# Source code for site_analysis.dynamic_voronoi_site_collection

"""Collection manager for dynamic Voronoi sites in crystal structures.

This module provides the DynamicVoronoiSiteCollection class, which manages a
collection of DynamicVoronoiSite objects and implements methods for assigning
atoms to these sites based on their positions in a crystal structure.

The DynamicVoronoiSiteCollection extends the base SiteCollection class with
specific functionality for dynamic Voronoi sites, including:

1. Calculating the dynamic centres of sites based on reference atom positions
2. Assigning atoms to sites using Voronoi tessellation principles

For atom assignment, the collection:

1. First updates each site's centre by calculating the mean position of its
   reference atoms, with special handling for periodic boundary conditions
2. Calculates distances from each (dynamically determined) site centre to each atom
3. Assigns each atom to the site with the nearest centre
4. Uses minimum-image convention distances to correctly handle periodic
   boundaries

This collection is particularly useful for tracking sites in frameworks
that deform during simulation, as the site centres adapt to the changing
positions of the reference atoms.
"""

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field

import numpy as np
from pymatgen.core import Structure
from site_analysis.site_collection import SiteCollection
from site_analysis.site import Site
from site_analysis.dynamic_voronoi_site import DynamicVoronoiSite
from site_analysis.pbc_utils import correct_pbc
from site_analysis.atom import Atom
from site_analysis.distances import all_mic_distances


@dataclass
class _CentreGroup:
    """Batched per-group state for sites with an identical reference count.

    Holds the cached periodic-image shifts together with the raw
    fractional coordinates they were computed for, and implements the
    vectorised incremental update.  The per-site fallback computation
    and the hand-off of centres to individual sites are left to the
    owning collection.

    Attributes:
        site_positions: Positions of this group's sites within the
            parent ``DynamicVoronoiSiteCollection.sites`` list.
        ref_indices: ``(n_sites, n_ref)`` int array of reference atom
            indices.
        pbc_shifts: ``(n_sites, n_ref, 3)`` int array of cached image
            shifts.
        cached_raw_frac: ``(n_sites, n_ref, 3)`` float array holding the
            raw coordinates seen on the previous frame.
        initialised: ``True`` once the caches describe a real frame.
    """
    site_positions: list[int]
    ref_indices: np.ndarray
    pbc_shifts: np.ndarray = field(init=False, repr=False)
    cached_raw_frac: np.ndarray = field(init=False, repr=False)
    initialised: bool = field(init=False, default=False)

    def __post_init__(self) -> None:
        """Allocate zero-filled caches sized from ``ref_indices``."""
        # The two-name unpack deliberately requires a 2-D index array.
        n_sites, n_ref = self.ref_indices.shape
        cache_shape = (n_sites, n_ref, 3)
        self.pbc_shifts = np.zeros(cache_shape, dtype=np.int64)
        self.cached_raw_frac = np.zeros(cache_shape)

    def try_fast_update(self, batch_ref: np.ndarray) -> np.ndarray | None:
        """Attempt the vectorised incremental update of the image shifts.

        The update is only performed when the group has been
        initialised and every coordinate has moved by less than 0.3
        fractional units (after removing whole-cell jumps) since the
        cached frame.  On success the caches are replaced and the site
        centres are returned; otherwise nothing is modified and ``None``
        tells the caller to run the per-site full PBC computation.

        Args:
            batch_ref: Raw fractional coordinates for this group,
                shape ``(n_sites, n_ref, 3)``.

        Returns:
            ``(n_sites, 3)`` array of site centres, or ``None`` when the
            fast path is not applicable.
        """
        if not self.initialised:
            return None

        raw_delta = batch_ref - self.cached_raw_frac
        # Whole-number part of the change: the raw coordinate jumped
        # across a cell boundary rather than the atom really moving.
        boundary_jumps = np.rint(raw_delta).astype(np.int64)
        true_motion = raw_delta - boundary_jumps
        # Written as "not all below the limit" so that NaN entries
        # also reject the fast path.
        within_limit = np.abs(true_motion) < 0.3
        if not within_limit.all():
            return None

        # Compensate the cached shifts for the boundary jumps.
        updated_shifts = self.pbc_shifts - boundary_jumps
        unwrapped = batch_ref + updated_shifts

        # Lift each site by whole cells so no coordinate is negative.
        lowest = unwrapped.min(axis=1)  # (n_sites, 3)
        lift = np.maximum(0, np.ceil(-lowest))  # (n_sites, 3)
        unwrapped = unwrapped + lift[:, np.newaxis, :]

        # Commit the new state only now that the update has succeeded.
        self.pbc_shifts = updated_shifts
        self.cached_raw_frac = batch_ref.copy()

        centres: np.ndarray = unwrapped.mean(axis=1) % 1.0
        return centres

    def initialise(self, batch_ref: np.ndarray) -> None:
        """Record a frame as the cache baseline and enable the fast path.

        Intended to be called after every row of ``pbc_shifts`` has been
        filled in by the per-site full PBC computation.

        Args:
            batch_ref: Raw fractional coordinates for this group,
                shape ``(n_sites, n_ref, 3)``.
        """
        self.cached_raw_frac = batch_ref.copy()
        self.initialised = True


class DynamicVoronoiSiteCollection(SiteCollection):
    """A collection of DynamicVoronoiSite objects.

    Manages a set of dynamic Voronoi sites: recomputes each site's
    centre from its reference atoms and assigns atoms to the site whose
    centre is nearest.

    Attributes:
        sites (list[DynamicVoronoiSite]): list of DynamicVoronoiSite objects.
    """

    def __init__(self, sites: list[Site]) -> None:
        """Create a DynamicVoronoiSiteCollection instance.

        Args:
            sites (list[DynamicVoronoiSite]): list of DynamicVoronoiSite
                objects.

        Returns:
            None

        Raises:
            TypeError: If any of the sites is not a DynamicVoronoiSite.
        """
        if any(not isinstance(site, DynamicVoronoiSite) for site in sites):
            raise TypeError("All sites must be DynamicVoronoiSite instances")
        super().__init__(sites)
        # Annotation only: narrows the element type for type checkers.
        self.sites: list[DynamicVoronoiSite]
        self._centre_groups: list[_CentreGroup] = self._build_centre_groups()

    def _build_centre_groups(self) -> list[_CentreGroup]:
        """Bucket sites by number of reference atoms, one group per count.

        Sites with equal reference counts can be stacked into a single
        rectangular index array for batch centre calculation.
        """
        members_by_count: dict[int, list[int]] = defaultdict(list)
        for site_index, site in enumerate(self.sites):
            n_reference = len(site.reference_indices)
            members_by_count[n_reference].append(site_index)
        return [
            _CentreGroup(
                site_positions=members,
                ref_indices=np.array(
                    [self.sites[i].reference_indices for i in members]),
            )
            for members in members_by_count.values()
        ]

    def _batch_calculate_centres(self,
            all_frac_coords: np.ndarray,
            lattice_matrix: np.ndarray) -> None:
        """Update every site's centre, processing one group at a time.

        Each group first attempts its vectorised fast path.  When that
        is declined (the group returns ``None``, e.g. on the first
        frame, after a reset, or after a large displacement), every site
        in the group is recomputed individually with ``correct_pbc`` and
        the group's caches are re-initialised.

        Args:
            all_frac_coords: Full fractional coordinate array from the
                structure, shape ``(n_atoms, 3)``.
            lattice_matrix: (3, 3) lattice matrix where rows are
                lattice vectors.
        """
        for group in self._centre_groups:
            # Fancy indexing gathers (n_sites, n_ref, 3) raw coordinates.
            group_raw = all_frac_coords[group.ref_indices]
            fast_centres = group.try_fast_update(group_raw)
            if fast_centres is None:
                # Fallback: per-site full PBC computation.
                for row, site_index in enumerate(group.site_positions):
                    site = self.sites[site_index]
                    unwrapped, image_shifts = correct_pbc(
                        group_raw[row], site.reference_center, lattice_matrix)
                    site._centre_coords = np.mean(unwrapped, axis=0) % 1.0
                    group.pbc_shifts[row] = image_shifts
                group.initialise(group_raw)
            else:
                for site_index, centre in zip(group.site_positions,
                                              fast_centres):
                    self.sites[site_index]._centre_coords = centre

    def reset(self) -> None:
        """Reset all sites and batch PBC caches for a fresh analysis run."""
        super().reset()
        # Clearing the flag forces the next frame through the fallback.
        for centre_group in self._centre_groups:
            centre_group.initialised = False

    def analyse_structure(self,
            atoms: list[Atom],
            structure: Structure) -> None:
        """Analyse a structure to assign atoms to sites.

        Passes the structure's fractional coordinates to each atom,
        recomputes the site centres, then assigns every atom to its
        nearest site.

        Args:
            atoms: List of atoms to be assigned to sites.
            structure: Pymatgen Structure containing atom positions.
        """
        frac_coords = structure.frac_coords
        for atom in atoms:
            atom.assign_coords(frac_coords)
        lattice = structure.lattice.matrix
        self._batch_calculate_centres(frac_coords, lattice)
        self.assign_site_occupations(atoms, lattice)

    def assign_site_occupations(self,
            atoms: list[Atom],
            lattice_matrix: np.ndarray) -> None:
        """Assign atoms to sites based on Voronoi tessellation.

        Site occupations are cleared first.  Each atom is then given to
        the site whose centre has the smallest distance to it, as
        returned by ``all_mic_distances``.

        Args:
            atoms: List of atoms to be assigned to sites.
            lattice_matrix: (3, 3) lattice matrix where rows are
                lattice vectors.
        """
        self.reset_site_occupations()
        if not atoms:
            return
        centres = np.array([site.centre for site in self.sites])
        positions = np.array([atom.frac_coords for atom in atoms])
        distances = all_mic_distances(centres, positions, lattice_matrix)
        # argmin over axis 0 yields one site index per atom (presumably
        # rows are sites and columns are atoms; confirm against
        # all_mic_distances).
        nearest_site = np.argmin(distances, axis=0)
        for atom, site_index in zip(atoms, nearest_site):
            self.update_occupation(self.sites[site_index], atom)