"""Exotic Linear Amortizer (LAX) contract implementation.
This module implements the LAX contract type - the most complex amortizing contract
with flexible array schedules that allow varying principal redemption amounts, rates,
and cycles over the life of the contract.
ACTUS Reference:
ACTUS v1.1 Section 7.3 - LAX: Exotic Linear Amortizer
Key Features:
- Array schedules for principal redemption (ARPRANX, ARPRCL, ARPRNXT)
- Array schedules for interest payments (ARIPANX, ARIPCL)
- Array schedules for rate resets (ARRRANX, ARRRCL, ARRATE)
- Increase/decrease indicators (ARINCDEC) for principal changes
- Fixed/variable rate indicators (ARFIXVAR)
- PI (Principal Increase) and PR (Principal Redemption) events
- PRF (Principal Redemption Amount Fixing) events
- All IPCB modes from LAM
Array Schedule Concept:
Instead of a single cycle and anchor, LAX uses arrays to define multiple
sub-schedules with different parameters. For example:
- ARPRANX = [2024-01-15, 2025-01-15, 2026-01-15]
- ARPRCL = ["1M", "1M", "1M"]
- ARPRNXT = [1000, 2000, 3000]
- ARINCDEC = ["INC", "INC", "DEC"]
This creates increasing principal for 2 years, then decreasing.
Example:
>>> from jactus.contracts import create_contract
>>> from jactus.core import ContractAttributes, ContractType, ContractRole
>>> from jactus.core import ActusDateTime, DayCountConvention
>>> from jactus.observers import ConstantRiskFactorObserver
>>>
>>> attrs = ContractAttributes(
... contract_id="STEP-UP-LOAN-001",
... contract_type=ContractType.LAX,
... contract_role=ContractRole.RPA,
... status_date=ActusDateTime(2024, 1, 1, 0, 0, 0),
... initial_exchange_date=ActusDateTime(2024, 1, 15, 0, 0, 0),
... maturity_date=ActusDateTime(2027, 1, 15, 0, 0, 0),
... currency="USD",
... notional_principal=100000.0,
... nominal_interest_rate=0.05,
... day_count_convention=DayCountConvention.A360,
... array_pr_anchor=[ActusDateTime(2024, 2, 15), ActusDateTime(2025, 1, 15)],
... array_pr_cycle=["1M", "1M"],
... array_pr_next=[1000.0, 2000.0],
... array_increase_decrease=["INC", "DEC"]
... )
>>>
>>> rf_obs = ConstantRiskFactorObserver(constant_value=0.05)
>>> contract = create_contract(attrs, rf_obs)
>>> result = contract.simulate()
"""
from typing import Any
import jax.numpy as jnp
from jactus.contracts.base import BaseContract, SimulationHistory
from jactus.core import (
ActusDateTime,
ContractAttributes,
ContractEvent,
ContractState,
ContractType,
DayCountConvention,
EventSchedule,
EventType,
)
from jactus.functions import BasePayoffFunction, BaseStateTransitionFunction
from jactus.observers import RiskFactorObserver
from jactus.observers.behavioral import BehaviorRiskFactorObserver
from jactus.observers.scenario import Scenario
from jactus.utilities import contract_role_sign, generate_schedule, year_fraction
[docs]
def generate_array_schedule(
anchors: list[ActusDateTime],
cycles: list[str] | None,
end: ActusDateTime,
filter_values: list[str] | None = None,
filter_target: str | None = None,
) -> list[ActusDateTime]:
"""Generate schedule from array of anchors and cycles.
If cycles is None or empty, each anchor date is treated as a single point event.
If cycles is provided, each (anchor, cycle) pair generates a recurring sub-schedule
bounded by the next anchor's start date (segment boundaries).
Args:
anchors: Array of anchor dates (start dates for each sub-schedule)
cycles: Array of cycles (one per anchor), or None for point events
end: End date (maturity date)
filter_values: Optional array of filter values (e.g., ARINCDEC)
filter_target: Optional target value to filter for (e.g., "DEC")
Returns:
Union of all sub-schedules, sorted and deduplicated
"""
if not anchors:
return []
# Point events: no cycles, just return anchors (after filtering)
if not cycles:
all_events = []
for i, anchor in enumerate(anchors):
if filter_values is not None and filter_target is not None:
if filter_values[i] != filter_target:
continue
all_events.append(anchor)
all_events = sorted(set(all_events))
return [d for d in all_events if d <= end]
if len(anchors) != len(cycles):
raise ValueError(
f"Anchors and cycles must have same length: {len(anchors)} vs {len(cycles)}"
)
if filter_values is not None and len(filter_values) != len(anchors):
raise ValueError(
f"Filter values must have same length as anchors: {len(filter_values)} vs {len(anchors)}"
)
all_events = []
for i, (anchor, cycle) in enumerate(zip(anchors, cycles, strict=False)):
# Skip if filter doesn't match
if filter_values is not None and filter_target is not None:
if filter_values[i] != filter_target:
continue
# Determine segment end: next anchor in the FULL array (not just filtered ones)
# or overall end if this is the last segment
segment_end = end
if i + 1 < len(anchors):
segment_end = anchors[i + 1]
# Generate sub-schedule bounded by segment end
sub_schedule = generate_schedule(start=anchor, cycle=cycle, end=segment_end)
# Filter to events at or after anchor but before segment_end
# (segment_end is the start of the next segment, so exclude it)
sub_schedule = [d for d in sub_schedule if anchor <= d < segment_end]
all_events.extend(sub_schedule)
# Sort and deduplicate
all_events = sorted(set(all_events))
# Filter to events before or at maturity
all_events = [d for d in all_events if d <= end]
return all_events
[docs]
class LAXPayoffFunction(BasePayoffFunction):
"""Payoff function for LAX contracts.
Extends LAM payoff functions with PI (Principal Increase) and PRF
(Principal Redemption Amount Fixing) events.
ACTUS Reference:
ACTUS v1.1 Section 7.3 - LAX Payoff Functions
Events:
All LAM events (AD, IED, PR, MD, PP, PY, FP, PRD, TD, IP, IPCI, IPCB, RR, RRF, SC, CE)
Plus:
PI: Principal Increase (negative principal redemption)
PRF: Principal Redemption Amount Fixing (update Prnxt)
"""
[docs]
def __init__(
self, contract_role: Any, currency: str, settlement_currency: str | None = None
) -> None:
"""Initialize LAX payoff function.
Args:
contract_role: Contract role (RPA or RPL)
currency: Contract currency
settlement_currency: Optional settlement currency
"""
super().__init__(
contract_role=contract_role,
currency=currency,
settlement_currency=settlement_currency,
)
[docs]
def calculate_payoff(
self,
event_type: EventType,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""Calculate payoff for given event type.
Args:
event_type: Type of event
state: Current contract state
attributes: Contract attributes
time: Event time
risk_factor_observer: Risk factor observer
Returns:
Payoff amount (JAX array)
"""
if event_type == EventType.AD:
return self._pof_ad(state, attributes, time)
if event_type == EventType.IED:
return self._pof_ied(state, attributes, time)
if event_type == EventType.PR:
return self._pof_pr(state, attributes, time)
if event_type == EventType.PI:
return self._pof_pi(state, attributes, time)
if event_type == EventType.MD:
return self._pof_md(state, attributes, time)
if event_type == EventType.PP:
return self._pof_pp(state, attributes, time, risk_factor_observer)
if event_type == EventType.PY:
return self._pof_py(state, attributes, time)
if event_type == EventType.FP:
return self._pof_fp(state, attributes, time)
if event_type == EventType.PRD:
return self._pof_prd(state, attributes, time)
if event_type == EventType.TD:
return self._pof_td(state, attributes, time)
if event_type == EventType.IP:
return self._pof_ip(state, attributes, time)
if event_type == EventType.IPCI:
return self._pof_ipci(state, attributes, time)
if event_type == EventType.IPCB:
return self._pof_ipcb(state, attributes, time)
if event_type == EventType.PRF:
return self._pof_prf(state, attributes, time)
if event_type == EventType.RR:
return self._pof_rr(state, attributes, time)
if event_type == EventType.RRF:
return self._pof_rrf(state, attributes, time)
if event_type == EventType.SC:
return self._pof_sc(state, attributes, time)
if event_type == EventType.CE:
return self._pof_ce(state, attributes, time)
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ad(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_AD: Analysis Date - no payoff."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ied(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_IED: Initial Exchange - disburse principal."""
role_sign = contract_role_sign(attrs.contract_role)
nt = attrs.notional_principal or 0.0
pdied = attrs.premium_discount_at_ied or 0.0
return jnp.array(role_sign * (-1) * (nt + pdied), dtype=jnp.float32)
def _pof_pr(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_PR: Principal Redemption - pay fixed principal amount.
No role_sign — state.prnxt is already signed.
"""
prnxt = state.prnxt or jnp.array(0.0, dtype=jnp.float32)
return state.nsc * prnxt
def _pof_pi(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_PI: Principal Increase - receive additional principal (negative PR).
PI payoff is the negative of PR — the sign is already handled
by the event type. No role_sign — state.prnxt is already signed.
"""
prnxt = state.prnxt or jnp.array(0.0, dtype=jnp.float32)
return -state.nsc * prnxt
def _pof_md(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_MD: Maturity - pay remaining principal.
Interest is paid by the IP event at maturity, not MD.
No role_sign — state.nt is already signed.
"""
return state.nsc * state.nt
def _pof_pp(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
rf_obs: RiskFactorObserver | None = None,
) -> jnp.ndarray:
"""POF_PP_LAX: Principal Prepayment.
Formula:
POF_PP_LAX = X^CURS_CUR(t) × f(O_ev(CID, PP, t))
The prepayment amount is observed from the risk factor observer.
"""
if rf_obs is None:
return jnp.array(0.0, dtype=jnp.float32)
try:
pp_amount = rf_obs.observe_event(
attrs.contract_id or "",
EventType.PP,
time,
state,
attrs,
)
return jnp.array(float(pp_amount), dtype=jnp.float32)
except (KeyError, NotImplementedError, TypeError):
return jnp.array(0.0, dtype=jnp.float32)
def _pof_py(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_PY: Penalty - not yet implemented."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_fp(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_FP: Fee Payment - pay accrued fees."""
return state.feac
def _pof_prd(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_PRD: Purchase - not yet implemented."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_td(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_TD: Termination - pay notional and accrued interest.
No role_sign — state vars are already signed.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
accrued = yf * state.ipnr * ipcb
return state.nsc * (state.nt + state.ipac + accrued)
def _pof_ip(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_IP: Interest Payment - pay accrued interest on IPCB.
No role_sign — state vars are already signed.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
accrued = yf * state.ipnr * ipcb
return state.isc * (state.ipac + accrued)
def _pof_ipci(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_IPCI: Interest Capitalization - no payoff."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ipcb(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_IPCB: Interest Calculation Base Fixing - no payoff."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_prf(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_PRF: Principal Redemption Amount Fixing - no payoff.
PRF events update Prnxt but don't generate cashflows.
"""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_rr(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_RR: Rate Reset - no payoff."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_rrf(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_RRF: Rate Reset Fixing - no payoff."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_sc(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_SC: Scaling - no payoff."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ce(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_CE: Credit Event - not yet implemented."""
return jnp.array(0.0, dtype=jnp.float32)
[docs]
class LAXStateTransitionFunction(BaseStateTransitionFunction):
"""State transition function for LAX contracts.
Extends LAM state transitions with PI (Principal Increase) and PRF
(Principal Redemption Amount Fixing) events.
ACTUS Reference:
ACTUS v1.1 Section 7.3 - LAX State Transition Functions
Key Differences from LAM:
- PI events: Increase notional (opposite of PR)
- PRF events: Fix Prnxt from array schedule
- Array-based schedule generation
"""
[docs]
def transition_state(
self,
event_type: EventType,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""Transition state for given event type.
Args:
event_type: Type of event
state: Current contract state
attributes: Contract attributes
time: Event time
risk_factor_observer: Risk factor observer
Returns:
New contract state
"""
if event_type == EventType.AD:
return self._stf_ad(state, attributes, time, risk_factor_observer)
if event_type == EventType.IED:
return self._stf_ied(state, attributes, time, risk_factor_observer)
if event_type == EventType.PR:
return self._stf_pr(state, attributes, time, risk_factor_observer)
if event_type == EventType.PI:
return self._stf_pi(state, attributes, time, risk_factor_observer)
if event_type == EventType.MD:
return self._stf_md(state, attributes, time, risk_factor_observer)
if event_type == EventType.PP:
return self._stf_pp(state, attributes, time, risk_factor_observer)
if event_type == EventType.PY:
return self._stf_py(state, attributes, time, risk_factor_observer)
if event_type == EventType.FP:
return self._stf_fp(state, attributes, time, risk_factor_observer)
if event_type == EventType.PRD:
return self._stf_prd(state, attributes, time, risk_factor_observer)
if event_type == EventType.TD:
return self._stf_td(state, attributes, time, risk_factor_observer)
if event_type == EventType.IP:
return self._stf_ip(state, attributes, time, risk_factor_observer)
if event_type == EventType.IPCI:
return self._stf_ipci(state, attributes, time, risk_factor_observer)
if event_type == EventType.IPCB:
return self._stf_ipcb(state, attributes, time, risk_factor_observer)
if event_type == EventType.PRF:
return self._stf_prf(state, attributes, time, risk_factor_observer)
if event_type == EventType.RR:
return self._stf_rr(state, attributes, time, risk_factor_observer)
if event_type == EventType.RRF:
return self._stf_rrf(state, attributes, time, risk_factor_observer)
if event_type == EventType.SC:
return self._stf_sc(state, attributes, time, risk_factor_observer)
if event_type == EventType.CE:
return self._stf_ce(state, attributes, time, risk_factor_observer)
return state
def _stf_ad(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_AD: Analysis Date - update status date only."""
return state.replace(sd=time)
def _stf_ied(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IED: Initial Exchange - initialize all state variables.
Same as LAM initialization.
"""
role_sign = contract_role_sign(attrs.contract_role)
# Determine IPCB (Interest Calculation Base)
ipcb_mode = attrs.interest_calculation_base or "NT"
if ipcb_mode == "NTIED":
# Fixed at IED notional
ipcb = role_sign * jnp.array(attrs.notional_principal, dtype=jnp.float32)
elif ipcb_mode == "NT":
# Track current notional (will be updated at PR events)
ipcb = role_sign * jnp.array(attrs.notional_principal, dtype=jnp.float32)
else: # NTL
# Will be set at first IPCB event
ipcb = role_sign * jnp.array(attrs.notional_principal, dtype=jnp.float32)
# Initialize prnxt from array or single value (signed by role)
prnxt_val = attrs.next_principal_redemption_amount
if prnxt_val is None and attrs.array_pr_next:
prnxt_val = attrs.array_pr_next[0]
prnxt = jnp.array(role_sign * (prnxt_val or 0.0), dtype=jnp.float32)
return state.replace(
sd=time,
nt=role_sign * jnp.array(attrs.notional_principal, dtype=jnp.float32),
ipnr=jnp.array(attrs.nominal_interest_rate or 0.0, dtype=jnp.float32),
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=jnp.array(1.0, dtype=jnp.float32),
isc=jnp.array(1.0, dtype=jnp.float32),
ipcb=ipcb,
prnxt=prnxt,
)
def _stf_pr(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PR: Principal Redemption - reduce notional, update IPCB if needed.
Same as LAM: Nt -= Prnxt (both are signed state variables).
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
# Calculate accrued interest using current IPCB
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
# Reduce notional by prnxt (both signed, cap at remaining notional)
prnxt = state.prnxt or jnp.array(0.0, dtype=jnp.float32)
effective_prnxt = jnp.sign(prnxt) * jnp.minimum(jnp.abs(prnxt), jnp.abs(state.nt))
new_nt = state.nt - effective_prnxt
# Update IPCB if mode is 'NT'
ipcb_mode = attrs.interest_calculation_base or "NT"
new_ipcb: jnp.ndarray
if ipcb_mode == "NT":
new_ipcb = new_nt
elif ipcb_mode == "NTIED":
new_ipcb = state.ipcb or jnp.array(0.0, dtype=jnp.float32) # Fixed at IED
else: # NTL
new_ipcb = state.ipcb or jnp.array(
0.0, dtype=jnp.float32
) # Only updated at IPCB events
return state.replace(
sd=time,
nt=new_nt,
ipac=new_ipac,
ipcb=new_ipcb,
)
def _stf_pi(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PI: Principal Increase - increase notional, update IPCB if needed.
Opposite of PR: Nt += Prnxt (both are signed state variables).
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
# Calculate accrued interest using current IPCB
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
# Increase notional by prnxt (both signed)
prnxt = state.prnxt or jnp.array(0.0, dtype=jnp.float32)
new_nt = state.nt + prnxt
# Update IPCB if mode is 'NT'
ipcb_mode = attrs.interest_calculation_base or "NT"
new_ipcb: jnp.ndarray
if ipcb_mode == "NT":
new_ipcb = new_nt
elif ipcb_mode == "NTIED":
new_ipcb = state.ipcb or jnp.array(0.0, dtype=jnp.float32) # Fixed at IED
else: # NTL
new_ipcb = state.ipcb or jnp.array(
0.0, dtype=jnp.float32
) # Only updated at IPCB events
return state.replace(
sd=time,
nt=new_nt,
ipac=new_ipac,
ipcb=new_ipcb,
)
def _stf_md(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_MD: Maturity - zero out all state variables."""
return state.replace(
sd=time,
nt=jnp.array(0.0, dtype=jnp.float32),
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
ipcb=jnp.array(0.0, dtype=jnp.float32),
)
def _stf_pp(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PP_LAX: Prepayment - accrue interest, reduce notional, update IPCB.
Updates:
ipac_t = Ipac_t⁻ + Y(Sd_t⁻, t) × Ipnr_t⁻ × Ipcb_t⁻
nt_t = Nt_t⁻ - PP_amount
ipcb_t = Nt_t (if IPCB='NT')
sd_t = t
"""
dcc = attrs.day_count_convention or DayCountConvention.A360
yf = year_fraction(state.sd, time, dcc)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
# Get prepayment amount from risk factor observer
try:
pp_amount = float(
risk_factor_observer.observe_event(
attrs.contract_id or "",
EventType.PP,
time,
state,
attrs,
)
)
except (KeyError, NotImplementedError, TypeError):
pp_amount = 0.0
new_nt = state.nt - jnp.array(pp_amount, dtype=jnp.float32)
# Update IPCB based on mode
ipcb_mode = attrs.interest_calculation_base or "NT"
if ipcb_mode in ("NT", "NTIED"):
new_ipcb = new_nt
else: # NTL - only updated at IPCB events
new_ipcb = state.ipcb or jnp.array(0.0, dtype=jnp.float32)
return state.replace(
sd=time,
nt=new_nt,
ipac=new_ipac,
ipcb=new_ipcb,
)
def _stf_py(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PY: Penalty - not yet implemented."""
return state.replace(sd=time)
def _stf_fp(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_FP: Fee Payment - reset accrued fees."""
# Reset fees after payment
return state.replace(sd=time, feac=jnp.array(0.0, dtype=jnp.float32))
def _stf_prd(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PRD: Purchase - not yet implemented."""
return state.replace(sd=time)
def _stf_td(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_TD: Termination - zero out all state variables."""
return state.replace(
sd=time,
nt=jnp.array(0.0, dtype=jnp.float32),
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
ipcb=jnp.array(0.0, dtype=jnp.float32),
)
def _stf_ip(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IP: Interest Payment - reset accrued interest."""
# Reset interest after payment
return state.replace(sd=time, ipac=jnp.array(0.0, dtype=jnp.float32))
def _stf_ipci(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IPCI: Interest Capitalization - add accrued interest to notional."""
role_sign = contract_role_sign(attrs.contract_role)
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
accrued = yf * state.ipnr * ipcb
# Add accrued interest to notional
new_nt = state.nt + role_sign * (state.ipac + accrued)
# Update IPCB if mode is 'NT'
ipcb_mode = attrs.interest_calculation_base or "NT"
new_ipcb: jnp.ndarray
if ipcb_mode == "NT":
new_ipcb = new_nt
else:
new_ipcb = state.ipcb or jnp.array(0.0, dtype=jnp.float32)
return state.replace(
sd=time,
nt=new_nt,
ipac=jnp.array(0.0, dtype=jnp.float32),
ipcb=new_ipcb,
)
def _stf_ipcb(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IPCB: Interest Calculation Base Fixing - update IPCB to current notional.
Only applies when IPCB mode is 'NTL' (lagged notional).
"""
ipcb_mode = attrs.interest_calculation_base or "NT"
new_ipcb: jnp.ndarray
if ipcb_mode == "NTL":
# Fix IPCB to current notional
new_ipcb = state.nt
else:
new_ipcb = state.ipcb or jnp.array(0.0, dtype=jnp.float32)
return state.replace(sd=time, ipcb=new_ipcb)
def _stf_prf(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PRF: Principal Redemption Amount Fixing - update Prnxt from array.
This event updates the Prnxt state variable based on the array schedule.
Prnxt is a signed state variable (role_sign applied).
"""
role_sign = contract_role_sign(attrs.contract_role)
if attrs.array_pr_anchor and attrs.array_pr_next:
prnxt_value = attrs.next_principal_redemption_amount or 0.0
# Find which array segment we're in
for i, anchor in enumerate(attrs.array_pr_anchor):
if time >= anchor:
prnxt_value = attrs.array_pr_next[i]
new_prnxt = jnp.array(role_sign * prnxt_value, dtype=jnp.float32)
else:
new_prnxt = state.prnxt or jnp.array(0.0, dtype=jnp.float32)
return state.replace(sd=time, prnxt=new_prnxt)
def _stf_rr(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_RR: Rate Reset - accrue interest, then update rate.
For LAX with array_rate, the array value acts as the spread for each segment.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
# Get new rate from market observation
identifier = attrs.rate_reset_market_object or "RATE"
observed = risk_factor_observer.observe_risk_factor(identifier, time, state, attrs)
multiplier = attrs.rate_reset_multiplier if attrs.rate_reset_multiplier is not None else 1.0
# Use array_rate as spread if available, otherwise use rate_reset_spread
spread = attrs.rate_reset_spread if attrs.rate_reset_spread is not None else 0.0
if attrs.array_rate and attrs.array_rr_anchor:
for i, anchor in enumerate(attrs.array_rr_anchor):
if time >= anchor and i < len(attrs.array_rate):
spread = attrs.array_rate[i]
new_rate = multiplier * observed + spread
if attrs.rate_reset_floor is not None:
new_rate = jnp.maximum(new_rate, jnp.array(attrs.rate_reset_floor, dtype=jnp.float32))
if attrs.rate_reset_cap is not None:
new_rate = jnp.minimum(new_rate, jnp.array(attrs.rate_reset_cap, dtype=jnp.float32))
return state.replace(sd=time, ipac=new_ipac, ipnr=new_rate)
def _stf_rrf(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_RRF: Rate Reset Fixing - fix interest rate from array.
Accrue interest, then set rate from array schedule.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
if attrs.array_rate:
rate = attrs.nominal_interest_rate or 0.0
if attrs.array_rr_anchor:
for i, anchor in enumerate(attrs.array_rr_anchor):
if time >= anchor and i < len(attrs.array_rate):
rate = attrs.array_rate[i]
new_rate = jnp.array(rate, dtype=jnp.float32)
else:
new_rate = state.ipnr
return state.replace(sd=time, ipac=new_ipac, ipnr=new_rate)
def _stf_sc(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_SC: Scaling - not yet implemented."""
return state.replace(sd=time)
def _stf_ce(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_CE: Credit Event - not yet implemented."""
return state.replace(sd=time)
[docs]
class ExoticLinearAmortizerContract(BaseContract):
"""LAX (Exotic Linear Amortizer) contract implementation.
LAX is the most complex amortizing contract, supporting flexible array schedules
for principal redemption, interest payments, and rate resets.
ACTUS Reference:
ACTUS v1.1 Section 7.3 - LAX: Exotic Linear Amortizer
"""
[docs]
def __init__(
self,
attributes: ContractAttributes,
risk_factor_observer: RiskFactorObserver,
child_contract_observer: Any | None = None,
):
"""Initialize LAX contract.
Args:
attributes: Contract attributes
risk_factor_observer: Risk factor observer for rate updates
child_contract_observer: Optional child contract observer
Raises:
ValueError: If contract_type is not LAX
"""
if attributes.contract_type != ContractType.LAX:
raise ValueError(f"Contract type must be LAX, got {attributes.contract_type.value}")
super().__init__(
attributes=attributes,
risk_factor_observer=risk_factor_observer,
child_contract_observer=child_contract_observer,
)
[docs]
def initialize_state(self) -> ContractState:
"""Initialize LAX contract state.
Initializes all state variables including Prnxt and Ipcb (same as LAM).
Returns:
Initial contract state
"""
role_sign = contract_role_sign(self.attributes.contract_role)
# Initialize Prnxt from single value or first array value
prnxt_val = self.attributes.next_principal_redemption_amount
if prnxt_val is None and self.attributes.array_pr_next:
prnxt_val = self.attributes.array_pr_next[0]
prnxt = jnp.array(role_sign * (prnxt_val or 0.0), dtype=jnp.float32)
return ContractState(
sd=self.attributes.status_date,
tmd=self.attributes.maturity_date or self.attributes.status_date,
nt=jnp.array(0.0, dtype=jnp.float32), # Set at IED
ipnr=jnp.array(0.0, dtype=jnp.float32), # Set at IED
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=jnp.array(1.0, dtype=jnp.float32),
isc=jnp.array(1.0, dtype=jnp.float32),
prnxt=prnxt,
ipcb=jnp.array(0.0, dtype=jnp.float32), # Set at IED
)
[docs]
def get_payoff_function(self, event_type: Any) -> LAXPayoffFunction:
"""Get LAX payoff function.
Args:
event_type: Type of event (not used, all events use same POF)
Returns:
LAX payoff function instance
"""
return LAXPayoffFunction(
contract_role=self.attributes.contract_role,
currency=self.attributes.currency,
settlement_currency=None,
)
[docs]
def get_state_transition_function(self, event_type: Any) -> LAXStateTransitionFunction:
"""Get LAX state transition function.
Args:
event_type: Type of event (not used, all events use same STF)
Returns:
LAX state transition function instance
"""
return LAXStateTransitionFunction()
[docs]
def generate_event_schedule(self) -> EventSchedule:
"""Generate complete event schedule for LAX contract.
LAX uses array schedules to generate PR, PI, PRF, IP, RR, and RRF events.
Returns:
EventSchedule with all contract events
"""
events = []
attributes = self.attributes
ied = attributes.initial_exchange_date
md = attributes.maturity_date
if not ied or not md:
return EventSchedule(events=(), contract_id=attributes.contract_id)
# AD: Analysis Date
events.append(
ContractEvent(
event_type=EventType.AD,
event_time=attributes.status_date,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# IED: Initial Exchange Date
events.append(
ContractEvent(
event_type=EventType.IED,
event_time=ied,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# PR/PI Schedule: Generated from array schedules with ARINCDEC filter
pr_cycles = attributes.array_pr_cycle # May be None for point events
if attributes.array_pr_anchor and attributes.array_increase_decrease:
# Note: No PRF events generated — the simulate() override injects
# prnxt from the array before each PR/PI event automatically.
# PR: Principal Redemption (ARINCDEC='DEC')
pr_schedule = generate_array_schedule(
anchors=attributes.array_pr_anchor,
cycles=pr_cycles,
end=md,
filter_values=attributes.array_increase_decrease,
filter_target="DEC",
)
for time in pr_schedule:
if ied < time < md:
events.append(
ContractEvent(
event_type=EventType.PR,
event_time=time,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# PI: Principal Increase (ARINCDEC='INC')
pi_schedule = generate_array_schedule(
anchors=attributes.array_pr_anchor,
cycles=pr_cycles,
end=md,
filter_values=attributes.array_increase_decrease,
filter_target="INC",
)
for time in pi_schedule:
if ied < time < md:
events.append(
ContractEvent(
event_type=EventType.PI,
event_time=time,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# IP Schedule: Generated from array schedules (cycles optional for point events)
if attributes.array_ip_anchor:
ip_schedule = generate_array_schedule(
anchors=attributes.array_ip_anchor,
cycles=attributes.array_ip_cycle, # May be None
end=md,
)
for time in ip_schedule:
if ied < time <= md:
events.append(
ContractEvent(
event_type=EventType.IP,
event_time=time,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# RR/RRF Schedule: Generated from array schedules with ARFIXVAR filter
# Cycles are optional (point events at anchor dates if no cycles)
if attributes.array_rr_anchor:
rr_cycles = attributes.array_rr_cycle # May be None
if attributes.array_fixed_variable:
# RR: Rate Reset (ARFIXVAR='V')
rr_schedule = generate_array_schedule(
anchors=attributes.array_rr_anchor,
cycles=rr_cycles,
end=md,
filter_values=attributes.array_fixed_variable,
filter_target="V",
)
for time in rr_schedule:
if ied < time <= md:
events.append(
ContractEvent(
event_type=EventType.RR,
event_time=time,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# RRF: Rate Reset Fixing (ARFIXVAR='F')
rrf_schedule = generate_array_schedule(
anchors=attributes.array_rr_anchor,
cycles=rr_cycles,
end=md,
filter_values=attributes.array_fixed_variable,
filter_target="F",
)
for time in rrf_schedule:
if ied < time <= md:
events.append(
ContractEvent(
event_type=EventType.RRF,
event_time=time,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
else:
# No filter - default to RR
rr_schedule = generate_array_schedule(
anchors=attributes.array_rr_anchor,
cycles=rr_cycles,
end=md,
)
for time in rr_schedule:
if ied < time <= md:
events.append(
ContractEvent(
event_type=EventType.RR,
event_time=time,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# IPCB Schedule: Only for IPCB='NTL' mode
if attributes.interest_calculation_base == "NTL":
if (
attributes.interest_calculation_base_cycle
and attributes.interest_calculation_base_anchor
):
ipcb_schedule = generate_schedule(
start=attributes.interest_calculation_base_anchor,
cycle=attributes.interest_calculation_base_cycle,
end=md,
)
for time in ipcb_schedule:
if ied < time <= md:
events.append(
ContractEvent(
event_type=EventType.IPCB,
event_time=time,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# IP at maturity if not already in schedule
ip_times = {e.event_time for e in events if e.event_type == EventType.IP}
if md not in ip_times:
events.append(
ContractEvent(
event_type=EventType.IP,
event_time=md,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# MD: Maturity Date
events.append(
ContractEvent(
event_type=EventType.MD,
event_time=md,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=attributes.currency or "XXX",
)
)
# Sort events by time, then by ACTUS processing order within same time
# ACTUS order for LAX: PRF→IPCB→PR/PI→IPCI→IP→FP→RR/RRF→SC→MD
# (RR/RRF after IP so rate change takes effect in next period)
event_order = {
EventType.AD: 0,
EventType.IED: 1,
EventType.PRF: 2,
EventType.IPCB: 3,
EventType.PR: 4,
EventType.PI: 4,
EventType.IPCI: 5,
EventType.IP: 6,
EventType.FP: 7,
EventType.PP: 8,
EventType.PY: 8,
EventType.RR: 9,
EventType.RRF: 9,
EventType.SC: 10,
EventType.TD: 11,
EventType.MD: 12,
}
events.sort(key=lambda e: (e.event_time, event_order.get(e.event_type, 99)))
return EventSchedule(events=tuple(events), contract_id=attributes.contract_id)
def _get_prnxt_for_time(self, time: ActusDateTime) -> float | None:
"""Look up the prnxt value from array for a given event time.
Returns the prnxt from the most recent array segment anchor at or before time.
Returns None if no array is defined.
"""
attrs = self.attributes
if not attrs.array_pr_anchor or not attrs.array_pr_next:
return None
# Find the most recent anchor <= time
prnxt_val = attrs.array_pr_next[0]
for i, anchor in enumerate(attrs.array_pr_anchor):
if time >= anchor and i < len(attrs.array_pr_next):
prnxt_val = attrs.array_pr_next[i]
return prnxt_val
[docs]
def simulate(
self,
risk_factor_observer: RiskFactorObserver | None = None,
child_contract_observer: Any | None = None,
scenario: Scenario | None = None, # noqa: ARG002
behavior_observers: list[BehaviorRiskFactorObserver] | None = None, # noqa: ARG002
) -> SimulationHistory:
"""Simulate LAX contract with array-aware prnxt injection.
Before each PR/PI event, updates state.prnxt from the array schedule
so the correct principal amount is used without explicit PRF events.
"""
risk_obs = risk_factor_observer or self.risk_factor_observer
role_sign = contract_role_sign(self.attributes.contract_role)
state = self.initialize_state()
initial_state = state
events_with_states = []
schedule = self.get_events()
for event in schedule.events:
stf = self.get_state_transition_function(event.event_type)
pof = self.get_payoff_function(event.event_type)
calc_time = event.calculation_time or event.event_time
# Inject prnxt from array before PR/PI events
if event.event_type in (EventType.PR, EventType.PI, EventType.PRF):
prnxt_val = self._get_prnxt_for_time(calc_time)
if prnxt_val is not None:
state = state.replace(prnxt=jnp.array(role_sign * prnxt_val, dtype=jnp.float32))
payoff = pof(
event_type=event.event_type,
state=state,
attributes=self.attributes,
time=calc_time,
risk_factor_observer=risk_obs,
)
state_post = stf(
event_type=event.event_type,
state_pre=state,
attributes=self.attributes,
time=calc_time,
risk_factor_observer=risk_obs,
)
processed_event = ContractEvent(
event_type=event.event_type,
event_time=event.event_time,
payoff=payoff,
currency=event.currency or self.attributes.currency or "XXX",
state_pre=state,
state_post=state_post,
sequence=event.sequence,
)
events_with_states.append(processed_event)
state = state_post
return SimulationHistory(
events=events_with_states,
states=[e.state_post for e in events_with_states if e.state_post is not None],
initial_state=initial_state,
final_state=state,
)