utils
Utility functions.
- jrystal._src.utils.absolute_square(array: Complex[Array, '...']) → Float[Array, '...']
Computes the squared magnitude of complex numbers in an array.
Calculates \(|z|^2\) for each complex number \(z\) in the input array by multiplying each element with its complex conjugate. This operation preserves the array shape while converting complex values to their real squared magnitudes.
Note
This is mathematically equivalent to \((\mathrm{Re}\,z)^2 + (\mathrm{Im}\,z)^2\) for each complex number \(z\); the implementation computes it by multiplying each element by its complex conjugate.
Example:
from jrystal._src.utils import absolute_square
x = 3 + 4j
absolute_square(x)  # Returns 25.0 (|3 + 4j|² = 3² + 4² = 25)
- Parameters:
array (Complex[Array, '...']) – Complex-valued array of any shape. The ‘…’ notation indicates arbitrary dimensions are supported.
- Returns:
Real-valued array of the same shape as input, containing the squared magnitudes of the complex values.
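A minimal JAX sketch of the computation described above, assuming the result is taken as the real part of \(z\bar{z}\) (whose imaginary part is zero). `absolute_square_sketch` is a hypothetical name used for illustration, not the library function.

```python
import jax.numpy as jnp


def absolute_square_sketch(array):
    """Squared magnitude |z|^2 via multiplication by the complex conjugate."""
    # z * conj(z) = (Re z)^2 + (Im z)^2 with zero imaginary part; keep the real part.
    return jnp.real(array * jnp.conj(array))


z = jnp.array([3 + 4j, 1 - 1j, 0j])
sq = absolute_square_sketch(z)  # squared magnitudes 25, 2 and 0
```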
- jrystal._src.utils.check_spin_number(num_electrons: int, spin: int) → None
Validates that the spin number is compatible with electron count.
Checks if the specified spin number (number of unpaired electrons) is physically possible given the total number of electrons. The spin number and total electron count must have the same parity (both odd or both even).
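- Parameters:
num_electrons (int) – Total number of electrons in the system.
spin (int) – Spin number, i.e. the number of unpaired electrons.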
- Raises:
ValueError – If the spin number is not valid for the given number of electrons (i.e., if they have different parity).
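The parity rule follows from writing the electron count as \(N = N_\uparrow + N_\downarrow\) and the spin number as \(S = N_\uparrow - N_\downarrow\): then \(N_\uparrow = (N + S)/2\) must be an integer. A hypothetical sketch (`check_spin_number_sketch` is not the library function) that checks only this parity condition:

```python
def check_spin_number_sketch(num_electrons: int, spin: int) -> None:
    """Raise ValueError if spin and num_electrons have different parity."""
    # n_up = (num_electrons + spin) / 2 is an integer only when
    # num_electrons and spin are both even or both odd.
    if (num_electrons - spin) % 2 != 0:
        raise ValueError(
            f"spin={spin} is incompatible with num_electrons={num_electrons}: "
            "both must be even or both odd."
        )


check_spin_number_sketch(6, 2)  # passes: 4 up, 2 down
check_spin_number_sketch(7, 1)  # passes: 4 up, 3 down
```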
- jrystal._src.utils.expand_coefficient(coeff_compact: Complex[Array, 'spin kpt band gpt'], mask: Bool[Array, 'x y z']) → Complex[Array, 'spin kpt band x y z']
Expands compact coefficients into a full grid using a boolean mask.
Transforms coefficients from a compact representation, which stores only the grid points selected by a boolean mask, into a full grid representation by placing each coefficient at its masked position. This converts a storage-efficient layout into a computation-friendly one.
- Parameters:
coeff_compact (Complex[Array, "spin kpt band gpt"]) – Compact coefficient array with dimensions for spin, k-points, bands, and grid points.
mask (Bool[Array, 'x y z']) – Boolean mask indicating valid grid points in the expanded representation. The number of True values must match the last dimension of coeff_compact.
- Returns:
The expanded coefficient array with dimensions matching the batch dimensions of coeff_compact (spin, kpt, band) followed by the spatial dimensions of the mask (x, y, z).
- Return type:
Complex[Array, 'spin kpt band x y z']
Note
The function first swaps the last two axes of the input coefficients to align with the expected output format, then creates a zero-filled array of the target shape and places the coefficients at the masked positions.
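The core step, placing compact values at the mask's True positions, can be sketched as a scatter into a zero-filled grid. This sketch covers only that step (the axis swap mentioned in the Note is left out), and `scatter_to_grid` is a hypothetical helper, not part of jrystal. It assumes the compact values are ordered like the mask's True entries in row-major order, which is the order `jnp.nonzero` returns, and it needs a concrete mask because `jnp.nonzero` without a static `size` cannot run under `jax.jit`.

```python
import jax.numpy as jnp


def scatter_to_grid(values, mask):
    """Place values of shape (..., n) at the True positions of a 3-D mask.

    Hypothetical helper: n must equal mask.sum(), and mask must be concrete.
    """
    ix, iy, iz = jnp.nonzero(mask)  # coordinates of True points, row-major order
    full = jnp.zeros(values.shape[:-1] + mask.shape, dtype=values.dtype)
    # Three adjacent integer index arrays select n points, so the indexed
    # slot has shape (..., n) and matches `values`.
    return full.at[..., ix, iy, iz].set(values)


mask = jnp.array([[[True, False], [False, False]],
                  [[False, False], [False, True]]])
coeff = jnp.array([[[[1 + 1j, 2 - 1j]]]])  # (spin=1, kpt=1, band=1, gpt=2)
grid = scatter_to_grid(coeff, mask)         # shape (1, 1, 1, 2, 2, 2)
```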
- jrystal._src.utils.fft_factor(n: int) → int
Finds the smallest valid FFT size that is >= n.
Determines the smallest number greater than or equal to n that can be factored as
\[\text{FFT size} = 2^a \times 3^b \times 5^c \times 7^d \times 11^e \times 13^f\]
where \(a, b, c, d\) are non-negative integers and \(e\) and \(f\) are either 0 or 1.
- Parameters:
n (int) – The minimum size needed for the FFT grid.
- Returns:
The smallest valid FFT size >= n that satisfies the prime factorization requirements.
- Return type:
int
- Raises:
ValueError – If n > 2048, as the implementation only supports sizes up to this threshold.
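A brute-force sketch of the search described above: increment a candidate until its factorization has the allowed form. `fft_factor_sketch` and `is_valid_fft_size` are hypothetical names, and the library's actual search strategy is not described here. The loop always terminates for n ≤ 2048 because 2048 = 2¹¹ is itself valid.

```python
def is_valid_fft_size(m: int) -> bool:
    """True if m = 2^a 3^b 5^c 7^d 11^e 13^f with e, f in {0, 1}."""
    for p in (2, 3, 5, 7):  # these primes may appear any number of times
        while m % p == 0:
            m //= p
    for p in (11, 13):  # these primes may appear at most once
        if m % p == 0:
            m //= p
    return m == 1


def fft_factor_sketch(n: int) -> int:
    """Smallest valid FFT size >= n (hypothetical re-implementation)."""
    if n > 2048:
        raise ValueError("n must be at most 2048")
    m = max(n, 1)  # start at 1 so the factor loop never sees m = 0
    while not is_valid_fft_size(m):
        m += 1
    return m


sizes = [fft_factor_sketch(k) for k in (17, 121, 169)]  # [18, 125, 175]
```

Note that 121 = 11² and 169 = 13² are rejected because 11 and 13 may each appear at most once.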
- jrystal._src.utils.safe_real(array: Array, tol: float = 1e-08) → Array
Safely converts a complex array to real by checking imaginary components.
Attempts to convert a complex array to real by verifying that all imaginary components are effectively zero (within the specified tolerance). This is useful for numerical computations whose results should be real but may carry tiny imaginary components due to floating-point errors.
- Parameters:
array (Array) – Input array that may be real or complex.
tol (float) – Tolerance threshold for considering imaginary components as zero. Defaults to 1e-8.
- Returns:
The real part of the input if it is complex and all imaginary components are within tolerance; a real-valued input is returned unchanged.
- Return type:
Array
- Raises:
ValueError – If the array has imaginary components larger than the specified tolerance.
Example:
from jrystal._src.utils import safe_real
x = 1.0 + 1e-10j
safe_real(x)  # Returns 1.0
y = 1.0 + 1.0j
safe_real(y)  # Raises ValueError
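A minimal eager-mode sketch of the documented behaviour, assuming the tolerance is compared against the absolute value of each imaginary component. `safe_real_sketch` is a hypothetical name. Because it branches in Python on array values, this sketch cannot be traced by `jax.jit`.

```python
import jax.numpy as jnp


def safe_real_sketch(array, tol: float = 1e-8):
    """Return the real part if all imaginary parts are within tol; else raise."""
    array = jnp.asarray(array)
    if not jnp.iscomplexobj(array):
        return array  # already real: returned unchanged
    # Python-level branch on concrete values (works eagerly, not under jax.jit).
    if jnp.any(jnp.abs(jnp.imag(array)) > tol):
        raise ValueError(f"imaginary part exceeds tolerance {tol}")
    return jnp.real(array)


r = safe_real_sketch(1.0 + 1e-10j)  # real-valued 1.0
```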
- jrystal._src.utils.squeeze_coefficient(coeff: Complex[Array, 'spin kpt band x y z'], mask: Bool[Array, 'spin kpt band x y z']) → Complex[Array, 'spin kpt gpt band']
Compresses coefficients by extracting values at masked positions.
Performs the inverse operation of expand_coefficient by extracting values from positions specified by a boolean mask and arranging them in a compact format. The output is transposed to have grid points before bands for efficient computation.
Note
The function extracts values at masked positions and then swaps the last two axes to arrange the output as (spin, kpt, gpt, band) rather than (spin, kpt, band, gpt).
- Parameters:
coeff (Complex[Array, "spin kpt band x y z"]) – Full coefficient array with dimensions for spin, k-points, bands, and spatial coordinates (x, y, z).
mask (Bool[Array, "spin kpt band x y z"]) – Boolean mask of the same shape as coeff indicating which positions should be included in the compact representation.
- Returns:
Compact coefficient array with dimensions (spin, kpt, gpt, band), where gpt represents the number of True values in the mask.
- Return type:
Complex[Array, 'spin kpt gpt band']
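A minimal sketch of the extract-then-swap sequence described in the Note, assuming every (spin, kpt, band) slice of the mask selects the same number of grid points; otherwise the extracted values could not be reshaped into a regular gpt axis. `squeeze_coefficient_sketch` is hypothetical. Boolean indexing with `coeff[mask]` requires a concrete mask, so this runs eagerly rather than under `jax.jit`.

```python
import jax.numpy as jnp


def squeeze_coefficient_sketch(coeff, mask):
    """Extract masked values into shape (spin, kpt, gpt, band)."""
    batch_shape = coeff.shape[:3]                 # (spin, kpt, band)
    selected = coeff[mask]                        # 1-D, in row-major order
    compact = selected.reshape(*batch_shape, -1)  # (spin, kpt, band, gpt)
    return jnp.swapaxes(compact, -1, -2)          # (spin, kpt, gpt, band)


grid_mask = jnp.array([[[True, False], [False, False]],
                       [[False, False], [False, True]]])
mask = jnp.broadcast_to(grid_mask, (1, 2, 3, 2, 2, 2))  # same 2 points per state
coeff = jnp.arange(48).reshape(1, 2, 3, 2, 2, 2).astype(jnp.complex64)
compact = squeeze_coefficient_sketch(coeff, mask)  # shape (1, 2, 2, 3)
```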
- jrystal._src.utils.vmapstack(times: int, args: List[Dict] = None) → Callable
Recursively applies JAX’s vmap function to vectorize operations over multiple dimensions.
Creates a decorator that applies JAX’s vmap transformation multiple times to a function, enabling vectorized operations over multiple batch dimensions. This is particularly useful for handling multi-dimensional batch processing in neural network operations.
- Parameters:
times (int) – Number of vmap applications. Should match the number of batch dimensions to be processed.
args (List[Dict]) – Optional list of dictionaries containing vmap configuration for each application. Each dictionary can contain standard vmap arguments like in_axes, out_axes, axis_size, etc. Defaults to None.
- Returns:
A decorator function that transforms the input function by applying vmap the specified number of times.
- Return type:
Callable
- Raises:
ValueError – If the length of args does not match the specified number of vmap applications (times).
Example:
import jax.numpy as jnp
from jrystal._src.utils import vmapstack
@vmapstack(2)  # f can now handle 2 batch dimensions
def f(x):
    return x * 2
x = jnp.ones((3, 4, 5))  # 2 batch dims (3, 4) with input dim 5
result = f(x)  # Shape will be (3, 4, 5)
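One way such a decorator can be built is by wrapping the function in `jax.vmap` once per batch dimension. The sketch below is hypothetical (`vmapstack_sketch`), and which entry of `args` configures the innermost versus the outermost `vmap` is an assumption here: the first entry is applied innermost.

```python
from typing import Callable, Dict, List, Optional

import jax
import jax.numpy as jnp


def vmapstack_sketch(times: int, args: Optional[List[Dict]] = None) -> Callable:
    """Decorator applying jax.vmap `times` times, one config dict per application."""
    if args is None:
        args = [{} for _ in range(times)]  # default vmap settings (in_axes=0, out_axes=0)
    if len(args) != times:
        raise ValueError(f"expected {times} vmap configs, got {len(args)}")

    def decorator(f: Callable) -> Callable:
        # args[0] wraps f first (innermost); args[-1] ends up outermost.
        for kwargs in args:
            f = jax.vmap(f, **kwargs)
        return f

    return decorator


@vmapstack_sketch(2)
def row_norm(v):
    return jnp.linalg.norm(v)  # only ever sees one vector of shape (5,)


norms = row_norm(jnp.ones((3, 4, 5)))  # shape (3, 4)
```

Using a reduction such as a norm makes the effect visible: the wrapped function only sees the trailing unbatched dimension, and both leading dimensions are mapped over.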
- jrystal._src.utils.volume(cell_vectors: Float[Array, '3 3']) → Float
Calculates the volume of a parallelepiped defined by three cell vectors.
Computes the volume of a unit cell in a crystal structure by calculating the determinant of the matrix formed by the three cell vectors. The absolute value of the determinant gives the volume of the parallelepiped.
Note
The volume is calculated as \(|\det(A)|\) where \(A\) is the matrix of cell vectors. This gives the volume of the parallelepiped formed by the three vectors regardless of their orientation.
Example:
import jax.numpy as jnp
from jrystal._src.utils import volume
# For a cubic cell of side length 2
vectors = jnp.array([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]])
volume(vectors)  # Returns 8.0
- Parameters:
cell_vectors (Float[Array, '3 3']) – A 3x3 matrix where each row represents a cell vector of the crystal structure. The vectors should be given in consistent units (e.g., Bohr radii or Angstroms).
- Returns:
The volume of the unit cell (in cubic units of the input vectors).
- Return type:
Float
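To see why the absolute value matters, here is a minimal sketch (hypothetical `volume_sketch`, computing \(|\det A|\) as the Note describes) applied to a non-orthogonal cell: the face-centred cubic primitive vectors \(\tfrac{a}{2}(0,1,1)\), \(\tfrac{a}{2}(1,0,1)\), \(\tfrac{a}{2}(1,1,0)\), whose cell volume is \(a^3/4\).

```python
import jax.numpy as jnp


def volume_sketch(cell_vectors):
    """|det(A)| for a 3x3 matrix whose rows are the cell vectors."""
    return jnp.abs(jnp.linalg.det(cell_vectors))


# FCC primitive cell with lattice constant a = 2, so the volume is a^3 / 4 = 2.
fcc = jnp.array([[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
v = volume_sketch(fcc)                 # approximately 2.0
v_swapped = volume_sketch(fcc[::-1])   # reversing rows flips det's sign, not the volume
```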
- jrystal._src.utils.wave_to_density(wave_grid: Complex[Array, 'spin kpt band x y z'], occupation: Float[Array, 'spin kpt band'] | None = None, axis: int | Tuple | List | None = None) → Float[Array, 'spin kpt band x y z'] | Float[Array, 'spin kpt band']
Computes electron density from wave functions in real space.
Calculates the electron density by taking the absolute square of wave functions and optionally applying occupation numbers. The density can be computed for the full grid or reduced along specified dimensions.
- Parameters:
wave_grid (Complex[Array, 'spin kpt band x y z']) – Complex wave function values on a real-space grid. The array has dimensions for spin, k-points, bands, and spatial coordinates (x,y,z).
occupation (Optional[Float[Array, 'spin kpt band']]) – Optional occupation numbers for each state (spin, k-point, band). If provided, the density will be weighted by these values. Defaults to None.
axis (Optional[Union[int, Tuple, List]]) – Optional specification of axes along which to sum the density when applying occupation numbers. If None, sums over spin, k-points, and bands (first three dimensions). Defaults to None.
- Returns:
The electron density grid. If occupation is None, the density grid has the same shape as the input wave_grid. If occupation is provided, the density grid is reduced along the specified axes after applying occupation weights.
- Return type:
Union[Float[Array, 'spin kpt band x y z'], Float[Array, 'spin kpt band']]
- Raises:
ValueError – If the shapes of wave_grid and occupation are incompatible for broadcasting.
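A minimal sketch of the documented behaviour, assuming occupations simply multiply each state's \(|\psi|^2\) before summation and that no normalisation or volume factor is applied (the text above does not say whether the library adds one). `wave_to_density_sketch` is hypothetical.

```python
import jax.numpy as jnp


def wave_to_density_sketch(wave_grid, occupation=None, axis=None):
    """|psi|^2 on the grid, optionally occupation-weighted and summed."""
    density = jnp.real(wave_grid * jnp.conj(wave_grid))  # (spin, kpt, band, x, y, z)
    if occupation is None:
        return density  # unreduced, same shape as wave_grid
    if axis is None:
        axis = (0, 1, 2)  # sum over spin, k-points and bands
    elif isinstance(axis, list):
        axis = tuple(axis)  # jnp.sum expects an int or a tuple of ints
    # Broadcast (spin, kpt, band) weights over the three spatial axes.
    weighted = density * occupation[..., None, None, None]
    return jnp.sum(weighted, axis=axis)


psi = jnp.ones((1, 2, 3, 4, 4, 4), dtype=jnp.complex64)
occ = jnp.ones((1, 2, 3))
rho = wave_to_density_sketch(psi, occ)  # shape (4, 4, 4), every entry 6.0
```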
- jrystal._src.utils.wave_to_density_reciprocal(wave_grid: Complex[Array, 'spin kpt band x y z'], occupation: Float[Array, 'spin kpt band'] | None = None, axis: int | Tuple | List | None = None) → Float[Array, 'spin kpt band x y z'] | Float[Array, 'spin kpt band']
Computes electron density from wave functions in reciprocal space.
Calculates the electron density by first computing the real-space density and then performing a Fourier transform to obtain the reciprocal space representation. This is useful for operations that are more efficient in reciprocal space, such as computing the Hartree potential.
- Parameters:
wave_grid (Complex[Array, 'spin kpt band x y z']) – Complex wave function values on a real-space grid. The array has dimensions for spin, k-points, bands, and spatial coordinates (x,y,z).
occupation (Optional[Float[Array, 'spin kpt band']]) – Optional occupation numbers for each state (spin, k-point, band). If provided, the density will be weighted by these values. Defaults to None.
axis (Optional[Union[int, Tuple, List]]) – Optional specification of axes along which to sum the density when applying occupation numbers. If None, sums over spin, k-points, and bands (first three dimensions). Defaults to None.
- Returns:
The electron density grid in reciprocal space. If occupation is None, the density grid has the same shape as the input wave_grid. If occupation is provided, the density grid is reduced along the specified axes after applying occupation weights.
- Return type:
Union[Float[Array, 'spin kpt band x y z'], Float[Array, 'spin kpt band']]
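A minimal sketch of the two-step procedure (real-space density, then a forward FFT over the spatial axes). `wave_to_density_reciprocal_sketch` is hypothetical, and the FFT normalisation used here (`jnp.fft.fftn` with its default, unnormalised forward transform) is an assumption, since the convention the library uses is not stated above.

```python
import jax.numpy as jnp


def wave_to_density_reciprocal_sketch(wave_grid, occupation=None, axis=None):
    """Real-space density |psi|^2 followed by a 3-D FFT over x, y, z."""
    density = jnp.real(wave_grid * jnp.conj(wave_grid))
    if occupation is not None:
        if axis is None:
            axis = (0, 1, 2)  # sum over spin, k-points and bands
        elif isinstance(axis, list):
            axis = tuple(axis)
        density = jnp.sum(density * occupation[..., None, None, None], axis=axis)
    # Forward FFT over the last three (spatial) axes, with no 1/N factor.
    # The reciprocal-space result is complex-valued in general.
    return jnp.fft.fftn(density, axes=(-3, -2, -1))


psi = jnp.ones((1, 2, 3, 4, 4, 4), dtype=jnp.complex64)
occ = jnp.ones((1, 2, 3))
rho_g = wave_to_density_reciprocal_sketch(psi, occ)  # shape (4, 4, 4)
```

With this convention the \(G = 0\) component equals the sum of the real-space density over all grid points (here 6 × 64 = 384), and a spatially uniform density has no other nonzero components.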