Source code for gpaw.band_descriptor

# Copyright (C) 2003  CAMP
# Please see the accompanying LICENSE file for further information.

"""Band-descriptors for blocked/strided groups.

This module contains classes defining two kinds of band groups:

* Blocked groups with contiguous band indices.
* Strided groups with evenly-spaced band indices.
"""

import numpy as np

from gpaw.mpi import serial_comm

NONBLOCKING = False


[docs]class BandDescriptor: r"""Descriptor-class for ordered lists of bands A ``BandDescriptor`` object holds information on how functions, such as wave functions and corresponding occupation numbers, are divided into groups according to band indices. The main information here is how many bands are stored on each processor and who gets what. This is how a 12 band array is laid out in memory on 3 cpu's:: a) Blocked groups b) Strided groups 3 7 11 9 10 11 myn 2 \ 6 \ 10 myn 6 7 8 | 1 \ 5 \ 9 | 3 4 5 | 0 4 8 | 0 1 2 | | +----- band_rank +----- band_rank Example: >>> a = np.zeros((3, 4)) >>> a.ravel()[:] = range(12) >>> a array([[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]]) >>> b = np.zeros((4, 3)) >>> b.ravel()[:] = range(12) >>> b.T array([[ 0., 3., 6., 9.], [ 1., 4., 7., 10.], [ 2., 5., 8., 11.]]) """ def __init__(self, nbands: int, comm=None, strided=False): """Construct band-descriptor object. Parameters: nbands: int Global number of bands. comm: MPI-communicator Communicator for band-groups. strided: bool Enable strided band distribution for better load balancing with many unoccupied bands. Note that if comm.size is 1, then all bands are contained on a single CPU and blocked/strided grouping loses its meaning. Attributes: ============ ====================================================== ``nbands`` Number of bands in total. ``mynbands`` Number of bands on this CPU. ``beg`` Beginning of band indices in group (inclusive). ``end`` End of band indices in group (exclusive). ``step`` Stride for band indices between ``beg`` and ``end``. ``comm`` MPI-communicator for band distribution. ============ ====================================================== """ if comm is None: comm = serial_comm self.comm = comm self.nbands = nbands self.strided = strided self.maxmynbands = (nbands + comm.size - 1) // comm.size if strided: assert nbands % comm.size == 0 self.beg = comm.rank self.end = nbands self.step = comm.size self.mynbands = self.maxmynbands else: self.beg = min(nbands, comm.rank * self.maxmynbands) self.end = min(nbands, self.beg + self.maxmynbands) self.mynbands = self.end - self.beg self.step = 1 def __len__(self): return self.mynbands
[docs] def get_slice(self, band_rank=None): """Return the slice of global bands which belong to a given rank.""" if band_rank is None: band_rank = self.comm.rank assert band_rank in range(self.comm.size) if self.strided: nstride = self.comm.size nslice = slice(band_rank, None, nstride) else: nslice = slice(self.beg, self.end) return nslice
[docs] def get_band_indices(self, band_rank=None): """Return the global band indices which belong to a given rank.""" nslice = self.get_slice(band_rank) return np.arange(*nslice.indices(self.nbands))
[docs] def get_band_ranks(self): """Return array of ranks as a function of global band indices.""" rank_n = np.empty(self.nbands, dtype=int) for band_rank in range(self.comm.size): nslice = self.get_slice(band_rank) rank_n[nslice] = band_rank assert (rank_n >= 0).all() and (rank_n < self.comm.size).all() return rank_n
[docs] def who_has(self, n): """Convert global band index to rank information and local index.""" if self.strided: myn, band_rank = divmod(n, self.comm.size) else: band_rank, myn = divmod(n, self.maxmynbands) return band_rank, myn
[docs] def global_index(self, myn, band_rank=None): """Convert rank information and local index to global index.""" if band_rank is None: band_rank = self.comm.rank if self.strided: return band_rank + myn * self.comm.size return band_rank * self.maxmynbands + myn
def get_size_of_global_array(self): return (self.nbands,)
[docs] def zeros(self, n=(), dtype=float, global_array=False): """Return new zeroed 3D array for this domain. The type can be set with the ``dtype`` keyword (default: ``float``). Extra dimensions can be added with ``n=dim``. A global array spanning all domains can be allocated with ``global_array=True``.""" # TODO XXX doc return self._new_array(n, dtype, True, global_array)
[docs] def empty(self, n=(), dtype=float, global_array=False): """Return new uninitialized 3D array for this domain. The type can be set with the ``dtype`` keyword (default: ``float``). Extra dimensions can be added with ``n=dim``. A global array spanning all domains can be allocated with ``global_array=True``.""" # TODO XXX doc return self._new_array(n, dtype, False, global_array)
def _new_array(self, n=(), dtype=float, zero=True, global_array=False): if global_array: shape = self.get_size_of_global_array() else: shape = (self.mynbands,) if isinstance(n, int): n = (n,) shape = tuple(shape) + n if zero: return np.zeros(shape, dtype) else: return np.empty(shape, dtype)
[docs] def collect(self, a_nx, broadcast=False): """Collect distributed array to master-CPU or all CPU's.""" if self.comm.size == 1: return a_nx xshape = a_nx.shape[1:] # Optimization for blocked groups if not self.strided: if self.comm.size * self.maxmynbands > self.nbands: return self.nasty_non_strided_collect(a_nx, broadcast) if broadcast: A_nx = self.empty(xshape, a_nx.dtype, global_array=True) self.comm.all_gather(a_nx, A_nx) return A_nx if self.comm.rank == 0: A_nx = self.empty(xshape, a_nx.dtype, global_array=True) else: A_nx = None self.comm.gather(a_nx, 0, A_nx) return A_nx # Collect all arrays on the master: if self.comm.rank != 0: # There can be several sends before the corresponding receives # are posted, so use syncronous send here self.comm.ssend(a_nx, 0, 3011) if broadcast: A_nx = self.empty(xshape, a_nx.dtype, global_array=True) self.comm.broadcast(A_nx, 0) return A_nx else: return None # Put the band groups from the slaves into the big array # for the whole collection of bands: A_nx = self.empty(xshape, a_nx.dtype, global_array=True) for band_rank in range(self.comm.size): if band_rank != 0: a_nx = self.empty(xshape, a_nx.dtype, global_array=False) self.comm.receive(a_nx, band_rank, 3011) A_nx[self.get_slice(band_rank), ...] = a_nx if broadcast: self.comm.broadcast(A_nx, 0) return A_nx
[docs] def distribute(self, B_nx, b_nx): """ distribute full array B_nx to band groups, result in b_nx. b_nx must be allocated.""" S = self.comm.size if S == 1: b_nx[:] = B_nx return # Optimization for blocked groups if not self.strided: M2 = self.maxmynbands if M2 * S == self.nbands: self.comm.scatter(B_nx, b_nx, 0) return if self.comm.rank == 0: C_nx = np.empty((S * M2,) + B_nx.shape[1:], B_nx.dtype) C_nx[:self.nbands] = B_nx else: C_nx = None if self.mynbands < M2: c_nx = np.empty((M2,) + b_nx.shape[1:], b_nx.dtype) else: c_nx = b_nx self.comm.scatter(C_nx, c_nx, 0) if self.mynbands < M2: b_nx[:] = c_nx[:self.mynbands] return if self.comm.rank != 0: self.comm.receive(b_nx, 0, 421) return else: requests = [] for band_rank in range(self.comm.size): if band_rank != 0: a_nx = B_nx[self.get_slice(band_rank), ...].copy() request = self.comm.send(a_nx, band_rank, 421, NONBLOCKING) # Remember to store a reference to the # send buffer (a_nx) so that is isn't # deallocated: requests.append((request, a_nx)) else: b_nx[:] = B_nx[self.get_slice(), ...] for request, a_nx in requests: self.comm.wait(request)
def nasty_non_strided_collect(self, a_nx, broadcast): xshape = a_nx.shape[1:] if broadcast: A_nx = self.nasty_non_strided_collect(a_nx, False) if A_nx is None: A_nx = self.empty(xshape, a_nx.dtype, global_array=True) self.comm.broadcast(A_nx, 0) return A_nx S = self.comm.size M2 = self.maxmynbands if self.comm.rank == 0: A_nx = np.empty((S * M2,) + xshape, a_nx.dtype) else: A_nx = None if self.mynbands < M2: b_nx = np.empty((M2,) + xshape, a_nx.dtype) b_nx[:self.mynbands] = a_nx else: b_nx = a_nx self.comm.gather(b_nx, 0, A_nx) if self.comm.rank > 0: return return A_nx[:self.nbands]