"""Linear Amortizer (LAM) contract implementation.
This module implements the LAM contract type - an amortizing loan with fixed
principal redemption amounts where principal is repaid in regular installments.
LAM is the foundation for other amortizing contracts (NAM, ANN, LAX).
ACTUS Reference:
ACTUS v1.1 Section 7.2 - LAM: Linear Amortizer
Key Features:
- Fixed principal redemption amounts (Prnxt)
- Regular principal reduction (PR events)
- Interest calculated on IPCB (Interest Calculation Base)
- Three IPCB modes: NT (notional tracking), NTIED (fixed at IED), NTL (lagged)
- Optional IPCB events for base fixing
- Maturity can be calculated if not provided
- 16 event types total
IPCB Modes:
- NT: Interest calculated on current notional (lagging one period)
- NTIED: Interest calculated on initial notional at IED (fixed)
- NTL: Interest calculated on notional at last IPCB event (lagged with updates)
Example:
>>> from jactus.contracts import create_contract
>>> from jactus.core import ContractAttributes, ContractType, ContractRole
>>> from jactus.core import ActusDateTime, DayCountConvention
>>> from jactus.observers import ConstantRiskFactorObserver
>>>
>>> attrs = ContractAttributes(
... contract_id="MORTGAGE-001",
... contract_type=ContractType.LAM,
... contract_role=ContractRole.RPA,
... status_date=ActusDateTime(2024, 1, 1, 0, 0, 0),
... initial_exchange_date=ActusDateTime(2024, 1, 15, 0, 0, 0),
... maturity_date=ActusDateTime(2054, 1, 15, 0, 0, 0), # 30 years
... currency="USD",
... notional_principal=300000.0,
... nominal_interest_rate=0.065,
... day_count_convention=DayCountConvention.A360,
... principal_redemption_cycle="1M", # Monthly payments
... next_principal_redemption_amount=1000.0, # $1000/month principal
... interest_calculation_base="NT" # Interest on current notional
... )
>>>
>>> rf_obs = ConstantRiskFactorObserver(constant_value=0.065)
>>> contract = create_contract(attrs, rf_obs)
>>> result = contract.simulate()
"""
import math
from typing import Any
import jax.numpy as jnp
from jactus.contracts.base import BaseContract
from jactus.core import (
ActusDateTime,
ContractAttributes,
ContractEvent,
ContractState,
ContractType,
DayCountConvention,
EventSchedule,
EventType,
)
from jactus.core.types import (
EVENT_SCHEDULE_PRIORITY,
BusinessDayConvention,
Calendar,
EndOfMonthConvention,
)
from jactus.functions import BasePayoffFunction, BaseStateTransitionFunction
from jactus.observers import RiskFactorObserver
from jactus.utilities import contract_role_sign, generate_schedule, year_fraction
[docs]
class LAMPayoffFunction(BasePayoffFunction):
"""Payoff function for LAM contracts.
Implements all LAM payoff functions according to ACTUS specification.
The key difference from PAM is the PR (Principal Redemption) event and
the use of IPCB (Interest Calculation Base) for interest calculations.
ACTUS Reference:
ACTUS v1.1 Section 7.2 - LAM Payoff Functions
Events:
AD: Analysis Date (0.0)
IED: Initial Exchange Date (disburse principal)
PR: Principal Redemption (fixed principal payment)
MD: Maturity Date (final principal + interest)
PP: Principal Prepayment
PY: Penalty Payment
FP: Fee Payment
PRD: Purchase Date
TD: Termination Date
IP: Interest Payment (on IPCB, not current notional)
IPCI: Interest Capitalization
IPCB: Interest Calculation Base fixing
RR: Rate Reset
RRF: Rate Reset Fixing
SC: Scaling
CE: Credit Event
"""
def _build_dispatch_table(self) -> dict[EventType, Any]:
"""Build event type → handler dispatch table.
Lambdas normalize varying handler signatures to a uniform
(state, attributes, time, risk_factor_observer) interface.
"""
return {
EventType.AD: lambda s, a, t, r: self._pof_ad(s, a),
EventType.IED: lambda s, a, t, r: self._pof_ied(s, a),
EventType.PR: lambda s, a, t, r: self._pof_pr(s, a),
EventType.MD: lambda s, a, t, r: self._pof_md(s, a),
EventType.PP: self._pof_pp,
EventType.PY: self._pof_py,
EventType.FP: self._pof_fp,
EventType.PRD: lambda s, a, t, r: self._pof_prd(s, a, t),
EventType.TD: lambda s, a, t, r: self._pof_td(s, a, t),
EventType.IP: lambda s, a, t, r: self._pof_ip(s, a, t),
EventType.IPCI: lambda s, a, t, r: self._pof_ipci(s, a),
EventType.IPCB: lambda s, a, t, r: self._pof_ipcb(s, a),
EventType.RR: lambda s, a, t, r: self._pof_rr(s, a),
EventType.RRF: lambda s, a, t, r: self._pof_rrf(s, a),
EventType.SC: lambda s, a, t, r: self._pof_sc(s, a),
EventType.CE: lambda s, a, t, r: self._pof_ce(s, a),
}
[docs]
def calculate_payoff(
self,
event_type: Any,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""Calculate payoff for LAM event via dict dispatch.
Args:
event_type: Type of event
state: Contract state before event
attributes: Contract attributes
time: Event time
risk_factor_observer: Risk factor observer
Returns:
JAX array containing the payoff amount
"""
handler = self._build_dispatch_table().get(event_type)
if handler is not None:
return handler(state, attributes, time, risk_factor_observer) # type: ignore[no-any-return]
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ad(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_AD: Analysis Date - no cashflow."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ied(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_IED: Initial Exchange - disburse principal.
Formula: R(CNTRL) × (-1) × Nsc × NT
"""
role_sign = contract_role_sign(self.contract_role)
return (
jnp.array(role_sign * (-1.0), dtype=jnp.float32)
* state.nsc
* (attrs.notional_principal or 0.0)
)
def _pof_pr(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_PR: Principal Redemption - pay fixed principal amount.
Formula: Nsc × Prnxt (capped at remaining notional)
No R(CNTRL) — Prnxt is a signed state variable.
"""
# Cap redemption at remaining notional to avoid overshoot
prnxt = state.prnxt or jnp.array(0.0, dtype=jnp.float32)
effective_prnxt = jnp.sign(prnxt) * jnp.minimum(jnp.abs(prnxt), jnp.abs(state.nt))
return state.nsc * effective_prnxt
def _pof_md(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_MD: Maturity - final principal + accrued interest + fees.
Formula: Nsc × Nt + Isc × Ipac + Feac
No R(CNTRL) — all signed state variables.
"""
return state.nsc * state.nt + state.isc * state.ipac + state.feac
def _pof_pp(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
rf_obs: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_PP_LAM: Principal Prepayment.
Formula:
POF_PP_LAM = X^CURS_CUR(t) × f(O_ev(CID, PP, t))
The prepayment amount is observed from the risk factor observer.
"""
try:
pp_amount = rf_obs.observe_event(
attrs.contract_id or "",
EventType.PP,
time,
state,
attrs,
)
return jnp.array(float(pp_amount), dtype=jnp.float32)
except (KeyError, NotImplementedError, TypeError):
return jnp.array(0.0, dtype=jnp.float32)
def _pof_py(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
rf_obs: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_PY: Penalty payment - observed penalty amount.
Note: Not yet implemented.
"""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_fp(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
rf_obs: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_FP: Fee payment - accrued fees.
Formula: Feac (signed state variable)
"""
return state.feac
def _pof_prd(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_PRD: Purchase - purchase price + accrued interest on IPCB.
Formula: R(CNTRL) × (-1) × (PPRD + Ipac + Y(Sd, t) × Ipnr × Ipcb)
R(CNTRL) needed because PPRD is an unsigned attribute.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
accrued_interest = yf * state.ipnr * ipcb
role_sign = contract_role_sign(self.contract_role)
pprd = attrs.price_at_purchase_date or 0.0
return jnp.array(role_sign * -1.0, dtype=jnp.float32) * (
jnp.array(pprd, dtype=jnp.float32) + state.ipac + accrued_interest
)
def _pof_td(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_TD: Termination - termination price + accrued interest on IPCB.
Formula: R(CNTRL) × (PTD + Ipac + Y(Sd, t) × Ipnr × Ipcb)
R(CNTRL) needed because PTD is an unsigned attribute.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
accrued_interest = yf * state.ipnr * ipcb
role_sign = contract_role_sign(self.contract_role)
ptd = attrs.price_at_termination_date or 0.0
return jnp.array(role_sign, dtype=jnp.float32) * (
jnp.array(ptd, dtype=jnp.float32) + state.ipac + accrued_interest
)
def _pof_ip(
self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
) -> jnp.ndarray:
"""POF_IP: Interest Payment - accrued interest on IPCB.
Formula: Isc × (Ipac + Y(Sd, t) × Ipnr × Ipcb)
No R(CNTRL) — all signed state variables.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
accrued_interest = yf * state.ipnr * ipcb
return state.isc * (state.ipac + accrued_interest)
def _pof_ipci(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_IPCI: Interest Capitalization - no cashflow."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ipcb(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_IPCB: Interest Calculation Base fixing - no cashflow."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_rr(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_RR: Rate Reset - no cashflow."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_rrf(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_RRF: Rate Reset Fixing - no cashflow."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_sc(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_SC: Scaling - no cashflow."""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ce(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
"""POF_CE: Credit Event - no cashflow."""
return jnp.array(0.0, dtype=jnp.float32)
[docs]
class LAMStateTransitionFunction(BaseStateTransitionFunction):
"""State transition function for LAM contracts.
Implements all LAM state transitions according to ACTUS specification.
The key differences from PAM are PR event handling and IPCB tracking.
ACTUS Reference:
ACTUS v1.1 Section 7.2 - LAM State Transition Functions
"""
def _build_dispatch_table(self) -> dict[EventType, Any]:
"""Build event type → handler dispatch table."""
return {
EventType.AD: self._stf_ad,
EventType.IED: self._stf_ied,
EventType.PR: self._stf_pr,
EventType.MD: self._stf_md,
EventType.PP: self._stf_pp,
EventType.PY: self._stf_py,
EventType.FP: self._stf_fp,
EventType.PRD: self._stf_prd,
EventType.TD: self._stf_td,
EventType.IP: self._stf_ip,
EventType.IPCI: self._stf_ipci,
EventType.IPCB: self._stf_ipcb,
EventType.RR: self._stf_rr,
EventType.RRF: self._stf_rrf,
EventType.SC: self._stf_sc,
EventType.CE: self._stf_ce,
}
[docs]
def transition_state(
self,
event_type: Any,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""Apply state transition for LAM event via dict dispatch.
Args:
event_type: Type of event
state: Contract state before event
attributes: Contract attributes
time: Event time
risk_factor_observer: Observer for market data
Returns:
New contract state after event
"""
handler = self._build_dispatch_table().get(event_type)
if handler is not None:
return handler(state, attributes, time, risk_factor_observer) # type: ignore[no-any-return]
return state
def _stf_ad(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_AD: Analysis Date - update status date only."""
return state.replace(sd=time)
def _stf_ied(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IED: Initial Exchange - initialize all state variables.
Key Feature: Initialize Ipcb based on IPCB mode.
"""
role_sign = contract_role_sign(attrs.contract_role)
# Determine IPCB (Interest Calculation Base)
if attrs.interest_calculation_base_amount is not None:
# IPCBA overrides: fixed at specified amount
ipcb = role_sign * jnp.array(attrs.interest_calculation_base_amount, dtype=jnp.float32)
else:
# Default: initialize to notional (mode-specific updates happen later)
ipcb = role_sign * jnp.array(attrs.notional_principal, dtype=jnp.float32)
# Initialize prnxt (signed with role_sign)
# If PRNXT is provided, use it. Otherwise, preserve the auto-calculated
# value from initialize_state (stored in state.prnxt).
if attrs.next_principal_redemption_amount is not None:
prnxt = role_sign * jnp.array(attrs.next_principal_redemption_amount, dtype=jnp.float32)
else:
prnxt = state.prnxt or jnp.array(0.0, dtype=jnp.float32) # Keep auto-calculated value
# Use accrued_interest from attributes if provided (signed with role_sign)
ipac_val = role_sign * attrs.accrued_interest if attrs.accrued_interest is not None else 0.0
return state.replace(
sd=time,
nt=role_sign * jnp.array(attrs.notional_principal or 0.0, dtype=jnp.float32),
ipnr=jnp.array(attrs.nominal_interest_rate or 0.0, dtype=jnp.float32),
ipac=jnp.array(ipac_val, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=jnp.array(1.0, dtype=jnp.float32),
isc=jnp.array(1.0, dtype=jnp.float32),
ipcb=ipcb,
prnxt=prnxt,
)
def _stf_pr(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PR: Principal Redemption - reduce notional, update IPCB if needed.
Formula:
Nt = Nt - Prnxt (Prnxt is signed)
Ipac = Ipac + Y(Sd, t) × Ipnr × Ipcb
Ipcb = Nt (if IPCB='NT')
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
# Calculate accrued interest using current IPCB
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
# Reduce notional (prnxt is signed), cap at remaining notional
prnxt = state.prnxt or jnp.array(0.0, dtype=jnp.float32)
effective_prnxt = jnp.sign(prnxt) * jnp.minimum(jnp.abs(prnxt), jnp.abs(state.nt))
new_nt = state.nt - effective_prnxt
# Update IPCB based on mode
ipcb_mode = attrs.interest_calculation_base or "NT"
new_ipcb: jnp.ndarray
if ipcb_mode in ("NT", "NTIED"):
new_ipcb = new_nt # Track current notional
else: # NTL
new_ipcb = state.ipcb or jnp.array(
0.0, dtype=jnp.float32
) # Only updated at IPCB events
return state.replace(
sd=time,
nt=new_nt,
ipac=new_ipac,
ipcb=new_ipcb,
)
def _stf_md(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_MD: Maturity - zero out all state variables."""
return state.replace(
sd=time,
nt=jnp.array(0.0, dtype=jnp.float32),
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
ipcb=jnp.array(0.0, dtype=jnp.float32),
)
def _stf_pp(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PP_LAM: Prepayment - accrue interest, reduce notional, update IPCB.
Updates:
ipac_t = Ipac_t⁻ + Y(Sd_t⁻, t) × Ipnr_t⁻ × Ipcb_t⁻
nt_t = Nt_t⁻ - PP_amount
ipcb_t = Nt_t (if IPCB='NT')
sd_t = t
"""
dcc = attrs.day_count_convention or DayCountConvention.A360
yf = year_fraction(state.sd, time, dcc)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
# Get prepayment amount from risk factor observer
try:
pp_amount = float(
risk_factor_observer.observe_event(
attrs.contract_id or "",
EventType.PP,
time,
state,
attrs,
)
)
except (KeyError, NotImplementedError, TypeError):
pp_amount = 0.0
new_nt = state.nt - jnp.array(pp_amount, dtype=jnp.float32)
# Update IPCB based on mode
ipcb_mode = attrs.interest_calculation_base or "NT"
if ipcb_mode in ("NT", "NTIED"):
new_ipcb = new_nt
else: # NTL - only updated at IPCB events
new_ipcb = state.ipcb or jnp.array(0.0, dtype=jnp.float32)
return state.replace(
sd=time,
nt=new_nt,
ipac=new_ipac,
ipcb=new_ipcb,
)
def _stf_py(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PY: Penalty - accrue interest and fees.
Note: Not yet fully implemented.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
return state.replace(sd=time, ipac=new_ipac)
def _stf_fp(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_FP: Fee Payment - reset accrued fees.
Note: Not yet fully implemented.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
return state.replace(sd=time, ipac=new_ipac, feac=jnp.array(0.0, dtype=jnp.float32))
def _stf_prd(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PRD: Purchase - accrue interest and update status date."""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
return state.replace(sd=time, ipac=new_ipac)
def _stf_td(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_TD: Termination - zero out all states."""
return state.replace(
sd=time,
nt=jnp.array(0.0, dtype=jnp.float32),
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
ipcb=jnp.array(0.0, dtype=jnp.float32),
)
def _stf_ip(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IP: Interest Payment - reset accrued interest."""
return state.replace(sd=time, ipac=jnp.array(0.0, dtype=jnp.float32))
def _stf_ipci(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IPCI: Interest Capitalization - add interest to notional.
Formula:
Nt = Nt + Ipac + Y(Sd, t) × Ipnr × Ipcb
Ipac = 0
Ipcb = Nt (if IPCB='NT')
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
accrued = yf * state.ipnr * ipcb
new_nt = state.nt + state.ipac + accrued
# Update IPCB if mode is 'NT'
ipcb_mode = attrs.interest_calculation_base or "NT"
new_ipcb: jnp.ndarray
if ipcb_mode == "NT":
new_ipcb = new_nt
else:
new_ipcb = state.ipcb or jnp.array(0.0, dtype=jnp.float32)
return state.replace(
sd=time, nt=new_nt, ipac=jnp.array(0.0, dtype=jnp.float32), ipcb=new_ipcb
)
def _stf_ipcb(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IPCB: Interest Calculation Base fixing - reset IPCB to current notional.
Formula:
Ipcb = Nt
Ipac = Ipac + Y(Sd, t) × Ipnr × Ipcb_old
Key Feature: Only used when IPCB='NTL'.
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb_old = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb_old
return state.replace(sd=time, ipcb=state.nt, ipac=new_ipac)
def _stf_rr(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_RR: Rate Reset - observe market rate and apply caps/floors.
Formula:
Ipac = Ipac + Y(Sd, t) * Ipnr * Ipcb
Ipnr = min(max(RRMLT * O_rf(RRMO, t) + RRSP, RRLF), RRLC)
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
# Observe market rate
market_object = attrs.rate_reset_market_object or ""
observed_rate = float(
risk_factor_observer.observe_risk_factor(market_object, time, state, attrs)
)
# Apply multiplier and spread
multiplier = attrs.rate_reset_multiplier if attrs.rate_reset_multiplier is not None else 1.0
spread = attrs.rate_reset_spread if attrs.rate_reset_spread is not None else 0.0
new_rate = multiplier * observed_rate + spread
# Apply floor and cap
if attrs.rate_reset_floor is not None:
new_rate = max(new_rate, attrs.rate_reset_floor)
if attrs.rate_reset_cap is not None:
new_rate = min(new_rate, attrs.rate_reset_cap)
return state.replace(
sd=time,
ipac=new_ipac,
ipnr=jnp.array(new_rate, dtype=jnp.float32),
)
def _stf_rrf(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_RRF: Rate Reset Fixing - fix interest rate to next reset rate.
Formula:
Ipac = Ipac + Y(Sd, t) * Ipnr * Ipcb
Ipnr = RRNXT
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
new_rate = attrs.rate_reset_next if attrs.rate_reset_next is not None else state.ipnr
return state.replace(
sd=time,
ipac=new_ipac,
ipnr=jnp.array(float(new_rate), dtype=jnp.float32),
)
def _stf_sc(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_SC: Scaling - update scaling multipliers from index.
Formula:
Ipac = Ipac + Y(Sd, t) × Ipnr × Ipcb
scaling_ratio = I(t) / I_ref
If SCEF[0] == 'I': Isc = scaling_ratio
If SCEF[1] == 'N': Nsc = scaling_ratio
"""
yf = year_fraction(state.sd, time, attrs.day_count_convention or DayCountConvention.A360)
ipcb = state.ipcb if state.ipcb is not None else state.nt
new_ipac = state.ipac + yf * state.ipnr * ipcb
# Observe current scaling index value
new_isc = state.isc
new_nsc = state.nsc
scaling_mo = attrs.scaling_market_object
if scaling_mo:
current_index = float(
risk_factor_observer.observe_risk_factor(scaling_mo, time, state, attrs)
)
ref_index = attrs.scaling_index_at_contract_deal_date or 1.0
if ref_index != 0.0:
scaling_ratio = current_index / ref_index
else:
scaling_ratio = 1.0
effect_str = str(attrs.scaling_effect.value) if attrs.scaling_effect else "000"
if len(effect_str) >= 1 and effect_str[0] == "I":
new_isc = jnp.array(scaling_ratio, dtype=jnp.float32)
if len(effect_str) >= 2 and effect_str[1] == "N":
new_nsc = jnp.array(scaling_ratio, dtype=jnp.float32)
return state.replace(sd=time, ipac=new_ipac, isc=new_isc, nsc=new_nsc)
def _stf_ce(
self,
state: ContractState,
attrs: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_CE: Credit Event - update status date."""
return state.replace(sd=time)
[docs]
class LinearAmortizerContract(BaseContract):
"""Linear Amortizer (LAM) contract.
Amortizing loan with fixed principal redemption amounts. Principal is repaid
in regular installments (PR events), with interest calculated on an Interest
Calculation Base (IPCB) that can track current notional, stay fixed, or update
periodically.
ACTUS Reference:
ACTUS v1.1 Section 7.2 - LAM: Linear Amortizer
Attributes:
attributes: Contract attributes
risk_factor_observer: Risk factor observer for market data
Example:
See module docstring for usage example.
"""
[docs]
def __init__(
self,
attributes: ContractAttributes,
risk_factor_observer: RiskFactorObserver,
child_contract_observer: Any | None = None,
):
"""Initialize LAM contract.
Args:
attributes: Contract attributes
risk_factor_observer: Risk factor observer
child_contract_observer: Optional child contract observer
Raises:
ValueError: If contract_type is not LAM
ValueError: If required attributes missing
"""
if attributes.contract_type != ContractType.LAM:
raise ValueError(f"Contract type must be LAM, got {attributes.contract_type}")
# Validate required attributes
if not attributes.initial_exchange_date:
raise ValueError("initial_exchange_date required for LAM")
if not attributes.principal_redemption_cycle and not attributes.maturity_date:
raise ValueError("Either principal_redemption_cycle or maturity_date required")
# Note: PRNXT validation is done at simulation time, not contract creation
# if not attributes.next_principal_redemption_amount:
# raise ValueError("next_principal_redemption_amount (PRNXT) required for LAM")
super().__init__(
attributes=attributes,
risk_factor_observer=risk_factor_observer,
child_contract_observer=child_contract_observer,
)
[docs]
def generate_event_schedule(self) -> EventSchedule:
"""Generate LAM event schedule per ACTUS specification.
Schedule formula for each event type:
IED: Single event at initial_exchange_date (if IED >= SD)
PR: S(PRANX, PRCL, MD) - principal redemption schedule
IP: S(IPANX, IPCL, MD) - interest payment schedule
IPCI: S(IPANX, IPCL, IPCED) - interest capitalization until end date
IPCB: S(IPCBANX, IPCBCL, MD) - if IPCB='NTL'
RR: S(RRANX, RRCL, MD) - rate reset schedule
FP: S(FEANX, FECL, MD) - fee payment schedule
PRD: Single event at purchase_date
TD: Single event at termination_date (truncates schedule)
MD: Single event at maturity_date
Events before status_date are excluded.
"""
events: list[ContractEvent] = []
attrs = self.attributes
ied = attrs.initial_exchange_date
md = attrs.maturity_date
sd = attrs.status_date
currency = attrs.currency or "XXX"
assert ied is not None
bdc = attrs.business_day_convention
eomc = attrs.end_of_month_convention
cal = attrs.calendar
# Calculate MD if not provided: MD = last PR date from schedule
if md is None and attrs.principal_redemption_cycle:
prnxt = attrs.next_principal_redemption_amount or 0.0
nt = attrs.notional_principal or 0.0
if prnxt > 0:
n_periods = math.ceil(nt / prnxt)
pr_anchor = attrs.principal_redemption_anchor or ied
# Generate enough dates to find MD
far_end = pr_anchor.add_period(f"{(n_periods + 2) * 12}M", EndOfMonthConvention.SD)
pr_dates = generate_schedule(
start=pr_anchor,
cycle=attrs.principal_redemption_cycle,
end=far_end,
end_of_month_convention=eomc or EndOfMonthConvention.SD,
business_day_convention=bdc or BusinessDayConvention.NULL,
calendar=cal or Calendar.NO_CALENDAR,
)
pr_dates = [d for d in pr_dates if d >= ied]
if len(pr_dates) >= n_periods:
md = pr_dates[n_periods - 1]
def _sched(
anchor: ActusDateTime, cycle: str, end: ActusDateTime | None
) -> list[ActusDateTime]:
return generate_schedule(
start=anchor,
cycle=cycle,
end=end,
end_of_month_convention=eomc or EndOfMonthConvention.SD,
business_day_convention=bdc or BusinessDayConvention.NULL,
calendar=cal or Calendar.NO_CALENDAR,
)
def _add(etype: EventType, time: ActusDateTime) -> None:
events.append(
ContractEvent(
event_type=etype,
event_time=time,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=currency,
sequence=0,
)
)
# IED: only if IED >= SD
if ied >= sd:
_add(EventType.IED, ied)
# PR: Principal Redemption schedule
if attrs.principal_redemption_cycle:
pr_anchor = attrs.principal_redemption_anchor or ied
pr_dates = _sched(pr_anchor, attrs.principal_redemption_cycle, md)
# Long stub: remove last PR date before MD if it's not on cycle end
pr_cycle_str = attrs.principal_redemption_cycle or ""
if pr_cycle_str.endswith("+") and pr_dates and md and pr_dates[-1] != md:
pr_dates = pr_dates[:-1]
for dt in pr_dates:
if md and dt >= md:
break
if dt >= ied:
_add(EventType.PR, dt)
# IP: Interest Payment schedule from IPANX (or IED)
if attrs.interest_payment_cycle:
ip_anchor = attrs.interest_payment_anchor or ied
ipced = attrs.interest_capitalization_end_date
ip_dates = _sched(ip_anchor, attrs.interest_payment_cycle, md)
if ipced and ipced not in ip_dates:
ip_dates = sorted(set(ip_dates + [ipced]))
# Stub handling: if MD not on cycle, add MD (and remove last for long stub)
ip_cycle_str = attrs.interest_payment_cycle or ""
if md and md not in ip_dates and ip_dates:
if ip_cycle_str.endswith("+"):
# Long stub: replace last cycle date with MD
ip_dates[-1] = md
else:
# Short stub (default): keep last cycle date and add MD
ip_dates.append(md)
ip_dates = sorted(set(ip_dates))
for dt in ip_dates:
if dt < ied:
continue
if ipced and dt <= ipced:
_add(EventType.IPCI, dt)
else:
_add(EventType.IP, dt)
# IPCB: Interest Calculation Base schedule (only if IPCB='NTL')
if attrs.interest_calculation_base == "NTL" and attrs.interest_calculation_base_cycle:
ipcb_anchor = attrs.interest_calculation_base_anchor or ied
ipcb_dates = _sched(ipcb_anchor, attrs.interest_calculation_base_cycle, md)
# Long stub: remove last date before MD
ipcb_cycle_str = attrs.interest_calculation_base_cycle or ""
if ipcb_cycle_str.endswith("+") and ipcb_dates and md and ipcb_dates[-1] != md:
ipcb_dates = ipcb_dates[:-1]
for dt in ipcb_dates:
if dt > ied and (not md or dt < md):
_add(EventType.IPCB, dt)
# RR: Rate Reset schedule (exclude events at MD, handle long stub)
if attrs.rate_reset_cycle and attrs.rate_reset_anchor:
rr_dates = _sched(attrs.rate_reset_anchor, attrs.rate_reset_cycle, md)
rr_cycle_str = attrs.rate_reset_cycle or ""
if rr_cycle_str.endswith("+") and rr_dates and md and rr_dates[-1] != md:
rr_dates = rr_dates[:-1]
first_rr = True
for dt in rr_dates:
if md and dt >= md:
break
if first_rr and attrs.rate_reset_next is not None:
_add(EventType.RRF, dt)
first_rr = False
else:
_add(EventType.RR, dt)
first_rr = False
# FP: Fee Payment schedule
if attrs.fee_payment_cycle:
fp_anchor = attrs.fee_payment_anchor or ied
fp_dates = _sched(fp_anchor, attrs.fee_payment_cycle, md)
for dt in fp_dates:
if dt > ied:
_add(EventType.FP, dt)
# SC: Scaling Index schedule
if attrs.scaling_index_cycle:
sc_anchor = attrs.scaling_index_anchor or ied
sc_dates = _sched(sc_anchor, attrs.scaling_index_cycle, md)
for dt in sc_dates:
if dt > ied:
_add(EventType.SC, dt)
# PRD: Purchase date
if attrs.purchase_date:
_add(EventType.PRD, attrs.purchase_date)
# TD: Termination date
if attrs.termination_date:
_add(EventType.TD, attrs.termination_date)
# MD: Maturity Date
if md:
_add(EventType.MD, md)
# Filter out events before SD, sort
events = [e for e in events if e.event_time >= sd]
# If PRD exists, remove IED and non-PRD events before/at PRD
if attrs.purchase_date:
prd_time = attrs.purchase_date
events = [
e
for e in events
if e.event_type == EventType.PRD
or (e.event_type != EventType.IED and e.event_time > prd_time)
]
events.sort(
key=lambda e: (e.event_time.to_iso(), EVENT_SCHEDULE_PRIORITY.get(e.event_type, 99))
)
# If TD exists, remove all events after TD
if attrs.termination_date:
td_time = attrs.termination_date
events = [e for e in events if e.event_time <= td_time]
# Reassign sequence numbers
for i in range(len(events)):
events[i] = ContractEvent(
event_type=events[i].event_type,
event_time=events[i].event_time,
payoff=events[i].payoff,
currency=events[i].currency,
sequence=i,
)
return EventSchedule(events=tuple(events), contract_id=attrs.contract_id)
def _pre_simulate_to_prd(self, attrs: ContractAttributes, prnxt: jnp.ndarray) -> ContractState:
"""Pre-simulate events from IED to PRD to get correct initial state.
When a contract has a purchase date, events between IED and PRD
(PR, IP, RR, IPCB, etc.) affect the state. This method runs those
events to compute the true state at purchase time.
"""
ied = attrs.initial_exchange_date
prd = attrs.purchase_date
sd = attrs.status_date
md = attrs.maturity_date
assert ied is not None
assert prd is not None
bdc = attrs.business_day_convention
eomc = attrs.end_of_month_convention
cal = attrs.calendar
def _sched(
anchor: ActusDateTime | None, cycle: str | None, end: ActusDateTime | None
) -> list[ActusDateTime]:
if anchor is None or cycle is None or end is None:
return []
return generate_schedule(
start=anchor,
cycle=cycle,
end=end,
end_of_month_convention=eomc or EndOfMonthConvention.SD,
business_day_convention=bdc or BusinessDayConvention.NULL,
calendar=cal or Calendar.NO_CALENDAR,
)
# Build pre-PRD event list (IED through events strictly before PRD)
pre_events: list[tuple[ActusDateTime, EventType]] = []
pre_events.append((ied, EventType.IED))
# PR events (include events at PRD time - they reduce NT before purchase)
if attrs.principal_redemption_cycle:
pr_anchor = attrs.principal_redemption_anchor or ied
pr_dates = _sched(pr_anchor, attrs.principal_redemption_cycle, md or prd)
for dt in pr_dates:
if dt >= ied and dt <= prd:
pre_events.append((dt, EventType.PR))
# IP events (include events at PRD time - they reset ipac before purchase)
if attrs.interest_payment_cycle:
ip_anchor = attrs.interest_payment_anchor or ied
ipced = attrs.interest_capitalization_end_date
ip_dates = _sched(ip_anchor, attrs.interest_payment_cycle, md or prd)
for dt in ip_dates:
if dt < ied or dt > prd:
continue
if ipced and dt <= ipced:
pre_events.append((dt, EventType.IPCI))
else:
pre_events.append((dt, EventType.IP))
# RR events (include events at PRD time)
if attrs.rate_reset_cycle and attrs.rate_reset_anchor:
rr_dates = _sched(attrs.rate_reset_anchor, attrs.rate_reset_cycle, md or prd)
first_rr = True
for dt in rr_dates:
if dt > prd:
break
if first_rr and attrs.rate_reset_next is not None:
pre_events.append((dt, EventType.RRF))
first_rr = False
else:
pre_events.append((dt, EventType.RR))
first_rr = False
# IPCB events
if attrs.interest_calculation_base == "NTL" and attrs.interest_calculation_base_cycle:
ipcb_anchor = attrs.interest_calculation_base_anchor or ied
ipcb_dates = _sched(ipcb_anchor, attrs.interest_calculation_base_cycle, md or prd)
for dt in ipcb_dates:
if dt > ied and dt <= prd:
pre_events.append((dt, EventType.IPCB))
# Only keep events from IED onward (some anchors may be before IED)
pre_events = [(t, e) for t, e in pre_events if t >= ied]
# Sort by time, then by event priority
pre_events.sort(key=lambda e: (e[0].to_iso(), EVENT_SCHEDULE_PRIORITY.get(e[1], 99)))
# Create initial state (pre-IED)
state = ContractState(
sd=ied,
tmd=md or ied,
nt=jnp.array(0.0, dtype=jnp.float32),
ipnr=jnp.array(0.0, dtype=jnp.float32),
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=jnp.array(1.0, dtype=jnp.float32),
isc=jnp.array(1.0, dtype=jnp.float32),
prnxt=prnxt,
ipcb=jnp.array(0.0, dtype=jnp.float32),
)
# Run STF for each pre-PRD event
stf = self.get_state_transition_function(None)
for time, etype in pre_events:
state = stf.transition_state(
event_type=etype,
state=state,
attributes=attrs,
time=time,
risk_factor_observer=self.risk_factor_observer,
)
# When IED < SD, advance state to SD (accrual before SD is not reported)
# Note: ipac is reset to 0 (or user-specified accrued_interest) because
# POF_PRD will compute accrual from SD to PRD. Setting ipac to the accrual
# from last event to SD would double-count interest.
if ied < sd and sd < prd:
ipac = jnp.array(attrs.accrued_interest or 0.0, dtype=jnp.float32)
state = state.replace(sd=sd, ipac=ipac)
return state
[docs]
def initialize_state(self) -> ContractState:
"""Initialize LAM contract state.
When IED < SD (contract already existed), state is initialized
as if STF_IED already ran: Nt, Ipnr are set from attributes,
and interest is accrued from IED (or IPANX) to SD.
Returns:
Initial contract state
"""
attrs = self.attributes
sd = attrs.status_date
ied = attrs.initial_exchange_date
role_sign = contract_role_sign(attrs.contract_role)
# Initialize Prnxt (next principal redemption amount)
if attrs.next_principal_redemption_amount is not None:
prnxt_val = attrs.next_principal_redemption_amount
elif (
attrs.notional_principal
and attrs.principal_redemption_cycle
and attrs.maturity_date
and ied
):
# Auto-calculate: PRNXT = NT / number_of_PR_periods
pr_anchor = attrs.principal_redemption_anchor or ied
pr_dates = generate_schedule(
start=pr_anchor,
cycle=attrs.principal_redemption_cycle,
end=attrs.maturity_date,
)
pr_dates = [d for d in pr_dates if d <= attrs.maturity_date]
if attrs.maturity_date not in pr_dates:
pr_dates.append(attrs.maturity_date)
num_periods = len(pr_dates)
prnxt_val = attrs.notional_principal / num_periods if num_periods > 0 else 0.0
else:
prnxt_val = 0.0
prnxt = jnp.array(role_sign * prnxt_val, dtype=jnp.float32)
# PRD pre-simulation: run events from IED to PRD to get correct state
if attrs.purchase_date and ied:
return self._pre_simulate_to_prd(attrs, prnxt)
needs_post_ied = ied and ied < sd
if needs_post_ied:
# Contract already started - initialize post-IED state
nt = role_sign * (attrs.notional_principal or 0.0)
ipnr = attrs.nominal_interest_rate or 0.0
dcc = attrs.day_count_convention or DayCountConvention.A360
init_sd = sd
accrual_start = attrs.interest_payment_anchor or ied
if attrs.accrued_interest is not None:
ipac = attrs.accrued_interest
elif accrual_start and accrual_start < sd:
yf = year_fraction(accrual_start, sd, dcc)
ipac = yf * ipnr * nt
else:
ipac = 0.0
if attrs.interest_calculation_base_amount is not None:
ipcb_val = role_sign * attrs.interest_calculation_base_amount
else:
ipcb_val = nt # IPCB initialized to signed notional at IED
return ContractState(
sd=init_sd,
tmd=attrs.maturity_date or sd,
nt=jnp.array(nt, dtype=jnp.float32),
ipnr=jnp.array(ipnr, dtype=jnp.float32),
ipac=jnp.array(ipac, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=jnp.array(1.0, dtype=jnp.float32),
isc=jnp.array(1.0, dtype=jnp.float32),
prnxt=prnxt,
ipcb=jnp.array(ipcb_val, dtype=jnp.float32),
)
return ContractState(
sd=sd,
tmd=attrs.maturity_date or sd,
nt=jnp.array(0.0, dtype=jnp.float32), # Set at IED
ipnr=jnp.array(0.0, dtype=jnp.float32), # Set at IED
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=jnp.array(1.0, dtype=jnp.float32),
isc=jnp.array(1.0, dtype=jnp.float32),
prnxt=prnxt,
ipcb=jnp.array(0.0, dtype=jnp.float32), # Set at IED
)
[docs]
def get_payoff_function(self, event_type: Any) -> LAMPayoffFunction:
"""Get LAM payoff function.
Args:
event_type: Type of event (not used, all events use same POF)
Returns:
LAM payoff function instance
"""
return LAMPayoffFunction(
contract_role=self.attributes.contract_role,
currency=self.attributes.currency,
settlement_currency=None,
)
[docs]
def get_state_transition_function(self, event_type: Any) -> LAMStateTransitionFunction:
"""Get LAM state transition function.
Args:
event_type: Type of event (not used, all events use same STF)
Returns:
LAM state transition function instance
"""
return LAMStateTransitionFunction()