# Source code for site_analysis.pbc_utils

"""Utilities for handling periodic boundary conditions.

Includes PBC correction, reference-centre unwrapping, and incremental
PBC shift updates (with optional numba acceleration).
"""

from __future__ import annotations

from typing import Literal, overload

import numpy as np
from site_analysis.distances import frac_to_cart

from site_analysis._compat import HAS_NUMBA

def apply_legacy_pbc_correction(frac_coords: np.ndarray) -> np.ndarray:
    """Apply the legacy spread-based periodic boundary condition handling.

    If the range of fractional coordinates along x, y, or z exceeds 0.5,
    assume that the site wraps around the periodic boundary in that
    dimension. Fractional coordinates for that dimension that are less than
    0.5 are incremented by 1.0. Each dimension is handled independently.

    Args:
        frac_coords: Array of fractional coordinates with shape (n, 3). It is
            not modified; a corrected copy is returned.

    Returns:
        Adjusted fractional coordinates with the same shape. An empty input
        yields an empty copy.

    Warning:
        This algorithm can produce incorrect results for sites spanning
        periodic boundaries in small unit cells. Consider using reference
        centre-based approaches for robust PBC handling.
    """
    corrected_coords: np.ndarray = frac_coords.copy()
    if corrected_coords.size == 0:
        # np.max / np.min have no identity for empty arrays and would raise;
        # there is nothing to correct anyway.
        return corrected_coords
    for dim in range(3):
        # Basic slicing gives a view, so the masked update below writes
        # straight into corrected_coords.
        column = corrected_coords[:, dim]
        spread = np.max(column) - np.min(column)
        if spread > 0.5:
            column[column < 0.5] += 1.0
    return corrected_coords
# All 27 integer lattice translations with components in {-1, 0, 1}.
# Row order follows the nested loops (dz varies fastest), so the zero shift
# [0, 0, 0] is row 13 and [-1, -1, -1] is row 0.
_PERIODIC_SHIFTS = np.array(
    [[dx, dy, dz] for dx in (-1, 0, 1) for dy in (-1, 0, 1) for dz in (-1, 0, 1)],
    dtype=np.int64,
)


@overload
def unwrap_vertices_to_reference_center(
    frac_coords: np.ndarray,
    reference_center: np.ndarray,
    lattice_matrix: np.ndarray,
    return_image_shifts: Literal[False] = ...,
) -> np.ndarray: ...


# No default here: the runtime default is False, so only an explicit True
# selects the tuple-returning form.
@overload
def unwrap_vertices_to_reference_center(
    frac_coords: np.ndarray,
    reference_center: np.ndarray,
    lattice_matrix: np.ndarray,
    return_image_shifts: Literal[True],
) -> tuple[np.ndarray, np.ndarray]: ...


# Fallback for callers passing a non-literal bool.
@overload
def unwrap_vertices_to_reference_center(
    frac_coords: np.ndarray,
    reference_center: np.ndarray,
    lattice_matrix: np.ndarray,
    return_image_shifts: bool = ...,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]: ...
def unwrap_vertices_to_reference_center(
    frac_coords: np.ndarray,
    reference_center: np.ndarray,
    lattice_matrix: np.ndarray,
    return_image_shifts: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Move each vertex to the periodic image nearest a reference centre.

    Each vertex gets 27 candidate images by adding every row of
    ``_PERIODIC_SHIFTS``, so only neighbours within one cell in each
    direction are considered. The candidates are converted to Cartesian
    coordinates with ``frac_to_cart``, and the one with the smallest
    Euclidean distance to the reference centre is kept. On a tie, the first
    row of ``_PERIODIC_SHIFTS`` wins. The whole set is then translated by a
    single non-negative integer vector so that no coordinate is negative.

    Args:
        frac_coords: Fractional coordinates, shape ``(n, 3)``.
        reference_center: Reference centre in fractional coordinates. It is
            passed through ``frac_to_cart`` like the vertices.
        lattice_matrix: ``(3, 3)`` lattice matrix whose rows are the lattice
            vectors.
        return_image_shifts: When True, also return the per-vertex integer
            image shift chosen from ``_PERIODIC_SHIFTS``. This shift excludes
            the uniform non-negative translation.

    Returns:
        The unwrapped coordinates, with the same shape as ``frac_coords`` and
        every component >= 0. When ``return_image_shifts`` is True, returns
        ``(unwrapped_coords, image_shifts)`` instead. Empty input is returned
        as-is (the same object), paired with a ``(0, 3)`` int64 shift array
        when shifts are requested.
    """
    if frac_coords.size == 0:
        no_shifts = np.zeros((0, 3), dtype=np.int64)
        return (frac_coords, no_shifts) if return_image_shifts else frac_coords

    n_images = len(_PERIODIC_SHIFTS)
    # (n, 27, 3): every candidate image of every vertex.
    candidates = frac_coords[:, None, :] + _PERIODIC_SHIFTS[None, :, :]
    centre_cart = frac_to_cart(reference_center, lattice_matrix)
    candidates_cart = frac_to_cart(candidates.reshape(-1, 3), lattice_matrix)
    separation = np.linalg.norm(candidates_cart - centre_cart, axis=1)
    separation = separation.reshape(-1, n_images)

    chosen_shifts: np.ndarray = _PERIODIC_SHIFTS[separation.argmin(axis=1)]
    unwrapped = frac_coords + chosen_shifts
    # Smallest non-negative integer offset per axis that lifts the minimum
    # coordinate to >= 0.
    offset = np.maximum(0, np.ceil(-unwrapped.min(axis=0)))
    unwrapped = unwrapped + offset

    if not return_image_shifts:
        return np.asarray(unwrapped)  # no-op; satisfies mypy no-any-return
    return unwrapped, chosen_shifts
def correct_pbc(
    frac_coords: np.ndarray,
    reference_center: np.ndarray | None,
    lattice_matrix: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Apply PBC correction to fractional coordinates.

    The strategy depends on ``reference_center``:

    * ``None``: the legacy spread-based correction
      (``apply_legacy_pbc_correction``) is used. The image shifts are the
      rounded difference between corrected and input coordinates, so they
      include every applied ``+1`` increment.
    * otherwise: vertices are unwrapped to their closest periodic images
      relative to the centre (``unwrap_vertices_to_reference_center``). The
      returned image shifts exclude the uniform non-negative translation that
      unwrapping applies afterwards.

    Args:
        frac_coords: Fractional coordinates, shape ``(n, 3)``.
        reference_center: Reference centre for unwrapping, or ``None`` for
            legacy spread-based correction.
        lattice_matrix: (3, 3) lattice matrix where rows are lattice vectors.
            Only the reference-centre path uses it.

    Returns:
        Tuple ``(corrected_coords, image_shifts)``. Both have shape ``(n, 3)``
        and ``image_shifts`` has ``int64`` dtype. Empty input is returned
        unchanged together with an empty ``(0, 3)`` shift array.
    """
    if frac_coords.size == 0:
        return frac_coords, np.zeros((0, 3), dtype=np.int64)

    if reference_center is None:
        legacy = apply_legacy_pbc_correction(frac_coords)
        legacy_shifts = np.round(legacy - frac_coords).astype(np.int64)
        return legacy, legacy_shifts

    return unwrap_vertices_to_reference_center(
        frac_coords,
        reference_center,
        lattice_matrix,
        return_image_shifts=True,
    )
# Largest per-frame displacement, in fractional units after removing
# whole-cell wraps, for which cached image shifts are still trusted. Both
# backends share it so they agree on when the cache is invalidated.
_MAX_FRAME_DISPLACEMENT = 0.3


def _numpy_update_pbc_shifts(
    frac_coords: np.ndarray,
    cached_raw_frac: np.ndarray,
    image_shifts: np.ndarray,
) -> tuple[bool, np.ndarray, np.ndarray]:
    """Incrementally update cached PBC image shifts between frames.

    Full PBC unwrapping is expensive because it evaluates all 27 periodic
    images per coordinate. Between successive MD frames, atoms typically
    move by small amounts, so the integer image shifts rarely change. This
    function detects coordinate wraps (jumps of roughly 1.0 in fractional
    coordinates) and adjusts the shifts accordingly, which avoids the full
    27-image search.

    The cache is considered invalid if any coordinate component has moved by
    ``_MAX_FRAME_DISPLACEMENT`` (0.3) fractional units or more after
    accounting for wrapping, or if a displacement is NaN. In that case the
    caller should fall back to a full PBC recomputation.

    Args:
        frac_coords: New raw fractional coordinates, shape ``(n, 3)``.
        cached_raw_frac: Previous raw fractional coordinates from the last
            frame, shape ``(n, 3)``.
        image_shifts: Current cached integer image shifts, shape ``(n, 3)``.

    Returns:
        Tuple of ``(cache_valid, new_vertex_coords, new_image_shifts)``.
        When valid, ``new_vertex_coords`` is ``frac_coords +
        new_image_shifts`` translated by the smallest non-negative integer
        vector that makes every component >= 0. Empty input is valid and
        yields empty arrays. If ``cache_valid`` is False, the other values
        are undefined and the caller should perform a full recomputation.
    """
    diff = frac_coords - cached_raw_frac
    # Whole-cell wraps (e.g. 0.99 -> 0.01) show up as diffs near +/-1.
    rounded = np.round(diff)
    physical_diff = diff - rounded
    # "not all(|d| < limit)" also rejects NaN. Validating before the int cast
    # keeps NaN away from astype, which would otherwise warn.
    if not np.all(np.abs(physical_diff) < _MAX_FRAME_DISPLACEMENT):
        return False, frac_coords, image_shifts
    wrapping = rounded.astype(np.int64)
    new_shifts = image_shifts - wrapping
    shifted = frac_coords + new_shifts
    if shifted.shape[0] == 0:
        # No vertices: np.min over an empty axis would raise.
        return True, shifted, new_shifts
    min_coords = np.min(shifted, axis=0)
    uniform = np.maximum(0, np.ceil(-min_coords))
    return True, shifted + uniform, new_shifts


if HAS_NUMBA:
    import numba  # type: ignore

    @numba.njit(cache=True)  # type: ignore[misc]
    def _numba_update_pbc_shifts(
        frac_coords: np.ndarray,
        cached_raw_frac: np.ndarray,
        image_shifts: np.ndarray,
    ) -> tuple[bool, np.ndarray, np.ndarray]:
        """JIT-compiled variant of ``_numpy_update_pbc_shifts``.

        Same algorithm and return contract, including NaN rejection and
        empty-input handling. It exits early on the first invalid
        coordinate, which avoids a full array traversal on cache misses.

        See ``_numpy_update_pbc_shifts`` for full documentation.
        """
        n = frac_coords.shape[0]
        new_shifts = np.empty((n, 3), dtype=np.int64)
        for i in range(n):
            for k in range(3):
                diff = frac_coords[i, k] - cached_raw_frac[i, k]
                rounded = np.round(diff)
                physical = diff - rounded
                # Written as not(|x| < limit) so NaN is rejected, matching
                # the NumPy backend. A two-sided >= / <= test lets NaN pass.
                if not (abs(physical) < _MAX_FRAME_DISPLACEMENT):
                    return False, frac_coords, image_shifts
                new_shifts[i, k] = image_shifts[i, k] - int(rounded)

        shifted = np.empty((n, 3))
        for i in range(n):
            for k in range(3):
                shifted[i, k] = frac_coords[i, k] + new_shifts[i, k]

        if n == 0:
            # No rows: the column-minimum scan below starts at shifted[0, k],
            # which would read out of bounds (njit does not bounds-check).
            return True, shifted, new_shifts

        for k in range(3):
            min_val = shifted[0, k]
            for i in range(1, n):
                if shifted[i, k] < min_val:
                    min_val = shifted[i, k]
            u = 0.0
            if min_val < 0.0:
                u = np.ceil(-min_val)
            for i in range(n):
                shifted[i, k] += u
        return True, shifted, new_shifts

    update_pbc_shifts = _numba_update_pbc_shifts
else:
    update_pbc_shifts = _numpy_update_pbc_shifts