"""Transition table for storing labelled transition data.
Provides the :class:`TransitionTable` class, which stores transition
counts or probabilities as a labelled square matrix with convenient
access patterns.
"""
from __future__ import annotations
import html
from typing import Generic, Sequence, TypeVar
import numpy as np
TableKey = TypeVar('TableKey', int, str)
[docs]
class TransitionTable(Generic[TableKey]):
"""A labelled square matrix of transition data.
Stores transition counts or probabilities with named keys for
rows and columns. Provides multiple access patterns:
- ``.matrix`` — the raw (read-only) :class:`numpy.ndarray`
- ``.get(from_key, to_key)`` — key-based lookup
- ``.to_dict()`` — square dict-of-dicts
- ``.reorder(keys)`` — return a new table with reordered rows/columns
- ``.filter(keys)`` — return a new table with only the specified keys
Args:
keys: Row and column labels (site indices or site labels).
matrix: A square 2-D numpy array of transition values.
Raises:
ValueError: If *matrix* is not 2-D and square, if
``len(keys) != matrix.shape[0]``, if *keys* contains
duplicates, or if *matrix* has a non-numeric dtype.
"""
__slots__ = ('_keys', '_matrix', '_key_to_index', '_frozen')
def __init__(
self,
keys: tuple[TableKey, ...],
matrix: np.ndarray,
) -> None:
self._matrix = np.array(matrix, copy=True)
if self._matrix.ndim != 2 or self._matrix.shape[0] != self._matrix.shape[1]:
raise ValueError(
f"matrix must be a square 2-D array, got shape {self._matrix.shape}"
)
if len(keys) != self._matrix.shape[0]:
raise ValueError(
f"len(keys) ({len(keys)}) != matrix dimension "
f"({self._matrix.shape[0]})"
)
if not np.issubdtype(self._matrix.dtype, np.number):
raise ValueError(
f"matrix must have a numeric dtype, got {self._matrix.dtype}"
)
if len(set(keys)) != len(keys):
raise ValueError("keys must not contain duplicates")
self._keys: tuple[TableKey, ...] = keys
self._key_to_index: dict[TableKey, int] = {k: i for i, k in enumerate(keys)}
self._matrix.flags.writeable = False
self._frozen = True
@property
def keys(self) -> tuple[TableKey, ...]:
"""Row and column labels."""
return self._keys
@property
def matrix(self) -> np.ndarray:
"""The transition data as a read-only 2-D numpy array."""
return self._matrix
[docs]
def get(self, from_key: TableKey, to_key: TableKey) -> int | float:
"""Look up a single transition value by key.
Args:
from_key: The source key (row).
to_key: The destination key (column).
Returns:
The transition value at ``(from_key, to_key)``.
Raises:
KeyError: If either key is not present in the table.
"""
try:
i = self._key_to_index[from_key]
except KeyError:
raise KeyError(from_key) from None
try:
j = self._key_to_index[to_key]
except KeyError:
raise KeyError(to_key) from None
value: int | float = self._matrix[i, j].item()
return value
[docs]
def to_dict(self) -> dict[TableKey, dict[TableKey, int | float]]:
"""Convert to a square dict-of-dicts.
Returns:
A dict ``{from_key: {to_key: value}}`` mirroring the matrix.
"""
return {
self._keys[i]: {
self._keys[j]: self._matrix[i, j].item()
for j in range(len(self._keys))
}
for i in range(len(self._keys))
}
[docs]
def reorder(self, keys: Sequence[TableKey]) -> TransitionTable[TableKey]:
"""Return a new table with rows and columns reordered.
Args:
keys: The desired key ordering. Must contain exactly
the same keys as the current table.
Returns:
A new :class:`TransitionTable` with reordered rows and columns.
Raises:
ValueError: If *keys* does not match the current key set exactly.
"""
new_keys: tuple[TableKey, ...] = tuple(keys)
new_key_set = set(new_keys)
if len(new_keys) != len(self._keys) or new_key_set != set(self._keys):
missing = [k for k in self._keys if k not in new_key_set]
extra = [k for k in new_keys if k not in self._key_to_index]
parts = []
if missing:
parts.append(f"missing keys: {missing!r}")
if extra:
parts.append(f"unknown keys: {extra!r}")
raise ValueError(
f"keys must be a permutation of the current keys; {'; '.join(parts)}"
)
order = [self._key_to_index[k] for k in new_keys]
reordered = self._matrix[np.ix_(order, order)]
return TransitionTable(keys=new_keys, matrix=reordered)
[docs]
def filter(self, keys: Sequence[TableKey]) -> TransitionTable[TableKey]:
"""Return a new table containing only the specified keys.
Extracts the requested rows and columns without re-normalising.
Rows and columns in the result follow the order given in *keys*.
Args:
keys: The keys to retain. Must be a subset of the current
keys with no duplicates.
Returns:
A new :class:`TransitionTable` with only the requested keys.
Raises:
ValueError: If *keys* contains unknown or duplicate keys.
"""
new_keys: tuple[TableKey, ...] = tuple(keys)
if len(new_keys) != len(set(new_keys)):
raise ValueError("keys must not contain duplicates")
unknown = [k for k in new_keys if k not in self._key_to_index]
if unknown:
raise ValueError(f"unknown keys: {unknown!r}")
if len(new_keys) == 0:
empty: np.ndarray = np.empty((0, 0), dtype=self._matrix.dtype)
return TransitionTable(keys=new_keys, matrix=empty)
order = [self._key_to_index[k] for k in new_keys]
filtered = self._matrix[np.ix_(order, order)]
return TransitionTable(keys=new_keys, matrix=filtered)
def __eq__(self, other: object) -> bool:
if not isinstance(other, TransitionTable):
return NotImplemented
return self._keys == other._keys and np.array_equal(self._matrix, other._matrix)
__hash__ = None # type: ignore[assignment]
def _formatted_cells(self) -> tuple[list[str], list[list[str]]]:
"""Return string keys and a grid of formatted cell values.
Auto-detects formatting from the matrix dtype: integers are
formatted with ``d``, floats with ``.3f``.
"""
fmt = 'd' if np.issubdtype(self._matrix.dtype, np.integer) else '.3f'
n = len(self._keys)
str_keys = [str(k) for k in self._keys]
cells = [[f'{self._matrix[i, j]:{fmt}}' for j in range(n)]
for i in range(n)]
return str_keys, cells
def __str__(self) -> str:
"""Return a formatted table of transition values."""
if len(self._keys) == 0:
return ''
str_keys, cells = self._formatted_cells()
col_width = max(
max(len(k) for k in str_keys),
max(len(v) for row in cells for v in row),
)
header = ' ' * col_width + ''.join(v.rjust(col_width + 2) for v in str_keys)
rows = []
for key, row in zip(str_keys, cells):
row_values = ''.join(v.rjust(col_width + 2) for v in row)
rows.append(key.rjust(col_width) + row_values)
return header + '\n' + '\n'.join(rows)
def _repr_html_(self) -> str:
"""Return an HTML table for Jupyter notebook display."""
if len(self._keys) == 0:
return ''
str_keys, cells = self._formatted_cells()
esc_keys = [html.escape(k) for k in str_keys]
header_cells = ''.join(f'<th>{k}</th>' for k in esc_keys)
header = f'<tr><th></th>{header_cells}</tr>'
rows = []
for key, row in zip(esc_keys, cells):
row_cells = ''.join(f'<td>{v}</td>' for v in row)
rows.append(f'<tr><th>{key}</th>{row_cells}</tr>')
return f'<table>{header}{"".join(rows)}</table>'
def __repr__(self) -> str:
return (
f"TransitionTable(keys={self._keys!r}, "
f"shape={self._matrix.shape[0]}x{self._matrix.shape[1]})"
)
def __setattr__(self, name: str, value: object) -> None:
if getattr(self, '_frozen', False):
raise AttributeError("TransitionTable is immutable")
super().__setattr__(name, value)