Source code for site_analysis.polyhedral_site

"""Polyhedral site representation for crystal structure analysis.

This module provides the PolyhedralSite class, which represents a site defined
by a polyhedron formed by a set of vertex atoms. These sites are commonly used
to represent coordination environments in crystal structures, such as tetrahedral
or octahedral sites.
"""

from __future__ import annotations

import itertools
import warnings

import numpy as np
from scipy.spatial import ConvexHull, Delaunay, QhullError
from site_analysis.site import Site
from site_analysis.tools import x_pbc
from site_analysis.atom import Atom
from site_analysis._compat import HAS_NUMBA
from site_analysis.pbc_utils import correct_pbc, update_pbc_shifts


if HAS_NUMBA:
    import numba  # type: ignore

    @numba.njit(cache=True)  # type: ignore[misc]
    def _numba_update_faces(
        vertex_coords: np.ndarray,
        face_simplices: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """JIT-compiled face normal and centre sign computation.

        For each triangular face ``(i0, i1, i2)`` a plane normal is computed as
        the cross product of the edge vectors ``v[i0] - v[i2]`` and
        ``v[i1] - v[i2]``. The normal is not normalised, and its direction
        depends on the vertex order within the face. That order is not assumed
        to be consistent across faces. Instead, ``centre_signs`` records which
        side of each face plane the vertex centroid lies on. A point is then
        tested by comparing its sign against the centroid's.

        Args:
            vertex_coords: (N_vertices, 3) array of vertex positions.
            face_simplices: (N_faces, 3) array of vertex indices per face.

        Returns:
            Tuple of (face_normals, face_ref_points, centre_signs).
            ``face_normals`` is (N_faces, 3). ``face_ref_points`` is
            (N_faces, 3) and holds the first vertex of each face.
            ``centre_signs`` is (N_faces,) and holds ``np.sign`` values:
            +1.0, -1.0, or 0.0 if the centroid lies exactly in a face plane,
            which only happens for degenerate geometry.
        """
        n_faces = face_simplices.shape[0]
        n_vertices = vertex_coords.shape[0]
        face_normals = np.empty((n_faces, 3))
        face_ref_points = np.empty((n_faces, 3))
        centre_signs = np.empty(n_faces)

        # Centroid (arithmetic mean) of ALL vertices. For a non-degenerate
        # convex polyhedron the vertex mean lies strictly inside, so it is a
        # valid "inside" reference for every face.
        centre = np.zeros(3)
        for i in range(n_vertices):
            for k in range(3):
                centre[k] += vertex_coords[i, k]
        for k in range(3):
            centre[k] /= n_vertices

        for j in range(n_faces):
            i0 = face_simplices[j, 0]
            i1 = face_simplices[j, 1]
            i2 = face_simplices[j, 2]

            # Edge vectors from vertex 2
            e0x = vertex_coords[i0, 0] - vertex_coords[i2, 0]
            e0y = vertex_coords[i0, 1] - vertex_coords[i2, 1]
            e0z = vertex_coords[i0, 2] - vertex_coords[i2, 2]
            e1x = vertex_coords[i1, 0] - vertex_coords[i2, 0]
            e1y = vertex_coords[i1, 1] - vertex_coords[i2, 1]
            e1z = vertex_coords[i1, 2] - vertex_coords[i2, 2]

            # Cross product (unnormalised; orientation follows vertex order)
            face_normals[j, 0] = e0y * e1z - e0z * e1y
            face_normals[j, 1] = e0z * e1x - e0x * e1z
            face_normals[j, 2] = e0x * e1y - e0y * e1x

            # Reference point (first vertex of face): any point on the face
            # plane works as the origin for the signed-distance test.
            for k in range(3):
                face_ref_points[j, k] = vertex_coords[i0, k]

            # Centre sign: which side of this face plane the centroid is on.
            # This lets the query ignore the arbitrary normal orientation.
            dot = 0.0
            for k in range(3):
                dot += face_normals[j, k] * (centre[k] - face_ref_points[j, k])
            centre_signs[j] = np.sign(dot)

        return face_normals, face_ref_points, centre_signs

    @numba.njit(cache=True)  # type: ignore[misc]
    def _numba_sn_query(
        x_pbc_points: np.ndarray,
        face_normals: np.ndarray,
        face_ref_points: np.ndarray,
        centre_signs: np.ndarray,
    ) -> bool:
        """JIT-compiled containment check using surface normals.

        For each PBC image point, checks whether the point lies on the
        same side as the polyhedron centre for every face. Points lying
        exactly on a face plane (dot product of 0.0) pass that face's check,
        so boundary points count as inside. Returns on the first image found
        inside, and abandons an image on its first failing face.

        Args:
            x_pbc_points: (N, 3) array of periodic boundary images.
            face_normals: (N_faces, 3) precomputed face normals. Their
                orientation is arbitrary per face; ``centre_signs`` determines
                which side counts as inside.
            face_ref_points: (N_faces, 3) reference vertex per face.
            centre_signs: (N_faces,) sign of centre dot product per face.
                If an entry is 0.0 (degenerate face), every point with a
                non-zero dot product fails that face.

        Returns:
            True if any PBC image is inside (or on the boundary of) the
            polyhedron.
        """
        n_points = x_pbc_points.shape[0]
        n_faces = face_normals.shape[0]
        for i in range(n_points):
            all_match = True
            for j in range(n_faces):
                # Signed (unnormalised) distance of the point from face j's plane.
                dot = 0.0
                for k in range(3):
                    dot += ((x_pbc_points[i, k] - face_ref_points[j, k])
                            * face_normals[j, k])
                # dot == 0.0 means the point lies exactly on the face
                # plane; np.sign(0.0) returns 0.0 which would never
                # match centre_signs (always +/-1), so we skip the
                # check to avoid falsely rejecting boundary points.
                if dot != 0.0 and np.sign(dot) != centre_signs[j]:
                    all_match = False
                    break
            if all_match:
                return True
        return False


class FaceTopologyCache:
    """Cached face topology and surface-normal data for polyhedral containment.

    The triangulated hull faces (which vertex triples form each face) are
    taken once from a ``ConvexHull`` of the initial vertex coordinates. They
    are then reused for every later update, and the hull is never recomputed.
    This relies on the vertex connectivity staying the same while coordinates
    change between timesteps. Each ``update`` call recomputes face normals,
    reference points and centre signs from new coordinates using that fixed
    topology.

    Attributes:
        face_simplices: (N_faces, 3) array of vertex indices per face.
    """

    face_simplices: np.ndarray
    # Per-timestep face data, written by update() and read by contains_point().
    _face_normals: np.ndarray
    _face_ref_points: np.ndarray
    _centre_signs: np.ndarray

    def __init__(self, vertex_coords: np.ndarray) -> None:
        """Build the face topology and the initial per-face data.

        Args:
            vertex_coords: (N_vertices, 3) array of vertex positions. A
                ConvexHull is built from these and its face connectivity kept.

        Raises:
            RuntimeError: If numba is not installed.
            scipy.spatial.QhullError: Propagated from ``ConvexHull`` when the
                vertex geometry is degenerate (e.g. coplanar vertices).
        """
        if not HAS_NUMBA:
            raise RuntimeError(
                "FaceTopologyCache requires numba. "
                "Install it with: pip install site-analysis[fast]"
            )
        hull = ConvexHull(vertex_coords)
        triangles = hull.simplices
        self.face_simplices = triangles
        self.update(vertex_coords)

    def update(self, vertex_coords: np.ndarray) -> None:
        """Refresh per-face data from new coordinates using the cached topology.

        Intended to be called once per timestep, after new vertex coordinates
        are available.

        Args:
            vertex_coords: (N_vertices, 3) array of current vertex positions.
        """
        normals, anchors, signs = _numba_update_faces(
            vertex_coords, self.face_simplices
        )
        self._face_normals = normals
        self._face_ref_points = anchors
        self._centre_signs = signs

    def contains_point(self, x_pbc_points: np.ndarray) -> bool:
        """Report whether any of the given PBC image points is inside.

        Args:
            x_pbc_points: (N, 3) array of periodic boundary images.

        Returns:
            True if at least one image lies inside (or on the boundary of)
            the polyhedron.
        """
        inside = _numba_sn_query(
            x_pbc_points,
            self._face_normals,
            self._face_ref_points,
            self._centre_signs,
        )
        return bool(inside)
class PolyhedralSite(Site):
    """Describes a site defined by the polyhedral volume enclosed by a set of vertex atoms.

    A PolyhedralSite determines whether atoms are inside the site volume by
    constructing a convex polyhedron from the vertex atoms and checking
    whether points lie within this polyhedron. The containment algorithm is
    selected automatically. When numba is available, a JIT-compiled surface
    normal method with cached face topology is used. Otherwise it falls back
    to Delaunay tessellation via scipy.

    The polyhedron vertices are defined using atom indices in a structure,
    and their coordinates are assigned from the structure when needed. This
    allows the polyhedron shape to adapt to changes in the crystal structure.

    Attributes:
        vertex_indices (list[int]): List of integer indices for the vertex
            atoms (counting from 0).
        vertex_coords (np.ndarray or None): Fractional coordinates of the
            vertices. Set using assign_vertex_coords(), or lazily on the next
            containment query after notify_structure_changed().
        reference_center (np.ndarray or None): Optional reference centre for
            PBC handling.

    See Also:
        :class:`~site_analysis.site.Site`: Parent class documenting inherited
        attributes (index, label, contains_atoms, trajectory, points,
        transitions).
    """

    def __init__(self,
                 vertex_indices: list[int],
                 label: str | None = None,
                 reference_center: np.ndarray | None = None):
        """Create a PolyhedralSite instance.

        Args:
            vertex_indices: List of integer indices for the vertex atoms
                (counting from 0). A numpy integer array, or a sequence
                containing numpy integer scalars, is also accepted. The
                indices are stored as a new list of Python ints, so later
                changes to the caller's sequence do not affect this site.
            label: Optional label for this site.
            reference_center: Optional reference centre for PBC handling.

        Returns:
            None

        Raises:
            ValueError: If vertex_indices is empty.
            TypeError: If any element in vertex_indices is not an integer.
        """
        if isinstance(vertex_indices, np.ndarray):
            vertex_indices = vertex_indices.tolist()
        if not vertex_indices:
            raise ValueError("vertex_indices cannot be empty")
        if not all(isinstance(idx, (int, np.integer)) for idx in vertex_indices):
            raise TypeError("All vertex indices must be integers")
        super(PolyhedralSite, self).__init__(label=label)
        # Normalise to plain ints. This accepts np.integer scalars and stops
        # bools from later acting as a numpy boolean mask when indexing.
        self.vertex_indices = [int(idx) for idx in vertex_indices]
        self.vertex_coords: np.ndarray | None = None
        self._delaunay: Delaunay | None = None
        self._face_topology_cache: FaceTopologyCache | None = None
        self._cache_stale: bool = True
        self._pending_frac_coords: np.ndarray | None = None
        self._pending_lattice_matrix: np.ndarray | None = None
        self._pbc_image_shifts: np.ndarray | None = None
        self._pbc_cached_raw_frac: np.ndarray | None = None
        self.reference_center = reference_center

    def __repr__(self) -> str:
        string = ('site_analysis.PolyhedralSite('
                  f'index={self.index}, '
                  f'label={self.label}, '
                  f'vertex_indices={self.vertex_indices}, '
                  f'contains_atoms={self.contains_atoms})')
        return string

    def reset(self) -> None:
        """Reset the trajectory for this site.

        Resets the contains_atoms and trajectory attributes to empty lists.
        Vertex coordinates, Delaunay tessellation, and PBC shift caches are
        unset. The face topology cache is preserved because it depends only
        on vertex indices, which are immutable.
        """
        super(PolyhedralSite, self).reset()
        self.vertex_coords = None
        self._delaunay = None
        self._cache_stale = True
        self._pending_frac_coords = None
        self._pending_lattice_matrix = None
        self._pbc_image_shifts = None
        self._pbc_cached_raw_frac = None

    @property
    def delaunay(self) -> Delaunay:
        """Delaunay tessellation of the vertex coordinates for this site.

        Computed the first time the attribute is requested and then stored
        for reuse. It is cleared whenever new vertex coordinates are stored
        or the site is reset.

        Returns:
            scipy.spatial.Delaunay

        Raises:
            RuntimeError: If vertex coordinates have not been assigned.
        """
        if self._delaunay is None:
            if self.vertex_coords is None:
                raise RuntimeError("Vertex coordinates have not been assigned.")
            self._delaunay = Delaunay(self.vertex_coords)
        return self._delaunay

    @property
    def coordination_number(self) -> int:
        """Coordination number for this site, defined as the number of vertices.

        Returns:
            int
        """
        return len(self.vertex_indices)

    @property
    def cn(self) -> int:
        """Coordination number for this site, defined as the number of vertices.

        Convenience alias for :attr:`coordination_number`.

        Returns:
            int
        """
        return self.coordination_number

    def notify_structure_changed(self,
                                 all_frac_coords: np.ndarray,
                                 lattice_matrix: np.ndarray) -> None:
        """Mark vertex coordinates as stale for lazy reassignment.

        Stores a reference to the full coordinate array so that PBC-corrected
        vertex coordinates can be computed on demand when ``contains_point``
        is next called.

        Args:
            all_frac_coords: Full fractional coordinate array from the
                structure, shape ``(n_atoms, 3)``.
            lattice_matrix: (3, 3) lattice matrix where rows are lattice
                vectors.
        """
        self._pending_frac_coords = all_frac_coords
        self._pending_lattice_matrix = lattice_matrix

    def assign_vertex_coords(self,
                             all_frac_coords: np.ndarray,
                             lattice_matrix: np.ndarray) -> None:
        """Assign fractional coordinates to the polyhedron vertices.

        Resets the cached Delaunay tessellation and marks the face topology
        cache stale, so the next ``contains_point`` call recomputes them.

        Args:
            all_frac_coords: Full fractional coordinate array, shape
                ``(n_atoms, 3)``.
            lattice_matrix: (3, 3) lattice matrix where rows are lattice
                vectors.

        Note:
            For bulk analysis prefer ``notify_structure_changed``. It defers
            vertex extraction and PBC correction until the site is actually
            queried.
        """
        frac_coords = all_frac_coords[self.vertex_indices]
        self._store_vertex_coords(frac_coords, lattice_matrix)

    def _assign_from_pending(self,
                             all_frac_coords: np.ndarray,
                             lattice_matrix: np.ndarray) -> None:
        """Compute PBC-corrected vertex coords from pending data.

        Clears the pending references first, so the pending data is consumed
        exactly once.

        Args:
            all_frac_coords: Full fractional coordinate array.
            lattice_matrix: (3, 3) lattice matrix where rows are lattice
                vectors.
        """
        self._pending_frac_coords = None
        self._pending_lattice_matrix = None
        frac_coords = all_frac_coords[self.vertex_indices]
        self._store_vertex_coords(frac_coords, lattice_matrix)

    def _store_vertex_coords(self,
                             frac_coords: np.ndarray,
                             lattice_matrix: np.ndarray) -> None:
        """Apply PBC correction and store vertex coordinates.

        On the first call, or when the incremental update reports an invalid
        (anomalous) displacement, this delegates to ``correct_pbc()`` for full
        PBC unwrapping. Otherwise it calls ``update_pbc_shifts()`` to update
        the cached integer image shifts incrementally from the previously
        cached raw coordinates.

        Sets ``vertex_coords``, clears the Delaunay tessellation, and marks
        the face topology cache as stale.

        Args:
            frac_coords: Raw fractional coordinates of the vertices, shape
                ``(n_vertices, 3)``.
            lattice_matrix: (3, 3) lattice matrix where rows are lattice
                vectors. Passed to ``correct_pbc()``; unused by the
                incremental update path.
        """
        updated = False
        if (self._pbc_image_shifts is not None
                and self._pbc_cached_raw_frac is not None):
            valid, vertex_coords, image_shifts = update_pbc_shifts(
                frac_coords, self._pbc_cached_raw_frac, self._pbc_image_shifts)
            updated = bool(valid)
        if not updated:
            # Full computation: first call only (or after anomalous displacement).
            vertex_coords, image_shifts = correct_pbc(
                frac_coords, self.reference_center, lattice_matrix)
        self._pbc_image_shifts = image_shifts
        # Copy so later in-place edits of the caller's array cannot corrupt
        # the baseline used for the next incremental update.
        self._pbc_cached_raw_frac = frac_coords.copy()
        self.vertex_coords = vertex_coords
        self._delaunay = None
        self._cache_stale = True

    def get_vertex_species(self, species: list[str]) -> list[str]:
        """Return species strings for this site's vertex atoms.

        Args:
            species: List of species strings for all atoms in the structure,
                indexed by atom index.

        Returns:
            Species strings for this site's vertex atoms.
        """
        return [species[i] for i in self.vertex_indices]

    @staticmethod
    def _warn_algo_deprecated(algo: str | None) -> None:
        """Emit the deprecation warning for the ignored ``algo`` argument.

        ``stacklevel=3`` skips this helper and the public method that called
        it. The warning is therefore attributed to user code, as the previous
        inline ``stacklevel=2`` call was.

        Args:
            algo: The value passed by the caller; nothing is emitted if None.
        """
        if algo is not None:
            warnings.warn(
                "The 'algo' parameter is deprecated and will be removed in a "
                "future version. The best available containment algorithm is "
                "now selected automatically.",
                DeprecationWarning,
                stacklevel=3,
            )

    def contains_point(self,
                       x: np.ndarray,
                       *,
                       algo: str | None = None,
                       pbc_images: np.ndarray | None = None) -> bool:
        """Test whether a specific point is enclosed by this polyhedral site.

        The containment algorithm is selected automatically based on available
        dependencies. When numba is installed, a JIT-compiled surface normal
        method is used. Otherwise it falls back to Delaunay tessellation.

        Args:
            x: Fractional coordinates of the point to test (length 3 array).
            algo: Deprecated. Previously selected the algorithm. Now ignored;
                the best available method is used automatically.
            pbc_images: Optional pre-computed PBC images of x, shape (N, 3).
                If provided, skips the internal ``x_pbc`` call.

        Returns:
            True if the point is inside the polyhedron.

        Raises:
            RuntimeError: If no vertex coordinates are available, or if the
                vertex geometry is degenerate (wrapping scipy's QhullError).
        """
        self._warn_algo_deprecated(algo)
        # Resolve any lazily-deferred structure update before querying.
        if (self._pending_frac_coords is not None
                and self._pending_lattice_matrix is not None):
            self._assign_from_pending(self._pending_frac_coords,
                                      self._pending_lattice_matrix)
        if self.vertex_coords is None:
            raise RuntimeError(
                f'no vertex coordinates set for polyhedral_site {self.index}'
            )
        x_images = pbc_images if pbc_images is not None else x_pbc(x)
        try:
            if HAS_NUMBA:
                if self._face_topology_cache is None:
                    # First query ever: build hull topology and face data.
                    self._face_topology_cache = FaceTopologyCache(self.vertex_coords)
                    self._cache_stale = False
                elif self._cache_stale:
                    # Topology is reused; only per-face data is refreshed.
                    self._face_topology_cache.update(self.vertex_coords)
                    self._cache_stale = False
                return self._face_topology_cache.contains_point(x_images)
            return self._contains_point_delaunay(x_images)
        except QhullError as e:
            raise RuntimeError(
                f"Degenerate vertex geometry for polyhedral_site {self.index} "
                f"(vertices {self.vertex_indices})"
            ) from e

    def _contains_point_delaunay(self, x: np.ndarray) -> bool:
        """Test containment using Delaunay tessellation.

        Args:
            x: Fractional coordinates as (3,) or (N, 3) array.

        Returns:
            True if any point is inside a simplex of the tessellation.
        """
        # find_simplex returns -1 for points outside every simplex.
        return bool(np.any(self.delaunay.find_simplex(x) >= 0))

    def contains_atom(self,
                      atom: Atom,
                      *,
                      algo: str | None = None,
                      pbc_images: np.ndarray | None = None) -> bool:
        """Test whether an atom is inside this polyhedron.

        Args:
            atom: The atom to test.
            algo: Deprecated. Previously selected the algorithm. Now ignored;
                the best available method is used automatically.
            pbc_images: Optional pre-computed PBC images of the atom's
                fractional coordinates. If provided, passed through to
                ``contains_point`` to avoid redundant computation.

        Returns:
            True if the atom is inside the polyhedron.
        """
        self._warn_algo_deprecated(algo)
        return self.contains_point(atom.frac_coords, pbc_images=pbc_images)

    def as_dict(self) -> dict:
        """Serialise this site to a dict.

        Extends the parent ``Site.as_dict()`` with the vertex indices, the
        current vertex coordinates (or None), and the PBC reference centre
        (or None). The vertex coordinates and reference centre are stored
        as-is (possibly numpy arrays).

        Returns:
            dict: Data suitable for :meth:`from_dict`.
        """
        d = super(PolyhedralSite, self).as_dict()
        d['vertex_indices'] = self.vertex_indices
        d['vertex_coords'] = self.vertex_coords
        d['reference_center'] = self.reference_center
        return d

    @classmethod
    def from_dict(cls, d: dict) -> PolyhedralSite:
        """Reconstruct a PolyhedralSite from a dict produced by :meth:`as_dict`.

        ``vertex_coords`` and ``reference_center`` are optional keys, so
        dicts written before ``reference_center`` was serialised still load.
        When present and not None, both are converted to float numpy arrays.
        This handles list data (e.g. after a JSON round-trip), which the
        numba containment path cannot accept.

        Args:
            d: Serialised site data. It must contain ``vertex_indices`` and
                ``contains_atoms``; ``vertex_coords``, ``reference_center``
                and ``label`` are optional.

        Returns:
            PolyhedralSite: The reconstructed site.
        """
        reference_center = d.get('reference_center')
        if reference_center is not None:
            reference_center = np.asarray(reference_center, dtype=float)
        polyhedral_site = cls(vertex_indices=d['vertex_indices'],
                              reference_center=reference_center)
        vertex_coords = d.get('vertex_coords')
        polyhedral_site.vertex_coords = (
            None if vertex_coords is None
            else np.asarray(vertex_coords, dtype=float))
        polyhedral_site.contains_atoms = d['contains_atoms']
        polyhedral_site.label = d.get('label')
        return polyhedral_site

    @property
    def centre(self) -> np.ndarray:
        """Return the fractional coordinates of the centre of this site.

        The centre is the arithmetic mean of the stored vertex coordinates.
        No wrapping into the unit cell is applied here.

        Returns:
            (np.array): (3,) numpy array.

        Raises:
            RuntimeError: If vertex coordinates have not been assigned.
        """
        if self.vertex_coords is None:
            raise RuntimeError("Vertex coordinates have not been assigned.")
        centre_coords = np.mean(self.vertex_coords, axis=0)
        return np.array(centre_coords)

    @classmethod
    def sites_from_vertex_indices(cls,
                                  vertex_indices: list[list[int]],
                                  label: str | None = None) -> list[PolyhedralSite]:
        """Create one site per vertex-index list, all sharing the same label.

        Args:
            vertex_indices: A list of vertex-index lists, one per site.
            label: Optional label applied to every created site.

        Returns:
            list[PolyhedralSite]: The new sites, in input order.
        """
        sites = [cls(vertex_indices=vi, label=label) for vi in vertex_indices]
        return sites