"""Base contract class for all ACTUS contracts.
This module implements the abstract base class that all ACTUS contract types
inherit from. It provides the core simulation engine and common functionality.
References:
ACTUS v1.1 Section 3 - Contract Types
ACTUS v1.1 Section 4 - Event Schedules
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
import flax.nnx as nnx
import jax.numpy as jnp
from jactus.core import (
ActusDateTime,
ContractAttributes,
ContractEvent,
ContractState,
EventSchedule,
)
from jactus.functions import PayoffFunction, StateTransitionFunction
from jactus.observers import ChildContractObserver, RiskFactorObserver
from jactus.observers.behavioral import BehaviorRiskFactorObserver, CalloutEvent
from jactus.observers.scenario import Scenario
[docs]
@dataclass
class SimulationHistory:
    """Results from contract simulation.

    Contains the complete history of events and states from a contract
    simulation run.

    Attributes:
        events: List of all processed events (scheduled + observed), each
            carrying its payoff and its pre-/post-event states.
        states: Post-event states in event order: one per processed event
            whose post-event state is set. The initial state is NOT part
            of this list; see ``initial_state``.
        initial_state: Contract state before first event. When
            ``IED >= SD`` (contract hasn't started yet), this is the
            pre-IED state with ``nt=0``. When ``IED < SD`` (contract
            already existed at status date), this reflects the post-IED
            state reconstructed from contract attributes (``nt=notional``,
            ``ipnr=rate``, etc.).
        final_state: Contract state after the last event. As produced by
            ``BaseContract.simulate``, this equals ``initial_state`` when
            no events were processed.

    Example:
        >>> history = contract.simulate(observers)
        >>> print(f"Generated {len(history.events)} events")
        >>> print(f"Final notional: {history.final_state.nt}")
    """

    events: list[ContractEvent]
    states: list[ContractState]
    initial_state: ContractState
    final_state: ContractState

    def get_cashflows(self) -> list[tuple[ActusDateTime, jnp.ndarray, str]]:
        """Extract cashflow timeline from events.

        Returns:
            List of ``(time, payoff, currency)`` tuples, one per event, in
            the same order as ``events``.

        Example:
            >>> cashflows = history.get_cashflows()
            >>> for time, amount, currency in cashflows:
            ...     print(f"{time.to_iso()}: {amount} {currency}")
        """
        return [(e.event_time, e.payoff, e.currency) for e in self.events]

    def filter_events(
        self, start: ActusDateTime | None = None, end: ActusDateTime | None = None
    ) -> list[ContractEvent]:
        """Filter events by time range.

        Args:
            start: Optional start time (inclusive). ``None`` means unbounded.
            end: Optional end time (inclusive). ``None`` means unbounded.

        Returns:
            A NEW list of the events in the specified range. It is never
            the internal ``events`` list itself, even when both bounds are
            ``None``, so mutating the result cannot corrupt the history.

        Example:
            >>> year_events = history.filter_events(
            ...     start=ActusDateTime(2024, 1, 1, 0, 0, 0),
            ...     end=ActusDateTime(2024, 12, 31, 23, 59, 59)
            ... )
        """
        return [
            e
            for e in self.events
            if (start is None or e.event_time >= start)
            and (end is None or e.event_time <= end)
        ]
[docs]
class BaseContract(nnx.Module, ABC):
    """Abstract base class for all ACTUS contracts.

    This class provides the core simulation engine and common functionality
    that all contract types share. Subclasses must implement the abstract
    methods to define contract-specific behavior.

    The class extends flax.nnx.Module for Pytree compatibility with JAX.

    Note: the scalar simulation path (this class) processes events sequentially
    in Python and does not support JIT/grad/vmap. For JIT-compiled batch
    simulation with autodiff, use the array-mode API (see ``*_array.py`` modules
    and ``simulate_portfolio()``).

    Attributes:
        attributes: Contract attributes (terms and conditions)
        risk_factor_observer: Observer for market risk factors
        child_contract_observer: Optional observer for child contracts
        rngs: Flax RNG state (defaults to ``nnx.Rngs(0)``)
        _event_cache: Cached event schedule (None until first computation)

    Example:
        >>> class MyContract(BaseContract):
        ...     def generate_event_schedule(self):
        ...         # Generate contract-specific events
        ...         pass
        ...     # ... implement other abstract methods
        >>> contract = MyContract(attributes, risk_observer)
        >>> history = contract.simulate()
        >>> cashflows = history.get_cashflows()

    References:
        ACTUS v1.1 Section 3 - Contract Types
        ACTUS v1.1 Section 4 - Algorithm
    """

    def __init__(
        self,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: ChildContractObserver | None = None,
        *,
        rngs: nnx.Rngs | None = None,
    ):
        """Initialize base contract.

        Args:
            attributes: Contract attributes (terms and conditions)
            risk_factor_observer: Observer for accessing market risk factors
            child_contract_observer: Optional observer for child contracts
            rngs: Optional Flax RNG state for stochastic contracts. When
                omitted, a deterministic ``nnx.Rngs(0)`` is used.

        Example:
            >>> from jactus.observers import ConstantRiskFactorObserver
            >>> attrs = ContractAttributes(...)
            >>> risk_obs = ConstantRiskFactorObserver(1.0)
            >>> contract = MyContract(attrs, risk_obs)
        """
        super().__init__()
        self.attributes = attributes
        self.risk_factor_observer = risk_factor_observer
        self.child_contract_observer = child_contract_observer
        self.rngs = rngs if rngs is not None else nnx.Rngs(0)
        self._event_cache: EventSchedule | None = None

    # ========================================================================
    # Abstract methods - must be implemented by subclasses
    # ========================================================================

    @abstractmethod
    def generate_event_schedule(self) -> EventSchedule:
        """Generate the scheduled events for this contract.

        This method must be implemented by each contract type to generate
        its specific event schedule according to ACTUS rules.

        Returns:
            EventSchedule containing all scheduled events

        Example:
            >>> def generate_event_schedule(self):
            ...     events = []
            ...     # Add IED event
            ...     events.append(ContractEvent(
            ...         event_type=EventType.IED,
            ...         event_time=self.attributes.initial_exchange_date,
            ...         ...
            ...     ))
            ...     # Add other events...
            ...     return EventSchedule(events)

        References:
            ACTUS v1.1 Section 4.1 - Event Schedule Generation
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_state(self) -> ContractState:
        """Initialize contract state before first event.

        Creates the initial state based on contract attributes. This state
        is used as the starting point for simulation.

        Returns:
            Initial ContractState

        Example:
            >>> def initialize_state(self):
            ...     return ContractState(
            ...         sd=self.attributes.status_date,
            ...         tmd=self.attributes.maturity_date,
            ...         nt=jnp.array(self.attributes.notional_principal),
            ...         ipnr=jnp.array(self.attributes.nominal_interest_rate),
            ...         ...
            ...     )

        References:
            ACTUS v1.1 Section 4.2 - State Initialization
        """
        raise NotImplementedError

    @abstractmethod
    def get_payoff_function(self, event_type: Any) -> PayoffFunction:
        """Get payoff function for a specific event type.

        Returns the appropriate payoff function (POF) for calculating
        the cashflow generated by the given event type.

        Args:
            event_type: Event type (e.g., EventType.IP, EventType.PR)

        Returns:
            PayoffFunction for the event type

        Example:
            >>> def get_payoff_function(self, event_type):
            ...     if event_type == EventType.IP:
            ...         return InterestPaymentPayoff()
            ...     elif event_type == EventType.PR:
            ...         return PrincipalRedemptionPayoff()
            ...     else:
            ...         return ZeroPayoff()

        References:
            ACTUS v1.1 Section 2.7 - Payoff Functions
        """
        raise NotImplementedError

    @abstractmethod
    def get_state_transition_function(self, event_type: Any) -> StateTransitionFunction:
        """Get state transition function for a specific event type.

        Returns the appropriate state transition function (STF) for
        updating contract state when the given event occurs.

        Args:
            event_type: Event type (e.g., EventType.IP, EventType.PR)

        Returns:
            StateTransitionFunction for the event type

        Example:
            >>> def get_state_transition_function(self, event_type):
            ...     if event_type == EventType.IP:
            ...         return InterestPaymentSTF()
            ...     elif event_type == EventType.PR:
            ...         return PrincipalRedemptionSTF()
            ...     else:
            ...         return IdentitySTF()

        References:
            ACTUS v1.1 Section 2.8 - State Transition Functions
        """
        raise NotImplementedError

    # ========================================================================
    # Concrete methods - common to all contracts
    # ========================================================================

    def get_lifetime(self) -> tuple[ActusDateTime, ActusDateTime]:
        """Get contract lifetime (start and end dates).

        Returns:
            Tuple of (start_date, end_date). ``start_date`` is the status
            date; ``end_date`` is the maturity date, or the status date
            when no maturity date is set.

        Example:
            >>> start, end = contract.get_lifetime()
            >>> print(f"Contract runs from {start.to_iso()} to {end.to_iso()}")
        """
        start = self.attributes.status_date
        # Contracts without a maturity date collapse to a zero-length lifetime.
        end = self.attributes.maturity_date or start
        return start, end

    def is_maturity_contract(self) -> bool:
        """Check if contract has a defined maturity date.

        Returns:
            True if contract has maturity date, False otherwise

        Example:
            >>> if contract.is_maturity_contract():
            ...     print("Contract matures at", contract.attributes.maturity_date)
        """
        return self.attributes.maturity_date is not None

    def get_events(self, force_regenerate: bool = False) -> EventSchedule:
        """Get event schedule with caching.

        Generates and caches the event schedule on first call. Subsequent
        calls return the cached schedule unless force_regenerate=True.

        Args:
            force_regenerate: If True, regenerate schedule even if cached

        Returns:
            EventSchedule containing all scheduled events

        Example:
            >>> events = contract.get_events()
            >>> print(f"Contract has {len(events)} events")
            >>> # Regenerate if attributes changed
            >>> events = contract.get_events(force_regenerate=True)
        """
        if self._event_cache is None or force_regenerate:
            self._event_cache = self.generate_event_schedule()
        return self._event_cache

    def get_events_in_range(
        self,
        start: ActusDateTime | None = None,
        end: ActusDateTime | None = None,
    ) -> list[ContractEvent]:
        """Get scheduled events within a time range.

        Args:
            start: Optional start time (inclusive)
            end: Optional end time (inclusive)

        Returns:
            New list of scheduled events in the specified range

        Example:
            >>> # Get all events in 2024
            >>> events_2024 = contract.get_events_in_range(
            ...     start=ActusDateTime(2024, 1, 1, 0, 0, 0),
            ...     end=ActusDateTime(2024, 12, 31, 23, 59, 59)
            ... )
        """
        schedule = self.get_events()
        filtered = list(schedule.events)
        if start is not None:
            filtered = [e for e in filtered if e.event_time >= start]
        if end is not None:
            filtered = [e for e in filtered if e.event_time <= end]
        return filtered

    def simulate(
        self,
        risk_factor_observer: RiskFactorObserver | None = None,
        child_contract_observer: ChildContractObserver | None = None,  # noqa: ARG002
        scenario: Scenario | None = None,
        behavior_observers: list[BehaviorRiskFactorObserver] | None = None,
    ) -> SimulationHistory:
        """Simulate contract through all events.

        Executes the full ACTUS algorithm:

        1. Collect callout events from behavioral observers (if any)
        2. Merge callout events into the scheduled event timeline
        3. Initialize state
        4. For each event: calculate payoff (POF) using pre-event state,
           apply state transition function (STF), and store event with states

        Behavioral observers can be provided in three ways:

        - Via a ``Scenario`` object (recommended for production use)
        - Via the ``behavior_observers`` list parameter
        - By passing a ``BehaviorRiskFactorObserver`` as the ``risk_factor_observer``

        The same observer object supplied through more than one of these
        channels is only activated once (``contract_start`` is called once).

        Args:
            risk_factor_observer: Optional override for risk factor observer.
                Takes precedence over the scenario's market observer.
            child_contract_observer: Accepted for API compatibility; the
                base simulation loop does not currently use it.
            scenario: Optional Scenario bundling market + behavioral observers.
                If provided, its market observer is used as the risk factor
                observer (unless ``risk_factor_observer`` is also provided),
                and its behavioral observers are activated for callout events.
            behavior_observers: Optional list of behavioral observers to
                activate for callout event injection.

        Returns:
            SimulationHistory with events and states.

        Example:
            >>> # Simple simulation (no behavioral models)
            >>> history = contract.simulate()
            >>>
            >>> # With a scenario
            >>> history = contract.simulate(scenario=my_scenario)
            >>>
            >>> # With explicit behavioral observers
            >>> history = contract.simulate(
            ...     behavior_observers=[prepayment_model],
            ... )

        References:
            ACTUS v1.1 Section 4 - Algorithm
        """
        # Resolve risk factor observer. Explicit ``None`` checks (instead of
        # ``a or b``) so an observer object that happens to be falsy is not
        # silently replaced by the default one.
        if risk_factor_observer is not None:
            risk_obs = risk_factor_observer
        elif scenario is not None:
            risk_obs = scenario.get_observer()
        else:
            risk_obs = self.risk_factor_observer

        # Collect all behavioral observers, in priority order
        candidates: list[BehaviorRiskFactorObserver] = []
        if scenario is not None:
            candidates.extend(scenario.behavior_observers.values())
        if behavior_observers is not None:
            candidates.extend(behavior_observers)
        if isinstance(risk_obs, BehaviorRiskFactorObserver):
            candidates.append(risk_obs)

        # De-duplicate by identity (order-preserving) so an observer supplied
        # through several channels is not started twice.
        all_behavior_observers: list[BehaviorRiskFactorObserver] = []
        seen_ids: set[int] = set()
        for observer in candidates:
            if id(observer) not in seen_ids:
                seen_ids.add(id(observer))
                all_behavior_observers.append(observer)

        # Initialize
        state = self.initialize_state()
        initial_state = state

        # Get scheduled events
        schedule = self.get_events()

        # Collect and merge callout events from behavioral observers
        if all_behavior_observers:
            callout_events = _collect_callout_events(all_behavior_observers, self.attributes)
            if callout_events:
                schedule = _merge_callout_events(schedule, callout_events, self.attributes)

        # Process each event
        events_with_states = []
        for event in schedule.events:
            # Get functions for this event type
            stf = self.get_state_transition_function(event.event_type)
            pof = self.get_payoff_function(event.event_type)

            # For CS (Calculate/Shift) BDC conventions, use the original
            # unadjusted date for calculations (year fraction, accrual).
            # For SC (Shift/Calculate) or NULL, calculation_time is None
            # and we use event_time as before.
            calc_time = event.calculation_time or event.event_time

            # Calculate payoff BEFORE state transition (using pre-event state)
            payoff = pof(
                event_type=event.event_type,
                state=state,
                attributes=self.attributes,
                time=calc_time,
                risk_factor_observer=risk_obs,
            )

            # Apply state transition AFTER payoff calculation
            state_post = stf(
                event_type=event.event_type,
                state_pre=state,
                attributes=self.attributes,
                time=calc_time,
                risk_factor_observer=risk_obs,
            )

            # Create event with states and payoff
            processed_event = ContractEvent(
                event_type=event.event_type,
                event_time=event.event_time,
                payoff=payoff,
                currency=event.currency or self.attributes.currency or "XXX",
                state_pre=state,
                state_post=state_post,
                sequence=event.sequence,
            )
            events_with_states.append(processed_event)
            state = state_post

        return SimulationHistory(
            events=events_with_states,
            states=[e.state_post for e in events_with_states if e.state_post is not None],
            initial_state=initial_state,
            final_state=state,
        )

    def get_cashflows(
        self,
        risk_factor_observer: RiskFactorObserver | None = None,
        child_contract_observer: ChildContractObserver | None = None,
    ) -> list[tuple[ActusDateTime, jnp.ndarray, str]]:
        """Get cashflow timeline from contract.

        Convenience method that simulates and extracts cashflows.

        Args:
            risk_factor_observer: Optional override for risk factor observer
            child_contract_observer: Forwarded to ``simulate`` (which does
                not currently use it)

        Returns:
            List of (time, payoff, currency) tuples

        Example:
            >>> cashflows = contract.get_cashflows()
            >>> total = sum(payoff for _, payoff, _ in cashflows)
            >>> print(f"Total cashflows: {total}")
        """
        history = self.simulate(risk_factor_observer, child_contract_observer)
        return history.get_cashflows()

    def validate(self) -> dict[str, list[str]]:
        """Validate contract attributes.

        Checks contract attributes for consistency and completeness.
        Returns any validation errors or warnings.

        Returns:
            Dictionary with 'errors' and 'warnings' lists

        Example:
            >>> result = contract.validate()
            >>> if result['errors']:
            ...     print("Validation failed:", result['errors'])
            >>> if result['warnings']:
            ...     print("Warnings:", result['warnings'])

        Note:
            Base implementation performs basic checks. Subclasses should
            override to add contract-specific validation.
        """
        errors: list[str] = []
        warnings: list[str] = []

        # Check required attributes
        if not self.attributes.contract_id:
            errors.append("contract_id is required")
        if not self.attributes.status_date:
            errors.append("status_date is required")

        # Check notional principal
        if (
            self.attributes.notional_principal is not None
            and self.attributes.notional_principal <= 0
        ):
            warnings.append("notional_principal should be positive")

        # Check interest rate (absolute value, so large negative rates warn too)
        if (
            self.attributes.nominal_interest_rate is not None
            and abs(self.attributes.nominal_interest_rate) > 1.0
        ):
            warnings.append("nominal_interest_rate seems unusually high (>100%)")

        return {"errors": errors, "warnings": warnings}
# ============================================================================
# Helper functions
# ============================================================================
[docs]
def sort_events_by_sequence(events: list[ContractEvent]) -> list[ContractEvent]:
    """Return the events ordered by ``(event_time, sequence)``.

    Events at the same time are ordered by their sequence number, which
    gives a deterministic processing order. The sort is stable, so events
    with identical time and sequence keep their input order.

    Args:
        events: Events to order. The input list is left untouched.

    Returns:
        A new, sorted list of events.

    Example:
        >>> events = [event3, event1, event2]
        >>> sorted_events = sort_events_by_sequence(events)
        >>> assert sorted_events[0].event_time <= sorted_events[1].event_time
    """
    ordered = list(events)
    ordered.sort(key=lambda ev: (ev.event_time, ev.sequence))
    return ordered
[docs]
def merge_scheduled_and_observed_events(
    scheduled: list[ContractEvent],
    observed: list[ContractEvent],
) -> list[ContractEvent]:
    """Combine scheduled and observed events into one ordered timeline.

    Scheduled events (from ``generate_event_schedule``) come first, then
    observed events (from child contract or risk factor observers). When
    two events share the same ``(event_time, event_type)``, only the first
    one encountered is kept. A scheduled event therefore wins over an
    observed duplicate. The result is ordered by time and sequence.

    Args:
        scheduled: Scheduled events.
        observed: Observed events.

    Returns:
        De-duplicated events sorted via ``sort_events_by_sequence``.

    Example:
        >>> all_events = merge_scheduled_and_observed_events(
        ...     scheduled_events,
        ...     observed_events
        ... )
    """
    # dict keeps insertion order and setdefault never overwrites, so each
    # key maps to its first occurrence, in first-seen order.
    first_by_key = {}
    for candidate in scheduled + observed:
        first_by_key.setdefault((candidate.event_time, candidate.event_type), candidate)

    return sort_events_by_sequence(list(first_by_key.values()))
def _collect_callout_events(
    behavior_observers: list[BehaviorRiskFactorObserver],
    attributes: ContractAttributes,
) -> list[CalloutEvent]:
    """Gather callout events from every behavioral observer.

    ``contract_start(attributes)`` is invoked on each observer in list
    order. Their returned callout events are concatenated and then ordered
    by ``time``. The sort is stable, so ties keep observer order.

    Args:
        behavior_observers: Behavioral observers to start.
        attributes: Contract attributes handed to each observer.

    Returns:
        All callout events, sorted by time.
    """
    gathered = [
        callout
        for observer in behavior_observers
        for callout in observer.contract_start(attributes)
    ]
    gathered.sort(key=lambda callout: callout.time)
    return gathered
def _merge_callout_events(
    schedule: EventSchedule,
    callout_events: list[CalloutEvent],
    attributes: ContractAttributes,
) -> EventSchedule:
    """Insert behavioral callout events into an event schedule.

    Each ``CalloutEvent`` is turned into a zero-payoff ``ContractEvent``.
    ``MRD`` callouts become ``PP`` (Principal Prepayment) events. ``AFD``
    callouts, and any callout type not in the mapping, become ``AD``
    (Analysis/Monitoring) events. A callout is skipped when an event with
    the same ``(time, event_type)`` is already present, either from the
    original schedule or from an earlier callout.

    Args:
        schedule: Existing event schedule.
        callout_events: Callout events to merge.
        attributes: Contract attributes (for currency).

    Returns:
        New EventSchedule (same ``contract_id``) with the callout events
        merged in and all events sorted.
    """
    from jactus.core.events import EVENT_SEQUENCE_ORDER
    from jactus.core.types import EventType

    # Map callout types to ACTUS event types; unmapped types fall back to AD.
    callout_type_map: dict[str, EventType] = {
        "MRD": EventType.PP,  # Prepayment → Principal Prepayment event
        "AFD": EventType.AD,  # Deposit transaction → Analysis/Monitoring event
    }

    merged = list(schedule.events)
    occupied = {(ev.event_time, ev.event_type) for ev in schedule.events}

    for callout in callout_events:
        mapped_type = callout_type_map.get(callout.callout_type, EventType.AD)
        slot = (callout.time, mapped_type)
        if slot in occupied:
            # An event of this type already exists at this time; skip duplicate.
            continue
        merged.append(
            ContractEvent(
                event_type=mapped_type,
                event_time=callout.time,
                payoff=jnp.array(0.0),
                currency=attributes.currency or "XXX",
                sequence=EVENT_SEQUENCE_ORDER.get(mapped_type, 20),
            )
        )
        occupied.add(slot)

    # Relies on ContractEvent's own ordering (the same comparison list.sort uses).
    return EventSchedule(tuple(sorted(merged)), schedule.contract_id)