"""Utilities for handling periodic boundary conditions.
Includes PBC correction, reference-centre unwrapping, and incremental
PBC shift updates (with optional numba acceleration).
"""
from __future__ import annotations
from typing import Literal, overload
import numpy as np
from site_analysis.distances import frac_to_cart
from site_analysis._compat import HAS_NUMBA
[docs]
def apply_legacy_pbc_correction(frac_coords: np.ndarray) -> np.ndarray:
    """Apply the legacy spread-based periodic boundary condition handling.

    If the range of fractional coordinates along x, y, or z exceeds 0.5,
    assume that the site wraps around the periodic boundary in that
    dimension. Fractional coordinates for that dimension that are less
    than 0.5 will be incremented by 1.0.

    Args:
        frac_coords: Array of fractional coordinates with shape (n, 3).
            The input array is never modified. An empty array (n == 0)
            is returned as an unchanged copy.

    Returns:
        Adjusted fractional coordinates with the same shape.

    Warning:
        This algorithm can produce incorrect results for sites spanning
        periodic boundaries in small unit cells. Consider using reference
        centre-based approaches for robust PBC handling.
    """
    corrected_coords: np.ndarray = frac_coords.copy()
    if corrected_coords.size == 0:
        # np.max/np.min have no identity for zero-size reductions and would
        # raise; with no vertices there is nothing to correct.
        return corrected_coords
    for dim in range(3):
        column = corrected_coords[:, dim]
        spread = np.max(column) - np.min(column)
        # A spread wider than half a cell is taken as evidence that the
        # site straddles the boundary in this dimension.
        if spread > 0.5:
            corrected_coords[column < 0.5, dim] += 1.0
    return corrected_coords
# Generate all 27 possible shifts: [-1,0,1] for each dimension
_PERIODIC_SHIFTS = np.array([[dx, dy, dz] for dx in [-1, 0, 1]
for dy in [-1, 0, 1]
for dz in [-1, 0, 1]],
dtype=np.int64)
@overload
def unwrap_vertices_to_reference_center(
frac_coords: np.ndarray,
reference_center: np.ndarray,
lattice_matrix: np.ndarray,
return_image_shifts: Literal[False] = ...,
) -> np.ndarray: ...
@overload
def unwrap_vertices_to_reference_center(
frac_coords: np.ndarray,
reference_center: np.ndarray,
lattice_matrix: np.ndarray,
return_image_shifts: Literal[True] = ...,
) -> tuple[np.ndarray, np.ndarray]: ...
[docs]
def unwrap_vertices_to_reference_center(
frac_coords: np.ndarray,
reference_center: np.ndarray,
lattice_matrix: np.ndarray,
return_image_shifts: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Vectorised unwrapping of vertices to their closest periodic images relative to a reference centre.
Args:
frac_coords: Array of fractional coordinates with shape (n, 3).
reference_center: Reference centre position for unwrapping.
lattice_matrix: (3, 3) lattice matrix where rows are lattice vectors.
return_image_shifts: If True, also return the per-vertex integer
image shifts (from ``_PERIODIC_SHIFTS``), separate from the
uniform non-negative shift.
Returns:
Unwrapped fractional coordinates with the same shape, shifted to
ensure all coordinates >= 0. If ``return_image_shifts`` is True,
returns a tuple of (unwrapped_coords, image_shifts).
"""
if frac_coords.size == 0:
if return_image_shifts:
return frac_coords, np.zeros((0, 3), dtype=np.int64)
return frac_coords
n_vertices = len(frac_coords)
vertex_images = frac_coords[:, np.newaxis, :] + _PERIODIC_SHIFTS[np.newaxis, :, :]
ref_cart = frac_to_cart(reference_center, lattice_matrix)
vertex_images_cart = frac_to_cart(vertex_images.reshape(n_vertices * 27, 3), lattice_matrix)
distances = np.linalg.norm(
vertex_images_cart - ref_cart, axis=1).reshape(n_vertices, 27)
best_indices = np.argmin(distances, axis=1)
image_shifts: np.ndarray = _PERIODIC_SHIFTS[best_indices]
result = frac_coords + image_shifts
min_coords = np.min(result, axis=0)
uniform_shift = np.maximum(0, np.ceil(-min_coords))
result = result + uniform_shift
if return_image_shifts:
return result, image_shifts
return np.asarray(result) # no-op; satisfies mypy no-any-return
[docs]
def correct_pbc(
frac_coords: np.ndarray,
reference_center: np.ndarray | None,
lattice_matrix: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""Apply PBC correction to fractional coordinates.
Selects the appropriate unwrapping strategy based on whether a
reference centre is provided. When a reference centre is given,
vertices are unwrapped to their closest periodic images relative
to that centre. Otherwise, the legacy spread-based correction is
applied.
Args:
frac_coords: Fractional coordinates, shape ``(n, 3)``.
reference_center: Reference centre for unwrapping, or ``None``
for legacy spread-based correction.
lattice_matrix: (3, 3) lattice matrix where rows are lattice
vectors. Passed to the reference-centre unwrapping path;
unused by the legacy spread-based path.
Returns:
Tuple of ``(corrected_coords, image_shifts)`` where both have
shape ``(n, 3)`` and ``image_shifts`` has ``int64`` dtype.
"""
if frac_coords.size == 0:
return frac_coords, np.zeros((0, 3), dtype=np.int64)
if reference_center is not None:
return unwrap_vertices_to_reference_center(
frac_coords, reference_center, lattice_matrix,
return_image_shifts=True)
corrected = apply_legacy_pbc_correction(frac_coords)
image_shifts = np.round(corrected - frac_coords).astype(np.int64)
return corrected, image_shifts
def _numpy_update_pbc_shifts(
frac_coords: np.ndarray,
cached_raw_frac: np.ndarray,
image_shifts: np.ndarray,
) -> tuple[bool, np.ndarray, np.ndarray]:
"""Incrementally update cached PBC image shifts between frames.
Full PBC unwrapping is expensive because it evaluates all 27
periodic images per coordinate. Between successive MD frames,
atoms typically move by small amounts, so the integer image
shifts rarely change. This function detects coordinate wraps
(jumps of ~1.0 in fractional coordinates) and adjusts shifts
accordingly, avoiding the full 27-image search.
If any coordinate has moved by more than 0.3 fractional units
(after accounting for wrapping), the cache is considered invalid
and the caller should fall back to full PBC recomputation.
Args:
frac_coords: New raw fractional coordinates, shape ``(n, 3)``.
cached_raw_frac: Previous raw fractional coordinates from the
last frame, shape ``(n, 3)``.
image_shifts: Current cached integer image shifts, shape
``(n, 3)``.
Returns:
Tuple of ``(cache_valid, new_vertex_coords, new_image_shifts)``.
If ``cache_valid`` is False, the other values are undefined and
the caller should perform a full recomputation.
"""
diff = frac_coords - cached_raw_frac
wrapping = np.round(diff).astype(np.int64)
physical_diff = diff - wrapping
if not np.all(np.abs(physical_diff) < 0.3):
return False, frac_coords, image_shifts
new_shifts = image_shifts - wrapping
shifted = frac_coords + new_shifts
min_coords = np.min(shifted, axis=0)
uniform = np.maximum(0, np.ceil(-min_coords))
return True, shifted + uniform, new_shifts
if HAS_NUMBA:
    import numba  # type: ignore

    @numba.njit(cache=True)  # type: ignore[misc]
    def _numba_update_pbc_shifts(
        frac_coords: np.ndarray,
        cached_raw_frac: np.ndarray,
        image_shifts: np.ndarray,
    ) -> tuple[bool, np.ndarray, np.ndarray]:
        """JIT-compiled variant of ``_numpy_update_pbc_shifts``.

        Same algorithm, but it exits on the first coordinate whose
        displacement invalidates the cache. This avoids a full array
        traversal on cache misses. See ``_numpy_update_pbc_shifts`` for
        full documentation.
        """
        n = frac_coords.shape[0]
        new_shifts = np.empty((n, 3), dtype=np.int64)
        for i in range(n):
            for k in range(3):
                diff = frac_coords[i, k] - cached_raw_frac[i, k]
                nearest = np.round(diff)
                physical = diff - nearest
                # ``not (|physical| < 0.3)`` rejects NaN/inf as well, matching
                # the NumPy path; two ordered comparisons would both be False
                # for NaN. The check also precedes the int() conversion,
                # which is undefined for non-finite floats.
                if not (abs(physical) < 0.3):
                    return False, frac_coords, image_shifts
                new_shifts[i, k] = image_shifts[i, k] - int(nearest)
        shifted = np.empty((n, 3))
        if n == 0:
            # Nothing to shift. Also avoids reading shifted[0, k] below on an
            # empty array, since njit code is not bounds-checked by default.
            return True, shifted, new_shifts
        for i in range(n):
            for k in range(3):
                shifted[i, k] = frac_coords[i, k] + new_shifts[i, k]
        # Per-axis uniform whole-cell lift so every coordinate is >= 0.
        for k in range(3):
            min_val = shifted[0, k]
            for i in range(1, n):
                if shifted[i, k] < min_val:
                    min_val = shifted[i, k]
            u = 0.0
            if min_val < 0.0:
                u = np.ceil(-min_val)
            for i in range(n):
                shifted[i, k] += u
        return True, shifted, new_shifts

    # Public entry point: the JIT kernel when numba is available...
    update_pbc_shifts = _numba_update_pbc_shifts
else:
    # ...otherwise the vectorised NumPy implementation.
    update_pbc_shifts = _numpy_update_pbc_shifts