"""Trajectory analysis for tracking site occupations over time.

This module provides the Trajectory class, which is responsible for analyzing
and tracking atom movements through crystallographic sites in a simulation
trajectory.

The Trajectory class manages the relationship between atoms and sites, analyzes
structures to assign atoms to sites, and records the movement history of atoms
between sites over time.

Key functionality includes:

- Assigning atoms to sites based on their positions in a structure
- Tracking atom migrations between sites over a sequence of structures
- Recording site occupation and transition data
- Supporting different site definitions via appropriate SiteCollection types

Note:
    Trajectory objects should typically be created using the TrajectoryBuilder
    class rather than directly instantiated. The builder provides an interface
    for configuring all aspects of the trajectory:

    >>> from site_analysis.builders import TrajectoryBuilder
    >>> trajectory = (TrajectoryBuilder()
    ...               .with_structure(structure)
    ...               .with_mobile_species("Li")
    ...               .with_spherical_sites(centres=[[0.5, 0.5, 0.5]], radii=[2.0])
    ...               .build())
"""
import warnings
from collections import Counter
from collections.abc import Iterable
from typing import Sequence
import numpy as np
from tqdm.auto import tqdm
from pymatgen.core import Structure
from .transition_table import TableKey, TransitionTable
from .atom import Atom
from .dynamic_voronoi_site import DynamicVoronoiSite
from .dynamic_voronoi_site_collection import DynamicVoronoiSiteCollection
from .polyhedral_site import PolyhedralSite
from .polyhedral_site_collection import PolyhedralSiteCollection
from .site import Site
from .site_collection import SiteCollection
from .spherical_site import SphericalSite
from .spherical_site_collection import SphericalSiteCollection
from .voronoi_site import VoronoiSite
from .voronoi_site_collection import VoronoiSiteCollection
class Trajectory:
    """Track which sites a set of atoms occupy over a simulation trajectory.

    A Trajectory holds a list of sites, a list of atoms, and the site
    collection matching the type of the sites. Structure analysis is
    delegated to that collection; this class records the per-timestep
    results on each atom's and each site's ``trajectory`` list and offers
    transition-count / transition-probability summaries.

    Attributes:
        site_collection: The collection object built from ``sites``.
        sites: The sites passed to the constructor.
        atoms: The atoms passed to the constructor.
        timesteps: Timestep labels recorded by :meth:`append_timestep`.
        atom_lookup: Maps each ``atom.index`` to its position in ``atoms``.
        site_lookup: Maps each ``site.index`` to its position in ``sites``.
    """

    def __init__(self,
                 sites: Sequence[Site],
                 atoms: list[Atom]) -> None:
        """Initialize a Trajectory object for site analysis of simulation trajectories.

        This constructor ensures all sites are of the same type and
        initializes the appropriate site collection based on the type of
        sites provided.

        Args:
            sites: Site objects (must all be of exactly the same type).
            atoms: Atom objects to track during the trajectory analysis.

        Raises:
            ValueError: If sites or atoms is empty.
            TypeError: If sites contains mixed site types or an
                unrecognised site type.
        """
        if not sites:
            raise ValueError("Cannot initialize Trajectory with empty sites list")
        if not atoms:
            raise ValueError("Cannot initialize Trajectory with empty atoms list")

        # All sites must share one concrete type, because a single
        # collection class is chosen for the whole trajectory.
        if len({type(s) for s in sites}) > 1:
            raise TypeError("A Trajectory cannot be initialised with mixed Site types")

        # Map site types to their corresponding collection classes.
        site_collection_map: dict[type[Site], type[SiteCollection]] = {
            PolyhedralSite: PolyhedralSiteCollection,
            VoronoiSite: VoronoiSiteCollection,
            SphericalSite: SphericalSiteCollection,
            DynamicVoronoiSite: DynamicVoronoiSiteCollection,
        }

        # Lookup is by exact type, so subclasses of the listed site
        # classes are not matched.
        site_type = type(sites[0])
        try:
            collection_class = site_collection_map[site_type]
        except KeyError:
            # ``from None``: the internal KeyError is an implementation
            # detail and only clutters the caller's traceback.
            raise TypeError(
                f"Site type {site_type} not recognised for Trajectory initialisation"
            ) from None

        self.site_collection = collection_class(sites)
        self.sites = sites
        self.atoms = atoms
        self.timesteps: list[int] = []
        self.atom_lookup = {a.index: i for i, a in enumerate(atoms)}
        self.site_lookup = {s.index: i for i, s in enumerate(sites)}

    def atom_by_index(self,
                      i: int) -> Atom:
        """Return the atom with the specified index.

        Args:
            i: Index (``Atom.index``) of the atom to return.

        Returns:
            The Atom object with the specified index.

        Raises:
            KeyError: If no atom with this index is tracked.
        """
        return self.atoms[self.atom_lookup[i]]

    def site_by_index(self,
                      i: int) -> Site:
        """Return the site with the specified index.

        Args:
            i: Index (``Site.index``) of the site to return.

        Returns:
            The Site object with the specified index.

        Raises:
            KeyError: If no site with this index is tracked.
        """
        return self.sites[self.site_lookup[i]]

    def analyse_structure(self,
                          structure: Structure) -> None:
        """Analyse a structure to assign atoms to sites.

        This delegates the analysis to the site collection's
        ``analyse_structure`` method.

        Args:
            structure: A pymatgen Structure object to be analysed.
        """
        self.site_collection.analyse_structure(self.atoms, structure)

    def assign_site_occupations(self,
                                structure: Structure) -> None:
        """Assign atoms to sites for a specific structure.

        This delegates the assignment to the site collection's
        ``assign_site_occupations`` method, passing it the tracked atoms
        and the lattice matrix of ``structure``.

        Args:
            structure: A pymatgen Structure object to be analysed.
        """
        self.site_collection.assign_site_occupations(
            self.atoms, structure.lattice.matrix)

    def site_coordination_numbers(self) -> Counter:
        """Return the coordination numbers of all sites.

        Returns:
            A Counter object mapping coordination numbers to their frequencies.
        """
        return Counter(s.coordination_number for s in self.sites)

    def site_labels(self) -> list[str | None]:
        """Return the labels of all sites.

        Returns:
            A list of site labels (or None for sites without labels),
            in the same order as ``self.sites``.
        """
        return [s.label for s in self.sites]

    @staticmethod
    def _normalise_counts(counts: TransitionTable[TableKey]) -> TransitionTable[TableKey]:
        """Row-normalise a counts table into probabilities.

        Rows whose counts sum to zero are left as all zeros rather than
        being divided by zero.
        """
        count_data = counts.matrix.astype(float)
        row_sums = count_data.sum(axis=1)
        probs = np.zeros_like(count_data)
        nonzero = row_sums > 0
        probs[nonzero] = count_data[nonzero] / row_sums[nonzero, np.newaxis]
        return TransitionTable(keys=counts.keys, matrix=probs)

    @staticmethod
    def _validate_destination(site_index: int, dest: int, valid_indices: set[int]) -> None:
        """Raise ValueError if dest is not in valid_indices."""
        if dest not in valid_indices:
            raise ValueError(
                f"Site {site_index} has a transition to unknown "
                f"site index {dest}."
            )

    def transition_counts_by_site(
        self,
        *,
        keys: Sequence[int] | None = None,
    ) -> TransitionTable[int]:
        """Return per-site transition counts as a :class:`TransitionTable`.

        Rows are source sites and columns are destination sites, both
        keyed by site index.

        Args:
            keys: Optional key ordering for rows and columns. If ``None``,
                keys are sorted. Must be a permutation of the default keys.

        Returns:
            A :class:`TransitionTable` of integer counts keyed by site index.

        Raises:
            ValueError: If *keys* does not match the default key set, or
                if a site has a transition to an unknown site index.
        """
        site_keys = tuple(sorted(s.index for s in self.sites))
        index_set = set(site_keys)
        idx_lookup = {k: i for i, k in enumerate(site_keys)}
        n = len(site_keys)
        matrix = np.zeros((n, n), dtype=int)
        for site in self.sites:
            row = idx_lookup[site.index]
            for dest, count in site.transitions.items():
                self._validate_destination(site.index, dest, index_set)
                matrix[row, idx_lookup[dest]] = count
        table = TransitionTable(keys=site_keys, matrix=matrix)
        if keys is not None:
            return table.reorder(keys)
        return table

    def transition_counts_by_label(
        self,
        *,
        keys: Sequence[str] | None = None,
    ) -> TransitionTable[str]:
        """Return label-aggregated transition counts as a :class:`TransitionTable`.

        Sites without labels are skipped. A warning is emitted if any
        transitions are dropped as a result.

        Args:
            keys: Optional key ordering for rows and columns. If ``None``,
                keys are sorted. Must be a permutation of the default keys.

        Returns:
            A :class:`TransitionTable` of integer counts keyed by site label.

        Raises:
            ValueError: If *keys* does not match the default key set, or
                if a site has a transition to an unknown site index.
        """
        all_site_indices = {s.index for s in self.sites}
        index_to_label = {
            s.index: s.label for s in self.sites if s.label is not None
        }
        label_keys = tuple(sorted(set(index_to_label.values())))
        label_lookup = {k: i for i, k in enumerate(label_keys)}
        n = len(label_keys)
        matrix = np.zeros((n, n), dtype=int)
        dropped = 0
        for site in self.sites:
            src_label = index_to_label.get(site.index)
            if src_label is None:
                # Unlabelled source: still validate its destinations so
                # that corrupt data is reported, then count every one of
                # its transitions as dropped.
                for dest in site.transitions:
                    self._validate_destination(site.index, dest, all_site_indices)
                dropped += sum(site.transitions.values())
                continue
            for dest, count in site.transitions.items():
                self._validate_destination(site.index, dest, all_site_indices)
                dest_label = index_to_label.get(dest)
                if dest_label is None:
                    dropped += count
                    continue
                # Several sites may share a label, so accumulate.
                matrix[label_lookup[src_label], label_lookup[dest_label]] += count
        if dropped > 0:
            warnings.warn(
                f"{dropped} transition(s) involving unlabelled sites were "
                f"excluded from the label-aggregated counts.",
                stacklevel=2,
            )
        table = TransitionTable(keys=label_keys, matrix=matrix)
        if keys is not None:
            return table.reorder(keys)
        return table

    def transition_probabilities_by_site(
        self,
        *,
        keys: Sequence[int] | None = None,
    ) -> TransitionTable[int]:
        """Return per-site row-normalised transition probabilities.

        Each row is normalised so that its values sum to 1.0. Rows with no
        outgoing transitions remain as all zeros.

        Args:
            keys: Optional key ordering for rows and columns. If ``None``,
                keys are sorted. Must be a permutation of the default keys.

        Returns:
            A :class:`TransitionTable` of float probabilities keyed by site index.

        Raises:
            ValueError: If *keys* does not match the default key set, or
                if a site has a transition to an unknown site index.
        """
        return self._normalise_counts(self.transition_counts_by_site(keys=keys))

    def transition_probabilities_by_label(
        self,
        *,
        keys: Sequence[str] | None = None,
    ) -> TransitionTable[str]:
        """Return label-aggregated row-normalised transition probabilities.

        Each row is normalised so that its values sum to 1.0. Rows with no
        outgoing transitions remain as all zeros. Sites without labels are
        skipped. A warning is emitted if any transitions are dropped.

        Args:
            keys: Optional key ordering for rows and columns. If ``None``,
                keys are sorted. Must be a permutation of the default keys.

        Returns:
            A :class:`TransitionTable` of float probabilities keyed by site label.

        Raises:
            ValueError: If *keys* does not match the default key set, or
                if a site has a transition to an unknown site index.
        """
        return self._normalise_counts(self.transition_counts_by_label(keys=keys))

    @property
    def atom_sites(self) -> list[int | None]:
        """Return the sites that each atom currently occupies.

        Returns:
            A list of site indices (or None for unoccupied atoms), one for each atom.
        """
        return [atom.in_site for atom in self.atoms]

    @property
    def site_occupations(self) -> list[list[int]]:
        """Return the atoms occupying each site.

        Returns:
            A list of lists, where each inner list contains the indices of atoms
            occupying a site.
        """
        return [s.contains_atoms for s in self.sites]

    def append_timestep(self,
                        structure: Structure,
                        t: int | None = None) -> None:
        """Append a new timestep to the trajectory.

        This method:

        1. Analyses the structure to assign atoms to sites
        2. Appends the current occupation to each atom's and each site's
           ``trajectory`` list
        3. Adds the timestep to the list of timesteps if provided

        Args:
            structure: A pymatgen Structure object for this timestep.
            t: Optional timestep index to record. If None, no timestep is recorded.
        """
        self.analyse_structure(structure)
        for atom in self.atoms:
            atom.trajectory.append(atom.in_site)
        for site in self.sites:
            # Store a snapshot, not the live list: otherwise any later
            # in-place change to ``contains_atoms`` would silently rewrite
            # the occupation history already recorded.
            site.trajectory.append(list(site.contains_atoms))
        if t is not None:
            self.timesteps.append(t)

    def reset(self) -> None:
        """Reset the trajectory.

        Calls ``reset()`` on every atom and on the site collection, and
        replaces the timesteps list with an empty one.
        """
        for atom in self.atoms:
            atom.reset()
        self.site_collection.reset()
        self.timesteps = []

    @property
    def atoms_trajectory(self) -> list[list[int | None]]:
        """Return the trajectory of all atoms.

        Returns:
            A list of lists, where each inner list represents a timestep and
            contains the site indices occupied by each atom at that timestep.
            Empty if no timestep has been appended.
        """
        # Transpose from per-atom histories to per-timestep rows.
        return list(map(list, zip(*[atom.trajectory for atom in self.atoms])))

    @property
    def sites_trajectory(self) -> list[list[list[int]]]:
        """Return the trajectory of all sites.

        Returns:
            A list of lists, where each inner list represents a timestep and
            contains the atom indices occupying each site at that timestep.
            Empty if no timestep has been appended.
        """
        # Transpose from per-site histories to per-timestep rows.
        return list(map(list, zip(*[site.trajectory for site in self.sites])))

    @property
    def at(self) -> list[list[int | None]]:
        """Shorthand for atoms_trajectory.

        Returns:
            The atoms_trajectory property.
        """
        return self.atoms_trajectory

    @property
    def st(self) -> list[list[list[int]]]:
        """Shorthand for sites_trajectory.

        Returns:
            The sites_trajectory property.
        """
        return self.sites_trajectory

    def trajectory_from_structures(
        self,
        structures: Sequence[Structure],
        progress: bool = False,
    ) -> None:
        """Generate a trajectory from a list of structures.

        Each structure is passed to :meth:`append_timestep` with a
        timestep label counting up from 1.

        Args:
            structures: The structures to analyse, one per timestep.
            progress: Show a progress bar. Automatically selects the
                appropriate widget for terminal or notebook environments.
        """
        generator: Iterable[tuple[int, Structure]] = enumerate(structures, 1)
        if progress:
            generator = tqdm(generator,
                             total=len(structures),
                             unit=' steps',
                             desc='Analysing trajectory')
        for timestep, s in generator:
            self.append_timestep(s, t=timestep)

    def __len__(self) -> int:
        """Return the length of the trajectory.

        Returns:
            The number of recorded timestep labels. Timesteps appended
            with ``t=None`` are not counted.
        """
        return len(self.timesteps)

    def site_summaries(self,
                       metrics: list[str] | None = None) -> list[dict]:
        """Generate summary statistics for all sites in this trajectory.

        Delegates to the site collection's ``summaries`` method.

        Args:
            metrics: List of metrics to include for each site. None returns
                default metrics for each site.

        Returns:
            List of summary dicts, one per site, in site order.
        """
        return self.site_collection.summaries(metrics=metrics)

    def write_site_summaries(self,
                             filename: str,
                             metrics: list[str] | None = None) -> None:
        """Write site summaries to a JSON file.

        Args:
            filename: Path to output JSON file.
            metrics: List of metrics to include for each site. None returns
                default metrics for each site.
        """
        import json

        summaries = self.site_summaries(metrics=metrics)
        # Explicit encoding so the output does not depend on the locale.
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(summaries, f, indent=2)
def update_occupation(site, atom):
    """Record that ``atom`` occupies ``site``.

    Two pieces of state are updated in place:

    1. ``atom.index`` is appended to ``site.contains_atoms``.
    2. ``atom.in_site`` is set to ``site.index``.

    Args:
        site (Site): The site that contains the atom.
        atom (Atom): The atom to be assigned to the site.

    Returns:
        None

    Note:
        This is a simplified version of the update_occupation method in
        SiteCollection classes, used for direct assignments without
        tracking transitions.
    """
    occupants = site.contains_atoms
    occupants.append(atom.index)
    atom.in_site = site.index