Source code for gpaw.lrtddft2.lr_communicators
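
"""Communicator helpers for LrTDDFT2: domain decomposition and
parallelization over electron-hole pairs."""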

import gpaw.mpi


class LrCommunicators:
    def __init__(self, world, dd_size: int, eh_size: int = None):
        """Create communicators for LrTDDFT calculation.

        Parameters
        ----------
        world
            MPI parent communicator (usually ``gpaw.mpi.world``)
        dd_size
            Number of domains for domain decomposition
        eh_size
            Number of groups for parallelization over electron-hole pairs

        Note
        ----
        Sizes must match, i.e., world.size must be equal to
        dd_size * eh_size, e.g., 1024 = 64 * 16.

        Tip
        ---
        Use enough processes for domain decomposition (dd_size) to fit
        everything (easily) into memory, and use the remaining processes
        for electron-hole pairs, as the K-matrix build is trivially
        parallel over them.

        Pass ``lr_comms.dd_comm`` to the ground state calculator when
        reading it for LrTDDFT.

        Examples
        --------
        For 8 MPI processes::

            lr_comms = LrCommunicators(gpaw.mpi.world, 4, 2)
            txt = 'lr_%06d_%06d.txt' % (lr_comms.dd_comm.rank,
                                        lr_comms.eh_comm.rank)
            calc = GPAW('unocc.gpw', communicator=lr_comms.dd_comm)
            lr = LrTDDFT2(calc, lr_communicators=lr_comms, txt=txt)
        """
        self.parent_comm = None
        self.dd_comm = None
        self.eh_comm = None

        self.world = world
        self.dd_size = dd_size
        self.eh_size = eh_size

        if self.world is None:
            return
        if self.dd_size is None:
            return

        if self.eh_size is None:
            self.eh_size = self.world.size // self.dd_size

        self.parent_comm = self.world

        if self.world.size != self.dd_size * self.eh_size:
            raise RuntimeError('Domain decomposition processes (dd_size) '
                               'times electron-hole processes (eh_size) '
                               'does not match the total number of '
                               'processes (world.size != dd_size * eh_size)')

        # World ranks are grouped into consecutive blocks of dd_size ranks,
        # one block per domain decomposition communicator; ranks with the
        # same offset within their block form an electron-hole communicator.
        dd_ranks = []
        eh_ranks = []
        for k in range(self.world.size):
            if k // self.dd_size == self.world.rank // self.dd_size:
                dd_ranks.append(k)
            if k % self.dd_size == self.world.rank % self.dd_size:
                eh_ranks.append(k)
        self.dd_comm = self.world.new_communicator(dd_ranks)
        self.eh_comm = self.world.new_communicator(eh_ranks)

    def initialize(self, calc):
        if self.parent_comm is None:
            if calc is not None:
                # No communicators were given; take the domain decomposition
                # communicator from the ground state calculator and do not
                # parallelize over electron-hole pairs.
                self.dd_comm = calc.density.gd.comm
                self.parent_comm = self.dd_comm.parent

                if self.parent_comm.size != self.dd_comm.size:
                    raise RuntimeError(
                        'Invalid communicators in LrTDDFT2. The ground '
                        'state calculator domain decomposition communicator '
                        'and its parent (or actually its parent\'s parent) '
                        'have different sizes. Please set up '
                        'LrCommunicators explicitly to avoid this, or '
                        'contact the developers if this is intentional.')
                self.eh_comm = gpaw.mpi.serial_comm
            else:
                self.parent_comm = gpaw.mpi.serial_comm
                self.dd_comm = gpaw.mpi.serial_comm
                self.eh_comm = gpaw.mpi.serial_comm
        else:
            # Check that parent_comm is valid
            if self.parent_comm != self.eh_comm.parent:
                raise RuntimeError(
                    'Invalid communicators in LrTDDFT2. The LrTDDFT2 '
                    'parent communicator is not the parent of the '
                    'electron-hole communicator. Please set up '
                    'LrCommunicators explicitly to avoid this.')
            if self.parent_comm != self.dd_comm.parent:
                raise RuntimeError(
                    'Invalid communicators in LrTDDFT2. The LrTDDFT2 '
                    'parent communicator is not the parent of the domain '
                    'decomposition communicator. Please set up '
                    'LrCommunicators explicitly to avoid this.')

    # Do not use unless absolutely necessary; this is slow:
    # def index_of_kss(self, i, p):
    #     for (ind, kss) in enumerate(self.kss_list):
    #         if kss.occ_ind == i and kss.unocc_ind == p:
    #             return ind
    #     return None
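    # The helpers below implement a cyclic (round-robin) distribution of
    # global indices: electron-hole pair index ip belongs to eh_comm rank
    # ip % eh_comm.size as local index ip // eh_comm.size, and likewise
    # jq is distributed over dd_comm.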
    def get_local_eh_index(self, ip):
        if ip % self.eh_comm.size != self.eh_comm.rank:
            return None
        return ip // self.eh_comm.size

    def get_local_dd_index(self, jq):
        if jq % self.dd_comm.size != self.dd_comm.rank:
            return None
        return jq // self.dd_comm.size

    def get_global_eh_index(self, lip):
        return lip * self.eh_comm.size + self.eh_comm.rank

    def get_global_dd_index(self, ljq):
        return ljq * self.dd_comm.size + self.dd_comm.rank

    def get_matrix_elem_proc_and_index(self, ip, jq):
        ehproc = ip % self.eh_comm.size
        ddproc = jq % self.dd_comm.size
        proc = ehproc * self.dd_comm.size + ddproc
        lip = ip // self.eh_comm.size
        # Local jq index within the domain decomposition communicator
        # (cf. get_local_dd_index).
        ljq = jq // self.dd_comm.size
        return (proc, ehproc, ddproc, lip, ljq)
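

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the GPAW module): checks the cyclic index
# mapping above without MPI.  '_FakeComm' and '_demo_index_mapping' are
# hypothetical names introduced here for illustration; a real run would use
# the MPI communicators created in __init__ instead.


class _FakeComm:
    """Minimal stand-in exposing only the 'size' and 'rank' attributes."""

    def __init__(self, size, rank):
        self.size = size
        self.rank = rank


def _demo_index_mapping():
    # world=None makes __init__ return before any MPI setup.
    comms = LrCommunicators(None, 1)
    comms.eh_comm = _FakeComm(size=4, rank=1)
    comms.dd_comm = _FakeComm(size=2, rank=0)

    # eh rank 1 of 4 owns global electron-hole indices 1, 5, 9, ...
    assert comms.get_local_eh_index(5) == 1     # 5 % 4 == 1 -> local 5 // 4
    assert comms.get_local_eh_index(6) is None  # owned by eh rank 2
    assert comms.get_global_eh_index(1) == 5    # inverse of the map above

    # Matrix element (ip=5, jq=3) lives on eh rank 1, dd rank 1, i.e. on
    # world rank 1 * 2 + 1 = 3, as local element (1, 1).
    assert comms.get_matrix_elem_proc_and_index(5, 3) == (3, 1, 1, 1, 1)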