"""Base classes for collections of sites in crystal structures.
This module defines:
- ``SiteCollection``: abstract base class that all site collection types
must inherit from. Provides the interface for site-atom assignment and
common functionality for managing site occupations.
- ``PriorityAssignmentMixin``: mixin providing priority-based site
assignment ordering. Used by collection types that check sites one at
a time (polyhedral, spherical) but not by those that use global
distance-matrix assignment (Voronoi, dynamic Voronoi).
- ``_NearestSiteLookup``: precomputed lookup for finding the nearest
site to a given position.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Generic, NamedTuple, Sequence, TypeVar, TYPE_CHECKING
import numpy as np
from .atom import Atom
from .site import Site
class _NearestSiteLookup(NamedTuple):
    """Immutable table mapping site centres to their site indices.

    Nearest-site queries wrap coordinate differences with the
    minimum-image convention in fractional space. That wrapping is only
    geometrically exact for orthogonal cells.
    """

    # Fractional coordinates of the site centres, one row per site.
    centres: np.ndarray
    # Site index for each row of ``centres`` (same ordering).
    site_indices: list[int]

    def nearest_site_index(self, frac_coords: np.ndarray) -> int:
        """Return the index of the site whose centre is closest.

        Distances are measured in fractional space after applying the
        minimum-image convention to each coordinate difference.

        Args:
            frac_coords: Fractional coordinates of the query position.

        Returns:
            The entry of ``site_indices`` for the closest centre.
        """
        offsets = self.centres - frac_coords
        # Subtracting the rounded offset selects the nearest periodic image.
        wrapped = offsets - np.round(offsets)
        closest_row = int(np.argmin(np.linalg.norm(wrapped, axis=1)))
        return self.site_indices[closest_row]
# Type variable for the concrete ``Site`` subclass held by a collection;
# used to parameterise ``PriorityAssignmentMixin``.
SiteT = TypeVar('SiteT', bound=Site)
class PriorityAssignmentMixin(Generic[SiteT]):
    """Mixin providing priority-based site assignment ordering.

    Provides ``_get_priority_sites(atom)``, a generator that yields sites
    in an optimised order based on recent site history, learned
    transitions, and precomputed distance ranking.

    Subclasses call ``_init_priority_ranking(centres, site_indices)`` from
    their ``__init__`` to enable distance-ranked ordering. If not called,
    the generator falls back to ``neighbouring_sites`` then arbitrary
    order (used by ``PolyhedralSiteCollection`` when reference centres
    are unavailable).

    Note:
        Distance ranking uses minimum-image convention in fractional
        space, which is only geometrically exact for orthogonal cells.
        For non-orthogonal cells the ranking is approximate, but
        correctness is unaffected since all sites are eventually checked.

    Expects to be mixed with ``SiteCollection`` which provides
    ``site_by_index``, ``neighbouring_sites``, and ``sites``.
    """

    # Type stubs for the SiteCollection interface this mixin requires.
    # These are provided by SiteCollection at runtime via MRO.
    if TYPE_CHECKING:
        sites: Sequence[SiteT]

        def site_by_index(self, index: int) -> SiteT: ...

        def neighbouring_sites(self, site_index: int) -> Sequence[SiteT]: ...

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments up the MRO and start with no ranking data.

        Both ranking attributes stay ``None`` until
        ``_init_priority_ranking`` is called with at least one centre.
        """
        super().__init__(*args, **kwargs)
        self._distance_ranked_sites: dict[int, list[int]] | None = None
        self._nearest_site_lookup: _NearestSiteLookup | None = None

    def _init_priority_ranking(self, centres: np.ndarray, site_indices: list[int]) -> None:
        """Precompute distance-ranked site ordering from the given centres.

        Does nothing if ``centres`` is empty (zero sites).

        Args:
            centres: (N, 3) array of fractional coordinates for each site.
                Array-likes such as nested lists are accepted and
                converted to an array.
            site_indices: Corresponding site indices, in the same order
                as the rows of ``centres``.
        """
        # Coerce once so that row arithmetic below works for array-likes
        # too; this is a no-op when ``centres`` is already an ndarray.
        centres = np.asarray(centres)
        if len(centres) == 0:
            return
        ranked: dict[int, list[int]] = {}
        for i, idx in enumerate(site_indices):
            diffs = centres - centres[i]
            # Minimum-image wrap in fractional space.
            diffs -= np.round(diffs)
            dists = np.linalg.norm(diffs, axis=1)
            order = np.argsort(dists)
            # Every other site, nearest first; the site itself is skipped.
            ranked[idx] = [site_indices[j] for j in order if j != i]
        self._distance_ranked_sites = ranked
        self._nearest_site_lookup = _NearestSiteLookup(
            centres=centres, site_indices=site_indices
        )

    def _get_priority_sites(self, atom: Atom) -> Generator[SiteT, None, None]:
        """Yield sites in priority order for optimised atom assignment.

        The generator picks an *anchor site* (the most recent site from
        the atom's history, or the nearest site centre if no history
        exists) and uses it to order the remaining sites.

        The checking sequence depends on available information.

        When trajectory history exists:

        1. Most recently visited site, then previously visited site
        2. Learned transition destinations from the anchor in
           frequency order
        3. Remaining sites by distance from anchor (if distance ranking
           available), otherwise neighbours then arbitrary order

        When no trajectory history exists:

        - If distance ranking is available: nearest site centre
          (anchor) first, then learned transitions, then
          distance-ranked outward
        - Otherwise: all sites in arbitrary order

        Each site is yielded at most once.

        Args:
            atom: Atom object with recent site history used to determine
                site priorities.

        Yields:
            Site: Sites in optimal checking order.
        """
        checked_indices: set[int] = set()
        anchor_index = None
        recent = [s for s in atom._recent_sites if s is not None]
        if recent:
            anchor_index = recent[0]
            for index in recent:
                # The history is owned by ``Atom``; skip repeated entries
                # so the at-most-once guarantee does not depend on it
                # holding distinct indices.
                if index in checked_indices:
                    continue
                yield self.site_by_index(index)
                checked_indices.add(index)
        elif self._nearest_site_lookup is not None:
            anchor_index = self._nearest_site_lookup.nearest_site_index(atom.frac_coords)
            yield self.site_by_index(anchor_index)
            checked_indices.add(anchor_index)

        if anchor_index is not None:
            # Learned transitions in frequency order
            anchor_site = self.site_by_index(anchor_index)
            for dest_index in anchor_site.most_frequent_transitions():
                if dest_index not in checked_indices:
                    yield self.site_by_index(dest_index)
                    checked_indices.add(dest_index)
            # Remaining sites
            if self._distance_ranked_sites is not None:
                for index in self._distance_ranked_sites[anchor_index]:
                    if index not in checked_indices:
                        yield self.site_by_index(index)
                        checked_indices.add(index)
            else:
                for neighbour_site in self.neighbouring_sites(anchor_index):
                    if neighbour_site.index not in checked_indices:
                        yield neighbour_site
                        checked_indices.add(neighbour_site.index)
                for site in self.sites:
                    if site.index not in checked_indices:
                        yield site
        else:
            # No history and no distance ranking: nothing to prioritise by.
            for site in self.sites:
                yield site
class SiteCollection(ABC):
    """Parent class for collections of sites.

    Collections of specific site types should inherit from this class.

    Attributes:
        sites (list): List of ``Site``-like objects.
    """

    def __init__(self, sites: Sequence[Site]) -> None:
        """Create a SiteCollection object.

        Args:
            sites (list): List of ``Site`` objects.

        Raises:
            ValueError: If there are duplicate site indices.
        """
        self.sites = sites
        # Create lookup dictionary for efficient site access by index
        self._site_lookup: dict[int, Site] = {}
        for site in sites:
            if site.index in self._site_lookup:
                raise ValueError(f"Duplicate site index detected: {site.index}. Site indices must be unique.")
            self._site_lookup[site.index] = site

    @abstractmethod
    def assign_site_occupations(self, atoms, lattice_matrix):
        """Assign atoms to sites.

        Args:
            atoms: List of Atom objects to be assigned to sites.
            lattice_matrix: (3, 3) lattice matrix where rows are lattice
                vectors.

        Note:
            The atom coordinates should already be consistent with the
            structure. Recommended usage is via ``analyse_structure()``.
        """
        raise NotImplementedError('assign_site_occupations should be implemented in'
                                  ' the derived class')

    @abstractmethod
    def analyse_structure(self, atoms, structure):
        """Perform a site analysis for a set of atoms on a specific structure.

        This method should be implemented in the derived subclass.

        Args:
            atoms (list(Atom)): List of Atom objects to be assigned to sites.
            structure (pymatgen.Structure): Pymatgen Structure object used
                to specify the atomic coordinates.

        Returns:
            None
        """
        raise NotImplementedError('analyse_structure should be implemented in the derived class')

    def neighbouring_sites(self, site_index):
        """If implemented, returns a list of sites that neighbour
        a given site.

        This method should be implemented in the derived subclass.

        Args:
            site_index (int): Index of the site to return a list of
                neighbours for.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # The space before "in" is required: adjacent string literals are
        # concatenated verbatim.
        raise NotImplementedError('neighbouring_sites should be implemented'
                                  ' in the derived class')

    def site_by_index(self, index):
        """Returns the site with a specific index.

        Args:
            index (int): index for the site to be returned.

        Returns:
            (Site)

        Raises:
            ValueError: If a site with the specified index is not contained
                in this SiteCollection.
        """
        site = self._site_lookup.get(index)
        if site is None:
            raise ValueError(f'No site with index {index} found')
        return site

    def update_occupation(self, site, atom):
        """Updates site and atom attributes for this atom occupying this site.

        Args:
            site (Site): The site to be updated.
            atom (Atom): The atom to be updated.

        Returns:
            None

        Notes:
            This method does the following:

            1. If the atom has moved to a new site, record an
               old_site --> new_site transition.
            2. Add this atom's index to the list of atoms occupying this site.
            3. Add this atom's fractional coordinates to the list of
               coordinates observed occupying this site.
            4. Assign this atom this site index.
            5. Pass the site index to ``atom.update_recent_site``.
        """
        previous_site_index = None
        if atom.trajectory:
            # The last trajectory entry is the previous site; it may be
            # ``None``, in which case no transition is recorded.
            previous_site_index = atom.trajectory[-1]
        if previous_site_index is not None:
            if previous_site_index != site.index:  # this atom has moved
                previous_site = self.site_by_index(previous_site_index)
                previous_site.transitions[site.index] += 1
        site.contains_atoms.append(atom.index)
        site.points.append(atom.frac_coords)
        atom.in_site = site.index
        atom.update_recent_site(site.index)

    def reset(self) -> None:
        """Reset the collection and all its sites for a fresh analysis run.

        Resets per-site state (occupations, trajectories, caches) via
        ``Site.reset()``. Subclasses may override to also clear
        collection-level caches, but should call ``super().reset()``.
        """
        for site in self.sites:
            site.reset()

    def reset_site_occupations(self):
        """Occupations of all sites in this site collection are set as empty.

        Args:
            None

        Returns:
            None
        """
        for s in self.sites:
            s.contains_atoms = []

    def sites_contain_points(self,
                             points: np.ndarray,
                             all_frac_coords: np.ndarray,
                             lattice_matrix: np.ndarray) -> bool:
        """Check whether the set of sites contain corresponding points.

        Args:
            points: (N, 3) array of fractional coordinates.
                One coordinate per site being checked.
            all_frac_coords: Full fractional coordinate array, shape
                ``(n_atoms, 3)``.
            lattice_matrix: (3, 3) lattice matrix where rows are lattice
                vectors.

        Returns:
            True if every point is contained by its corresponding site.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('sites_contain_points() should be'
                                  ' implemented in the derived class')

    def summaries(self, metrics: list[str] | None = None) -> list[dict]:
        """Generate summary statistics for all sites in the collection.

        Args:
            metrics: List of metrics to include for each site. None returns
                default metrics for each site.

        Returns:
            List of summary dicts, one per site, in site order.
        """
        return [site.summary(metrics=metrics) for site in self.sites]