Source code for gpaw.fftw

"""
Python wrapper for FFTW3 library
================================

.. autoclass:: FFTPlans

"""
from __future__ import annotations

import weakref
from types import ModuleType

import numpy as np
from scipy.fft import fftn, ifftn, irfftn, rfftn
import warnings

import gpaw.cgpaw as cgpaw
from gpaw.typing import Array1D, Array3D, DTypeLike, IntVector

ESTIMATE = 64
MEASURE = 0
PATIENT = 32
EXHAUSTIVE = 8

_plan_cache: dict[tuple, weakref.ReferenceType] = {}


def have_fftw() -> bool:
    """Tell whether the compiled extension exposes FFTW.

    The answer is True exactly when the ``cgpaw`` module has an
    ``FFTWPlan`` attribute.
    """
    compiled_with_fftw = hasattr(cgpaw, 'FFTWPlan')
    return compiled_with_fftw
def check_fft_size(n: int, factors=(2, 3, 5, 7)) -> bool:
    """Check if n is an efficient fft size.

    Efficient means that n can be written as a product of the numbers
    in *factors* (by default the small primes 2, 3, 5 and 7).

    Parameters
    ----------
    n:
        Candidate size.  Values smaller than 1 are never efficient.
    factors:
        Allowed factors.  Entries smaller than 2 cannot reduce *n* and
        are ignored.

    >>> check_fft_size(17)
    False
    >>> check_fft_size(18)
    True
    """
    # Guard against 0 (and negatives): 0 is divisible by everything and
    # 0 // x == 0, so dividing factors out of it would never terminate.
    if n < 1:
        return False

    for x in factors:
        if x < 2:
            # Dividing by 1 makes no progress; skip instead of spinning.
            continue
        # Strip every power of this factor from n.
        while n % x == 0:
            n //= x

    # Whatever is left must be 1 if n consisted of allowed factors only.
    return bool(n == 1)
def get_efficient_fft_size(N: int, n=1, factors=(2, 3, 5, 7)) -> int:
    """Return smallest efficient fft size.

    Must be greater than or equal to N and divisible by n.

    Parameters
    ----------
    N:
        Lower bound for the returned size.
    n:
        The returned size is a multiple of this number.  Must be >= 1.
    factors:
        Allowed factors, forwarded to :func:`check_fft_size`.

    Raises
    ------
    ValueError
        If *n* is smaller than 1.

    Note that the search only stops once a multiple of *n* passes
    :func:`check_fft_size`; if no such multiple exists it does not
    terminate.

    >>> get_efficient_fft_size(17)
    18
    """
    if n < 1:
        raise ValueError(f'n must be a positive integer, got {n!r}')

    # Round N up to a multiple of n (ceiling division), but start the
    # search no lower than n itself: sizes below 1 are never efficient.
    N = max(-(-N // n), 1) * n
    while not check_fft_size(N, factors):
        N += n
    return N
def empty(shape, dtype=float):
    """numpy.empty() equivalent with 16 byte alignment.

    Only ``dtype=complex`` is accepted (enforced by an assertion).
    One spare float is allocated so that the start of the returned
    array can be shifted onto a 16-byte boundary.
    """
    assert dtype == complex
    count = np.prod(shape)
    n_floats = 2 * count  # each complex number occupies two floats
    raw = np.empty(n_floats + 1)
    # 0 if the buffer already starts on a 16-byte boundary, else 1 float:
    skip = (raw.ctypes.data % 16) // 8
    aligned = raw[skip:n_floats + skip].view(complex)
    aligned.shape = shape
    return aligned
def create_plans(size_c: IntVector,
                 dtype: DTypeLike,
                 flags: int = MEASURE,
                 xp: ModuleType = np) -> FFTPlans:
    """Create plan-objects for FFT and inverse FFT.

    Plans are cached through weak references keyed on
    ``(tuple(size_c), dtype, flags, xp)``; a cached plan is returned
    as long as it is still alive, otherwise a new one is built.
    """
    cache_key = (tuple(size_c), dtype, flags, xp)

    # The cache holds weak references, so a hit may point to a dead plan:
    cached_ref = _plan_cache.get(cache_key)
    if cached_ref is not None:
        alive = cached_ref()
        if alive is not None:
            return alive

    # No usable cached plan; pick the backend:
    if xp is not np:
        new_plan = CuPyFFTPlans(size_c, dtype)
    elif have_fftw():
        new_plan = FFTWPlans(size_c, dtype, flags)
    else:
        new_plan = NumpyFFTPlans(size_c, dtype)

    _plan_cache[cache_key] = weakref.ref(new_plan)
    return new_plan
class FFTPlans:
    """Base class holding the work arrays for 3-d FFTs.

    ``tmp_Q`` is a complex array and ``tmp_R`` shares its memory: for
    ``dtype == float`` it is a float view of ``tmp_Q`` cut down to
    ``size_c[2]`` along the last axis, otherwise it is the very same
    array.  Subclasses implement :meth:`fft` and :meth:`ifft`.
    """

    def __init__(self,
                 size_c: IntVector,
                 dtype: DTypeLike,
                 empty=empty):
        self.shape: tuple[int, ...]
        if dtype == float:
            # Real transform: the last axis of tmp_Q has n // 2 + 1 entries.
            nz = size_c[2]
            self.shape = (size_c[0], size_c[1], nz // 2 + 1)
            self.tmp_Q = empty(self.shape, complex)
            self.tmp_R = self.tmp_Q.view(float)[:, :, :nz]
            return

        # Complex transform: both work arrays are one and the same.
        self.shape = tuple(size_c)
        self.tmp_Q = empty(size_c, complex)
        self.tmp_R = self.tmp_Q

    def fft(self) -> None:
        """Do FFT from ``tmp_R`` to ``tmp_Q``.

        >>> plans = create_plans([4, 1, 1], float)
        >>> plans.tmp_R[:, 0, 0] = [1, 0, 1, 0]
        >>> plans.fft()
        >>> plans.tmp_Q[:, 0, 0]
        array([2.+0.j, 0.+0.j, 2.+0.j, 0.+0.j])
        """
        raise NotImplementedError

    def ifft(self) -> None:
        """Do inverse FFT from ``tmp_Q`` to ``tmp_R``.

        >>> plans = create_plans([4, 1, 1], complex)
        >>> plans.tmp_Q[:, 0, 0] = [0, 1j, 0, 0]
        >>> plans.ifft()
        >>> plans.tmp_R[:, 0, 0]
        array([ 0.+1.j, -1.+0.j,  0.-1.j,  1.+0.j])
        """
        raise NotImplementedError

    def ifft_sphere(self, coef_G, pw, out_R):
        """Inverse transform the coefficients ``coef_G`` into ``out_R``.

        With ``coef_G is None`` only ``out_R.scatter_from(None)`` is
        called.  Otherwise the coefficients are pasted into ``tmp_Q``
        via ``pw.paste``, transformed with :meth:`ifft` and handed to
        ``out_R.scatter_from``.
        """
        if coef_G is None:
            out_R.scatter_from(None)
            return

        pw.paste(coef_G, self.tmp_Q)

        if pw.dtype == float:
            # Copy conjugated, index-reversed entries of the last-index-0
            # plane into the rest of that plane.
            # NOTE(review): presumably this restores the symmetry needed
            # by the complex-to-real transform -- confirm.
            plane = self.tmp_Q[:, :, 0]
            n, m = (s // 2 - 1 for s in out_R.desc.size_c[:2])
            plane[0, -m:] = plane[0, m:0:-1].conj()
            plane[n:0:-1, -m:] = plane[-n:, m:0:-1].conj()
            plane[-n:, -m:] = plane[n:0:-1, m:0:-1].conj()
            plane[-n:, 0] = plane[n:0:-1, 0].conj()

        self.ifft()
        out_R.scatter_from(self.tmp_R)

    def fft_sphere(self, in_R, pw):
        """Transform ``in_R.data`` and return the coefficients cut by ``pw``.

        The result of ``pw.cut(tmp_Q)`` is scaled by ``1 / tmp_R.size``.
        """
        self.tmp_R[:] = in_R.data
        self.fft()
        scale = 1 / self.tmp_R.size
        return pw.cut(self.tmp_Q) * scale
class FFTWPlans(FFTPlans):
    """FFTW3 3d transforms."""

    def __init__(self, size_c, dtype, flags=MEASURE):
        """Allocate the work arrays and create both FFTW plans.

        Raises ImportError when the extension lacks FFTW.
        """
        if not have_fftw():
            raise ImportError('Not compiled with FFTW.')
        super().__init__(size_c, dtype)
        real_space, reciprocal = self.tmp_R, self.tmp_Q
        # Sign -1 goes tmp_R -> tmp_Q, sign +1 goes tmp_Q -> tmp_R:
        self._fftplan = cgpaw.FFTWPlan(real_space, reciprocal, -1, flags)
        self._ifftplan = cgpaw.FFTWPlan(reciprocal, real_space, 1, flags)

    def fft(self):
        """Execute the plan going from ``tmp_R`` to ``tmp_Q``."""
        cgpaw.FFTWExecute(self._fftplan)

    def ifft(self):
        """Execute the plan going from ``tmp_Q`` to ``tmp_R``."""
        cgpaw.FFTWExecute(self._ifftplan)

    def __del__(self):
        # Attributes will not exist if execution stops during FFTW planning
        for attr in ('_fftplan', '_ifftplan'):
            if hasattr(self, attr):
                cgpaw.FFTWDestroy(getattr(self, attr))
class NumpyFFTPlans(FFTPlans):
    """Numpy fallback.

    The transforms are the ``scipy.fft`` functions imported at the top
    of the module; the real or complex variant is chosen from the
    dtype of ``tmp_R``.
    """

    def fft(self):
        """Do FFT from ``tmp_R`` to ``tmp_Q``."""
        forward = rfftn if self.tmp_R.dtype == float else fftn
        self.tmp_Q[:] = forward(self.tmp_R, overwrite_x=True)

    def ifft(self):
        """Do inverse FFT from ``tmp_Q`` to ``tmp_R``."""
        backward = irfftn if self.tmp_R.dtype == float else ifftn
        self.tmp_R[:] = backward(self.tmp_Q, self.tmp_R.shape,
                                 norm='forward', overwrite_x=True)
def rfftn_patch(tmp_R):
    """Fallback for ``cupyx.scipy.fft.rfftn`` built on the complex ``fftn``.

    Emits a warning, does the full transform of the 3-d array ``tmp_R``
    and keeps the first ``tmp_R.shape[-1] // 2 + 1`` entries along the
    last axis.
    """
    from gpaw.gpu import cupyx
    # The two literals are concatenated, so the first one must end with a
    # space to keep the sentences apart.
    warnings.warn(f'CuFFTError for cupyx.scipy.fft.rfftn {tmp_R.shape}. '
                  'reverting to using just fftn. This is a bug in ROCM cupy.')
    return cupyx.scipy.fft.fftn(tmp_R)[:, :, :tmp_R.shape[-1] // 2 + 1]


class CuPyFFTPlans(FFTPlans):
    """FFT plans using ``cupyx.scipy.fft`` with ``cp.empty`` work arrays."""

    def __init__(self, size_c: IntVector, dtype: DTypeLike):
        from gpaw.core import PWDesc
        from gpaw.gpu import cupy as cp
        self.dtype = dtype
        super().__init__(size_c, dtype, empty=cp.empty)
        # Index arrays returned by indices(), keyed on the pw object:
        self.Q_G_cache: dict[PWDesc, Array1D] = {}

    def fft(self):
        """Do FFT from ``tmp_R`` to ``tmp_Q``.

        For real data a ``CuFFTError`` from ``rfftn`` is caught and
        :func:`rfftn_patch` is used instead.
        """
        from gpaw.gpu import cupyx
        from gpaw.gpu import cupy as cp
        if self.tmp_R.dtype == float:
            try:
                self.tmp_Q[:] = cupyx.scipy.fft.rfftn(self.tmp_R)
            except cp.cuda.cufft.CuFFTError:
                self.tmp_Q[:] = rfftn_patch(self.tmp_R)
        else:
            self.tmp_Q[:] = cupyx.scipy.fft.fftn(self.tmp_R)

    def ifft(self):
        """Do inverse FFT from ``tmp_Q`` to ``tmp_R``."""
        from gpaw.gpu import cupyx
        if self.tmp_R.dtype == float:
            self.tmp_R[:] = cupyx.scipy.fft.irfftn(
                self.tmp_Q, self.tmp_R.shape,
                norm='forward', overwrite_x=True)
        else:
            self.tmp_R[:] = cupyx.scipy.fft.ifftn(
                self.tmp_Q, self.tmp_R.shape,
                norm='forward', overwrite_x=True)

    def indices(self, pw):
        """Return ``pw.indices(self.shape)`` as a cupy array (cached per pw)."""
        from gpaw.gpu import cupy as cp
        Q_G = self.Q_G_cache.get(pw)
        if Q_G is None:
            Q_G = cp.asarray(pw.indices(self.shape))
            self.Q_G_cache[pw] = Q_G
        return Q_G

    def ifft_sphere(self, coef_G, pw, out_R):
        """Inverse transform the coefficients ``coef_G`` into ``out_R``.

        With ``coef_G is None`` only ``out_R.scatter_from(None)`` is
        called.  When ``out_R.desc.comm.size == 1`` the result is
        written straight into ``out_R.data``; otherwise it goes to
        ``tmp_R`` and is passed to ``out_R.scatter_from``.
        """
        from gpaw.gpu import cupyx
        if coef_G is None:
            out_R.scatter_from(None)
            return

        if out_R.desc.comm.size == 1:
            array_R = out_R.data
        else:
            array_R = self.tmp_R

        # Zero the work array and place the coefficients at their
        # flat indices:
        array_Q = self.tmp_Q
        array_Q[:] = 0.0
        Q_G = self.indices(pw)
        array_Q.ravel()[Q_G] = coef_G

        if self.dtype == complex:
            array_R[:] = cupyx.scipy.fft.ifftn(
                array_Q, array_Q.shape,
                norm='forward', overwrite_x=True)
        else:
            # We need a GPU kernel for this stuff:
            # (conjugated, index-reversed entries of the last-index-0
            # plane are copied into the rest of that plane)
            t = array_Q[:, :, 0]
            n, m = (s // 2 - 1 for s in out_R.desc.size_c[:2])
            t[0, -m:] = t[0, m:0:-1].conj()
            t[n:0:-1, -m:] = t[-n:, m:0:-1].conj()
            t[-n:, -m:] = t[n:0:-1, m:0:-1].conj()
            t[-n:, 0] = t[n:0:-1, 0].conj()
            array_R[:] = cupyx.scipy.fft.irfftn(
                array_Q, out_R.desc.global_shape(),
                norm='forward', overwrite_x=True)

        if out_R.desc.comm.size > 1:
            out_R.scatter_from(array_R)

    def fft_sphere(self, in_R, pw):
        """Transform ``in_R`` and return the coefficients selected by ``pw``.

        The selected values are scaled by ``1 / in_R.size``.  For real
        data a ``CuFFTError`` from ``rfftn`` is caught and
        :func:`rfftn_patch` is used instead.
        """
        from gpaw.gpu import cupyx
        from gpaw.gpu import cupy as cp
        if self.dtype == complex:
            out_Q = cupyx.scipy.fft.fftn(in_R)
        else:
            try:
                out_Q = cupyx.scipy.fft.rfftn(in_R)
            except cp.cuda.cufft.CuFFTError:
                out_Q = rfftn_patch(in_R)
        Q_G = self.indices(pw)
        coef_G = out_Q.ravel()[Q_G] * (1 / in_R.size)
        return coef_G


# The rest of this file will be removed in the future ...

def check_fftw_inputs(in_R, out_R):
    """Assert that two arrays form a valid input/output pair for a 3-d FFT.

    Both arrays must be 3-d with dtype float or complex.  Two complex
    arrays must have equal shapes.  Otherwise exactly one array must be
    real and the other complex, with equal first two dimensions and a
    complex last dimension of ``1 + n // 2`` where ``n`` is the real
    one's last dimension.  Violations raise AssertionError.
    """
    for arr in in_R, out_R:
        # Note: Arrays not necessarily contiguous due to 16-byte alignment
        assert arr.ndim == 3  # We can perhaps relax this requirement
        assert arr.dtype == float or arr.dtype == complex

    if in_R.dtype == out_R.dtype == complex:
        assert in_R.shape == out_R.shape
    else:
        # One real and one complex:
        R, C = (in_R, out_R) if in_R.dtype == float else (out_R, in_R)
        assert C.dtype == complex
        assert R.shape[:2] == C.shape[:2]
        assert C.shape[2] == 1 + R.shape[2] // 2
class FFTPlan:
    """FFT 3d transform.

    Base class that validates the array pair with
    ``check_fftw_inputs`` and stores the arrays, the sign and the
    flags; subclasses implement :meth:`execute`.
    """

    def __init__(self,
                 in_R: Array3D,
                 out_R: Array3D,
                 sign: int,
                 flags: int = MEASURE):
        check_fftw_inputs(in_R, out_R)
        self.in_R, self.out_R = in_R, out_R
        self.sign, self.flags = sign, flags

    def execute(self) -> None:
        """Run the transform (implemented by subclasses)."""
        raise NotImplementedError
class FFTWPlan(FFTPlan):
    """FFTW3 3d transform."""

    def __init__(self, in_R, out_R, sign, flags=MEASURE):
        """Validate the arrays and create the FFTW plan.

        Raises ImportError when the extension lacks FFTW, and whatever
        the base-class validation raises for an invalid array pair.
        """
        if not have_fftw():
            raise ImportError('Not compiled with FFTW.')
        # Run the base-class validation first so that the arrays are
        # checked before they are handed to the C planner.
        FFTPlan.__init__(self, in_R, out_R, sign, flags)
        self._ptr = cgpaw.FFTWPlan(in_R, out_R, sign, flags)

    def execute(self):
        """Execute the plan."""
        cgpaw.FFTWExecute(self._ptr)

    def __del__(self):
        # _ptr is missing if __init__ failed before the plan was created,
        # and is reset to None here so the plan is destroyed only once.
        if getattr(self, '_ptr', None):
            cgpaw.FFTWDestroy(self._ptr)
            self._ptr = None
class NumpyFFTPlan(FFTPlan):
    """Numpy fallback."""

    def execute(self):
        """Transform ``in_R`` into ``out_R`` with ``numpy.fft``.

        The routine is picked from the dtypes and, for two complex
        arrays, from ``sign``.  Results of the inverse routines are
        multiplied by ``out_R.size``.
        """
        src, dst = self.in_R, self.out_R

        if src.dtype == float:
            # real -> complex
            dst[:] = np.fft.rfftn(src)
            return

        if dst.dtype == float:
            # complex -> real
            dst[:] = np.fft.irfftn(src, dst.shape)
        elif self.sign == 1:
            dst[:] = np.fft.ifftn(src, dst.shape)
        else:
            dst[:] = np.fft.fftn(src)
            return

        # Both inverse branches end up here:
        dst *= dst.size
def create_plan(in_R: Array3D,
                out_R: Array3D,
                sign: int,
                flags: int = MEASURE) -> FFTPlan:
    """Return an ``FFTWPlan`` if FFTW is available, else a ``NumpyFFTPlan``."""
    plan_class = FFTWPlan if have_fftw() else NumpyFFTPlan
    return plan_class(in_R, out_R, sign, flags)