"""Child contract observer for composite ACTUS contracts.
This module implements observers for child contracts in composite contract structures.
In ACTUS, composite contracts can observe events, states, and attributes of their
child contracts to determine their own behavior.
References:
ACTUS v1.1 Section 2.10 - Child Contract Observer
ACTUS v1.1 Section 3.4 - Composite Contract Types
"""
from abc import ABC, abstractmethod
from typing import Protocol, runtime_checkable
import jax.numpy as jnp
from jactus.core import ActusDateTime, ContractAttributes, ContractEvent, ContractState
@runtime_checkable
class ChildContractObserver(Protocol):
    """Protocol for observing child contract data in composite contracts.

    The child contract observer provides access to events, states, and attributes
    of child contracts within a composite structure. This enables parent contracts
    to make decisions based on child contract behavior.

    Methods correspond to ACTUS observer functions:

    - ``observe_events``: U_ev(i, t, a) - observe child events at time t
    - ``observe_state``: U_sv(i, t, x, a) - observe child state at time t
    - ``observe_attribute``: U_ca(i, x) - observe a child attribute value

    Example:
        >>> observer = MockChildContractObserver()
        >>> observer.register_child(
        ...     "loan1",
        ...     events=loan_events,
        ...     state=loan_state,
        ...     attributes={"notional_principal": 100000.0},
        ... )
        >>> events = observer.observe_events("loan1", current_time, attributes)
        >>> state = observer.observe_state("loan1", current_time, None, attributes)
        >>> notional = observer.observe_attribute("loan1", "notional_principal")

    Note:
        This protocol is decorated with ``runtime_checkable`` so ``isinstance()``
        checks work. Such checks only verify that the three methods exist; they
        do not validate signatures or return types.
        Implementations should report an unknown child contract clearly rather
        than returning placeholder data (the implementations in this module
        raise ``KeyError``).

    References:
        ACTUS v1.1 Section 2.10 - Child Contract Observer
    """

    def observe_events(
        self,
        identifier: str,
        time: ActusDateTime,
        attributes: ContractAttributes | None = None,
    ) -> list[ContractEvent]:
        """Observe events from a child contract.

        Returns all events from the identified child contract that occur at or
        after the specified time. This allows parent contracts to react to
        child contract events.

        Args:
            identifier: Child contract identifier
            time: Observation time
            attributes: Optional parent contract attributes for filtering

        Returns:
            List of child contract events

        Example:
            >>> events = observer.observe_events("child1", t0, parent_attrs)
            >>> principal_events = [e for e in events if e.event_type == EventType.PR]
        """
        ...

    def observe_state(
        self,
        identifier: str,
        time: ActusDateTime,
        state: ContractState | None = None,
        attributes: ContractAttributes | None = None,
    ) -> ContractState:
        """Observe state from a child contract.

        Returns the state of the identified child contract at the specified time.
        This allows parent contracts to access child contract state variables.

        Args:
            identifier: Child contract identifier
            time: Observation time
            state: Optional parent state for context
            attributes: Optional parent attributes for context

        Returns:
            Child contract state at the specified time

        Example:
            >>> child_state = observer.observe_state("child1", t0, parent_state, parent_attrs)
            >>> child_notional = child_state.nt
        """
        ...

    def observe_attribute(
        self,
        identifier: str,
        attribute_name: str,
    ) -> jnp.ndarray:
        """Observe an attribute value from a child contract.

        Returns the value of a specific attribute from the identified child contract.
        This allows parent contracts to access child contract configuration.

        Args:
            identifier: Child contract identifier
            attribute_name: Name of the attribute to observe

        Returns:
            Attribute value as JAX array

        Example:
            >>> notional = observer.observe_attribute("child1", "notional_principal")
            >>> rate = observer.observe_attribute("child1", "nominal_interest_rate")
        """
        ...
class BaseChildContractObserver(ABC):
    """Template-method base class for child contract observers.

    The public ``observe_*`` methods wrap three abstract hooks
    (``_get_events``, ``_get_state`` and ``_get_attribute``) with result
    validation and uniform error messages. Concrete observers only have to
    implement the hooks.

    Example:
        >>> class MyObserver(BaseChildContractObserver):
        ...     def _get_events(self, identifier, time, attributes):
        ...         return self.children[identifier].get_events()
        ...     def _get_state(self, identifier, time, state, attributes):
        ...         return self.children[identifier].get_state(time)
        ...     def _get_attribute(self, identifier, attribute_name):
        ...         return getattr(self.children[identifier].attributes, attribute_name)

    References:
        ACTUS v1.1 Section 2.10 - Child Contract Observer
    """

    @abstractmethod
    def _get_events(
        self,
        identifier: str,  # noqa: ARG002
        time: ActusDateTime,  # noqa: ARG002
        attributes: ContractAttributes | None,  # noqa: ARG002
    ) -> list[ContractEvent]:
        """Hook: fetch the raw event list for one child contract.

        A ``KeyError`` raised here is reported by ``observe_events`` as an
        unknown child contract.

        Args:
            identifier: Child contract identifier
            time: Observation time
            attributes: Optional parent attributes

        Returns:
            List of child contract events (must be a ``list``)
        """
        raise NotImplementedError

    @abstractmethod
    def _get_state(
        self,
        identifier: str,  # noqa: ARG002
        time: ActusDateTime,  # noqa: ARG002
        state: ContractState | None,  # noqa: ARG002
        attributes: ContractAttributes | None,  # noqa: ARG002
    ) -> ContractState:
        """Hook: fetch the state of one child contract.

        A ``KeyError`` raised here is reported by ``observe_state`` as an
        unknown child contract.

        Args:
            identifier: Child contract identifier
            time: Observation time
            state: Optional parent state
            attributes: Optional parent attributes

        Returns:
            Child contract state (must be a ``ContractState``)
        """
        raise NotImplementedError

    @abstractmethod
    def _get_attribute(
        self,
        identifier: str,  # noqa: ARG002
        attribute_name: str,  # noqa: ARG002
    ) -> jnp.ndarray:
        """Hook: fetch one raw attribute value of a child contract.

        ``KeyError`` (unknown child) and ``AttributeError`` (unknown attribute)
        raised here are re-raised by ``observe_attribute`` with uniform messages.

        Args:
            identifier: Child contract identifier
            attribute_name: Attribute name

        Returns:
            Attribute value; converted to a float32 JAX array by the caller
        """
        raise NotImplementedError

    def observe_events(
        self,
        identifier: str,
        time: ActusDateTime,
        attributes: ContractAttributes | None = None,
    ) -> list[ContractEvent]:
        """Return the child's events via ``_get_events`` after validating the result.

        Args:
            identifier: Child contract identifier
            time: Observation time
            attributes: Optional parent attributes

        Returns:
            List of child contract events

        Raises:
            KeyError: If the hook signals an unknown child contract
            ValueError: If the hook returns something other than a list
        """
        try:
            observed = self._get_events(identifier, time, attributes)
            if isinstance(observed, list):
                return observed
            raise ValueError(f"Expected list of events, got {type(observed)}")
        except KeyError as exc:
            raise KeyError(f"Child contract not found: {identifier}") from exc

    def observe_state(
        self,
        identifier: str,
        time: ActusDateTime,
        state: ContractState | None = None,
        attributes: ContractAttributes | None = None,
    ) -> ContractState:
        """Return the child's state via ``_get_state`` after validating the result.

        Args:
            identifier: Child contract identifier
            time: Observation time
            state: Optional parent state
            attributes: Optional parent attributes

        Returns:
            Child contract state

        Raises:
            KeyError: If the hook signals an unknown child contract
            ValueError: If the hook returns something other than a ContractState
        """
        try:
            observed = self._get_state(identifier, time, state, attributes)
            if isinstance(observed, ContractState):
                return observed
            raise ValueError(f"Expected ContractState, got {type(observed)}")
        except KeyError as exc:
            raise KeyError(f"Child contract not found: {identifier}") from exc

    def observe_attribute(
        self,
        identifier: str,
        attribute_name: str,
    ) -> jnp.ndarray:
        """Return a child attribute via ``_get_attribute`` as a float32 JAX array.

        Args:
            identifier: Child contract identifier
            attribute_name: Attribute name

        Returns:
            Attribute value as a float32 JAX array

        Raises:
            KeyError: If the hook signals an unknown child contract
            AttributeError: If the hook signals an unknown attribute
        """
        try:
            return jnp.asarray(
                self._get_attribute(identifier, attribute_name), dtype=jnp.float32
            )
        except KeyError as exc:
            raise KeyError(f"Child contract not found: {identifier}") from exc
        except AttributeError as exc:
            raise AttributeError(
                f"Attribute '{attribute_name}' not found in child contract '{identifier}'"
            ) from exc
class MockChildContractObserver(BaseChildContractObserver):
    """Mock implementation for child contract observation.

    This observer stores child contract data in plain dictionaries and provides
    simple observation capabilities. Useful for testing and development.

    Attributes:
        child_events: Dictionary mapping child IDs to event lists
        child_states: Dictionary mapping child IDs to states
        child_attributes: Dictionary mapping child IDs to attribute dicts

    Example:
        >>> observer = MockChildContractObserver()
        >>> # Register child contract data
        >>> observer.register_child("loan1",
        ...     events=[event1, event2],
        ...     state=loan_state,
        ...     attributes={"notional_principal": 100000.0}
        ... )
        >>> # Observe child data
        >>> events = observer.observe_events("loan1", t0)
        >>> state = observer.observe_state("loan1", t0)
        >>> notional = observer.observe_attribute("loan1", "notional_principal")

    Note:
        This is a simple mock implementation. Real implementations would
        integrate with actual contract simulation engines.
        Each kind of data is looked up independently. For example, a child
        registered with events but no state raises ``KeyError`` from
        ``observe_state``.

    References:
        ACTUS v1.1 Section 2.10 - Child Contract Observer
    """

    def __init__(self) -> None:
        """Initialize empty child contract observer."""
        self.child_events: dict[str, list[ContractEvent]] = {}
        self.child_states: dict[str, ContractState] = {}
        self.child_attributes: dict[str, dict[str, float]] = {}

    def register_child(
        self,
        identifier: str,
        events: list[ContractEvent] | None = None,
        state: ContractState | None = None,
        attributes: dict[str, float] | None = None,
    ) -> None:
        """Register (or update) a child contract with its data.

        Each argument that is not ``None`` replaces any previously registered
        data of that kind for ``identifier``. Arguments left as ``None`` keep
        the existing data. The given objects are stored as-is (not copied).

        Args:
            identifier: Child contract identifier
            events: Optional list of child events
            state: Optional child state
            attributes: Optional dictionary of attribute name -> value

        Example:
            >>> observer.register_child("child1",
            ...     events=[ContractEvent(...)],
            ...     state=ContractState(...),
            ...     attributes={"notional_principal": 100000.0}
            ... )
        """
        if events is not None:
            self.child_events[identifier] = events
        if state is not None:
            self.child_states[identifier] = state
        if attributes is not None:
            self.child_attributes[identifier] = attributes

    def _get_events(
        self,
        identifier: str,
        time: ActusDateTime,
        attributes: ContractAttributes | None,  # noqa: ARG002
    ) -> list[ContractEvent]:
        """Retrieve child events at or after ``time``.

        Args:
            identifier: Child contract identifier
            time: Observation time
            attributes: Optional parent attributes (unused in mock)

        Returns:
            List of child events at or after time, in registration order

        Raises:
            KeyError: If no events were registered for ``identifier``
        """
        all_events = self.child_events[identifier]
        return [e for e in all_events if e.event_time >= time]

    def _get_state(
        self,
        identifier: str,
        time: ActusDateTime,  # noqa: ARG002
        state: ContractState | None,  # noqa: ARG002
        attributes: ContractAttributes | None,  # noqa: ARG002
    ) -> ContractState:
        """Retrieve the registered state of a child contract.

        The mock keeps a single state per child, so ``time`` is ignored.

        Args:
            identifier: Child contract identifier
            time: Observation time (unused in mock)
            state: Optional parent state (unused in mock)
            attributes: Optional parent attributes (unused in mock)

        Returns:
            Child contract state

        Raises:
            KeyError: If no state was registered for ``identifier``
        """
        return self.child_states[identifier]

    def _get_attribute(
        self,
        identifier: str,
        attribute_name: str,
    ) -> jnp.ndarray:
        """Retrieve a registered attribute of a child contract.

        Args:
            identifier: Child contract identifier
            attribute_name: Attribute name

        Returns:
            Attribute value as a float32 JAX array

        Raises:
            KeyError: If no attributes were registered for ``identifier``
            AttributeError: If the attribute is not in the child's attribute dict
        """
        # An unknown child must surface as KeyError, so index it outside the try.
        child_attrs = self.child_attributes[identifier]
        try:
            value = child_attrs[attribute_name]
        except KeyError:
            raise AttributeError(
                f"Attribute '{attribute_name}' not found in child '{identifier}'"
            ) from None
        return jnp.array(value, dtype=jnp.float32)

    def apply_conditions(
        self,
        attributes: ContractAttributes,
        overrides: dict[str, float],
    ) -> ContractAttributes:
        """Apply conditional attribute overrides to contract attributes.

        Builds a modified copy of contract attributes based on child contract
        observations. Used in composite contracts where parent attributes
        depend on child state.

        Args:
            attributes: Original contract attributes (not modified)
            overrides: Dictionary of attribute names to new values. Keys that
                are not fields of ``attributes`` are silently ignored.

        Returns:
            New instance of the same class as ``attributes``, with overrides applied

        Example:
            >>> # Override notional based on child contract
            >>> child_notional = observer.observe_attribute("child1", "notional_principal")
            >>> new_attrs = observer.apply_conditions(
            ...     parent_attrs,
            ...     {"notional_principal": float(child_notional)}
            ... )
        """
        attr_dict = attributes.model_dump()
        # Only overwrite existing fields; unknown override keys are dropped.
        attr_dict.update({key: value for key, value in overrides.items() if key in attr_dict})
        # Rebuild through the constructor of the *actual* class so that a
        # ContractAttributes subclass is not downcast (and its extra fields kept).
        return type(attributes)(**attr_dict)
class SimulatedChildContractObserver(BaseChildContractObserver):
    """Child contract observer backed by full simulation histories.

    Stores the complete event history from simulated child contracts and
    provides time-aware state lookups: the state returned is the one recorded
    at the latest history entry at or before the query time.

    Example:
        >>> observer = SimulatedChildContractObserver()
        >>> child_result = child_contract.simulate()
        >>> observer.register_simulation("child1", child_result.events, child_attrs)
        >>> state = observer.observe_state("child1", query_time)

    Note:
        State lookups do not require the registered events to be in
        chronological order. When several history entries share a timestamp,
        the one registered last wins (i.e. the post-state of the last event
        on that date).
    """

    def __init__(self) -> None:
        """Initialize empty observer."""
        # {id: [(time, state), ...]} in registration order: optional initial
        # state first, then each event's post-event state.
        self._histories: dict[str, list[tuple[ActusDateTime, ContractState]]] = {}
        self._events: dict[str, list[ContractEvent]] = {}
        self._attributes: dict[str, ContractAttributes] = {}

    def register_simulation(
        self,
        identifier: str,
        events: list[ContractEvent],
        attributes: ContractAttributes | None = None,
        initial_state: ContractState | None = None,
    ) -> None:
        """Register (or replace) a child contract's simulation results.

        Events and state history are always replaced. Attributes are replaced
        only when ``attributes`` is given.

        Args:
            identifier: Child contract identifier
            events: Full list of simulated events (with state_pre/state_post)
            attributes: Optional child contract attributes
            initial_state: Optional initial state (for queries before first event)
        """
        self._events[identifier] = events
        history: list[tuple[ActusDateTime, ContractState]] = []
        # The initial state is keyed by its own ``sd`` (presumably the status
        # date), so queries between that date and the first event resolve.
        if initial_state is not None:
            history.append((initial_state.sd, initial_state))
        # Events without a post-event state carry no state information.
        history.extend(
            (e.event_time, e.state_post) for e in events if e.state_post is not None
        )
        self._histories[identifier] = history
        if attributes is not None:
            self._attributes[identifier] = attributes

    def _get_events(
        self,
        identifier: str,
        time: ActusDateTime,
        attributes: ContractAttributes | None,  # noqa: ARG002
    ) -> list[ContractEvent]:
        """Return registered events at or after the given time.

        Raises:
            KeyError: If no simulation was registered for ``identifier``
        """
        return [e for e in self._events[identifier] if e.event_time >= time]

    def _get_state(
        self,
        identifier: str,
        time: ActusDateTime,
        state: ContractState | None,  # noqa: ARG002
        attributes: ContractAttributes | None,  # noqa: ARG002
    ) -> ContractState:
        """Return the state of the latest history entry at or before ``time``.

        The history is scanned in full rather than assuming it is sorted, so
        out-of-order event lists still yield the correct state.

        Raises:
            KeyError: If no simulation was registered for ``identifier``, or no
                history entry is at or before ``time``. Note that
                ``observe_state`` re-raises either case as "Child contract not
                found".
        """
        history = self._histories[identifier]
        best_time: ActusDateTime | None = None
        best_state: ContractState | None = None
        for event_time, event_state in history:
            # ``>=`` (not ``>``) lets a later entry at an equal timestamp win,
            # matching the post-state of the last event on that date.
            if event_time <= time and (best_time is None or event_time >= best_time):
                best_time, best_state = event_time, event_state
        if best_state is None:
            raise KeyError(f"No state found for '{identifier}' at or before {time.to_iso()}")
        return best_state

    def _get_attribute(
        self,
        identifier: str,
        attribute_name: str,
    ) -> jnp.ndarray:
        """Return a child contract attribute as a float32 JAX array.

        Raises:
            KeyError: If no attributes were registered for ``identifier``
            AttributeError: If the attribute is missing or ``None``
        """
        if identifier not in self._attributes:
            raise KeyError(f"No attributes registered for '{identifier}'")
        attrs = self._attributes[identifier]
        # A None-valued attribute is treated the same as a missing one.
        value = getattr(attrs, attribute_name, None)
        if value is None:
            raise AttributeError(f"Attribute '{attribute_name}' not found in '{identifier}'")
        return jnp.array(float(value), dtype=jnp.float32)