"""Index mapping between reference and target crystal structures.

This module provides the IndexMapper class, which maps coordinating atom indices
from a reference structure to corresponding atoms in a target structure. This mapping
is essential when transferring site definitions from one structure to another,
particularly when atom orderings differ or when structures have been distorted.

The IndexMapper uses a distance-based approach to find the closest corresponding atoms
in the target structure for each reference atom, considering periodic boundary conditions.
It enforces a 1:1 mapping constraint to ensure that each reference atom maps to a distinct
target atom, which is necessary for preserving the topology of coordination environments.

The mapping can be filtered by atom species to ensure that atoms map only to atoms
of the same species, which is important for maintaining chemical validity when
mapping between structures with mixed compositions.

This module is a core component of the reference-based workflow, enabling the transfer
of site definitions between different structures or timesteps in a simulation.
"""
import numpy as np
from site_analysis.distances import all_mic_distances


class IndexMapper:
"""Maps coordinating atom indices between reference and target crystal structures.
Used to translate coordination environments defined in an ideal reference
structure to corresponding atoms in a target structure, handling permutations
and structural distortions via distance-based matching.
The mapper verifies 1:1 correspondence between reference and target atoms,
ensuring each coordinating position maps to exactly one target atom. If
this constraint is violated, a ValueError is raised.
"""
[docs]
def map_coordinating_atoms(
self,
ref_frac_coords: np.ndarray,
target_frac_coords: np.ndarray,
lattice_matrix: np.ndarray,
ref_coordinating: list[list[int]],
target_species: list[str] | None = None,
species_filter: str | list[str] | None = None,
) -> list[list[int]]:
"""Map coordinating atom indices from reference to target structure.
Args:
ref_frac_coords: Fractional coordinates of the reference structure,
shape (N, 3).
target_frac_coords: Fractional coordinates of the target structure,
shape (M, 3).
lattice_matrix: Lattice matrix (3x3) for distance calculations.
ref_coordinating: List of coordinating atom index lists from reference.
Each sublist contains indices of atoms that define a site.
target_species: Full species list for all atoms in the target structure.
Required when using species_filter.
species_filter: Optional filter for which species to map to.
If specified, only maps to atoms of these species in the target
structure.
Returns:
List of coordinating atom index lists mapped to the target structure.
Maintains the same structure as input but with updated indices.
Raises:
ValueError: If 1:1 mapping cannot be achieved (e.g., missing atoms,
ambiguous distances, or insufficient target atoms in target structure).
"""
if target_species is not None and len(target_species) != len(target_frac_coords):
raise ValueError(
f"target_species length ({len(target_species)}) does not match "
f"target_frac_coords rows ({len(target_frac_coords)})"
)
# Extract all unique coordinating atoms from reference structure
unique_indices = self._extract_unique_coordinating_atoms(ref_coordinating)
# Create mapping from reference to target structure
index_mapping = self._find_closest_atom_mapping(
ref_frac_coords, target_frac_coords, lattice_matrix,
unique_indices, target_species, species_filter,
)
# Map the coordination lists using the established mapping
mapped_coordinating = self._apply_mapping(ref_coordinating, index_mapping)
return mapped_coordinating
def _extract_unique_coordinating_atoms(
self,
ref_coordinating: list[list[int]]
) -> list[int]:
"""Extract unique coordinating atom indices from coordination lists.
Args:
ref_coordinating: List of coordinating atom index lists.
Returns:
Sorted list of unique coordinating atom indices.
"""
unique_indices = set()
for coord_list in ref_coordinating:
unique_indices.update(coord_list)
return sorted(list(unique_indices))
def _find_closest_atom_mapping(
self,
ref_frac_coords: np.ndarray,
target_frac_coords: np.ndarray,
lattice_matrix: np.ndarray,
ref_indices: list[int],
target_species: list[str] | None,
species_filter: str | list[str] | None,
) -> dict[int, int]:
"""Find closest atom in target structure for each reference atom.
Args:
ref_frac_coords: Fractional coordinates of the reference structure,
shape (N, 3).
target_frac_coords: Fractional coordinates of the target structure,
shape (M, 3).
lattice_matrix: Lattice matrix (3x3) for distance calculations.
ref_indices: List of reference atom indices to map.
target_species: Full species list for all atoms in the target structure.
species_filter: Optional species filter for the target structure.
Returns:
Dictionary mapping reference indices to target indices.
Raises:
ValueError: If 1:1 mapping cannot be achieved.
"""
if not ref_indices:
return {}
# Ensure that species_filter is a list if it is specified
if isinstance(species_filter, str):
species_filter = [species_filter]
# Create a filtered list of target atoms in the target structure
if species_filter is not None:
if target_species is None:
raise ValueError(
"target_species must be provided when using species_filter"
)
target_mask = np.array([s in species_filter for s in target_species])
if not np.any(target_mask):
raise ValueError(
f"No atoms of species {species_filter} found in target structure"
)
else:
target_mask = np.ones(len(target_frac_coords), dtype=bool)
# Get coordinates of reference atoms to map
ref_coords = ref_frac_coords[ref_indices]
# Get coordinates of target atoms in the target structure
target_indices = np.where(target_mask)[0]
filtered_target_coords = target_frac_coords[target_indices]
# Calculate distances between reference and target atoms (with PBC)
dr_ij = all_mic_distances(ref_coords, filtered_target_coords, lattice_matrix)
# Find closest target atom for each reference atom
closest_indices = np.argmin(dr_ij, axis=1)
mapped_indices = target_indices[closest_indices]
# Check for 1:1 mapping violations
if len(mapped_indices) != len(np.unique(mapped_indices)):
# Find the duplicates
seen = set()
duplicates = []
for idx in mapped_indices:
if idx in seen:
duplicates.append(int(idx))
seen.add(idx)
raise ValueError(
f"1:1 mapping violation: Multiple reference atoms map to "
f"the same target atom(s) at indices {duplicates}"
)
# Create mapping dictionary
mapping = {ref_indices[i]: int(mapped_indices[i]) for i in range(len(ref_indices))}
return mapping
def _apply_mapping(
self,
ref_coordinating: list[list[int]],
index_mapping: dict[int, int]
) -> list[list[int]]:
"""Apply the index mapping to coordination lists.
Args:
ref_coordinating: Original coordination lists from reference.
index_mapping: Mapping from reference to target indices.
Returns:
Coordination lists with mapped indices.
"""
mapped_coordinating = []
for coord_list in ref_coordinating:
# Convert each index to Python int to ensure consistent types
mapped_list = [int(index_mapping[ref_idx]) for ref_idx in coord_list]
mapped_coordinating.append(mapped_list)
return mapped_coordinating