Source code for gpaw.utilities.blas

```# Copyright (C) 2003  CAMP

"""
Python wrapper functions for the ``C`` package:
Basic Linear Algebra Subroutines (BLAS)

http://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms
and
http://www.netlib.org/lapack/lug/node145.html
"""
from typing import TypeVar

import gpaw.cgpaw as cgpaw
import numpy as np
import scipy.linalg.blas as blas
from gpaw import debug
from gpaw.new import prod
from gpaw.typing import Array2D, ArrayND
from gpaw.utilities import is_contiguous

def is_finite(array, tril=False):
if isinstance(array, np.ndarray):
xp = np
else:
from gpaw.gpu import cupy as xp
if tril:
array = xp.tril(array)
return xp.isfinite(array).all()

__all__ = ['mmm']

T = TypeVar('T', float, complex)

def mmm(alpha: T,
a: Array2D,
opa: str,
b: Array2D,
opb: str,
beta: T,
c: Array2D) -> None:
"""Matrix-matrix multiplication using dgemm or zgemm.

For opa='N' and opb='N', we have:::

c <- αab + βc.

Use 'T' to transpose matrices and 'C' to transpose and complex conjugate
matrices.
"""

assert opa in 'NTC'
assert opb in 'NTC'

if opa == 'N':
a1, a2 = a.shape
else:
a2, a1 = a.shape
if opb == 'N':
b1, b2 = b.shape
else:
b2, b1 = b.shape
assert a2 == b1
assert c.shape == (a1, b2)

assert a.dtype == b.dtype == c.dtype
assert a.strides[1] == c.itemsize or a.size == 0
assert b.strides[1] == c.itemsize or b.size == 0
assert c.strides[1] == c.itemsize or c.size == 0
if a.dtype == float:
assert not isinstance(alpha, complex)
assert not isinstance(beta, complex)
else:
assert a.dtype == complex

cgpaw.mmm(alpha, a, opa, b, opb, beta, c)

def gpu_mmm(alpha, a, opa, b, opb, beta, c):
"""Launch CPU or GPU version of mmm()."""
m = b.shape[1] if opb == 'N' else b.shape[0]
n = a.shape[0] if opa == 'N' else a.shape[1]
k = b.shape[0] if opb == 'N' else b.shape[1]
lda = a.strides[0] // a.itemsize
ldb = b.strides[0] // b.itemsize
ldc = c.strides[0] // c.itemsize
cgpaw.mmm_gpu(alpha, a.data.ptr, lda, opa,
b.data.ptr, ldb, opb, beta,
c.data.ptr, ldc, c.itemsize,
m, n, k)

def gpu_scal(alpha, x):
"""alpha x

Performs the operation::

x <- alpha * x

"""
if debug:
if isinstance(alpha, complex):
assert is_contiguous(x, complex)
else:
assert isinstance(alpha, float)
assert x.dtype in [float, complex]
assert x.flags.c_contiguous
cgpaw.scal_gpu(alpha, x.data.ptr, x.shape, x.dtype)

def to2d(array: ArrayND) -> Array2D:
"""2D view af ndarray.

>>> to2d(np.zeros((2, 3, 4))).shape
(2, 12)
"""
shape = array.shape
return array.reshape((shape[0], prod(shape[1:])))

def mmmx(alpha: T,
a: ArrayND,
opa: str,
b: ArrayND,
opb: str,
beta: T,
c: ArrayND) -> None:
"""Matrix-matrix multiplication using dgemm or zgemm.

Arrays a, b and c are converted to 2D arrays before calling mmm().
"""
mmm(alpha, to2d(a), opa, to2d(b), opb, beta, to2d(c))

def gpu_gemm(alpha, a, b, beta, c, transa='n'):
"""General Matrix Multiply.

Performs the operation::

c <- alpha * b.a + beta * c

If transa is "n", ``b.a`` denotes the matrix multiplication defined by::

_
\
(b.a)        =  ) b  * a
ijkl...   /_  ip   pjkl...
p

If transa is "t" or "c", ``b.a`` denotes the matrix multiplication
defined by::

_
\
(b.a)        =  ) b    *    a
ij        /_  iklm...   jklm...
klm...

where in case of "c" also complex conjugate of a is taken.
"""
if debug:
assert beta == 0.0 or is_finite(c)

assert (a.dtype == float and b.dtype == float and c.dtype == float and
isinstance(alpha, float) and isinstance(beta, float) or
a.dtype == complex and b.dtype == complex and
c.dtype == complex)
assert a.flags.c_contiguous
if transa == 'n':
assert c.flags.c_contiguous or (c.ndim == 2
and c.strides[1] == c.itemsize)
assert b.ndim == 2
assert b.strides[1] == b.itemsize
assert a.shape[0] == b.shape[1]
assert c.shape == b.shape[0:1] + a.shape[1:]
else:
assert b.size == 0 or b[0].flags.c_contiguous
assert c.strides[1] == c.itemsize
assert a.shape[1:] == b.shape[1:]
assert c.shape == (b.shape[0], a.shape[0])

cgpaw.gemm_gpu(alpha, a.data.ptr, a.shape,
b.data.ptr, b.shape, beta,
c.data.ptr, c.shape,
a.dtype, transa)

def gpu_gemv(alpha, a, x, beta, y, trans='t'):
"""General Matrix Vector product.

Performs the operation::

y <- alpha * a.x + beta * y

``a.x`` denotes matrix multiplication, where the product-sum is
over the entire length of the vector x and
the first dimension of a (for trans='n'), or
the last dimension of a (for trans='t' or 'c').

If trans='c', the complex conjugate of a is used. The default is
trans='t', i.e. behaviour like np.dot with a 2D matrix and a vector.
"""
if debug:
assert (a.dtype == float and x.dtype == float and y.dtype == float and
isinstance(alpha, float) and isinstance(beta, float) or
a.dtype == complex and x.dtype == complex and
y.dtype == complex)
assert a.flags.c_contiguous
assert y.flags.c_contiguous
assert x.ndim == 1
assert y.ndim == a.ndim - 1
if trans == 'n':
assert a.shape[0] == x.shape[0]
assert a.shape[1:] == y.shape
else:
assert a.shape[-1] == x.shape[0]
assert a.shape[:-1] == y.shape

cgpaw.gemv_gpu(alpha, a.data.ptr, a.shape,
x.data.ptr, x.shape, beta,
y.data.ptr, a.dtype,
trans)

def axpy(alpha, x, y):
"""alpha x plus y.

Performs the operation::

y <- alpha * x + y

"""
if x.size == 0:
return
assert x.flags.contiguous
assert y.flags.contiguous
x = x.ravel()
y = y.ravel()
if x.dtype == float:
z = blas.daxpy(x, y, a=alpha)
else:
z = blas.zaxpy(x, y, a=alpha)
assert z is y, (x, y, x.shape, y.shape)

def gpu_axpy(alpha, x, y):
"""alpha x plus y.

Performs the operation::

y <- alpha * x + y

"""
if debug:
if isinstance(alpha, complex):
assert is_contiguous(x, complex) and is_contiguous(y, complex)
else:
assert isinstance(alpha, float)
assert x.dtype in [float, complex]
assert x.dtype == y.dtype
assert x.flags.c_contiguous and y.flags.c_contiguous
assert x.shape == y.shape

cgpaw.axpy_gpu(alpha, x.data.ptr, x.shape,
y.data.ptr, y.shape,
x.dtype)

def rk(alpha, a, beta, c, trans='c'):
"""Rank-k update of a matrix.

For ``trans='c'`` the following operation is performed:::

†
c <- αaa + βc,

and for ``trans='t'`` we get:::

†
c <- αa a + βc

If the ``a`` array has more than 2 dimensions then the 2., 3., ...
axes are combined.

Only the lower triangle of ``c`` will contain sensible numbers.
"""
if debug:
assert beta == 0.0 or is_finite(c, tril=True)

assert (a.dtype == float and c.dtype == float or
a.dtype == complex and c.dtype == complex)
assert a.flags.c_contiguous
assert a.ndim > 1
if trans == 'n':
assert c.shape == (a.shape[1], a.shape[1])
else:
assert c.shape == (a.shape[0], a.shape[0])
assert c.strides[1] == c.itemsize or c.size == 0

cgpaw.rk(alpha, a, beta, c, trans)

def gpu_rk(alpha, a, beta, c, trans='c'):
"""Launch CPU or GPU version of rk()."""
cgpaw.rk_gpu(alpha, a.data.ptr, a.shape,
beta, c.data.ptr, c.shape,
a.dtype)

def r2k(alpha, a, b, beta, c, trans='c'):
"""Rank-2k update of a matrix.

Performs the operation::

dag        cc       dag
c <- alpha * a . b    + alpha  * b . a    + beta * c

or if trans='n'::
dag           cc   dag
c <- alpha * a   . b + alpha  * b   . a + beta * c

where ``a.b`` denotes the matrix multiplication defined by::

_
\
(a.b)   =  ) a         * b
ij   /_  ipklm...     pjklm...
pklm...

``cc`` denotes complex conjugation.

``dag`` denotes the hermitian conjugate (complex conjugation plus a
swap of axis 0 and 1).

Only the lower triangle of ``c`` will contain sensible numbers.
"""
if debug:
assert beta == 0.0 or is_finite(c, tril=True)
assert (a.dtype == float and b.dtype == float and c.dtype == float or
a.dtype == complex and b.dtype == complex and
c.dtype == complex)
assert a.flags.c_contiguous and b.flags.c_contiguous
assert a.ndim > 1
assert a.shape == b.shape
if trans == 'c':
assert c.shape == (a.shape[0], a.shape[0])
else:
assert c.shape == (a.shape[1], a.shape[1])
assert c.strides[1] == c.itemsize or c.size == 0

cgpaw.r2k(alpha, a, b, beta, c, trans)

def gpu_r2k(alpha, a, b, beta, c, trans='c'):
"""Launch CPU or GPU version of r2k()."""
cgpaw.r2k_gpu(alpha, a.data.ptr, a.shape,
b.data.ptr, b.shape, beta,
c.data.ptr, c.shape,
a.dtype)

def gpu_dotc(a, b):
r"""Dot product, conjugating the first vector with complex arguments.

Returns the value of the operation::

_
\   cc
) a       * b
/_  ijk...    ijk...
ijk...

``cc`` denotes complex conjugation.
"""
if debug:
assert ((is_contiguous(a, float) and is_contiguous(b, float)) or
(is_contiguous(a, complex) and is_contiguous(b, complex)))
assert a.shape == b.shape

return cgpaw.dotc_gpu(a.data.ptr, a.shape,
b.data.ptr, a.dtype)

def gpu_dotu(a, b):
"""Dot product, NOT conjugating the first vector with complex arguments.

Returns the value of the operation::

_
\
) a       * b
/_  ijk...    ijk...
ijk...

"""
if debug:
assert ((is_contiguous(a, float) and is_contiguous(b, float)) or
(is_contiguous(a, complex) and is_contiguous(b, complex)))
assert a.shape == b.shape

return cgpaw.dotu_gpu(a.data.ptr, a.shape,
b.data.ptr, a.dtype)

def _gemmdot(a, b, alpha=1.0, beta=1.0, out=None, trans='n'):
"""Matrix multiplication using gemm.

return reference to out, where::

out <- alpha * a . b + beta * out

If out is None, a suitably sized zero array will be created.

``a.b`` denotes matrix multiplication, where the product-sum is
over the last dimension of a, and either
the first dimension of b (for trans='n'), or
the last dimension of b (for trans='t' or 'c').

If trans='c', the complex conjugate of b is used.
"""
# Store original shapes
ashape = a.shape
bshape = b.shape

# Vector-vector multiplication is handled by dotu
if a.ndim == 1 and b.ndim == 1:
assert out is None
if trans == 'c':
return alpha * np.vdot(b, a)  # dotc conjugates *first* argument
else:
return alpha * a.dot(b)

# Map all arrays to 2D arrays
a = a.reshape(-1, a.shape[-1])
if trans == 'n':
b = b.reshape(b.shape[0], -1)
outshape = a.shape[0], b.shape[1]
else:  # 't' or 'c'
b = b.reshape(-1, b.shape[-1])

# Apply BLAS gemm routine
outshape = a.shape[0], b.shape[trans == 'n']
if out is None:
# (ATLAS can't handle uninitialized output array)
out = np.zeros(outshape, a.dtype)
else:
out = out.reshape(outshape)
mmmx(alpha, a, 'N', b, trans.upper(), beta, out)

# Determine actual shape of result array
if trans == 'n':
outshape = ashape[:-1] + bshape[1:]
else:  # 't' or 'c'
outshape = ashape[:-1] + bshape[:-1]
return out.reshape(outshape)

if not hasattr(cgpaw, 'mmm'):
# These are the functions used with noblas=True
# TODO: move these functions elsewhere so that
# they can be used for unit tests

def op(o, m):
if o.upper() == 'N':
return m
if o.upper() == 'T':
return m.T
if o.upper() == 'C':
return m.conj().T
raise ValueError(f'unknown op: {o}')

[docs]    def rk(alpha, a, beta, c, trans='c'):  # noqa
if c.size == 0:
return
if beta == 0:
c[:] = 0.0
else:
c *= beta
if trans == 'n':
c += alpha * a.conj().T.dot(a)
else:
a = a.reshape((len(a), -1))
c += alpha * a.dot(a.conj().T)

def r2k(alpha, a, b, beta, c, trans='c'):  # noqa
if c.size == 0:
return
if beta == 0.0:
c[:] = 0.0
else:
c *= beta
if trans == 'c':
c += (alpha * a.reshape((len(a), -1))
.dot(b.reshape((len(b), -1)).conj().T) +
alpha * b.reshape((len(b), -1))
.dot(a.reshape((len(a), -1)).conj().T))
else:
c += alpha * (a.conj().T @ b + b.conj().T @ a)

[docs]    def mmm(alpha: T, a: np.ndarray, opa: str,  # noqa
b: np.ndarray, opb: str,
beta: T, c: np.ndarray) -> None:
if beta == 0.0:
c[:] = 0.0
else:
c *= beta
c += alpha * op(opa, a).dot(op(opb, b))

gemmdot = _gemmdot

elif not debug:
mmm = cgpaw.mmm  # noqa
rk = cgpaw.rk  # noqa
r2k = cgpaw.r2k  # noqa
gemmdot = _gemmdot

else:
def gemmdot(a, b, alpha=1.0, beta=1.0, out=None, trans='n'):
assert a.flags.c_contiguous
assert b.flags.c_contiguous
assert a.dtype == b.dtype
if trans == 'n':
assert a.shape[-1] == b.shape[0]
else:
assert a.shape[-1] == b.shape[-1]
if out is not None:
assert out.flags.c_contiguous
assert a.dtype == out.dtype
assert a.ndim > 1 or b.ndim > 1
if trans == 'n':
assert out.shape == a.shape[:-1] + b.shape[1:]
else:
assert out.shape == a.shape[:-1] + b.shape[:-1]
return _gemmdot(a, b, alpha, beta, out, trans)
```