Spherical Harmonics

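In E(3)-equivariant neural networks, features are decomposed into irreducible representations (irreps), and spherical harmonics, which are themselves equivariant, are used to map geometric data such as 3D vectors onto these irreps. For each order \(l \geq 0\), the spherical harmonics \(Y^l\) are a family of \(2l+1\) real functions \(Y^l_m : \mathbb{R}^3 \to \mathbb{R}\); together they map \(\mathbb{R}^3\) to \(\mathbb{R}^{2l+1}\), the space on which the irrep \(D^l\) of order \(l\) acts: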

\[ Y^l\left(\vec{x}\right) = \left[ Y_{-l}^{l}\left(\vec{x}\right), Y_{-l+1}^{l}\left(\vec{x}\right), ... , Y_{l}^{l}\left(\vec{x}\right) \right], \]
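Since each order contributes \(2l+1\) components, stacking the harmonics of every order up to some \(l_{\max}\) gives a vector with

\[ \sum_{l=0}^{l_{\max}} (2l+1) = (l_{\max}+1)^2 \]

components; for example, \(l_{\max} = 2\) gives \(1 + 3 + 5 = 9\).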

and \(Y^l_m\) are real (normalized) spherical harmonics defined by:

\[\begin{split} Y^l_m = \begin{cases} \sqrt{2}(-1)^m\operatorname{Im}[\hat{Y}^l_{|m|}] & \text{if $m<0$} \\ \hat{Y}^l_{|m|} & \text{if $m=0$} \\ \sqrt{2}(-1)^m\operatorname{Re}[\hat{Y}^l_{m}] & \text{if $m>0$} \\ \end{cases} \end{split}\]

with:

\[ \hat{Y}^l_{m} = (-1)^m \sqrt{\frac{2l+1}{4\pi} \frac{(l-m)!}{(l+m)!}}\, P_m^l(\cos\theta)\, e^{im\phi}, \]

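where \(P_m^l\) are the associated Legendre polynomials without the Condon–Shortley phase (the factor \((-1)^m\) in \(\hat{Y}^l_m\) reintroduces it), \(\theta\) and \(\phi\) are the polar and azimuthal spherical coordinates of \(\vec{x}\), and \(m \in \{-l, -l+1, \dots, l\}\).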
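The definition can be checked numerically. The sketch below is a plain numpy transcription of the formulas above; it does not use e3nn, and the helper names `assoc_legendre_no_cs`, `real_sph_harm` and `gram_matrix` are hypothetical. It assumes \(\theta\) is the polar angle measured from the \(z\) axis and \(\phi\) the azimuth, builds \(Y^l_m\) from \(\hat{Y}^l_{|m|}\), and verifies that the \(2l+1\) functions are orthonormal on the unit sphere using a quadrature that is exact for these integrands (Gauss-Legendre nodes in \(\cos\theta\), a uniform grid in \(\phi\)).

```python
import math

import numpy as np


def assoc_legendre_no_cs(l, m, x):
    """P_l^m(x) for 0 <= m <= l, WITHOUT the Condon-Shortley phase."""
    # Start from P_m^m(x) = (2m-1)!! (1 - x^2)^(m/2)
    pmm = np.ones_like(x)
    if m > 0:
        s = np.sqrt(np.clip(1.0 - x * x, 0.0, None))  # sin(theta) >= 0
        double_fact = 1.0
        for k in range(1, 2 * m, 2):  # 1 * 3 * ... * (2m - 1)
            double_fact *= k
        pmm = double_fact * s**m
    if l == m:
        return pmm
    # P_{m+1}^m(x) = (2m + 1) x P_m^m(x)
    pm1m = (2 * m + 1) * x * pmm
    if l == m + 1:
        return pm1m
    # Upward recurrence: (l - m) P_l^m = (2l - 1) x P_{l-1}^m - (l + m - 1) P_{l-2}^m
    p_prev, p_curr = pmm, pm1m
    for ll in range(m + 2, l + 1):
        p_next = ((2 * ll - 1) * x * p_curr - (ll + m - 1) * p_prev) / (ll - m)
        p_prev, p_curr = p_curr, p_next
    return p_curr


def real_sph_harm(l, m, theta, phi):
    """Real Y^l_m built from the complex hat-Y^l_{|m|} as in the definition above."""
    am = abs(m)
    norm = math.sqrt((2 * l + 1) / (4 * math.pi)
                     * math.factorial(l - am) / math.factorial(l + am))
    # hat-Y^l_{|m|} = (-1)^|m| * norm * P_{|m|}^l(cos theta) * exp(i |m| phi)
    y_hat = ((-1) ** am * norm * assoc_legendre_no_cs(l, am, np.cos(theta))
             * np.exp(1j * am * phi))
    if m < 0:
        return math.sqrt(2) * (-1) ** m * y_hat.imag
    if m == 0:
        return y_hat.real
    return math.sqrt(2) * (-1) ** m * y_hat.real


def gram_matrix(l, n_theta=32, n_phi=64):
    """Matrix of integrals of Y^l_m * Y^l_n over the unit sphere."""
    # Gauss-Legendre nodes integrate polynomials in u = cos(theta) exactly;
    # a uniform phi grid integrates low-degree trigonometric polynomials exactly.
    u, w_u = np.polynomial.legendre.leggauss(n_theta)
    phi = np.linspace(0.0, 2 * math.pi, n_phi, endpoint=False)
    T, PHI = np.meshgrid(np.arccos(u), phi, indexing="ij")
    W = np.outer(w_u, np.full(n_phi, 2 * math.pi / n_phi))  # dOmega = du dphi
    Y = np.stack([real_sph_harm(l, m, T, PHI) for m in range(-l, l + 1)])
    return np.einsum("aij,bij,ij->ab", Y, Y, W)


for l in range(4):
    print(l, np.allclose(gram_matrix(l), np.eye(2 * l + 1), atol=1e-10))  # True for each l
```

The Gram matrix coming out as the identity confirms that the \(\sqrt{2}\) factors and the two \((-1)^m\) signs combine into an orthonormal real basis.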

Each \(Y^l\) is equivariant under \(SO(3)\) with respect to the irrep of the same order \(l\): for any rotation matrix \(R\),

\[ Y^l(R\vec{x}) = D^l(R)\,Y^l(\vec{x}), \]

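or, component-wise,

\[ Y^l_m(R\vec{x}) = \sum_{n=-l}^{l} D^l_{mn}(R)\,Y^l_n(\vec{x}) = D^l_m(R) \cdot Y^l(\vec{x}), \]

where \(D^l(R)\) is the \((2l+1) \times (2l+1)\) matrix representing \(R\) in the irreducible representation of \(SO(3)\) of order \(l\), and \(D^l_m(R)\) denotes its \(m\)-th row.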
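The equivariance relation can be checked numerically without writing down Wigner matrices. The numpy sketch below (helper names hypothetical, e3nn not used) relies on closed forms of \(Y^1\) and \(Y^2\) at unit vectors, obtained by evaluating the definitions above with \(\theta\) measured from the \(z\) axis. For \(l = 1\), \(Y^1(\vec{x}) = \sqrt{3/4\pi}\,(y, z, x)\) is a permutation \(P\vec{x}\) up to scale, so \(D^1(R) = P R P^\top\). For \(l = 2\), the sketch recovers \(D^2(R)\) by least squares from random sample points and checks that an exact, orthogonal solution exists, as the relation predicts.

```python
import numpy as np

C1 = np.sqrt(3 / (4 * np.pi))
C2A = 0.5 * np.sqrt(15 / np.pi)
C2B = 0.25 * np.sqrt(5 / np.pi)
C2C = 0.25 * np.sqrt(15 / np.pi)


def y1(v):
    """Y^1 at unit vectors v of shape (N, 3), components ordered m = -1, 0, 1."""
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    return C1 * np.stack([y, z, x], axis=1)


def y2(v):
    """Y^2 at unit vectors v of shape (N, 3), components ordered m = -2, ..., 2."""
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    return np.stack([C2A * x * y, C2A * y * z, C2B * (3 * z**2 - 1),
                     C2A * x * z, C2C * (x**2 - y**2)], axis=1)


def random_rotation(rng):
    """Random proper rotation from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix the sign ambiguity of QR column by column
    if np.linalg.det(q) < 0:     # flip one axis so that det(q) = +1
        q[:, 0] = -q[:, 0]
    return q


rng = np.random.default_rng(0)
R = random_rotation(rng)
v = rng.normal(size=(200, 3))
v /= np.linalg.norm(v, axis=1, keepdims=True)  # points on the unit sphere
v_rot = v @ R.T                                # row i is R @ v[i]

# l = 1: Y^1(x) = C1 * P x with P mapping (x, y, z) -> (y, z, x), so D^1(R) = P R P^T
P = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=float)
D1 = P @ R @ P.T
print(np.allclose(y1(v_rot), y1(v) @ D1.T))  # True: Y^1(Rx) = D^1(R) Y^1(x)

# l = 2: solve Y^2(R x_i) = D^2 Y^2(x_i) for D^2 over all samples (rows: B = A D^T)
A, B = y2(v), y2(v_rot)
D2_T, *_ = np.linalg.lstsq(A, B, rcond=None)
D2 = D2_T.T
print(np.allclose(A @ D2_T, B))           # True: an exact linear map exists
print(np.allclose(D2 @ D2.T, np.eye(5)))  # True: and it is orthogonal
```

Fitting \(D^2\) from data is only a check; in practice the matrices are computed directly from the rotation.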

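e3nn provides utility functions to compute spherical harmonics. The following code evaluates \(Y^0\), \(Y^1\) and \(Y^2\) on a grid over the unit sphere and draws each component \(Y^l_m\) as a surface whose radius is \(|Y^l_m|\) and whose color encodes the signed value of \(Y^l_m\):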

import torch
import math
from e3nn import o3
import plotly.graph_objects as go

# Hide axes, grid and tick labels so that only the surfaces are visible
axis = dict(
    showbackground=False,
    showticklabels=False,
    showgrid=False,
    zeroline=False,
    title='',
    nticks=3,
)

# One wide, flat scene: the 2l+1 surfaces of a given l are laid out side by side along x
layout = dict(
    width=690,
    height=160,
    scene=dict(
        xaxis=dict(
            **axis,
            range=[-8, 8]
        ),
        yaxis=dict(
            **axis,
            range=[-2, 2]
        ),
        zaxis=dict(
            **axis,
            range=[-2, 2]
        ),
        aspectmode='manual',
        aspectratio=dict(x=8, y=2, z=2),
        camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=0, y=-5, z=5),
            projection=dict(type='orthographic'),
        ),
    ),
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    margin=dict(l=0, r=0, t=0, b=0)
)


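def s2_grid():
    # Grid of 40 polar angles (beta) x 80 azimuthal angles (alpha) covering the sphere;
    # o3.angles_to_xyz turns them into unit vectors of shape (40, 80, 3)
    betas = torch.linspace(0, math.pi, 40)
    alphas = torch.linspace(0, 2 * math.pi, 80)
    beta, alpha = torch.meshgrid(betas, alphas, indexing='ij')
    return o3.angles_to_xyz(alpha, beta)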

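def trace(r, f, c, radial_abs=True):
    # Surface for one component f: radius |f| (or 1 if radial_abs is False),
    # shifted by the offset c and colored by the signed value of f
    if radial_abs:
        a = f.abs()
    else:
        a = 1
    return dict(
        x=a * r[..., 0] + c[0],
        y=a * r[..., 1] + c[1],
        z=a * r[..., 2] + c[2],
        surfacecolor=f
    )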

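def plot(data, radial_abs=True):
    # data has shape (40, 80, n): one function on the sphere grid per component
    r = s2_grid()
    n = data.shape[-1]
    # Center the n surfaces along x at -(n-1), -(n-3), ..., n-1
    traces = [
        trace(r, data[..., i], torch.tensor([2.0 * i - (n - 1.0), 0.0, 0.0]), radial_abs=radial_abs)
        for i in range(n)
    ]
    # Shared symmetric color range so that colors are comparable across components
    cmax = max(d['surfacecolor'].abs().max().item() for d in traces)
    traces = [go.Surface(**d, colorscale='RdYlBu', cmin=-cmax, cmax=cmax) for d in traces]
    fig = go.Figure(data=traces, layout=layout)
    fig.show()
    return fig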
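

r = s2_grid()
lmax = 2

# Plot Y^0, Y^1 and Y^2. The third positional argument of o3.spherical_harmonics is
# `normalize` (normalize the input vectors), so the normalization is passed by keyword;
# with normalization='norm', each Y^l(x) has unit norm.
for l in range(lmax + 1):
    yl = o3.spherical_harmonics(l, r, normalize=True, normalization='norm')  # shape (40, 80, 2l+1)
    fig = plot(yl)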
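The plotting loop refers to the 'norm' normalization. The e3nn documentation (reference 2) describes three normalizations: 'integral' (each \(Y^l_m\) has unit \(L^2\) norm on the sphere, as in the definition above), 'component' (\(\|Y^l(\vec{x})\|^2 = 2l+1\)) and 'norm' (\(\|Y^l(\vec{x})\| = 1\)), for unit \(\vec{x}\). They differ only by an \(l\)-dependent constant because of the addition theorem \(\sum_m Y^l_m(\vec{x})^2 = (2l+1)/4\pi\). The numpy sketch below checks this for \(l \le 2\) using closed forms of the definition above; the helper names are hypothetical and the scale factors are derived from the addition theorem rather than read from e3nn's source, so treat them as an assumption to verify against the library.

```python
import numpy as np


def sh_integral(l, v):
    """Orthonormal ('integral') real SH for l <= 2 at unit vectors v of shape (N, 3).

    Closed forms of the definition above (theta measured from the z axis),
    components ordered m = -l, ..., l.
    """
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    if l == 0:
        return np.full((len(v), 1), 1.0 / np.sqrt(4 * np.pi))
    if l == 1:
        return np.sqrt(3 / (4 * np.pi)) * np.stack([y, z, x], axis=1)
    if l == 2:
        a = 0.5 * np.sqrt(15 / np.pi)
        b = 0.25 * np.sqrt(5 / np.pi)
        c = 0.25 * np.sqrt(15 / np.pi)
        return np.stack([a * x * y, a * y * z, b * (3 * z**2 - 1),
                         a * x * z, c * (x**2 - y**2)], axis=1)
    raise NotImplementedError("closed forms are only written out for l <= 2")


def scale(l, normalization):
    """Assumed factor converting 'integral' values to the requested normalization."""
    factors = {
        "integral": 1.0,
        "component": np.sqrt(4 * np.pi),           # gives ||Y^l||^2 = 2l + 1
        "norm": np.sqrt(4 * np.pi / (2 * l + 1)),  # gives ||Y^l|| = 1
    }
    return factors[normalization]


rng = np.random.default_rng(1)
v = rng.normal(size=(100, 3))
v /= np.linalg.norm(v, axis=1, keepdims=True)  # project onto the unit sphere

for l in range(3):
    y = sh_integral(l, v)
    # Addition theorem: the squared norm is the same at every point of the sphere
    addition = np.allclose((y**2).sum(axis=1), (2 * l + 1) / (4 * np.pi))
    component = np.allclose(((scale(l, "component") * y) ** 2).sum(axis=1), 2 * l + 1)
    unit_norm = np.allclose(np.linalg.norm(scale(l, "norm") * y, axis=1), 1.0)
    print(l, addition, component, unit_norm)  # True True True for each l
```

Under 'norm', every \(|Y^l_m|\) is at most 1, which keeps the plotted surfaces above within the fixed axis ranges.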

Disclaimer

Conventions for spherical harmonics (normalization, the Condon–Shortley phase, the choice of polar axis, and the ordering of components) vary between fields, and so does the formula that converts complex spherical harmonics into real ones. e3nn may adopt different conventions than the formulas given here, but the underlying ideas are the same.
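As a concrete instance of these conventions, evaluating the formulas above for \(l = 0\) and \(l = 1\) at a unit vector \(\vec{x} = (x, y, z)\), assuming \(\theta\) is measured from the \(z\) axis and \(\phi\) from the \(x\) axis, gives

\[ Y^0(\vec{x}) = \frac{1}{\sqrt{4\pi}}, \qquad Y^1(\vec{x}) = \sqrt{\frac{3}{4\pi}}\,\left(y, z, x\right). \]

With these formulas the order-1 harmonics are a rescaled permutation of the Cartesian coordinates; a library that picks a different polar axis, phase, or component ordering can return a different permutation, sign, or scale of the same vector.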

References

  1. Geiger, Mario, and Tess Smidt. “e3nn: Euclidean neural networks.” arXiv preprint arXiv:2207.09453 (2022).

  2. e3nn documentation. “Spherical Harmonics.” https://docs.e3nn.org/en/stable/api/o3/o3_sh.html