# Source code for jrystal._src.ewald

# Copyright 2025 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ewald summation for periodic systems. """

import jax
import jax.numpy as jnp
from jaxtyping import Float, Array


def ewald_coulomb_repulsion(
  positions: Float[Array, 'atom d'],
  charges: Float[Array, 'atom'],
  g_vector_grid: Float[Array, 'x y z 3'],
  vol: Float,
  ewald_eta: Float,
  ewald_grid: Float[Array, 'nt d'],
) -> Float:
  """Calculate the nuclei repulsion energy using Ewald summation method.

  This function computes the nuclei repulsion energy for a periodic system
  using the Ewald summation technique, which splits the calculation into
  real and reciprocal space contributions. The method provides an efficient
  way to handle long-range Coulomb interactions in periodic systems.

  With ``q`` the charges, ``r`` the positions, ``eta`` the Ewald parameter,
  ``vol`` the cell volume and ``d_ijT = |r_j - r_i - T|``, the returned
  energy is::

    E = 1/2 sum_{i,j} sum'_T  q_i q_j erfc(eta d_ijT) / d_ijT
      + 2 pi / vol  sum_{G != 0}  exp(-|G|^2 / (4 eta^2)) / |G|^2  |S(G)|^2
      - eta / sqrt(pi)  sum_i q_i^2
      - pi / (2 eta^2 vol)  (sum_i q_i)^2

  where ``S(G) = sum_i q_i exp(i G . r_i)`` is the structure factor and the
  prime on the T-sum drops the ``i == j, T == 0`` self term.

  .. note::
    Further reading:

    - Textbook: Martin, R. M. (2020). Electronic structure: basic theory and
      practical methods. Cambridge university press. (Appendix F.2)
    - Our tutorial: :doc:`Ewald Summation </tutorial/ewald>`

  Args:
    positions (Float[Array, 'atom d']): Array of shape (atom, d) containing
      atomic positions in d-dimensional space.
    charges (Float[Array, 'atom']): Array of shape (atom,) containing the
      charges of each atom.
    g_vector_grid (Float[Array, 'x y z 3']): Array of shape (x, y, z, 3)
      containing the reciprocal lattice vectors. The entry at index
      (0, 0, 0) is treated as G = 0 and excluded from the sum.
    vol (Float): Float representing the volume of the unit cell.
    ewald_eta (Float): Float controlling the split between real and
      reciprocal space contributions. Also known as the Ewald convergence
      parameter.
    ewald_grid (Float[Array, 'nt d']): Array of shape (nt, d) containing the
      real-space lattice translation vectors T for the Ewald sum. The
      real-space sum runs only over these vectors, so the set should contain
      the zero vector and extend far enough that ``erfc(eta * |T|)`` is
      negligible beyond it. Can be generated using
      :func:`jrystal.grid.translation_vectors`.

  Returns:
    Float: The total Coulomb repulsion energy computed using Ewald summation.
    Includes both real-space and reciprocal-space contributions, as well as
    self-interaction corrections.
  """
  # ---- Real-space part ------------------------------------------------
  # tau[i, j] = r_j - r_i, shape [na, na, d].
  tau = jnp.expand_dims(positions, 0) - jnp.expand_dims(positions, 1)
  # Shift every pair by every lattice translation: [na, na, nt, d].
  tau_t = jnp.expand_dims(tau, 2) - jnp.expand_dims(ewald_grid, axis=(0, 1))
  # The tiny epsilon keeps sqrt differentiable at zero separation.
  tau_t_norm = jnp.sqrt(jnp.sum(tau_t**2, axis=-1) + 1e-20)  # [na, na, nt]
  # Move the self term (i == j, T == 0) to a huge distance so that
  # erfc(eta * r) / r evaluates to exactly 0 there.
  tau_t_norm = jnp.where(tau_t_norm <= 1e-9, 1e20, tau_t_norm)
  ew_ovlp = jnp.sum(
    jax.scipy.special.erfc(ewald_eta * tau_t_norm) / tau_t_norm, axis=2
  )  # [na, na]
  # Elementwise products (not matmul) so the result does not depend on the
  # backend's default matmul precision.
  e_real = jnp.sum(charges[:, None] * charges[None, :] * ew_ovlp) / 2

  # ---- Reciprocal-space part ------------------------------------------
  g_origin = (0,) * (g_vector_grid.ndim - 1)  # grid index treated as G = 0
  gvec_norm_sq = jnp.sum(g_vector_grid**2, axis=-1)  # [x, y, z]
  # Use |G|^2 = 1 at G = 0 to avoid 0/0 (in values and gradients); the
  # weight at that point is zeroed right after.
  safe_norm_sq = gvec_norm_sq.at[g_origin].set(1.)
  weight = jnp.exp(-safe_norm_sq / (4 * ewald_eta**2)) / safe_norm_sq
  weight = weight.at[g_origin].set(0.)  # exclude G = 0
  # |S(G)|^2 = sum_{i,j} q_i q_j cos(G . (r_j - r_i)) computed through the
  # structure factor: O(nG * na) memory instead of O(nG * na^2).
  phase = jnp.sum(
    jnp.expand_dims(g_vector_grid, -2) * positions, axis=-1
  )  # [x, y, z, na]
  s_re = jnp.sum(charges * jnp.cos(phase), axis=-1)  # [x, y, z]
  s_im = jnp.sum(charges * jnp.sin(phase), axis=-1)  # [x, y, z]
  e_recip = 2 * jnp.pi / vol * jnp.sum(weight * (s_re**2 + s_im**2))

  # ---- Single-atom terms ----------------------------------------------
  # Self-energy correction: -eta / sqrt(pi) * sum_i q_i^2.
  e_self = -ewald_eta / jnp.sqrt(jnp.pi) * jnp.sum(charges**2)
  # G = 0 correction; vanishes when the cell is charge neutral.
  e_charged = -jnp.pi / (2 * ewald_eta**2 * vol) * jnp.sum(charges)**2

  return e_real + e_recip + e_self + e_charged