
Planewave module.

dict | Array | Tuple, freq_mask: Bool[Array, 'x y z']) Complex[Array, 'spin kpt band x y z']

Create the linear coefficients to combine the frequency components.

This function takes a raw parameter of shape (num_spin, num_kpts, num_g, num_bands) and orthogonalizes it over the last two dimensions, so that the resulting tensor satisfies the unitary constraint einsum('kabc,labc->kl', ret[i, j].conj(), ret[i, j]) == eye(num_bands).

The pw_param should be created from param_init(), and the same freq_mask used in param_init() should be passed here. As mentioned in param_init(), wave functions are built as linear combinations of 3D Fourier components, with two extra requirements:

  1. Wave functions with the same spin and the same k point must be orthogonal to each other.

  2. We only activate some of the frequency components with the freq_mask.

The raw parameter returned from param_init() has shape (num_spin, num_kpts, num_g, num_bands), where num_g is the number of activated frequencies, i.e. the number of True entries in freq_mask. This function first orthogonalizes over the last two dimensions, then reorganizes the orthogonalized parameter into a 3D grid with the same shape as the frequency mask.
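The two steps described above can be sketched in NumPy as follows. This is a minimal sketch, not the library's coeff(): the name `coeff_sketch` is hypothetical, the orthogonalization is assumed to be a reduced QR decomposition (the text does not say which method is used), and inactive frequency components are assumed to be zero.

```python
import numpy as np


def coeff_sketch(pw_param: np.ndarray, freq_mask: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: orthonormalize the raw parameter, then scatter it onto freq_mask.

    pw_param:  complex array of shape (num_spin, num_kpts, num_g, num_bands)
    freq_mask: bool array of shape (x, y, z) with exactly num_g True entries
    returns:   complex array of shape (num_spin, num_kpts, num_bands, x, y, z)
    """
    # Assumption: reduced QR over the last two axes (requires num_g >= num_bands).
    # The columns of q are orthonormal under the complex inner product.
    q, _ = np.linalg.qr(pw_param)            # (spin, kpt, num_g, num_bands)
    q = np.swapaxes(q, -1, -2)               # (spin, kpt, num_bands, num_g)
    num_spin, num_kpts, num_bands, _ = q.shape
    grid = np.zeros((num_spin, num_kpts, num_bands) + freq_mask.shape, dtype=q.dtype)
    # Boolean indexing over the trailing (x, y, z) axes fills the active entries in C order;
    # assumption: inactive frequency components stay zero.
    grid[..., freq_mask] = q
    return grid


rng = np.random.default_rng(0)
mask = np.zeros((4, 4, 4), dtype=bool)
mask[:2, :2, :2] = True                      # 8 active frequency components
raw = rng.normal(size=(1, 2, 8, 3)) + 1j * rng.normal(size=(1, 2, 8, 3))
ret = coeff_sketch(raw, mask)
gram = np.einsum('kabc,labc->kl', ret[0, 1].conj(), ret[0, 1])
print(np.allclose(gram, np.eye(3)))          # True
```

The `.conj()` in the check matters: for complex coefficients, orthonormality is defined with the conjugate inner product.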

Further reading: 1. Why and how to mask the frequency components. 2. Bloch theorem.

Parameters:

  • pw_param (Union[dict, Array, Tuple]) – The raw parameter, typically created from param_init().

  • freq_mask (Bool[Array, 'x y z']) – A 3D mask to select the frequency components.


Returns:

Complex array of shape (num_spin, num_kpts, num_bands, x, y, z). It satisfies the unitary constraint that, for any i, j, einsum('kabc,labc->kl', ret[i, j].conj(), ret[i, j]) is an identity matrix.

Return type:

Complex[Array, 'spin kpt band x y z']

Complex[Array, 'spin kpt band x y z'], vol: Float, occupation: Float[Array, 'spin kpt band'] | None = None) Float[Array, 'spin kpt band x y z'] | Float[Array, 'spin kpt band']

Compute the density at the spatial grid.

In a system with many electrons, the electron density is the sum of contributions from multiple wave functions overlapping in space. As mentioned in wave_grid(), each wave function is a linear combination of 3D Fourier components. To compute the density, we usually only need to take the absolute square of each wave function and sum the results.

\[\rho(r) = \sum_i |\psi_i(r)|^2\]

This function evaluates the density \(\rho(r)\) at the spatial grid generated from jrystal.grid.r_vectors().

In crystals, this is slightly more complicated. The wave function has the form:

\[\psi(r) = \frac{1}{\sqrt{\Omega_\text{cell}}} e^{ikr} \sum_G c_{kG} e^{iGr}\]

The \(c_{kG}\) can be computed from param_init() and coeff(). To calculate the density, we only need \(c_{kG}\) and the occupation \(o_k\) over \(k\).

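\[\rho(r) = \frac{1}{\Omega_\text{cell}} \sum_k o_k \left| \sum_G c_{kG} e^{iGr} \right|^2\]

The Bloch phase \(e^{ikr}\) has unit modulus and cancels in the absolute square; the sums over spin and band are left implicit.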
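A minimal NumPy sketch of this density computation follows. It assumes `coeff` has shape (spin, kpt, band, x, y, z) and that the real-space wave function is obtained with a 3D inverse FFT scaled as in wave_grid() (multiply by \(N\), divide by \(\sqrt{\Omega_\text{cell}}\)); the Bloch phase is dropped because it cancels. The name `density_grid_sketch` is hypothetical and this is not the library's implementation.

```python
import numpy as np


def density_grid_sketch(coeff, vol, occupation=None):
    """Hypothetical sketch of the density on the real-space grid.

    coeff:      complex array (spin, kpt, band, x, y, z)
    vol:        unit-cell volume
    occupation: optional real array (spin, kpt, band)
    """
    n = np.prod(coeff.shape[-3:])                       # number of grid points N
    # u(r) on the grid: ifftn divides by N, so multiply it back, then normalize by sqrt(vol)
    psi = np.fft.ifftn(coeff, axes=(-3, -2, -1)) * n / np.sqrt(vol)
    dens = np.abs(psi) ** 2                             # (spin, kpt, band, x, y, z)
    if occupation is None:
        return dens
    # Weight each orbital density by its occupation and sum over spin, kpt and band
    return np.einsum('skb,skbxyz->xyz', occupation, dens)


rng = np.random.default_rng(1)
c = rng.normal(size=(1, 1, 2, 4, 4, 4)) + 1j * rng.normal(size=(1, 1, 2, 4, 4, 4))
c /= np.linalg.norm(c.reshape(1, 1, 2, -1), axis=-1)[..., None, None, None]  # unit-norm bands
vol = 10.0
occ = np.array([[[2.0, 1.0]]])
rho = density_grid_sketch(c, vol, occ)
# A Riemann sum of rho over the cell recovers the total occupation (Parseval)
print(np.isclose(rho.sum() * vol / rho.size, occ.sum()))  # True
```

The final check is a useful sanity test for any density routine: with unit-norm coefficients, integrating the density over the cell gives the total number of electrons.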
Parameters:

  • coeff – The \(c_{kG}\) part of the parameter. It can have leading batch dimensions, which are summed to obtain the overall density; the shape is (spin, kpt, band, x, y, z).

  • vol – Volume of the unit cell, a real scalar.

  • occupation – The occupation numbers over different k points, with shape (spin, kpt, band); it should have the same leading dimensions as coeff. This argument is optional. When occupation=None, the density contribution from each \(k\) is returned without summation. If occupation is provided, the densities from all \(k\) are summed, weighted by the occupation.


Returns:

A real-valued tensor representing the density at the spatial grid computed from jrystal.grid.r_vectors(). The shape is (x, y, z) if occupation is provided, otherwise (spin, kpt, band, x, y, z).

Return type:

Union[Float[Array, "spin kpt band x y z"], Float[Array, "spin kpt band"]]

Complex[Array, 'spin kpt band x y z'], vol: float | Array, occupation: Float[Array, 'spin kpt band'] | None = None) Complex[Array, 'spin kpt band x y z'] | Complex[Array, 'spin kpt band']

Fourier transform of the density grid.

In a system with many electrons, the electron density is the sum of contributions from multiple wave functions overlapping in space. As mentioned in wave_grid(), each wave function is a linear combination of 3D Fourier components. To compute the density, we usually only need to take the absolute square of each wave function \(\psi_i\), multiply it by the occupation \(f_i\), and sum the results:

\[\rho(\vb{r}) = \sum_i f_i |\psi_i(\vb{r})|^2\]

density_grid() computes the density \(\rho(\vb{r})\) at the spatial grid generated from jrystal.grid.r_vectors().

The Fourier coefficients of \(\rho(\vb{r})\) are

\[\tilde{\rho}(\vb{G}) = \frac{1}{\Omega} \int_\Omega \rho(\vb{r}) e^{-\text{i} \vb{G}^\top \vb{r}} \dd{\vb{r}}\]

\(\tilde{\rho}(\vb{G})\) can be evaluated at any \(\vb{G}\), but this function evaluates it on the grid generated from jrystal.grid.g_vectors(). This is equivalent to computing \(\rho(\vb{r})\) at jrystal.grid.r_vectors() and then applying the discrete Fourier transform.

Since this function is just composing FFT with density_grid(), we refer you to density_grid() for more details.
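Approximating the integral above by a Riemann sum over the \(N\) grid points (volume element \(\Omega/N\)) gives \(\tilde{\rho}(\vb{G}) \approx \frac{1}{N} \sum_{\vb{r}} \rho(\vb{r}) e^{-\text{i} \vb{G}^\top \vb{r}}\), i.e. a forward FFT divided by \(N\). The sketch below assumes that normalization; whether the library uses exactly this convention is not stated here, and `density_reciprocal_sketch` is a hypothetical name.

```python
import numpy as np


def density_reciprocal_sketch(rho):
    """Hypothetical sketch: Fourier coefficients of a real-space density grid.

    rho: real array (..., x, y, z) sampled on the real-space grid
    """
    n = np.prod(rho.shape[-3:])
    # fftn computes sum_r rho(r) exp(-i G.r); dividing by N turns the sum into
    # (1/Omega) * integral with volume element Omega / N
    return np.fft.fftn(rho, axes=(-3, -2, -1)) / n


rng = np.random.default_rng(2)
rho = rng.random((4, 5, 6))
rho_g = density_reciprocal_sketch(rho)
# The G = 0 component is the average density over the cell
print(np.isclose(rho_g[0, 0, 0], rho.mean()))  # True
```

Because \(\rho\) is real, the coefficients are Hermitian-symmetric, \(\tilde{\rho}(-\vb{G}) = \tilde{\rho}(\vb{G})^*\).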

Parameters:

  • coeff – Coefficients of the wave functions.

  • vol – volume of the unit cell.

  • occupation – occupation over the different \(k\) components, see more in density_grid().


Returns:

A complex-valued tensor representing \(\tilde{\rho}(\vb{G})\) evaluated at the \(\vb{G}\) generated from jrystal.grid.g_vectors().

Array, num_bands: int, num_kpts: int, freq_mask: Bool[Array, 'x y z'], spin_restricted: bool = True) dict

Initialize the raw parameters.

This function generates a random tensor of shape (num_spin, num_kpts, num_g, num_bands), where num_g is the number of True items in the freq_mask.

In planewave-based calculations, a wave function is represented as a linear combination of 3D Fourier components. Therefore, to create one wave function we need a 3D tensor holding the mixing coefficient of each frequency component (denoted \(G\)). freq_mask is a 3D mask that selects which frequency components are used; the number of selected components is denoted num_g.

The roles of num_bands and num_kpts need some explanation. Intuitively, a wave function consists of high-frequency components whose period is smaller than the unit cell (denoted \(G\)) and components whose period is larger than the unit cell (denoted \(k\)).

In a solid, the wave function takes the form:

\[\psi(r) = e^{ikr}\sum_G c_{kG} e^{iGr}\]

This function generates a raw parameter, which after processing by coeff() can be used as the \(c_{kG}\) part of the above equation.
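A one-dimensional numerical check of this Bloch form, assuming a lattice with period \(a\) and \(G = 2\pi m / a\) for integer \(m\): the sum over \(G\) is periodic in the cell, so shifting \(r\) by one lattice period only multiplies \(\psi\) by \(e^{ika}\). `bloch_psi` is a hypothetical helper for illustration, not part of the module.

```python
import numpy as np


def bloch_psi(r, k, c, a):
    """Hypothetical 1D Bloch wave: psi(r) = exp(i k r) * sum_G c_G exp(i G r)."""
    m = np.arange(len(c))
    g = 2 * np.pi * m / a          # reciprocal lattice vectors of a 1D lattice with period a
    return np.exp(1j * k * r) * np.sum(c * np.exp(1j * g * r))


rng = np.random.default_rng(3)
c = rng.normal(size=5) + 1j * rng.normal(size=5)
a, k, r = 2.0, 0.3, 0.7
# Translating by one period multiplies psi by the phase exp(i k a)
print(np.isclose(bloch_psi(r + a, k, c, a), np.exp(1j * k * a) * bloch_psi(r, k, c, a)))  # True
```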

Further reading: 1. Why and how to mask the frequency components. 2. Bloch theorem.

As far as this function is concerned, it simply returns a randomly initialized parameter of shape (num_spin, num_kpts, num_g, num_bands). The input arguments to this function are only used to determine the shape.

Note that the returned raw parameter cannot be used directly to weight the frequency components, because in quantum chemistry the wave functions must be orthogonal to each other. See coeff() for converting the raw parameter into a unitary tensor.
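A minimal sketch of the shape bookkeeping described above, using NumPy's random generator in place of a JAX key. The name `param_init_sketch`, the explicit `num_spin` argument, and the choice of independent standard normals for the real and imaginary parts are assumptions for illustration; the distribution param_init() actually uses is not specified here.

```python
import numpy as np


def param_init_sketch(rng, num_bands, num_kpts, freq_mask, num_spin=1):
    """Hypothetical sketch: random complex raw parameter (num_spin, num_kpts, num_g, num_bands)."""
    num_g = int(freq_mask.sum())   # number of activated frequency components
    shape = (num_spin, num_kpts, num_g, num_bands)
    # Assumption: independent standard normals for the real and imaginary parts
    return rng.normal(size=shape) + 1j * rng.normal(size=shape)


mask = np.zeros((5, 5, 5), dtype=bool)
mask[1:4, 1:4, 1:4] = True           # 27 active components
raw = param_init_sketch(np.random.default_rng(0), num_bands=4, num_kpts=2, freq_mask=mask)
print(raw.shape)  # (1, 2, 27, 4)
```

Only freq_mask.sum() enters the shape; the positions of the True entries matter later, when coeff() scatters the orthogonalized values back onto the 3D grid.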

Parameters:

  • key (Array) – Random key for initializing the parameters.

  • num_bands (int) – The number of bands.

  • num_kpts (int) – The number of k points.

  • freq_mask (Bool[Array, 'x y z']) – A 3D mask that denotes which frequency components are selected.

  • spin_restricted (bool) – If True, num_spin=1, else num_spin=2.


Returns:

A complex-valued raw parameter of shape (num_spin, num_kpts, num_g, num_bands).

Return type:

Complex[Array, 'spin kpt gpt band']

Complex[Array, 'spin kpt band x y z'], vol: Float) Complex[Array, 'spin kpt band x y z']

Wave function evaluated at a grid of spatial locations.

This function implements the \(u(r)\) part of the Bloch wave function:

\[u(r)=\frac{1}{\sqrt{\Omega_\text{cell}}} \sum_G c_{G} e^{iG^\top r}\]

where \(G\) ranges over the 3D frequency components and \(\Omega_\text{cell}\) is the volume of the crystal unit cell; the \(\frac{1}{\sqrt{\Omega_\text{cell}}}\) factor keeps the wave function normalized within the cell.

Here \(c\) is the linear coefficient that combines the \(G\) components generated with jrystal.grid.g_vectors(). The wave function can be evaluated at any spatial location \(r\) in \(O(|G|)\) time, so evaluating it naively on a spatial grid of size \(|G|\) costs \(O(|G|^2)\). On the specific grid from jrystal.grid.r_vectors(), the inverse FFT (IFFT) evaluates the above equation in \(O(|G|\log|G|)\). The \(G\) and \(R\) grids can be obtained from jrystal.grid.g_vectors() and jrystal.grid.r_vectors(), respectively.

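import jax
import jax.numpy as jnp
import jrystal

G = jrystal.grid.g_vectors(*args)  # (x y z, 3)
R = jrystal.grid.r_vectors(*args)  # (x y z, 3)
coefficients = ...  # (x y z)
vol = ...

def wave_function(r):
  # Direct sum over all G components at a single location r: O(|G|)
  return (coefficients * jnp.exp(1j * G @ r)).sum() / jnp.sqrt(vol)

# The following is O(|G|^2): one O(|G|) sum per grid point, vmapped over x, y, z
wave_at_R_naive = jax.vmap(jax.vmap(jax.vmap(wave_function)))(R)
# The following is O(|G|log|G|): wave_grid() (documented here) uses the IFFT
wave_at_R_fft = wave_grid(coefficients, vol)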

The IFFT implements

\[x_n = \frac{1}{N} \sum_{k=0}^{N-1} X_k e^{i\frac{2\pi}{N}kn}\]

which differs slightly from the definition of the wave function. The implementation does two things to align them:

1. It multiplies by \(N\) to cancel the \(\frac{1}{N}\) factor of the IFFT (in 3D, \(N\) is the total number of grid points, i.e. the product of the three grid sizes).

2. It divides by \(\sqrt{\Omega_\text{cell}}\).
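The following NumPy sketch checks these two alignment steps on a small grid. It assumes integer frequency indices \(k\) and grid indices \(n\) on the same (n1, n2, n3) grid, as in the IFFT formula: multiplying np.fft.ifftn by \(N\) and dividing by \(\sqrt{\Omega_\text{cell}}\) reproduces the direct \(O(N^2)\) sum. The helper names `u_naive` and `u_fft` are hypothetical.

```python
import numpy as np


def u_naive(coeff, vol):
    """Direct O(N^2) sum: u_n = (1/sqrt(vol)) * sum_k c_k exp(i 2 pi sum_d k_d n_d / N_d)."""
    sizes = np.array(coeff.shape)
    idx = np.stack(np.meshgrid(*[np.arange(s) for s in coeff.shape], indexing='ij'), -1).reshape(-1, 3)
    # phase[n, k]: rows index grid points n, columns index frequencies k (same C order)
    phase = np.exp(2j * np.pi * (idx / sizes) @ idx.T)
    return (phase @ coeff.reshape(-1)).reshape(coeff.shape) / np.sqrt(vol)


def u_fft(coeff, vol):
    """O(N log N): undo the 1/N of the inverse FFT, then divide by sqrt(vol)."""
    return np.fft.ifftn(coeff) * coeff.size / np.sqrt(vol)


rng = np.random.default_rng(4)
c = rng.normal(size=(3, 4, 5)) + 1j * rng.normal(size=(3, 4, 5))
print(np.allclose(u_naive(c, 7.0), u_fft(c, 7.0)))  # True
```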

The coeff passed to this function has shape (..., x, y, z) and may have arbitrary leading dimensions. It can be created using param_init() and coeff(): param_init() creates a raw parameter, and coeff() converts it into the coefficients that linearly weight the 3D Fourier components.

Parameters:

  • coeff (Complex[Array, 'spin kpt band x y z']) – Wave function coefficients, with shape (spin, kpt, band, x, y, z).

  • vol (float) – Volume of the unit cell.


Returns:

Wave function evaluated at the spatial grid.

Return type:

Complex[Array, 'spin kpt band x y z']

Float[Array, '3'], coeff: Complex[Array, 'spin kpt band x y z'], cell_vectors: Float[Array, '3 3'], g_vector_grid: Float[Array, 'x y z 3'] | None = None) Complex[Array, 'spin kpt band']

Evaluate plane wave functions at location r.

This function computes the \(\psi(r)\) following the equation:

\[\psi(r) = \frac{1}{\sqrt{\Omega_\text{cell}}} \sum_G c_{G} e^{iGr}\]

coeff provides \(c_{G}\), and cell_vectors is used to generate the grid of frequency components \(G\) by calling jrystal.grid.g_vectors(). r is the location at which the wave function is evaluated.
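A minimal NumPy sketch of this evaluation, assuming `g_vector_grid` is already available (the sketch does not reimplement jrystal.grid.g_vectors()) and that \(\Omega_\text{cell}\) is the absolute determinant of `cell_vectors`. The usage example builds \(G\) vectors for a cubic cell by hand and compares against the FFT route at one grid point; `evaluate_at_sketch` is a hypothetical name.

```python
import numpy as np


def evaluate_at_sketch(r, coeff, cell_vectors, g_vector_grid):
    """Hypothetical sketch: psi(r) = (1/sqrt(vol)) * sum_G c_G exp(i G.r).

    r:             (3,) location
    coeff:         (..., x, y, z) coefficients, e.g. (spin, kpt, band, x, y, z)
    g_vector_grid: (x, y, z, 3) G vectors matching the last three axes of coeff
    """
    vol = abs(np.linalg.det(cell_vectors))       # assumption: Omega_cell = |det(cell_vectors)|
    phase = np.exp(1j * (g_vector_grid @ r))     # (x, y, z)
    # Contract the frequency axes; leading (spin, kpt, band) axes are kept
    return np.einsum('...xyz,xyz->...', coeff, phase) / np.sqrt(vol)


# Cubic cell of side L; G = 2*pi*m/L with integer m on an (n, n, n) grid (hand-built for the demo)
L, n = 3.0, 4
m = np.stack(np.meshgrid(*[np.arange(n)] * 3, indexing='ij'), -1)
g_grid = 2 * np.pi * m / L
rng = np.random.default_rng(5)
c = rng.normal(size=(1, 1, 2, n, n, n)) + 1j * rng.normal(size=(1, 1, 2, n, n, n))
r = np.array([1, 2, 3]) * L / n                  # real-space grid point with index (1, 2, 3)
psi_r = evaluate_at_sketch(r, c, L * np.eye(3), g_grid)
psi_fft = np.fft.ifftn(c, axes=(-3, -2, -1)) * n**3 / np.sqrt(L**3)
print(np.allclose(psi_r, psi_fft[..., 1, 2, 3]))  # True
```

Evaluating at a single point costs one pass over the \(G\) grid, which is the \(O(|G|)\) cost mentioned for wave_grid(); the FFT route only pays off when the whole real-space grid is needed.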

Parameters:

  • r – Spatial location to evaluate the wave function, shape: (3,).

  • coeff – Wave function coefficients, which has a shape of (spin, kpt, band, x, y, z). It can be created from param_init() followed by coeff().

  • cell_vectors – The cell vectors of the crystal unit cell.

  • g_vector_grid – The G vectors computed from jrystal.grid.g_vectors(). If None, they are computed from cell_vectors.


Returns:

Complex tensor representing the wave functions evaluated at location r.

Return type:

Complex[Array, 'spin kpt band']