Source code for gpaw.matrix

"""BLACS distributed matrix object."""
from typing import Dict, Tuple
import numpy as np
import scipy.linalg as linalg

import _gpaw
from gpaw import debug
from gpaw.mpi import serial_comm, _Communicator
import gpaw.utilities.blas as blas


_global_blacs_context_store: Dict[Tuple[_Communicator, int, int], int] = {}


[docs]def matrix_matrix_multiply(alpha, a, opa, b, opb, beta=0.0, c=None, symmetric=False): """BLAS-style matrix-matrix multiplication. Will use dgemm/zgemm/dsyrk/zherk/dsyr2k/zher2k as apropriate or the equivalent PBLAS functions for distributed matrices. The coefficients alpha and beta are of type float. Matrices a, b and c must have same type (float or complex). The strings opa and opb must be 'N', 'T', or 'C' . For opa='N' and opb='N', the operation performed is equivalent to:: c.array[:] = alpha * np.dot(a.array, b.array) + beta * c.array Replace a.array with a.array.T or a.array.T.conj() for opa='T' and 'C' respectively (similarly for opb). Use symmetric=True if the result matrix is symmetric/hermetian (only lower half of c will be evaluated). """ return _matrix(a).multiply(alpha, opa, _matrix(b), opb, beta, c if c is None else _matrix(c), symmetric)
def suggest_blocking(N, ncpus): """Suggest blocking of NxN matrix. Returns rows, columns, blocksize tuple.""" nprow = ncpus npcol = 1 # Make npcol and nprow as close to each other as possible npcol_try = npcol while npcol_try < nprow: if ncpus % npcol_try == 0: npcol = npcol_try nprow = ncpus // npcol npcol_try += 1 assert npcol * nprow == ncpus # ScaLAPACK creates trouble if there aren't at least a few whole blocks. # Choose block size so that there will always be at least one whole block # and at least two blocks in total. blocksize = max((N - 2) // max(nprow, npcol), 1) # The next commented line would give more whole blocks. # blocksize = max(N // max(nprow, npcol) - 2, 1) # Use block size that is a power of 2 and at most 64 blocksize = 2**int(np.log2(blocksize)) blocksize = max(min(blocksize, 64), 1) return nprow, npcol, blocksize
[docs]class Matrix: def __init__(self, M, N, dtype=None, data=None, dist=None): """Matrix object. M: int Rows. N: int Columns. dtype: type Data type (float or complex). dist: tuple or None BLACS distribution given as (communicator, rows, colums, blocksize) tuple. Default is None meaning no distribution. data: ndarray or None. Numpy ndarray to use for starage. By default, a new ndarray will be allocated. """ self.shape = (M, N) if dtype is None: if data is None: dtype = float else: dtype = data.dtype self.dtype = np.dtype(dtype) dist = dist or () if isinstance(dist, tuple): dist = create_distribution(M, N, *dist) self.dist = dist if data is None: self.array = np.empty(dist.shape, self.dtype) else: self.array = data.reshape(dist.shape) self.comm = serial_comm self.state = 'everything is fine' def __len__(self): return self.shape[0] def __repr__(self): dist = str(self.dist).split('(')[1] return 'Matrix({}: {}'.format(self.dtype.name, dist)
[docs] def new(self, dist='inherit'): """Create new matrix of same shape and dtype. Default is to use same BLACS distribution. Use dist to use another distribution. """ return Matrix(*self.shape, dtype=self.dtype, dist=self.dist if dist == 'inherit' else dist)
def __setitem__(self, i, x): # assert i == slice(None) if isinstance(x, np.ndarray): 1 / 0 # sssssself.array[:] = x else: x.eval(self) def __iadd__(self, x): x.eval(self, 1.0) return self
[docs] def multiply(self, alpha, opa, b, opb, beta=0.0, out=None, symmetric=False): """BLAS-style Matrix-matrix multiplication. See matrix_matrix_multipliction() for details. """ dist = self.dist if out is None: assert beta == 0.0 if opa == 'N': M = self.shape[0] else: M = self.shape[1] if opb == 'N': N = b.shape[1] else: N = b.shape[0] out = Matrix(M, N, self.dtype, dist=(dist.comm, dist.rows, dist.columns)) if dist.comm.size > 1: # Special cases that don't need scalapack - most likely also # faster: if alpha == 1.0 and opa == 'N' and opb == 'N': return fastmmm(self, b, out, beta) if alpha == 1.0 and beta == 1.0 and opa == 'N' and opb == 'C': if symmetric: return fastmmm2(self, b, out) else: return fastmmm2notsym(self, b, out) dist.multiply(alpha, self, opa, b, opb, beta, out, symmetric) return out
[docs] def redist(self, other): """Redistribute to other BLACS layout.""" if self is other: return d1 = self.dist d2 = other.dist n1 = d1.rows * d1.columns n2 = d2.rows * d2.columns if n1 == n2 == 1: other.array[:] = self.array return if n2 == 1 and d1.blocksize is None: assert d2.blocksize is None comm = d1.comm if comm.rank == 0: M = len(self) m = (M + comm.size - 1) // comm.size other.array[:m] = self.array for r in range(1, comm.size): m1 = min(r * m, M) m2 = min(m1 + m, M) comm.receive(other.array[m1:m2], r) else: comm.send(self.array, 0) return if n1 == 1 and d2.blocksize is None: assert d1.blocksize is None comm = d1.comm if comm.rank == 0: M = len(self) m = (M + comm.size - 1) // comm.size other.array[:] = self.array[:m] for r in range(1, comm.size): m1 = min(r * m, M) m2 = min(m1 + m, M) comm.send(self.array[m1:m2], r) else: comm.receive(other.array, 0) return c = d1.comm if d1.comm.size > d2.comm.size else d2.comm n = max(n1, n2) if n < c.size: c = c.new_communicator(np.arange(n)) if c is not None: M, N = self.shape d1 = create_distribution(M, N, c, d1.rows, d1.columns, d1.blocksize) d2 = create_distribution(M, N, c, d2.rows, d2.columns, d2.blocksize) if n1 == n: ctx = d1.desc[1] else: ctx = d2.desc[1] redist(d1, self.array, d2, other.array, ctx)
[docs] def invcholesky(self): """Inverse of Cholesky decomposition. Only the lower part is used. """ if self.state == 'a sum is needed': self.comm.sum(self.array, 0) if self.comm.rank == 0: if self.dist.comm.size > 1: S = self.new(dist=(self.dist.comm, 1, 1)) self.redist(S) else: S = self if self.dist.comm.rank == 0: if debug: S.array[np.triu_indices(S.shape[0], 1)] = 42.0 L_nn = linalg.cholesky(S.array, lower=True, overwrite_a=True, check_finite=debug) S.array[:] = linalg.inv(L_nn, overwrite_a=True, check_finite=debug) if S is not self: S.redist(self) if self.comm.size > 1: self.comm.broadcast(self.array, 0) self.state == 'everything is fine'
[docs] def eigh(self, cc=False, scalapack=(None, 1, 1, None)): """Calculate eigenvectors and eigenvalues. Matrix must be symmetric/hermitian and stored in lower half. cc: bool Complex conjugate matrix before finding eigenvalues. scalapack: tuple BLACS distribution for ScaLapack to use. Default is to do serial diagonalization. """ slcomm, rows, columns, blocksize = scalapack if self.state == 'a sum is needed': self.comm.sum(self.array, 0) slcomm = slcomm or self.dist.comm dist = (slcomm, rows, columns, blocksize) redist = (rows != self.dist.rows or columns != self.dist.columns or blocksize != self.dist.blocksize) if redist: H = self.new(dist=dist) self.redist(H) else: assert self.dist.comm.size == slcomm.size H = self eps = np.empty(H.shape[0]) if rows * columns == 1: if self.comm.rank == 0 and self.dist.comm.rank == 0: if cc and H.dtype == complex: np.negative(H.array.imag, H.array.imag) if debug: H.array[np.triu_indices(H.shape[0], 1)] = 42.0 eps[:], H.array.T[:] = linalg.eigh(H.array, lower=True, # ??? overwrite_a=True, check_finite=debug) self.dist.comm.broadcast(eps, 0) elif slcomm.rank < rows * columns: assert cc array = H.array.copy() info = _gpaw.scalapack_diagonalize_dc(array, H.dist.desc, 'U', H.array, eps) assert info == 0, info if redist: H.redist(self) assert (self.state == 'a sum is needed') == ( self.comm.size > 1) if self.comm.size > 1: self.comm.broadcast(self.array, 0) self.comm.broadcast(eps, 0) self.state == 'everything is fine' return eps
[docs] def complex_conjugate(self): """Inplace complex conjugation.""" if self.dtype == complex: np.negative(self.array.imag, self.array.imag)
def _matrix(M): """Dig out Matrix object from wrapper(s).""" if isinstance(M, Matrix): return M return _matrix(M.matrix) class NoDistribution: comm = serial_comm rows = 1 columns = 1 blocksize = None def __init__(self, M, N): self.shape = (M, N) def __str__(self): return 'NoDistribution({}x{})'.format(*self.shape) def global_index(self, n): return n def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric): if symmetric: assert opa == 'N' assert opb == 'C' or opb == 'T' and a.dtype == float if a is b: blas.rk(alpha, a.array, beta, c.array) else: if beta == 1.0 and a.shape[1] == 0: return blas.r2k(0.5 * alpha, a.array, b.array, beta, c.array) else: blas.mmm(alpha, a.array, opa, b.array, opb, beta, c.array) class BLACSDistribution: serial = False def __init__(self, M, N, comm, r, c, b): self.comm = comm self.rows = r self.columns = c self.blocksize = b key = (comm, r, c) context = _global_blacs_context_store.get(key) if context is None: try: context = _gpaw.new_blacs_context(comm.get_c_object(), c, r, 'R') except AttributeError: pass else: _global_blacs_context_store[key] = context if b is None: if c == 1: br = (M + r - 1) // r bc = max(1, N) elif r == 1: br = M bc = (N + c - 1) // c else: raise ValueError('Please specify block size!') else: br = bc = b if context is None: assert b is None assert c == 1 n = N m = min((comm.rank + 1) * br, M) - min(comm.rank * br, M) else: n, m = _gpaw.get_blacs_local_shape(context, N, M, bc, br, 0, 0) if n < 0 or m < 0: n = m = 0 self.shape = (m, n) lld = max(1, n) if context is not None: self.desc = np.array([1, context, N, M, bc, br, 0, 0, lld], np.intc) def __str__(self): return ('BLACSDistribution(global={}, local={}, blocksize={})' .format(*('{}x{}'.format(*shape) for shape in [self.desc[3:1:-1], self.shape, self.desc[5:3:-1]]))) def global_index(self, myi): return self.comm.rank * int(self.desc[5]) + myi def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric): if symmetric: assert opa == 'N' assert opb == 'C' or opb == 'T' and a.dtype == float N, K = a.shape if a is b: _gpaw.pblas_rk(N, K, alpha, a.array, beta, c.array, a.dist.desc, c.dist.desc, 'U') else: _gpaw.pblas_r2k(N, K, 0.5 * alpha, b.array, a.array, beta, c.array, b.dist.desc, a.dist.desc, c.dist.desc, 'U') else: Ka, M = a.shape N, Kb = b.shape if opa == 'N': Ka, M = M, Ka if opb == 'N': N, Kb = Kb, N _gpaw.pblas_gemm(N, M, Ka, alpha, b.array, a.array, beta, c.array, b.dist.desc, a.dist.desc, c.dist.desc, opb, opa) def redist(dist1, M1, dist2, M2, context): _gpaw.scalapack_redist(dist1.desc, dist2.desc, M1, M2, dist1.desc[2], dist1.desc[3], 1, 1, 1, 1, # 1-indexing context, 'G') def create_distribution(M, N, comm=None, r=1, c=1, b=None): if comm is None or comm.size == 1: assert r == 1 and abs(c) == 1 or c == 1 and abs(r) == 1 return NoDistribution(M, N) return BLACSDistribution(M, N, comm, r if r != -1 else comm.size, c if c != -1 else comm.size, b) def fastmmm(m1, m2, m3, beta): comm = m1.dist.comm buf1 = m2.array N = len(m1) n = (N + comm.size - 1) // comm.size for r in range(comm.size): if r == 0: buf2 = np.empty((n, buf1.shape[1]), dtype=buf1.dtype) rrequest = None srequest = None if r < comm.size - 1: rrank = (comm.rank + r + 1) % comm.size rn1 = min(rrank * n, N) rn2 = min(rn1 + n, N) if rn2 > rn1: rrequest = comm.receive(buf2[:rn2 - rn1], rrank, 21, False) srank = (comm.rank - r - 1) % comm.size if len(m2.array) > 0: srequest = comm.send(m2.array, srank, 21, False) r0 = (comm.rank + r) % comm.size n1 = min(r0 * n, N) n2 = min(n1 + n, N) blas.mmm(1.0, m1.array[:, n1:n2], 'N', buf1[:n2 - n1], 'N', beta, m3.array) beta = 1.0 if r == 0: buf1 = np.empty_like(buf2) buf1, buf2 = buf2, buf1 if rrequest: comm.wait(rrequest) if srequest: comm.wait(srequest) return m3 def fastmmm2(a, b, out): if a.comm: assert b.comm is a.comm if a.comm.size > 1: assert out.comm == a.comm assert out.state == 'a sum is needed' comm = a.dist.comm M, N = a.shape m = (M + comm.size - 1) // comm.size mym = len(a.array) buf1 = np.empty((m, N), dtype=a.dtype) buf2 = np.empty((m, N), dtype=a.dtype) half = comm.size // 2 aa = a.array bb = b.array for r in range(half + 1): rrequest = None srequest = None if r < half: srank = (comm.rank + r + 1) % comm.size rrank = (comm.rank - r - 1) % comm.size skip = (comm.size % 2 == 0 and r == half - 1) m1 = min(rrank * m, M) m2 = min(m1 + m, M) if not (skip and comm.rank < half) and m2 > m1: rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False) if not (skip and comm.rank >= half) and mym > 0: srequest = comm.send(b.array, srank, 11, False) if not (comm.size % 2 == 0 and r == half and comm.rank < half): m1 = min(((comm.rank - r) % comm.size) * m, M) m2 = min(m1 + m, M) if r == 0: # symmmmmmmmmmmmmmmmmmmmmmetricccccccccccccccc blas.mmm(1.0, aa, 'N', bb, 'C', 1.0, out.array[:, m1:m2]) else: beta = 1.0 if r <= comm.rank else 0.0 blas.mmm(1.0, aa, 'N', buf2[:m2 - m1], 'C', beta, out.array[:, m1:m2]) # out.array[:, m1:m2] = m12.array[:, :m2 - m1] if rrequest: comm.wait(rrequest) if srequest: comm.wait(srequest) bb = buf1 buf1, buf2 = buf2, buf1 requests = [] blocks = [] nrows = (comm.size - 1) // 2 for row in range(nrows): for column in range(comm.size - nrows + row, comm.size): if comm.rank == row: m1 = min(column * m, M) m2 = min(m1 + m, M) if mym > 0 and m2 > m1: requests.append( comm.send(out.array[:, m1:m2].T.conj().copy(), column, 12, False)) elif comm.rank == column: m1 = min(row * m, M) m2 = min(m1 + m, M) if mym > 0 and m2 > m1: block = np.empty((mym, m2 - m1), out.dtype) blocks.append((m1, m2, block)) requests.append(comm.receive(block, row, 12, False)) comm.waitall(requests) for m1, m2, block in blocks: out.array[:, m1:m2] += block return out def fastmmm2notsym(a, b, out): if a.comm: assert b.comm is a.comm if a.comm.size > 1: assert out.comm == a.comm assert out.state == 'a sum is needed' comm = a.dist.comm M, N = a.shape m = (M + comm.size - 1) // comm.size mym = len(a.array) buf1 = np.empty((m, N), dtype=a.dtype) buf2 = np.empty((m, N), dtype=a.dtype) aa = a.array bb = b.array for r in range(comm.size): rrequest = None srequest = None if r < comm.size - 1: srank = (comm.rank + r + 1) % comm.size rrank = (comm.rank - r - 1) % comm.size m1 = min(rrank * m, M) m2 = min(m1 + m, M) if m2 > m1: rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False) if mym > 0: srequest = comm.send(b.array, srank, 11, False) m1 = min(((comm.rank - r) % comm.size) * m, M) m2 = min(m1 + m, M) # symmmmmmmmmmmmmmmmmmmmmmetricccccccccccccccc ?? blas.mmm(1.0, aa, 'N', bb[:m2 - m1], 'C', 1.0, out.array[:, m1:m2]) if rrequest: comm.wait(rrequest) if srequest: comm.wait(srequest) bb = buf1 buf1, buf2 = buf2, buf1 return out