"""Abstract base class for site definitions in crystal structures.
This module defines the core Site abstraction, which represents a bounded volume
in a crystal structure that can contain zero or more atoms. The Site class
serves as the abstract base class that all specific site types (polyhedral,
spherical, Voronoi, etc.) in the site_analysis package must inherit from.
Concrete site implementations must override the abstract methods to define:
- How to determine whether a point is contained within the site
- How to calculate the center of the site
- Site-specific properties like coordination number
This class should not be instantiated directly; use one of the concrete
subclasses instead.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any
from .atom import Atom
import numpy as np
[docs]
class Site(ABC):
"""Parent class for defining sites.
A Site is a bounded volume that can contain none, one, or more atoms.
This class defines the attributes and methods expected for specific
Site subclasses.
Attributes:
index (int): Numerical ID, intended to be unique to each site.
label (`str`: optional): Optional string given as a label for this site.
Default is `None`.
contains_atoms (list): list of the atoms contained by this site in the
structure last processed.
trajectory (list(list(int))): Nested list of atoms that have visited this
site at each timestep.
points (list): list of fractional coordinates for atoms assigned as
occupying this site.
transitions (collections.Counter): Stores observed transitions from this
site to other sites. Format is {index: count} with ``index`` giving
the index of each destination site, and ``count`` giving the number
of observed transitions to this site.
"""
_newid = 0
# Site._newid provides a counter that is incremented each time a
# new site is initialised. This allows each site to have a
# unique numerical index.
# Site._newid can be reset to 0 by calling Site.reset_index()
# with the default arguments.
def __init__(self,
label: str | None=None) -> None:
"""Initialise a Site object.
Args:
label (`str`: optional): Optional string used to label this site.
Returns:
None
"""
self.index = Site._newid
Site._newid += 1
self.label = label
self.contains_atoms: list[int] = []
self.trajectory: list[list[int]] = []
self.points: list[np.ndarray] = []
self.transitions: Counter = Counter()
[docs]
def reset(self) -> None:
"""Reset the trajectory for this site.
Returns the contains_atoms and trajectory attributes
to empty lists.
Args:
None
Returns:
None
"""
self.contains_atoms = []
self.trajectory = []
self.transitions = Counter()
[docs]
@abstractmethod
def contains_point(self,
x: np.ndarray) -> bool:
"""Test whether the fractional coordinate x is contained by this site.
Args:
x: Fractional coordinate to test.
Returns:
True if the point is contained by this site.
"""
raise NotImplementedError('contains_point should be implemented '
'in the derived class')
[docs]
def contains_atom(self,
atom: Atom) -> bool:
"""Test whether this site contains a specific atom.
Args:
atom: The atom to test.
Returns:
True if the atom is contained by this site.
"""
return self.contains_point(atom.frac_coords)
[docs]
def as_dict(self) -> dict:
"""Json-serializable dict representation of this Site.
Args:
None
Returns:
(dict)
"""
d = {'index': self.index,
'contains_atoms': self.contains_atoms,
'trajectory': self.trajectory,
'points': self.points,
'transitions': self.transitions}
if self.label:
d['label'] = self.label
return d
[docs]
@classmethod
def from_dict(cls,
d: dict) -> Site:
"""Create a Site object from a dict representation.
Args:
d (dict): The dict representation of this Site.
Returns:
(Site)
"""
site = cls()
site.index = d['index']
site.trajectory = d['trajectory']
site.contains_atoms = d['contains_atoms']
site.points = d['points']
site.transitions = d['transitions']
site.label = d.get('label')
return site
@property
@abstractmethod
def centre(self) -> np.ndarray:
"""Returns the centre point of this site.
This method should be implemented in the derived subclass.
Args:
None
Returns:
None
"""
raise NotImplementedError('centre should be implemented '
'in the derived class')
[docs]
@classmethod
def reset_index(cls,
newid: int=0) -> None:
"""Reset the site index counter.
Args:
newid (`int`: optional): New starting index. Default is 0.
Returns:
None
"""
Site._newid = newid
@property
def coordination_number(self) -> int:
"""Returns the coordination number of this site.
This method should be implemented in the derived subclass.
Args:
None
Returns:
int
"""
raise NotImplementedError('coordination_number should be implemented '
'in the derived class')
@abstractmethod
def __repr__(self) -> str:
"""Return a string representation of this site.
This method should be implemented in the derived subclass.
Returns:
str: A string representation of the site including its
class name and important attributes.
"""
raise NotImplementedError('__repr__ should be implemented '
'in the derived class')
[docs]
def most_frequent_transitions(self):
"""Return list of site indices ordered by transition frequency (most common first).
Returns:
list[int]: Site indices sorted by transition count in descending order.
Returns empty list if no transitions have been recorded.
"""
return sorted(self.transitions.keys(), key=self.transitions.get, reverse=True)
@property
def average_occupation(self) -> float | None:
"""Calculate the average site occupation over the trajectory.
Returns the fraction of timesteps where the site was occupied
(contained at least one atom).
Returns:
float | None: Average occupation between 0.0 and 1.0, or None
if trajectory is empty (no data processed).
"""
if not self.trajectory:
return None
occupied_timesteps = sum(1 for timestep in self.trajectory if timestep)
return occupied_timesteps / len(self.trajectory)
[docs]
def residence_times(self,
filter_length: int = 0,
include_edge_runs: bool = False) -> tuple[int, ...]:
"""Compute per-atom residence time run lengths from the site trajectory.
For each atom that visits this site, builds a binary
occupied/not-occupied sequence across all timesteps, then extracts
the lengths of consecutive occupied runs. The result is a flat tuple
of all run lengths from all atoms.
By default, runs that touch the first or last timestep are excluded
because they are truncated by the trajectory boundary and
underestimate the true residence time. Set ``include_edge_runs=True``
to include them.
Args:
filter_length: Maximum interior gap length to fill before
computing run lengths. Gaps of ``filter_length`` or fewer
consecutive unoccupied frames are filled (treated as if the
atom remained in the site) provided the gap is flanked by
occupied frames from the same atom on both sides. Gaps at the
trajectory edges are never filled. Default is 0 (no
filtering).
include_edge_runs: Whether to include runs that touch the first
or last timestep. Default is False (exclude truncated runs).
Returns:
A tuple of consecutive-occupation run lengths for all atoms
that visit this site. Returns an empty tuple if the trajectory
is empty or the site is never occupied.
Raises:
ValueError: If ``filter_length`` is negative.
TypeError: If ``filter_length`` is not an integer.
Examples:
>>> site.trajectory = [[], [1], [1], [1], []]
>>> site.residence_times()
(3,)
Runs touching the trajectory boundary are excluded by default:
>>> site.trajectory = [[1], [1], [1], []]
>>> site.residence_times()
()
>>> site.residence_times(include_edge_runs=True)
(3,)
Short gaps can be filled before computing run lengths:
>>> site.trajectory = [[], [1], [], [1], []]
>>> site.residence_times()
(1, 1)
>>> site.residence_times(filter_length=1)
(3,)
"""
if not isinstance(filter_length, int):
raise TypeError(f"filter_length must be an integer, got {type(filter_length).__name__}")
if filter_length < 0:
raise ValueError(f"filter_length must be non-negative, got {filter_length}")
if not self.trajectory:
return ()
all_atoms: set[int] = set()
for timestep in self.trajectory:
all_atoms.update(timestep)
if not all_atoms:
return ()
n_timesteps = len(self.trajectory)
run_lengths: list[int] = []
for atom in sorted(all_atoms):
occupied = np.array([atom in ts for ts in self.trajectory], dtype=bool)
if filter_length > 0:
occupied = self._filter_gaps(occupied, filter_length)
# Extract run lengths of True values using diff
padded = np.concatenate(([False], occupied, [False]))
diffs = np.diff(padded.astype(int))
starts = np.where(diffs == 1)[0]
ends = np.where(diffs == -1)[0]
for s, e in zip(starts, ends):
if not include_edge_runs:
if s == 0 or e == n_timesteps:
continue
run_lengths.append(e - s)
return tuple(run_lengths)
@staticmethod
def _filter_gaps(occupied: np.ndarray, filter_length: int) -> np.ndarray:
"""Fill short interior gaps in an occupied/not-occupied boolean array.
Only gaps flanked by occupied frames on both sides are filled.
Gaps at the trajectory edges are never filled.
Args:
occupied: Boolean array where True means the atom is in the site.
filter_length: Maximum gap length to fill.
Returns:
A new boolean array with short interior gaps filled.
"""
result = np.array(occupied)
# Find gaps (runs of False)
padded = np.concatenate(([True], occupied, [True]))
diffs = np.diff(padded.astype(int))
gap_starts = np.where(diffs == -1)[0]
gap_ends = np.where(diffs == 1)[0]
for gs, ge in zip(gap_starts, gap_ends):
if gs == 0 or ge == len(occupied):
continue
if ge - gs <= filter_length:
result[gs:ge] = True
return result
[docs]
def summary(self, metrics: list[str] | None = None) -> dict:
"""Generate summary statistics and computed properties.
By default, returns commonly used metrics excluding any with None values.
When specific metrics are requested, they are included even if their values are None.
Args:
metrics: List of metrics to include, or None for defaults.
Available metrics:
- 'index': Site's unique identifier
- 'label': Site label (if set)
- 'site_type': Class name (e.g., 'SphericalSite')
- 'average_occupation': Fraction of timesteps occupied (0.0-1.0)
- 'transitions': Dict of transitions to other sites
Default behaviour (metrics=None):
Returns index, site_type, average_occupation, transitions.
Also includes label if set. Excludes any metrics with None values.
Returns:
dict: Summary statistics. Keys depend on requested metrics.
Raises:
ValueError: If any requested metrics are not available.
Examples:
>>> site.summary() # Default metrics, excluding None values
{'index': 0, 'site_type': 'SphericalSite', 'transitions': {}}
>>> site.summary(metrics=['index', 'average_occupation'])
{'index': 0, 'average_occupation': None} # Includes None when explicitly requested
"""
# Define available metrics
available_metrics = ['index', 'label', 'site_type', 'average_occupation', 'transitions']
# Track if we're using defaults
using_defaults = metrics is None
if metrics is None:
# Default: all metrics except label (only if present)
metrics = ['index', 'site_type', 'average_occupation', 'transitions']
if self.label is not None:
metrics.insert(1, 'label')
# Validate metrics
invalid_metrics = set(metrics) - set(available_metrics)
if invalid_metrics:
raise ValueError(f"Invalid metric(s): {invalid_metrics}. Available metrics: {available_metrics}")
# Build summary dict
summary_dict = {}
for metric in metrics:
value: Any
if metric == 'site_type':
value = self.__class__.__name__
elif metric == 'transitions':
value = dict(self.transitions)
else:
value = getattr(self, metric)
# Include if: value is not None OR user explicitly requested it
if value is not None or not using_defaults:
summary_dict[metric] = value
return summary_dict