Spherical Harmonics
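In E(3)-equivariant neural networks, mathematical objects are decomposed into irreducible representations (irreps), and geometric quantities such as direction vectors are mapped onto these irreps with spherical harmonics, which are equivariant. For each degree \(l\), the spherical harmonics \(Y^l\) are a family of \(2l+1\) functions \(Y^l_m: \mathbb{R}^3 \to \mathbb{R}\), which together map \(\mathbb{R}^3\) to the \((2l+1)\)-dimensional space on which the irrep \(D^l\) acts:
\[
Y^l(\vec x) = \left(Y^l_{-l}(\vec x), Y^l_{-l+1}(\vec x), \dots, Y^l_{l}(\vec x)\right) \in \mathbb{R}^{2l+1},
\]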
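and \(Y^l_m\) are the real (normalized) spherical harmonics, defined by:
\[
Y^l_m(\theta, \phi) =
\begin{cases}
\sqrt{2}\, N^l_{|m|}\, P^l_{|m|}(\cos\theta)\, \sin(|m|\phi) & \text{if } m < 0, \\
N^l_0\, P^l_0(\cos\theta) & \text{if } m = 0, \\
\sqrt{2}\, N^l_m\, P^l_m(\cos\theta)\, \cos(m\phi) & \text{if } m > 0,
\end{cases}
\]
with:
\[
N^l_m = \sqrt{\frac{2l+1}{4\pi}\,\frac{(l-m)!}{(l+m)!}},
\]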
where \(P^l_m\) are the associated Legendre polynomials without the Condon–Shortley phase \((-1)^m\), \(\theta, \phi\) are the polar and azimuthal angles (spherical coordinates) of \(\vec x\), and \(m \in \{-l, -l+1, \dots, l\}\).
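The definition above can be checked with a short pure-Python sketch. It assumes the formulas exactly as written here (polar axis \(z\), no Condon–Shortley phase, orthonormal normalization), which is not necessarily e3nn's internal convention, and the function names `assoc_legendre`, `norm_const`, `real_sh` and `real_sh_vector` are hypothetical. The associated Legendre polynomials are computed with the standard upward recurrence in \(l\).

```python
import math


def assoc_legendre(l, m, x):
    """Associated Legendre polynomial P^l_m(x) for 0 <= m <= l, without the Condon-Shortley phase."""
    # Start from P^m_m(x) = (2m - 1)!! (1 - x^2)^(m/2)
    s = math.sqrt(max(0.0, 1.0 - x * x))
    p_mm = 1.0
    for k in range(1, m + 1):
        p_mm *= (2 * k - 1) * s
    if l == m:
        return p_mm
    # P^(m+1)_m(x) = (2m + 1) x P^m_m(x)
    p_prev, p_curr = p_mm, (2 * m + 1) * x * p_mm
    # Upward recurrence: (k - m) P^k_m = (2k - 1) x P^(k-1)_m - (k + m - 1) P^(k-2)_m
    for k in range(m + 2, l + 1):
        p_prev, p_curr = p_curr, ((2 * k - 1) * x * p_curr - (k + m - 1) * p_prev) / (k - m)
    return p_curr


def norm_const(l, m):
    """N^l_m = sqrt((2l + 1) / (4 pi) * (l - m)! / (l + m)!) for m >= 0."""
    return math.sqrt((2 * l + 1) / (4 * math.pi) * math.factorial(l - m) / math.factorial(l + m))


def real_sh(l, m, theta, phi):
    """Real orthonormal spherical harmonic Y^l_m at polar angle theta and azimuth phi."""
    am = abs(m)
    p = assoc_legendre(l, am, math.cos(theta))
    if m < 0:
        return math.sqrt(2) * norm_const(l, am) * p * math.sin(am * phi)
    if m == 0:
        return norm_const(l, 0) * p
    return math.sqrt(2) * norm_const(l, m) * p * math.cos(m * phi)


def real_sh_vector(l, theta, phi):
    """Y^l = (Y^l_{-l}, ..., Y^l_l) as a list of 2l + 1 floats."""
    return [real_sh(l, m, theta, phi) for m in range(-l, l + 1)]


theta, phi = 1.1, 0.6  # polar and azimuthal angles of one direction
for l in range(3):
    y = real_sh_vector(l, theta, phi)
    # number of components, squared norm, and (2l + 1) / (4 pi) for comparison
    print(l, len(y), sum(v * v for v in y), (2 * l + 1) / (4 * math.pi))
```

For any direction, the squared norm of `real_sh_vector(l, theta, phi)` should equal \((2l+1)/(4\pi)\) (the addition theorem), and for \(l = 1\) the three components are \(\sqrt{3/(4\pi)}\,(y, z, x)\) for the unit vector \((x, y, z)\).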
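Each \(Y^l\) is equivariant under \(SO(3)\) and transforms according to the irrep of the same degree \(l\). For any rotation matrix \(R\):
\[
Y^l(R\vec x) = D^l(R)\, Y^l(\vec x),
\]
i.e., component-wise,
\[
Y^l_m(R\vec x) = \sum_{m'=-l}^{l} D^l_{mm'}(R)\, Y^l_{m'}(\vec x),
\]
where \(D^l(R)\) is the \((2l+1) \times (2l+1)\) matrix of \(R\) in the degree-\(l\) irreducible representation of \(SO(3)\) (a real Wigner D-matrix).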
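The equivariance property can be tested numerically without knowing \(D^l(R)\) in closed form: sample points \(\vec x_i\), evaluate \(Y^l(\vec x_i)\) and \(Y^l(R\vec x_i)\), and fit the matrix \(D\) by least squares. If \(Y^l\) is equivariant, the fit is exact, \(D\) is orthogonal, and \(D(R_1 R_2) = D(R_1)\,D(R_2)\). The sketch below does this with NumPy for \(l = 2\), using the Cartesian form of the real orthonormal spherical harmonics (assumed convention: polar axis \(z\), no Condon–Shortley phase, components ordered \(m = -2, \dots, 2\)); the helpers `sh_l2`, `random_rotation` and `fit_d_matrix` are hypothetical names.

```python
import numpy as np


def sh_l2(x):
    """Real orthonormal l = 2 spherical harmonics of the directions x: shape (..., 3) -> (..., 5).

    Assumed convention: polar axis z, no Condon-Shortley phase, order m = -2, ..., 2.
    """
    x = x / np.linalg.norm(x, axis=-1, keepdims=True)  # only the direction matters
    X, Y, Z = x[..., 0], x[..., 1], x[..., 2]
    c1 = 0.5 * np.sqrt(15 / np.pi)
    c0 = 0.25 * np.sqrt(5 / np.pi)
    c2 = 0.25 * np.sqrt(15 / np.pi)
    return np.stack(
        [c1 * X * Y, c1 * Y * Z, c0 * (3 * Z**2 - 1), c1 * X * Z, c2 * (X**2 - Y**2)],
        axis=-1,
    )


def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, determinant +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix the sign ambiguity of QR column by column
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flip one column to turn a reflection into a rotation
    return q


def fit_d_matrix(sh, R, n_samples=200, seed=0):
    """Least-squares estimate of D(R) from Y(R x_i) = D(R) Y(x_i); also returns the max residual."""
    rng = np.random.default_rng(seed)
    pts = rng.normal(size=(n_samples, 3))
    Y = sh(pts)            # row i is Y(x_i)
    Y_rot = sh(pts @ R.T)  # row i is Y(R x_i)
    # Row-wise, Y(R x_i) = D Y(x_i) reads Y_rot = Y @ D.T, so solve for D.T
    D_T, *_ = np.linalg.lstsq(Y, Y_rot, rcond=None)
    residual = np.abs(Y @ D_T - Y_rot).max()
    return D_T.T, residual


rng = np.random.default_rng(42)
R1, R2 = random_rotation(rng), random_rotation(rng)
D1, res1 = fit_d_matrix(sh_l2, R1)
D2, _ = fit_d_matrix(sh_l2, R2)
D12, _ = fit_d_matrix(sh_l2, R1 @ R2)
print("max residual:", res1)                             # close to machine precision
print("orthogonal:", np.allclose(D1 @ D1.T, np.eye(5)))  # should print True
print("homomorphism:", np.allclose(D12, D1 @ D2))        # should print True
```

Fitting \(D\) from samples avoids committing to a particular Wigner D-matrix convention; libraries such as e3nn instead compute these matrices directly from the rotation (for example with `o3.wigner_D`).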
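e3nn provides utility functions to compute spherical harmonics. The following code plots \(Y^l\) for \(l = 0, 1, 2\): one figure per \(l\), each showing the \(2l+1\) components \(Y^l_m\) side by side, where the distance of each surface from its center is \(|Y^l_m|\) and its color gives the value of \(Y^l_m\):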
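```python
import math

import plotly.graph_objects as go
import torch
from e3nn import o3

# Shared axis style: hide background, ticks, grid and labels
axis = dict(
    showbackground=False,
    showticklabels=False,
    showgrid=False,
    zeroline=False,
    title='',
    nticks=3,
)

# Wide, short scene: the 2l+1 surfaces are laid out side by side along x
layout = dict(
    width=690,
    height=160,
    scene=dict(
        xaxis=dict(**axis, range=[-8, 8]),
        yaxis=dict(**axis, range=[-2, 2]),
        zaxis=dict(**axis, range=[-2, 2]),
        aspectmode='manual',
        aspectratio=dict(x=8, y=2, z=2),
        camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=0, y=-5, z=5),
            projection=dict(type='orthographic'),
        ),
    ),
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    margin=dict(l=0, r=0, t=0, b=0),
)


def s2_grid():
    """Return a (40, 80, 3) grid of unit vectors covering the sphere."""
    betas = torch.linspace(0, math.pi, 40)       # polar angle
    alphas = torch.linspace(0, 2 * math.pi, 80)  # azimuthal angle
    beta, alpha = torch.meshgrid(betas, alphas, indexing='ij')
    return o3.angles_to_xyz(alpha, beta)


def trace(r, f, c, radial_abs=True):
    """Surface data for one function f sampled on the grid r, centered at c.

    With radial_abs=True the radius is |f|; otherwise the unit sphere is drawn
    and only the color encodes f.
    """
    if radial_abs:
        a = f.abs()
    else:
        a = 1
    return dict(
        x=a * r[..., 0] + c[0],
        y=a * r[..., 1] + c[1],
        z=a * r[..., 2] + c[2],
        surfacecolor=f,
    )


def plot(data, radial_abs=True):
    """Plot each component data[..., i] as its own surface, spaced 2 apart along x."""
    r = s2_grid()
    n = data.shape[-1]
    traces = [
        trace(r, data[..., i], torch.tensor([2.0 * i - (n - 1.0), 0.0, 0.0]), radial_abs=radial_abs)
        for i in range(n)
    ]
    # Symmetric color range shared by all components, so that 0 maps to the middle color
    cmax = max(d['surfacecolor'].abs().max().item() for d in traces)
    traces = [go.Surface(**d, colorscale='RdYlBu', cmin=-cmax, cmax=cmax) for d in traces]
    fig = go.Figure(data=traces, layout=layout)
    fig.show()
    return fig


r = s2_grid()
lmax = 2
for l in range(lmax + 1):
    # normalize=True projects r onto the unit sphere; normalization='norm' makes ||Y^l(r)|| = 1
    yl = o3.spherical_harmonics(l, r, normalize=True, normalization='norm')
    fig = plot(yl)  # one figure with the 2l+1 components Y^l_m
```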
Disclaimer
Conventions for defining spherical harmonics (normalization, phase, ordering of the components and orientation of the polar axis) vary between fields, so the formula for the real spherical harmonics also depends on which field you are in. e3nn may adopt different conventions from the formulas given here, but the ideas are the same.
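One concrete source of these differences is normalization. With the orthonormal definition of \(Y^l_m\) given above, the addition theorem makes the squared norm of \(Y^l(\vec x)\) the same for every direction, so the common alternatives differ from it only by a factor that depends on \(l\):
\[
\int_{S^2} \left(Y^l_m\right)^2 d\Omega = 1,
\qquad
\sum_{m=-l}^{l} \left(Y^l_m(\vec x)\right)^2 = \frac{2l+1}{4\pi}
\quad\Longrightarrow\quad
\left\|\sqrt{4\pi}\, Y^l(\vec x)\right\|^2 = 2l+1,
\qquad
\left\|\sqrt{\tfrac{4\pi}{2l+1}}\, Y^l(\vec x)\right\| = 1.
\]
According to the e3nn documentation, the `normalization` argument of `o3.spherical_harmonics` (options `'integral'`, `'component'` and `'norm'`) selects between these three cases; it is worth checking this against the documentation of the e3nn version in use.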
References
Geiger, Mario, and Tess Smidt. “e3nn: Euclidean neural networks.” arXiv preprint arXiv:2207.09453 (2022).