Source code for gpaw.utilities.ps2ae

from math import pi, sqrt
from warnings import warn
from typing import Optional, List, Dict

import numpy as np
from ase.units import Bohr, Ha

from gpaw.calculator import GPAW
from gpaw.atom.shapefunc import shape_functions
from gpaw.fftw import get_efficient_fft_size
from gpaw.grid_descriptor import GridDescriptor
from gpaw.lfc import LocalizedFunctionsCollection as LFC
from gpaw.utilities import h2gpts
from import PWDescriptor
from gpaw.mpi import serial_comm
from gpaw.setup import Setup
from gpaw.spline import Spline
from gpaw.typing import Array3D

class Interpolator:
    """Carry arrays from grid ``gd1`` over to grid ``gd2``.

    Each grid descriptor is wrapped in a ``PWDescriptor``; the actual work
    is delegated to the ``interpolate`` method of the first descriptor.
    """

    def __init__(self, gd1, gd2, dtype=float):
        """Build one ``PWDescriptor`` per grid.

        gd1:
            Descriptor of the grid the input arrays are given on.
        gd2:
            Descriptor of the grid to interpolate to.
        dtype:
            Data type handed to both ``PWDescriptor`` objects.
        """
        # The first positional argument is 0.0 for both descriptors
        # (presumably the plane-wave cutoff -- TODO confirm).
        self.pd1, self.pd2 = (PWDescriptor(0.0, grid, dtype)
                              for grid in (gd1, gd2))

    def interpolate(self, a_r):
        """Return ``a_r`` interpolated from the first grid to the second.

        Only the first element of whatever ``pd1.interpolate`` returns is
        handed back to the caller.
        """
        result = self.pd1.interpolate(a_r, self.pd2)
        return result[0]

# Value passed as ``points=`` to ``rgd.spline`` when the radial potential
# corrections are turned into splines (see ``add_potential_correction``).
POINTS = 200

class PS2AE:
    """Transform PS to AE wave functions.

    Interpolates PS wave functions to a fine grid and adds PAW
    corrections in order to obtain true AE wave functions.
    """

    # NOTE(review): this class was reconstructed from a listing in which a
    # number of dotted expressions were missing (the code did not parse).
    # Every line tagged "restored" is a best-effort reconstruction and
    # should be confirmed against the upstream GPAW sources.

    def __init__(self,
                 calc: GPAW,
                 grid_spacing: float = 0.05,
                 n: int = 2,
                 h=None  # deprecated
                 ):
        """Create transformation object.

        calc: GPAW calculator object
            The calculator that has the wave functions.
        grid_spacing: float
            Desired grid-spacing in Angstrom.
        n: int
            Force number of points to be a multiple of n.
        h:
            Deprecated alias for ``grid_spacing``.  When given, a warning
            is issued and its value replaces ``grid_spacing``.
        """
        if h is not None:
            warn('Please use grid_spacing=... instead of h=...')
            grid_spacing = h

        self.calc = calc
        gd = calc.wfs.gd  # restored

        # Descriptor with the same points and cell as the wave-function
        # grid, but on the serial communicator:
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)

        # Descriptor for the final grid:
        N_c = h2gpts(grid_spacing / Bohr, gd.cell_cv)
        N_c = np.array([get_efficient_fft_size(N, n) for N in N_c])
        gd2 = self.gd = GridDescriptor(N_c, gd.cell_cv,  # restored: self.gd
                                       comm=serial_comm)
        self.interpolator = Interpolator(gd1, gd2, self.calc.wfs.dtype)

        self._dphi: Optional[LFC] = None  # PAW correction

        # Grid-point volume of the final grid scaled by Bohr**3
        # (presumably Ang**3 -- TODO confirm).
        self.dv = self.gd.dv * Bohr**3  # restored: self.gd.dv

    @property
    def dphi(self) -> LFC:
        """Localized functions holding the AE minus PS partial waves.

        Built on first access and cached in ``self._dphi``; later accesses
        return the cached object.
        """
        if self._dphi is not None:
            return self._dphi

        # One list of splines per distinct setup object, so atoms sharing
        # a setup share the splines as well.
        splines: Dict[Setup, List[Spline]] = {}
        dphi_aj = []
        for setup in self.calc.wfs.setups:
            dphi_j = splines.get(setup)
            if dphi_j is None:
                rcut = max(setup.rcut_j) * 1.1
                gcut = setup.rgd.ceil(rcut)
                dphi_j = []
                for l, phi_g, phit_g in zip(setup.l_j,
                                            setup.data.phi_jg,  # restored
                                            setup.data.phit_jg):  # restored
                    dphi_g = (phi_g - phit_g)[:gcut]
                    # POINTS is the module constant (200), the same value
                    # used for the potential-correction splines.
                    dphi_j.append(setup.rgd.spline(dphi_g, rcut, l,
                                                   points=POINTS))
                splines[setup] = dphi_j
            dphi_aj.append(dphi_j)

        self._dphi = LFC(self.gd, dphi_aj,  # restored: self.gd
                         kd=self.calc.wfs.kd.copy(),
                         dtype=self.calc.wfs.dtype)
        self._dphi.set_positions(self.calc.spos_ac)

        return self._dphi

    def get_wave_function(self,
                          n: int,
                          k: int = 0,
                          s: int = 0,
                          ae: bool = True,
                          periodic: bool = False) -> Array3D:
        """Interpolate wave function.

        Returns 3-d array in units of Ang**-1.5.

        n: int
            Band index.
        k: int
            K-point index.
        s: int
            Spin index.
        ae: bool
            Add PAW correction to get an all-electron wave function.
        periodic:
            Return periodic part of wave-function, u(r), instead of
            psi(r)=exp(ikr)u(r).
        """
        u_r = self.calc.get_pseudo_wave_function(n, k, s, periodic=True)
        u_R = self.interpolator.interpolate(u_r * Bohr**1.5)

        k_c = self.calc.wfs.kd.ibzk_kc[k]
        gamma = np.isclose(k_c, 0.0).all()

        # No phase factor is needed when all components of k are ~0.
        if gamma:
            eikr_R = 1.0
        else:
            eikr_R = self.gd.plane_wave(k_c)  # restored

        if ae:
            dphi = self.dphi
            wfs = self.calc.wfs
            P_nI = wfs.collect_projections(k, s)

            if wfs.world.rank == 0:  # restored
                psi_R = u_R * eikr_R
                # Cut the row for band n into one slice per atom.
                P_ai = {}
                I1 = 0
                for a, setup in enumerate(wfs.setups):
                    I2 = I1 + setup.ni  # restored
                    P_ai[a] = P_nI[n, I1:I2]
                    I1 = I2
                dphi.add(psi_R, P_ai, k)
                u_R = psi_R / eikr_R

            # Hand the corrected array from rank 0 to all other ranks.
            wfs.world.broadcast(u_R, 0)  # restored

        if periodic:
            return u_R * Bohr**-1.5
        else:
            return u_R * eikr_R * Bohr**-1.5

    def get_pseudo_density(self,
                           add_compensation_charges: bool = True
                           ) -> Array3D:
        """Interpolate pseudo density.

        add_compensation_charges: bool
            Also add the ``ghat_l`` functions of every setup, weighted by
            the multipole moments, to the interpolated density.

        Returns the interpolated array divided by Bohr**3.
        """
        dens = self.calc.density
        gd1 = dens.gd  # restored
        assert gd1.comm.size == 1
        interpolator = Interpolator(gd1, self.gd)  # restored: self.gd
        # Sum the first dens.nspins components of the pseudo density.
        dens_r = dens.nt_sG[:dens.nspins].sum(axis=0)
        dens_R = interpolator.interpolate(dens_r)

        if add_compensation_charges:
            dens.calculate_multipole_moments()
            ghat = LFC(self.gd,  # restored: self.gd
                       [setup.ghat_l for setup in dens.setups],
                       integral=sqrt(4 * pi))
            ghat.set_positions(self.calc.spos_ac)
            Q_aL = {}
            for a, Q_L in dens.Q_aL.items():
                # Work on a copy so the density object is left untouched.
                Q_aL[a] = Q_L.copy()
                Q_aL[a][0] += dens.setups[a].Nv / (4 * pi)**0.5
            ghat.add(dens_R, Q_aL)

        return dens_R / Bohr**3

    def get_electrostatic_potential(self,
                                    ae: bool = True,
                                    rcgauss: float = 0.02) -> Array3D:
        """Interpolate electrostatic potential.

        Return value in eV.

        ae: bool
            Add PAW correction to get the all-electron potential.
        rcgauss: float
            Width of gaussian (in Angstrom) used to represent the nuclear
            charge.
        """
        gd = self.calc.hamiltonian.finegd
        v_r = self.calc.get_electrostatic_potential() / Ha
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)
        interpolator = Interpolator(gd1, self.gd)  # restored: self.gd
        v_R = interpolator.interpolate(v_r)

        if ae:
            self.add_potential_correction(v_R, rcgauss / Bohr)

        return v_R * Ha

    def add_potential_correction(self,
                                 v_R: Array3D,
                                 rcgauss: float) -> None:
        """Add per-atom radial potential corrections to ``v_R`` in place.

        v_R:
            Potential on the final grid; modified in place.
        rcgauss: float
            Width handed to ``shape_functions(rgd, 'gauss', ...)`` for the
            nuclear charge (the caller passes ``rcgauss / Bohr``).
        """
        dens = self.calc.density
        # Move D_asp and Q_aL to the serial partition while they are read;
        # the original partition is put back after the loop.
        dens.D_asp.redistribute(dens.atom_partition.as_serial())
        dens.Q_aL.redistribute(dens.atom_partition.as_serial())

        dv_a1 = []
        for a, D_sp in dens.D_asp.items():
            setup = dens.setups[a]
            c = setup.xc_correction
            rgd = c.rgd
            params = setup.data.shape_function.copy()  # restored
            params['lmax'] = 0
            ghat_g = shape_functions(rgd, **params)[0]
            Z_g = shape_functions(rgd, 'gauss', rcgauss, lmax=0)[0] * setup.Z
            D_q = np.dot(D_sp.sum(0), c.B_pqL[:, :, 0])  # restored
            dn_g = np.dot(D_q, (c.n_qg - c.nt_qg)) * sqrt(4 * pi)  # restored
            dn_g += 4 * pi * (c.nc_g - c.nct_g)
            dn_g -= Z_g
            dn_g -= dens.Q_aL[a][0] * ghat_g * sqrt(4 * pi)
            dv_g = rgd.poisson(dn_g) / sqrt(4 * pi)
            # Divide by r everywhere except the first radial point, which
            # is copied from its neighbour instead.
            dv_g[1:] /= rgd.r_g[1:]
            dv_g[0] = dv_g[1]
            dv_g[-1] = 0.0
            dv_a1.append([rgd.spline(dv_g, points=POINTS)])

        dens.D_asp.redistribute(dens.atom_partition)
        dens.Q_aL.redistribute(dens.atom_partition)

        if dv_a1:
            dv = LFC(self.gd, dv_a1)  # restored: self.gd
            dv.set_positions(self.calc.spos_ac)
            dv.add(v_R)
        # Every rank takes part in the broadcast, including ranks that had
        # no atoms to process above.
        dens.gd.comm.broadcast(v_R, 0)  # restored
def interpolate_weight(calc, weight, h=0.05, n=2):
    """Interpolate a cDFT weight function given on the fine grid.

    calc:
        Calculator whose ``density.finegd`` describes the grid that
        ``weight`` is given on.
    weight:
        Weight function on that fine grid.
    h:
        Target grid spacing; divided by ``Bohr`` before being handed to
        ``h2gpts`` (presumably given in Angstrom -- TODO confirm).
    n:
        Second argument to ``get_efficient_fft_size`` for every grid
        dimension.

    Returns the weight function interpolated to the target grid.
    """
    fine_gd = calc.density.finegd

    # Collect the complete array (broadcast to all ranks), then zero-pad it.
    full_weight = fine_gd.zero_pad(fine_gd.collect(weight, broadcast=True))

    # Same points and cell as the fine grid, but on the serial communicator.
    serial_gd = GridDescriptor(fine_gd.N_c, fine_gd.cell_cv,
                               comm=serial_comm)
    local_weight = np.zeros_like(full_weight)
    serial_gd.distribute(full_weight, local_weight)

    # Number of grid points of the target grid.
    target_N_c = np.array([get_efficient_fft_size(npts, n)
                           for npts in h2gpts(h / Bohr, fine_gd.cell_cv)])
    target_gd = GridDescriptor(target_N_c, fine_gd.cell_cv,
                               comm=serial_comm)

    return Interpolator(serial_gd, target_gd).interpolate(local_weight)