Source code for site_analysis.reference_workflow.structure_aligner

"""Structure alignment tools for comparing and superimposing crystal structures.

This module provides the StructureAligner class, which finds the optimal
translation vector to superimpose one crystal structure onto another. This
alignment is important for:

1. Comparing structures from different sources with different coordinate origins
2. Analyzing structural changes while accounting for rigid translations
3. Preparing structures for site mapping in reference-based workflows

The alignment algorithm optimizes a translation vector to minimise distances
between corresponding atoms in the two structures, considering periodic
boundary conditions. It supports different optimisation metrics (RMSD or
maximum atom distance) and can align based on specific atom species.

This module is a key component of the reference-based workflow for defining
sites in one structure based on a template from another structure.
"""

import numpy as np
from pymatgen.core import Structure
from typing import Any, Callable
from site_analysis.tools import calculate_species_distances

[docs] class StructureAligner: """Aligns crystal structures via translation optimization. This class provides methods to align a reference structure to a target structure by finding the optimal translation vector that minimizes distances between corresponding atoms, considering periodic boundary conditions. """
[docs] def align(self, reference: Structure, target: Structure, species: list[str] | None = None, metric: str = 'rmsd', tolerance: float = 1e-4, algorithm: str = 'Nelder-Mead', minimizer_options: dict[str, Any] | None = None) -> tuple[Structure, np.ndarray, dict[str, float]]: """Align reference structure to target structure via translation. Finds the optimal translation vector that minimises distances between corresponding atoms in the two structures. Args: reference: Reference structure to translate. target: Target structure to align to. species: Species to include in alignment. If None, all species present in both structures are used. metric: Distance metric to optimise ('rmsd' or 'max_dist'). tolerance: Convergence tolerance for the optimiser. algorithm: Optimisation algorithm ('Nelder-Mead' or 'differential_evolution'). minimizer_options: Additional options passed to the optimiser. Returns: A tuple of (aligned_structure, translation_vector, metrics) where aligned_structure is the translated reference, translation_vector is the applied translation in fractional coordinates, and metrics is a dictionary of alignment quality measures. Raises: ValueError: If structures have incompatible compositions or if optimisation fails. """ # Extract arrays from Structure at the public boundary ref_frac_coords = reference.frac_coords target_frac_coords = target.frac_coords lattice_matrix = reference.lattice.matrix ref_species = [s.species_string for s in reference] target_species_list = [s.species_string for s in target] # Validate structures and get species to use valid_species = self._validate_structures( ref_species, target_species_list, species) # Create objective function objective_function = self._create_objective_function( ref_frac_coords, target_frac_coords, lattice_matrix, ref_species, target_species_list, valid_species, metric) # Run the appropriate optimiser using the dispatcher translation_vector = self._run_minimizer( algorithm, objective_function, tolerance, minimizer_options) # Apply the translation to get the aligned structure aligned_structure = self._apply_translation(reference, translation_vector) # Calculate final metrics using arrays aligned_coords = (ref_frac_coords + translation_vector) % 1.0 species_distances, all_distances = calculate_species_distances( aligned_coords, target_frac_coords, lattice_matrix, ref_species, target_species_list, species=valid_species) metrics = { 'rmsd': float(np.sqrt(np.mean(np.array(all_distances)**2))) if all_distances else float('inf'), 'max_dist': float(np.max(all_distances)) if all_distances else float('inf'), 'mean_dist': float(np.mean(all_distances)) if all_distances else float('inf'), } return aligned_structure, translation_vector, metrics
def _create_objective_function(self, ref_frac_coords: np.ndarray, target_frac_coords: np.ndarray, lattice_matrix: np.ndarray, ref_species: list[str], target_species: list[str], valid_species: list[str], metric: str) -> Callable[[np.ndarray], float]: """Create the objective function for optimisation. Args: ref_frac_coords: Fractional coordinates of the reference structure. target_frac_coords: Fractional coordinates of the target structure. lattice_matrix: Lattice matrix for distance calculations. ref_species: Species strings for each site in the reference. target_species: Species strings for each site in the target. valid_species: List of species to include in alignment. metric: Metric to optimise (``'rmsd'`` or ``'max_dist'``). Returns: Objective function that takes a translation vector and returns the distance metric value. """ def objective_function( translation_vector: np.ndarray) -> float: translation_vector = translation_vector % 1.0 translated_coords = (ref_frac_coords + translation_vector) % 1.0 _, all_distances = calculate_species_distances( translated_coords, target_frac_coords, lattice_matrix, ref_species, target_species, species=valid_species) if not all_distances: return float('inf') if metric == 'rmsd': return float(np.sqrt(np.mean(np.array(all_distances)**2))) elif metric == 'max_dist': return float(np.max(all_distances)) else: raise ValueError(f"Unknown metric: {metric}") return objective_function def _validate_structures(self, ref_species: list[str], target_species: list[str], species: list[str] | None = None) -> list[str]: """Validate that structures can be aligned and determine species to use. Args: ref_species: List of species strings for each site in the reference. target_species: List of species strings for each site in the target. species: Optional list of species to use for alignment. If ``None``, all species are used and compositions must match exactly. Returns: List of species to use for alignment. Raises: ValueError: If structures cannot be aligned. """ if species is None: ref_counts: dict[str, int] = {} for s in ref_species: ref_counts[s] = ref_counts.get(s, 0) + 1 target_counts: dict[str, int] = {} for s in target_species: target_counts[s] = target_counts.get(s, 0) + 1 if ref_counts != target_counts: raise ValueError( f"Structures have different compositions: " f"{ref_counts} vs {target_counts}" ) species_to_use = sorted(ref_counts.keys()) else: species_to_use = species for sp in species_to_use: ref_count = sum(1 for s in ref_species if s == sp) target_count = sum(1 for s in target_species if s == sp) if ref_count == 0: raise ValueError(f"Species {sp} not found in reference structure") if target_count == 0: raise ValueError(f"Species {sp} not found in target structure") if ref_count != target_count: raise ValueError( f"Different number of {sp} atoms: " f"{ref_count} in reference vs {target_count} in target" ) return species_to_use def _apply_translation(self, structure: Structure, translation_vector: np.ndarray) -> Structure: """Apply translation to entire structure. Args: structure: Structure to translate translation_vector: Translation vector to apply Returns: Translated structure """ # Create a copy of the structure new_structure: Structure = structure.copy() # Apply translation to all sites for i, site in enumerate(new_structure): frac_coords = site.frac_coords + translation_vector # Ensure coordinates are within [0, 1) frac_coords = frac_coords % 1.0 new_structure[i] = site.species, frac_coords return new_structure def _run_minimizer(self, algorithm: str, objective_function: Callable[[np.ndarray], float], tolerance: float, minimizer_options: dict[str, Any] | None = None) -> np.ndarray: """Run the selected minimization algorithm. Args: algorithm: Name of the algorithm to run objective_function: Function to minimize tolerance: Convergence tolerance minimizer_options: Additional options for the minimizer Returns: np.ndarray: Optimal translation vector Raises: ValueError: If the algorithm is not supported """ # Get the algorithm registry algorithm_registry = self._get_algorithm_registry() # Check if algorithm is supported if algorithm not in algorithm_registry: raise ValueError(f"Unsupported algorithm: {algorithm}. " f"Supported algorithms: {', '.join(algorithm_registry.keys())}") # Get the appropriate implementation method run_algorithm = algorithm_registry[algorithm] # Call the selected algorithm implementation return run_algorithm(objective_function, tolerance, minimizer_options) def _get_algorithm_registry(self) -> dict[str, Callable[ [Callable[[np.ndarray], float], float, dict[str, Any] | None ], np.ndarray]]: """Get the registry of supported optimization algorithms. Returns: dict: Dictionary mapping algorithm names to implementation methods """ return { 'Nelder-Mead': self._run_nelder_mead, 'differential_evolution': self._run_differential_evolution } def _run_nelder_mead(self, objective_function: Callable[[np.ndarray], float], tolerance: float, minimizer_options: dict[str, Any] | None = None) -> np.ndarray: """Run Nelder-Mead optimization. Args: objective_function: Function to minimize tolerance: Convergence tolerance minimizer_options: Additional options for the minimizer Returns: np.ndarray: Optimised translation vector Raises: ValueError: If optimization fails """ from scipy.optimize import minimize # Ensure minimizer_options is a dictionary minimizer_options = minimizer_options or {} # Default options - ensure they exactly match the original implementation options: dict[str, Any] = { 'xatol': tolerance, 'fatol': tolerance } # Update with user-provided options options.update(minimizer_options) # Run optimisation result = minimize( # type: ignore[call-overload] objective_function, x0=np.array([0, 0, 0]), # Start with zero translation method='Nelder-Mead', options=options ) if not result.success: raise ValueError(f"Optimization failed: {result.message}") # Ensure in [0,1) range return np.array(result.x) % 1.0 def _run_differential_evolution(self, objective_function: Callable[[np.ndarray], float], tolerance: float, minimizer_options: dict[str, Any] | None = None) -> np.ndarray: """Run differential evolution optimization. Args: objective_function: Function to minimize tolerance: Convergence tolerance minimizer_options: Additional options for the minimizer Returns: np.ndarray: Optimal translation vector """ from scipy.optimize import differential_evolution # Default options for differential evolution options = { 'tol': tolerance, 'popsize': 15, 'maxiter': 1000, 'strategy': 'best1bin', 'updating': 'immediate', 'workers': 1 # Default to single process for compatibility } # Bounds for translation vector (all components in [0,1)) bounds = [(0, 1), (0, 1), (0, 1)] # Update with user-provided options if minimizer_options: options.update(minimizer_options) # Extract bounds if provided in options if minimizer_options and 'bounds' in minimizer_options: bounds = minimizer_options['bounds'] options.pop('bounds') # Run optimization result = differential_evolution( objective_function, bounds=bounds, # type: ignore[arg-type] **options # type: ignore[arg-type] ) if not result.success: raise ValueError(f"Differential evolution optimization failed: {result.message}") return np.array(result.x) % 1.0 # Ensure in [0,1) range