# Source code for gpaw.point_groups.check

"""Symmetry checking code."""
import sys
from typing import Sequence, Any, Dict, List, Union

from ase import Atoms
import numpy as np
from numpy.linalg import inv, det, solve
from scipy.ndimage import map_coordinates

from . import PointGroup
from gpaw.hints import Array2D, Array3D

# An axis is given either as a name ('x', 'y' or 'z', optionally prefixed
# with '-'), as a sequence of numbers, or as None when left unspecified.
Axis = Union[str, Sequence[float], None]

class SymmetryChecker:
    """Check atoms and grid functions against a point group."""

    def __init__(self,
                 group: Union[str, PointGroup],
                 center: Sequence[float],
                 radius: float = 2.0,
                 x: Axis = None,
                 y: Axis = None,
                 z: Axis = None,
                 grid_spacing: float = 0.2):
        """Check point-group symmetries.

        If a non-standard orientation is desired then two of
        *x*, *y*, *z* can be specified.

        *group* is a PointGroup or the name of one.  *center* is the
        point the symmetry operations are applied around.  *radius* and
        *grid_spacing* define the sphere of sampling points used by
        check_function().
        """
        if isinstance(group, str):
            group = PointGroup(group)
        self.group = group
        self.normalized_table = group.get_normalized_table()
        self.points = sphere(radius, grid_spacing)
        self.center = center
        self.grid_spacing = grid_spacing
        self.rotation = rotation_matrix([x, y, z])

    def check_atoms(self, atoms: Atoms, tol: float = 1e-5) -> bool:
        """Check if atoms have all the symmetries.

        Unit of *tol* is Angstrom.

        Returns False as soon as one operation maps an atom onto a
        position where the nearest atom is further away than *tol* or
        has a different atomic number; returns True otherwise.
        """
        numbers = atoms.numbers
        # Work in the rotated frame with the center at the origin.
        positions = (atoms.positions - self.center).dot(self.rotation.T)
        icell = np.linalg.inv(atoms.cell.dot(self.rotation.T))
        for op in self.group.operations.values():
            P = positions.dot(op.T)
            for i, pos in enumerate(P):
                # Differences in scaled coordinates so that whole cell
                # vectors can be removed along periodic directions.
                sdiff = (pos - positions).dot(icell)
                sdiff -= sdiff.round() * atoms.pbc
                dist2 = (sdiff.dot(atoms.cell)**2).sum(1)
                j = dist2.argmin()
                if dist2[j] > tol**2 or numbers[j] != numbers[i]:
                    return False
        return True

    def check_function(self,
                       function: Array3D,
                       grid_vectors: Array2D = None) -> Dict[str, Any]:
        """Check function on uniform grid.

        *grid_vectors* holds the grid step vectors as rows (defaults to
        the unit matrix).

        Returns a dictionary with the keys:

        * 'symmetry': label of the representation with largest character
        * 'norm': sum of the squared function times the volume element
        * 'overlaps': one overlap per symmetry operation
        * 'characters': mapping from representation label to value
        """
        if grid_vectors is None:
            grid_vectors = np.eye(3)
        dv = abs(det(grid_vectors))
        norm1 = (function**2).sum() * dv
        M = inv(grid_vectors).T
        overlaps: List[float] = []
        for op in self.group.operations.values():
            # Express the operation in the frame of the function.
            op = self.rotation.T.dot(op).dot(self.rotation)
            # Transformed sampling points converted to grid indices.
            pts = (self.points.dot(op.T) + self.center).dot(M.T)
            pts %= function.shape
            values = map_coordinates(function, pts.T, mode='wrap')
            if not overlaps:
                # The values of the first operation are the reference
                # (presumably the identity -- confirm with PointGroup).
                values1 = values
            overlaps.append(values.dot(values1) * self.grid_spacing**3)

        # Average the overlaps over each group of operations (sizes
        # given by group.nops) and normalize by the first overlap.
        reduced_overlaps = []
        i1 = 0
        for n in self.group.nops:
            i2 = i1 + n
            reduced_overlaps.append(sum(overlaps[i1:i2]) / n / overlaps[0])
            i1 = i2

        characters = solve(self.normalized_table.T, reduced_overlaps)
        best = self.group.symmetries[characters.argmax()]

        return {'symmetry': best,
                'norm': norm1,
                'overlaps': overlaps,
                'characters': {symmetry: value
                               for symmetry, value
                               in zip(self.group.symmetries, characters)}}

    def check_band(self,
                   calc,
                   band: int,
                   spin: int = 0) -> Dict[str, Any]:
        """Check wave function from GPAW calculation.

        Returns the dictionary produced by check_function().
        """
        wfs = calc.get_pseudo_wave_function(band, spin=spin, pad=True)
        # One grid step along each cell vector.
        grid_vectors = (calc.atoms.cell.T / wfs.shape).T
        return self.check_function(wfs, grid_vectors)

    def check_calculation(self,
                          calc,
                          n1: int,
                          n2: int,
                          spin: int = 0,
                          output: str = '-') -> None:
        """Check several wave functions from GPAW calculation.

        Bands *n1* to *n2* - 1 are checked; a false *n2* (0 or None)
        means up to the number of bands of *calc*.  A table is written
        to the file named *output* or to standard output for '-'.
        """
        lines = ['band    energy     norm  normcut     best     ' +
                 ''.join(f'{sym:8}' for sym in self.group.symmetries)]
        n2 = n2 or calc.get_number_of_bands()
        for n in range(n1, n2):
            dct = self.check_band(calc, n, spin)
            best = dct['symmetry']
            norm = dct['norm']
            # Overlap of the first operation's values with themselves.
            normcut = dct['overlaps'][0]
            eig = calc.get_eigenvalues(spin=spin)[n]
            lines.append(
                f'{n:4} {eig:9.3f} {norm:8.3f} {normcut:8.3f} {best:>8}' +
                ''.join(f'{x:8.3f}'
                        for x in dct['characters'].values()))

        text = '\n'.join(lines) + '\n'
        if output == '-':
            sys.stdout.write(text)
        else:
            # The context manager closes the file even if writing fails.
            with open(output, 'w') as fd:
                fd.write(text)

def sphere(radius: float, grid_spacing: float) -> Array2D:
    """Return the points of a cubic grid lying inside a sphere.

    The grid is centered on the origin, has step *grid_spacing* and
    points on the surface of the sphere are included.  One point per
    row is returned.

    >>> points = sphere(1.1, 1.0)
    >>> points.shape
    (7, 3)
    """
    # Number of grid steps needed on each side of the origin.
    half_width = int(radius / grid_spacing) + 1
    ticks = np.linspace(-half_width,
                        half_width,
                        2 * half_width + 1) * grid_spacing
    mesh = np.meshgrid(ticks, ticks, ticks, indexing='ij')
    candidates = np.array(mesh).reshape((3, -1)).T
    inside = (candidates**2).sum(1) <= radius**2
    return candidates[inside]

def rotation_matrix(axes: Sequence[Axis]) -> Array3D:
    """Build a rotation matrix with the given axes as rows.

    With no axes given, the unit matrix is returned.  Otherwise exactly
    one axis must be None; it is replaced by the cross product of the
    other two.

    >>> rotation_matrix(['-y', 'x', None])
    array([[ 0, -1,  0],
           [ 1,  0,  0],
           [ 0,  0,  1]])
    """
    unspecified = [index for index, axis in enumerate(axes) if axis is None]
    if len(unspecified) == len(axes):
        return np.eye(3)

    # Exactly one axis must be left for us to work out.
    assert len(unspecified) == 1
    j = unspecified[0]

    rows = [None if axis is None else normalize(axis) for axis in axes]
    # Negative indices wrap around, giving the cyclic order of the axes.
    rows[j] = np.cross(rows[j - 2], rows[j - 1])

    return np.array(rows)

def normalize(vector: Union[str, Sequence[float]]) -> Sequence[float]:
    """Normalize a vector.

    The *vector* must be a sequence of three numbers or one of the following
    strings: x, y, z, -x, -y, -z.

    A plain axis name gives a list of integers, a name with a leading
    minus sign gives the negated integer array and a sequence of numbers
    gives an array scaled to unit length.  Any other string raises
    KeyError.

    >>> normalize('x')
    [1, 0, 0]
    >>> normalize('-y')
    array([ 0, -1,  0])
    """
    if isinstance(vector, str):
        # Look at the first character only: comparing the whole string
        # to '-' would never match names such as '-x'.
        if vector.startswith('-'):
            return -np.array(normalize(vector[1:]))
        return {'x': [1, 0, 0], 'y': [0, 1, 0], 'z': [0, 0, 1]}[vector]
    return np.array(vector) / np.linalg.norm(vector)