"""Reference-based workflow for defining sites in crystal structures.
This module provides the ReferenceBasedSites class, which is the main orchestrator
for defining crystallographic sites in a target structure based on coordination
environments identified in a reference structure. This approach is particularly
useful for:
1. Analyzing structures with distortions relative to an ideal reference
2. Tracking specific site types across different structures or simulation frames
3. Creating consistent site definitions across diverse structures
The ReferenceBasedSites class integrates several components to accomplish this workflow:
- StructureAligner: Aligns the structures to find the optimal translation vector
- CoordinationEnvironmentFinder: Identifies coordination environments in the reference
- IndexMapper: Maps atom indices between reference and target structures
- SiteFactory: Creates appropriate site objects in the target structure
This approach lets users define sites based on ideal coordination environments in
a reference structure, then create corresponding sites in real or distorted structures
where those same environments might be harder to identify directly.
"""
from collections import Counter
from typing import Any

import numpy as np
from pymatgen.core import Structure

from site_analysis.reference_workflow.structure_aligner import StructureAligner
from site_analysis.reference_workflow.coord_finder import CoordinationEnvironmentFinder
from site_analysis.reference_workflow.index_mapper import IndexMapper
from site_analysis.reference_workflow.site_factory import SiteFactory
from site_analysis.polyhedral_site import PolyhedralSite
from site_analysis.dynamic_voronoi_site import DynamicVoronoiSite
class ReferenceBasedSites:
    """Main orchestrator for defining sites using a reference structure approach.

    This class ties together all components needed to define sites in crystal
    structures using a reference structure as a template:

    1. StructureAligner - to find the translation that best aligns the
       reference and target structures
    2. CoordinationEnvironmentFinder - to find coordination environments in the reference
    3. IndexMapper - to map environments from reference to target structure
    4. SiteFactory - to create appropriate site objects

    Attributes:
        reference_structure: Reference structure defining ideal coordination environments.
        target_structure: Target structure where sites will be created.
        aligned_structure: Reference structure translated to match the target
            (None if align=False).
        translation_vector: Translation vector used for alignment (None if align=False).
        alignment_metrics: Metrics describing quality of structure alignment
            (None if align=False).
    """

    def __init__(self,
                 reference_structure: Structure,
                 target_structure: Structure,
                 align: bool = True,
                 align_species: list[str] | None = None,
                 align_metric: str = 'rmsd',
                 align_algorithm: str = 'Nelder-Mead',
                 align_minimizer_options: dict[str, Any] | None = None,
                 align_tolerance: float = 1e-4) -> None:
        """Initialise ReferenceBasedSites with reference and target structures.

        Args:
            reference_structure: Reference structure defining ideal coordination environments.
            target_structure: Target structure where sites will be created.
            align: Whether to perform structure alignment. Default is True.
            align_species: Species to use for alignment. Default is all species.
            align_metric: Metric for alignment ('rmsd', 'max_dist'). Default is 'rmsd'.
            align_algorithm: Algorithm for optimization ('Nelder-Mead',
                'differential_evolution'). Default is 'Nelder-Mead'.
            align_minimizer_options: Additional options for the minimizer. Default is None.
            align_tolerance: Convergence tolerance for alignment optimizer. Default is 1e-4.

        Raises:
            ValueError: If alignment is requested and fails.
        """
        self.reference_structure = reference_structure
        self.target_structure = target_structure

        # Eagerly extract plain arrays / species lists from both structures;
        # the mapping step passes these straight to IndexMapper.
        self._ref_frac_coords: np.ndarray = reference_structure.frac_coords
        self._ref_lattice_matrix: np.ndarray = reference_structure.lattice.matrix
        self._ref_species = [s.species_string for s in reference_structure]
        self._target_frac_coords = target_structure.frac_coords
        self._target_lattice_matrix = target_structure.lattice.matrix
        self._target_species = [s.species_string for s in target_structure]

        # Alignment results; stay None when align=False.
        self.aligned_structure: Structure | None = None
        self.translation_vector: np.ndarray | None = None
        self.alignment_metrics: dict[str, float] | None = None
        self._aligned_frac_coords: np.ndarray | None = None

        # Helper components are created lazily on first use.
        self._coord_finder: CoordinationEnvironmentFinder | None = None
        self._index_mapper: IndexMapper | None = None
        self._site_factory: SiteFactory | None = None

        if align:
            self._align_structures(
                align_species,
                align_metric,
                align_algorithm,
                align_minimizer_options,
                align_tolerance
            )

    @property
    def _effective_ref_coords(self) -> np.ndarray:
        """Reference fractional coordinates, preferring the aligned ones if available."""
        if self._aligned_frac_coords is not None:
            return self._aligned_frac_coords
        return self._ref_frac_coords

    def create_polyhedral_sites(self,
                                center_species: str,
                                vertex_species: str | list[str],
                                cutoff: float,
                                n_vertices: int,
                                label: str | None = None,
                                labels: list[str] | None = None,
                                target_species: str | list[str] | None = None,
                                use_reference_centers: bool = True) -> list[PolyhedralSite]:
        """Create PolyhedralSite objects based on coordination environments in the reference structure.

        Args:
            center_species: Species at the center of coordination environments.
            vertex_species: Species at vertices of coordination environments.
            cutoff: Cutoff distance for coordination environment (required).
            n_vertices: Number of vertices per environment (required).
            label: Label to apply to all created sites. Default is None.
            labels: List of labels for each site. Default is None.
            target_species: Species to map to in the target structure. Default is None.
            use_reference_centers: Whether to use reference centers for PBC handling.
                See TrajectoryBuilder.with_polyhedral_sites() for details.
                Default is True.

        Returns:
            List of PolyhedralSite objects.

        Raises:
            ValueError: If coordination environments cannot be found or mapped,
                or if both label and labels are provided.
        """
        mapped_environments, reference_centers = self._prepare_environments(
            center_species=center_species,
            coordination_species=vertex_species,
            cutoff=cutoff,
            n_coord=n_vertices,
            target_species=target_species,
            use_reference_centers=use_reference_centers,
        )
        site_factory = self._initialise_site_factory()
        return site_factory.create_polyhedral_sites(
            mapped_environments,
            reference_centers=reference_centers,
            label=label,
            labels=labels
        )

    def create_dynamic_voronoi_sites(self,
                                     center_species: str,
                                     reference_species: str | list[str],
                                     cutoff: float,
                                     n_reference: int,
                                     label: str | None = None,
                                     labels: list[str] | None = None,
                                     target_species: str | list[str] | None = None,
                                     use_reference_centers: bool = True) -> list[DynamicVoronoiSite]:
        """Create DynamicVoronoiSite objects based on coordination environments in the reference structure.

        Args:
            center_species: Species at the center of coordination environments.
            reference_species: Species of reference atoms used to define the dynamic site centers.
            cutoff: Cutoff distance for finding reference atoms (required).
            n_reference: Number of reference atoms per site (required).
            label: Label to apply to all created sites. Default is None.
            labels: List of labels for each site. Default is None.
            target_species: Species to map to in the target structure. Default is None.
            use_reference_centers: Whether to use reference centers for PBC handling.
                See TrajectoryBuilder.with_polyhedral_sites() for details.
                Default is True.

        Returns:
            List of DynamicVoronoiSite objects.

        Raises:
            ValueError: If coordination environments cannot be found or mapped,
                or if both label and labels are provided.
        """
        mapped_environments, reference_centers = self._prepare_environments(
            center_species=center_species,
            coordination_species=reference_species,
            cutoff=cutoff,
            n_coord=n_reference,
            target_species=target_species,
            use_reference_centers=use_reference_centers,
        )
        site_factory = self._initialise_site_factory()
        return site_factory.create_dynamic_voronoi_sites(
            mapped_environments,
            reference_centers=reference_centers,
            label=label,
            labels=labels
        )

    def _prepare_environments(self,
                              center_species: str,
                              coordination_species: str | list[str],
                              cutoff: float,
                              n_coord: int,
                              target_species: str | list[str] | None,
                              use_reference_centers: bool,
                              ) -> tuple[list[list[int]], list[np.ndarray] | None]:
        """Find, validate and map reference environments for the site-creation methods.

        Args:
            center_species: Species at the center of coordination environments.
            coordination_species: Species of the coordinating atoms.
            cutoff: Cutoff distance for coordination environment.
            n_coord: Number of coordinating atoms per environment.
            target_species: Species to map to in the target structure (or None).
            use_reference_centers: Whether to compute reference centers.

        Returns:
            Tuple of (environments mapped onto the target structure,
            list of reference centers or None when use_reference_centers is False).

        Raises:
            ValueError: If environments cannot be found, contain duplicate
                indices, or cannot be mapped.
        """
        ref_environments = self._find_coordination_environments(
            center_species=center_species,
            coordination_species=coordination_species,
            cutoff=cutoff,
            n_coord=n_coord
        )
        # Reject environments that list the same atom more than once
        # (repeat periodic images of one neighbour).
        self._validate_unique_environments(ref_environments)

        if use_reference_centers:
            reference_centers = self._calculate_reference_centers_from_indices(
                list(ref_environments.keys())
            )
        else:
            reference_centers = None

        # Dict ordering keeps environments aligned with reference_centers.
        mapped_environments = self._map_environments(
            list(ref_environments.values()),
            target_species
        )
        return mapped_environments, reference_centers

    def _align_structures(self,
                          align_species: list[str] | None = None,
                          align_metric: str = 'rmsd',
                          align_algorithm: str = 'Nelder-Mead',
                          align_minimizer_options: dict[str, Any] | None = None,
                          align_tolerance: float = 1e-4) -> None:
        """Align the reference and target structures and store the results.

        On success sets ``aligned_structure``, ``translation_vector``,
        ``alignment_metrics`` and the cached aligned fractional coordinates.

        Args:
            align_species: Species to use for alignment. Default is all species.
            align_metric: Metric for alignment ('rmsd', 'max_dist'). Default is 'rmsd'.
            align_algorithm: Algorithm for optimization ('Nelder-Mead',
                'differential_evolution'). Default is 'Nelder-Mead'.
            align_minimizer_options: Additional options for the minimizer. Default is None.
            align_tolerance: Convergence tolerance for alignment optimizer. Default is 1e-4.

        Raises:
            ValueError: If alignment fails.
        """
        try:
            aligner = StructureAligner()
            aligned_structure, translation_vector, metrics = aligner.align(
                self.reference_structure,
                self.target_structure,
                species=align_species,
                metric=align_metric,
                tolerance=align_tolerance,
                algorithm=align_algorithm,
                minimizer_options=align_minimizer_options,
            )
        except ValueError as e:
            # Re-raise with more context.
            raise ValueError(f"Structure alignment failed: {e}") from e

        self.aligned_structure = aligned_structure
        self.translation_vector = translation_vector
        self.alignment_metrics = metrics
        self._aligned_frac_coords = aligned_structure.frac_coords

    def _find_coordination_environments(self,
                                        center_species: str,
                                        coordination_species: str | list[str],
                                        cutoff: float,
                                        n_coord: int) -> dict[int, list[int]]:
        """Find coordination environments in the reference structure.

        Args:
            center_species: Species at the center of coordination environments.
            coordination_species: Coordination atom species.
            cutoff: Cutoff distance for coordination environment.
            n_coord: Number of coordination atoms per environment.

        Returns:
            Dictionary mapping center atom indices to lists of coordinating atom indices.

        Raises:
            ValueError: If coordination environments cannot be found.
        """
        try:
            coord_finder = self._initialise_coord_finder()
            return coord_finder.find_environments(
                center_species=center_species,
                coordination_species=coordination_species,
                n_coord=n_coord,
                cutoff=cutoff
            )
        except ValueError as e:
            # Re-raise with more context.
            raise ValueError(
                f"Failed to find coordination environments for {center_species} centers "
                f"and {coordination_species} coordinating atoms: {e}"
            ) from e

    def _map_environments(self,
                          ref_environments: list[list[int]],
                          target_species: str | list[str] | None = None) -> list[list[int]]:
        """Map coordination environments from reference to target structure.

        Args:
            ref_environments: List of environments from reference structure.
            target_species: Species to map to in the target structure. Default is None.

        Returns:
            List of mapped environments for the target structure
            (empty if ``ref_environments`` is empty).

        Raises:
            ValueError: If environments cannot be mapped between structures.
        """
        if not ref_environments:
            return []

        try:
            index_mapper = self._initialise_index_mapper()
            return index_mapper.map_coordinating_atoms(
                ref_frac_coords=self._effective_ref_coords,
                target_frac_coords=self._target_frac_coords,
                lattice_matrix=self._ref_lattice_matrix,
                ref_coordinating=ref_environments,
                target_species=self._target_species,
                species_filter=target_species,
            )
        except ValueError as e:
            # Re-raise with more context.
            species_str = f" for {target_species} species" if target_species else ""
            raise ValueError(
                f"Failed to map coordination environments{species_str}: {e}"
            ) from e

    def _initialise_site_factory(self) -> SiteFactory:
        """Initialise the site factory if not already done.

        Note: This method exists primarily for testing purposes.
        """
        if self._site_factory is None:
            self._site_factory = SiteFactory(self.target_structure)
        return self._site_factory

    def _initialise_coord_finder(self) -> CoordinationEnvironmentFinder:
        """Initialise the coordination environment finder if not already done."""
        if self._coord_finder is None:
            self._coord_finder = CoordinationEnvironmentFinder(self.reference_structure)
        return self._coord_finder

    def _initialise_index_mapper(self) -> IndexMapper:
        """Initialise the index mapper if not already done."""
        if self._index_mapper is None:
            self._index_mapper = IndexMapper()
        return self._index_mapper

    def _validate_unique_environments(self, environments: dict[int, list[int]]) -> None:
        """Validate that each environment contains unique atom indices.

        Args:
            environments: Dict of environments, where keys are center atom indices
                and values are lists of coordinating atom indices.

        Raises:
            ValueError: If any environment contains duplicate atom indices.
        """
        for center_idx, env in environments.items():
            # Counter keeps first-appearance order, so duplicates are reported
            # in the order they occur in the environment.
            duplicates = [idx for idx, count in Counter(env).items() if count > 1]
            if duplicates:
                raise ValueError(
                    f"Environment for center atom {center_idx} contains duplicate atom indices {duplicates}. "
                    f"This typically occurs in small unit cells where the same atom "
                    f"appears as a neighbor in multiple periodic images. "
                    f"Please use a larger supercell for your analysis."
                )

    def _calculate_reference_centers_from_indices(self, center_indices: list[int]) -> list[np.ndarray]:
        """Return copies of the (aligned, if available) reference coordinates of the given atoms."""
        coords = self._effective_ref_coords
        return [coords[i].copy() for i in center_indices]