"""Principal At Maturity (PAM) contract implementation.
This module implements the PAM contract type - a bullet loan with interest payments
where the principal is repaid at maturity. PAM is the foundational loan contract
and serves as a template for amortizing contracts (LAM, NAM, ANN).
ACTUS Reference:
ACTUS v1.1 Section 7.1 - PAM: Principal At Maturity
Key Features:
- Principal repaid in full at maturity
- Regular interest payments (IP events)
- Optional interest capitalization (IPCI)
- Variable interest rates with rate resets (RR, RRF)
- Fees (FP events)
- Prepayments (PP events)
- Scaling (SC events)
- 14 event types total
Example:
>>> from jactus.contracts.pam import PrincipalAtMaturityContract
>>> from jactus.core import ContractAttributes, ContractType, ContractRole
>>> from jactus.observers import ConstantRiskFactorObserver
>>>
>>> attrs = ContractAttributes(
... contract_id="LOAN-001",
... contract_type=ContractType.PAM,
... contract_role=ContractRole.RPA,
... status_date=ActusDateTime(2024, 1, 1, 0, 0, 0),
... initial_exchange_date=ActusDateTime(2024, 1, 15, 0, 0, 0),
... maturity_date=ActusDateTime(2029, 1, 15, 0, 0, 0),
... currency="USD",
... notional_principal=100000.0,
... nominal_interest_rate=0.05,
... day_count_convention=DayCountConvention.A360,
... interest_payment_cycle="1Y"
... )
>>>
>>> rf_obs = ConstantRiskFactorObserver(constant_value=0.05)
>>> contract = PrincipalAtMaturityContract(
... attributes=attrs,
... risk_factor_observer=rf_obs
... )
>>> result = contract.simulate()
"""
from typing import Any
import flax.nnx as nnx
import jax.numpy as jnp
from jactus.contracts.base import BaseContract
from jactus.core import (
ActusDateTime,
ContractAttributes,
ContractEvent,
ContractRole,
ContractState,
ContractType,
DayCountConvention,
EventSchedule,
EventType,
FeeBasis,
)
from jactus.core.time import adjust_to_business_day
from jactus.core.types import (
EVENT_SCHEDULE_PRIORITY,
BusinessDayConvention,
Calendar,
EndOfMonthConvention,
)
from jactus.functions import BasePayoffFunction, BaseStateTransitionFunction
from jactus.observers import ChildContractObserver, RiskFactorObserver
from jactus.utilities import contract_role_sign, generate_schedule, year_fraction
def _last_cycle_date_before(
anchor: ActusDateTime,
cycle: str,
target: ActusDateTime,
) -> ActusDateTime:
"""Find the last date in a cycle sequence that is <= target.
Walks forward from *anchor* by *cycle* until exceeding *target*,
then returns the previous step. Used to compute the correct
accrual start for contracts whose IED predates the status date.
"""
from jactus.core.time import add_period
current = anchor
while True:
next_date = add_period(current, cycle)
if next_date > target:
return current
current = next_date
[docs]
class PAMPayoffFunction(BasePayoffFunction):
"""Payoff function for PAM contracts.
Implements all 14 PAM payoff functions according to ACTUS specification.
Uses dictionary-based dispatch for O(1) lookup and future jax.lax.switch
compatibility (requires EventType.index integer mapping).
ACTUS Reference:
ACTUS v1.1 Section 7.1 - PAM Payoff Functions
Events:
AD: Analysis Date (0.0)
IED: Initial Exchange Date (disburse principal)
MD: Maturity Date (return principal + accrued)
PP: Principal Prepayment
PY: Penalty Payment
FP: Fee Payment
PRD: Purchase Date
TD: Termination Date
IP: Interest Payment
IPCI: Interest Capitalization
RR: Rate Reset
RRF: Rate Reset Fixing
SC: Scaling
CE: Credit Event
"""
def _build_dispatch_table(self) -> dict[EventType, Any]:
"""Build event type → handler dispatch table.
Returns a dict mapping each handled EventType to its payoff method.
This replaces the if/elif chain with O(1) dict lookup and provides
a clear registry of supported events.
"""
return {
EventType.AD: self._pof_ad,
EventType.IED: self._pof_ied,
EventType.MD: self._pof_md,
EventType.PP: self._pof_pp,
EventType.PY: self._pof_py,
EventType.FP: self._pof_fp,
EventType.PRD: self._pof_prd,
EventType.TD: self._pof_td,
EventType.IP: self._pof_ip,
EventType.IPCI: self._pof_ipci,
EventType.RR: self._pof_rr,
EventType.RRF: self._pof_rrf,
EventType.SC: self._pof_sc,
EventType.CE: self._pof_ce,
}
[docs]
def calculate_payoff(
self,
event_type: Any,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""Calculate payoff for PAM events.
Dispatches to specific payoff function via dict lookup (O(1)).
Args:
event_type: Type of event
state: Current contract state
attributes: Contract attributes
time: Event time
risk_factor_observer: Observer for market data
Returns:
Payoff amount as JAX array
ACTUS Reference:
POF_[event]_PAM functions from Section 7.1
"""
handler = self._build_dispatch_table().get(event_type)
if handler is not None:
return handler(state, attributes, time, risk_factor_observer) # type: ignore[no-any-return]
# Unknown event type - return 0
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ad(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_AD_PAM: Analysis Date has no cashflow.
Returns:
0.0
"""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ied(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_IED_PAM: Initial Exchange - disburse principal.
Formula:
POF_IED_PAM = X^CURS_CUR(t) × R(CNTRL) × (-1) × (NT + PDIED)
Where:
NT: Notional principal
PDIED: Premium/discount at IED
R(CNTRL): Role sign
X^CURS_CUR(t): FX rate
Returns:
Negative of notional plus premium/discount (outflow for lender)
"""
# Get notional and premium/discount
nt = attributes.notional_principal or 0.0
pdied = attributes.premium_discount_at_ied or 0.0
# Calculate payoff: R(CNTRL) × (-1) × (NT + PDIED)
# R(CNTRL) is needed because NT and PDIED are unsigned attributes
role_sign = contract_role_sign(self.contract_role)
payoff = role_sign * (-1.0) * (nt + pdied)
return jnp.array(payoff, dtype=jnp.float32)
def _pof_md(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_MD_PAM: Maturity Date - return principal + accrued.
Formula:
POF_MD_PAM = X^CURS_CUR(t) × (Nsc_t⁻ × Nt_t⁻ + Isc_t⁻ × Ipac_t⁻ + Feac_t⁻)
Returns:
Scaled notional + scaled accrued interest + accrued fees
"""
nsc = float(state.nsc)
nt = float(state.nt)
isc = float(state.isc)
ipac = float(state.ipac)
feac = float(state.feac)
payoff = nsc * nt + isc * ipac + feac
return jnp.array(payoff, dtype=jnp.float32)
def _pof_pp(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_PP_PAM: Principal Prepayment.
Formula:
POF_PP_PAM = X^CURS_CUR(t) × f(O_ev(CID, PP, t))
The prepayment amount is observed from the risk factor observer
using the contract ID and PP event type.
Returns:
Observed prepayment amount
"""
# Observe prepayment amount from risk factor observer
try:
pp_amount = risk_factor_observer.observe_event(
attributes.contract_id or "",
EventType.PP,
time,
state,
attributes,
)
return jnp.array(float(pp_amount), dtype=jnp.float32)
except (KeyError, NotImplementedError, TypeError):
return jnp.array(0.0, dtype=jnp.float32)
def _pof_py(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_PY_PAM: Penalty Payment.
Formula depends on PYTP (penalty type):
PYTP='A': PYRT
PYTP='N': Y(Sd_t⁻, t) × Nt_t⁻ × PYRT
PYTP='I': Interest rate differential (simplified to type N)
No R(CNTRL) — Nt is a signed state variable.
"""
pytp = attributes.penalty_type
pyrt = attributes.penalty_rate or 0.0
if pytp == "A":
return jnp.array(pyrt, dtype=jnp.float32)
if pytp == "N" or pytp == "I":
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state.sd, time, dcc)
nt = float(state.nt)
return jnp.array(yf * nt * pyrt, dtype=jnp.float32)
return jnp.array(0.0, dtype=jnp.float32)
def _pof_fp(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_FP_PAM: Fee Payment.
Formula depends on FEB (fee basis):
FEB='A': FER
FEB='N': Y(Sd_t⁻, t) × Nt_t⁻ × FER + Feac_t⁻
No R(CNTRL) — Nt and Feac are signed state variables.
"""
feb = attributes.fee_basis
fer = attributes.fee_rate or 0.0
if feb == FeeBasis.A:
return jnp.array(fer, dtype=jnp.float32)
if feb == FeeBasis.N:
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state.sd, time, dcc)
nt = float(state.nt)
feac = float(state.feac)
return jnp.array(yf * nt * fer + feac, dtype=jnp.float32)
feac = float(state.feac)
return jnp.array(feac, dtype=jnp.float32)
def _pof_prd(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_PRD_PAM: Purchase Date - pay purchase price + accrued interest.
Formula:
POF_PRD_PAM = R(CNTRL) × (-1) ×
(PPRD + Ipac_t⁻ + Y(Sd_t⁻, t) × Ipnr_t⁻ × Nt_t⁻)
Returns:
Negative of (purchase price + accrued interest)
"""
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state.sd, time, dcc)
ipac = float(state.ipac)
ipnr = float(state.ipnr)
nt = float(state.nt)
accrued_interest = yf * ipnr * nt
pprd = attributes.price_at_purchase_date or 0.0
# POF_PRD_PAM = (-1) × (PPRD + Ipac + Y × Ipnr × Nt)
# No R(CNTRL) — Ipac and Nt are signed state variables
payoff = (-1.0) * (pprd + ipac + accrued_interest)
return jnp.array(payoff, dtype=jnp.float32)
def _pof_td(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_TD_PAM: Termination Date - receive termination price + accrued.
Formula:
POF_TD_PAM = R(CNTRL) ×
(PTD + Ipac_t⁻ + Y(Sd_t⁻, t) × Ipnr_t⁻ × Nt_t⁻)
Returns:
Termination price + accrued interest
"""
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state.sd, time, dcc)
ipac = float(state.ipac)
ipnr = float(state.ipnr)
nt = float(state.nt)
accrued_interest = yf * ipnr * nt
ptd = attributes.price_at_termination_date or 0.0
# POF_TD_PAM = PTD + Ipac + Y × Ipnr × Nt
# No R(CNTRL) — Ipac and Nt are signed state variables
payoff = ptd + ipac + accrued_interest
return jnp.array(payoff, dtype=jnp.float32)
def _pof_ip(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_IP_PAM: Interest Payment.
Formula:
POF_IP_PAM = X^CURS_CUR(t) × Isc_t⁻ × (Ipac_t⁻ + Y(Sd_t⁻, t) × Ipnr_t⁻ × Nt_t⁻)
Returns:
Scaled accrued interest payment
"""
isc = float(state.isc)
ipac = float(state.ipac)
ipnr = float(state.ipnr)
nt = float(state.nt)
# Calculate year fraction from last status date to now
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state.sd, time, dcc)
# Interest payment = Isc × (Ipac + YF × Ipnr × Nt)
payoff = isc * (ipac + yf * ipnr * nt)
return jnp.array(payoff, dtype=jnp.float32)
def _pof_ipci(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_IPCI_PAM: Interest Capitalization - no cashflow.
Returns:
0.0 (interest is capitalized into notional)
"""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_rr(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_RR_PAM: Rate Reset - no cashflow.
Returns:
0.0 (rate is updated in state transition)
"""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_rrf(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_RRF_PAM: Rate Reset Fixing - no cashflow.
Returns:
0.0 (rate is fixed in state transition)
"""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_sc(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_SC_PAM: Scaling - no cashflow.
Returns:
0.0 (scaling multipliers updated in state transition)
"""
return jnp.array(0.0, dtype=jnp.float32)
def _pof_ce(
self,
state: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> jnp.ndarray:
"""POF_CE_PAM: Credit Event - no cashflow.
Returns:
0.0 (handled in state transition)
"""
return jnp.array(0.0, dtype=jnp.float32)
[docs]
class PAMStateTransitionFunction(BaseStateTransitionFunction):
"""State transition function for PAM contracts.
Implements all 14 PAM state transition functions according to ACTUS specification.
Uses dictionary-based dispatch for O(1) lookup.
ACTUS Reference:
ACTUS v1.1 Section 7.1 - PAM State Transition Functions
"""
def _build_dispatch_table(self) -> dict[EventType, Any]:
"""Build event type → handler dispatch table."""
return {
EventType.AD: self._stf_ad,
EventType.IED: self._stf_ied,
EventType.MD: self._stf_md,
EventType.PP: self._stf_pp,
EventType.PY: self._stf_py,
EventType.FP: self._stf_fp,
EventType.PRD: self._stf_prd,
EventType.TD: self._stf_td,
EventType.IP: self._stf_ip,
EventType.IPCI: self._stf_ipci,
EventType.RR: self._stf_rr,
EventType.RRF: self._stf_rrf,
EventType.SC: self._stf_sc,
EventType.CE: self._stf_ce,
}
[docs]
def transition_state(
self,
event_type: Any,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""Transition PAM contract state.
Dispatches to specific state transition function via dict lookup (O(1)).
Args:
event_type: Type of event
state_pre: State before event
attributes: Contract attributes
time: Event time
risk_factor_observer: Observer for market data
Returns:
Updated contract state
ACTUS Reference:
STF_[event]_PAM functions from Section 7.1
"""
handler = self._build_dispatch_table().get(event_type)
if handler is not None:
return handler(state_pre, attributes, time, risk_factor_observer) # type: ignore[no-any-return]
# Unknown event type - return unchanged state
return state_pre
def _stf_ad(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_AD_PAM: Analysis Date - accrue interest and update status date.
Updates:
ipac_t = Ipac_t⁻ + Y(Sd_t⁻, t) × Ipnr_t⁻ × Nt_t⁻
sd_t = t
"""
# Calculate year fraction
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state_pre.sd, time, dcc)
# Accrue interest
ipac = float(state_pre.ipac) + yf * float(state_pre.ipnr) * float(state_pre.nt)
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=state_pre.nt,
ipnr=state_pre.ipnr,
ipac=jnp.array(ipac, dtype=jnp.float32),
feac=state_pre.feac,
nsc=state_pre.nsc,
isc=state_pre.isc,
)
def _stf_ied(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IED_PAM: Initial Exchange - set notional and interest rate.
Updates:
nt_t = R(CNTRL) × NT
ipnr_t = IPNR (if defined, else 0.0)
ipac_t = IPAC (if given) or calculated from IPANX
sd_t = t
"""
# Get role sign
role_sign = self._get_role_sign(attributes.contract_role)
# Set notional with role sign
nt = role_sign * (attributes.notional_principal or 0.0)
# Set interest rate
ipnr = attributes.nominal_interest_rate or 0.0
# Set initial accrued interest
# Per ACTUS spec: use IPAC if given, else calculate from IPANX if before IED
if attributes.accrued_interest is not None:
ipac = attributes.accrued_interest
elif (
attributes.interest_payment_anchor is not None
and attributes.interest_payment_anchor < time
):
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(attributes.interest_payment_anchor, time, dcc)
ipac = yf * ipnr * abs(nt)
else:
ipac = 0.0
return ContractState(
sd=time,
tmd=attributes.maturity_date or time,
nt=jnp.array(nt, dtype=jnp.float32),
ipnr=jnp.array(ipnr, dtype=jnp.float32),
ipac=jnp.array(ipac, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=jnp.array(1.0, dtype=jnp.float32),
isc=jnp.array(1.0, dtype=jnp.float32),
)
def _stf_md(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_MD_PAM: Maturity - zero notional and accruals, preserve rate.
Per ACTUS spec, ipnr is preserved at maturity (not zeroed).
Updates:
nt_t = 0.0
ipac_t = 0.0
feac_t = 0.0
sd_t = t
"""
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=jnp.array(0.0, dtype=jnp.float32),
ipnr=state_pre.ipnr,
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=state_pre.nsc,
isc=state_pre.isc,
)
def _stf_pp(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PP_PAM: Prepayment - reduce notional and accrue interest.
Updates:
ipac_t = Ipac_t⁻ + Y(Sd_t⁻, t) × Ipnr_t⁻ × Nt_t⁻
nt_t = Nt_t⁻ - PP_amount
sd_t = t
"""
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state_pre.sd, time, dcc)
ipac = float(state_pre.ipac) + yf * float(state_pre.ipnr) * float(state_pre.nt)
# Get prepayment amount from risk factor observer
try:
pp_amount = float(
risk_factor_observer.observe_event(
attributes.contract_id or "",
EventType.PP,
time,
state_pre,
attributes,
)
)
except (KeyError, NotImplementedError, TypeError):
pp_amount = 0.0
new_nt = float(state_pre.nt) - pp_amount
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=jnp.array(new_nt, dtype=jnp.float32),
ipnr=state_pre.ipnr,
ipac=jnp.array(ipac, dtype=jnp.float32),
feac=state_pre.feac,
nsc=state_pre.nsc,
isc=state_pre.isc,
)
def _stf_py(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PY_PAM: Penalty - accrue interest, no notional change."""
return self._stf_ad(state_pre, attributes, time, risk_factor_observer)
def _stf_fp(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_FP_PAM: Fee Payment - reset fees, accrue interest."""
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state_pre.sd, time, dcc)
ipac = float(state_pre.ipac) + yf * float(state_pre.ipnr) * float(state_pre.nt)
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=state_pre.nt,
ipnr=state_pre.ipnr,
ipac=jnp.array(ipac, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32), # Reset fees
nsc=state_pre.nsc,
isc=state_pre.isc,
)
def _stf_prd(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_PRD_PAM: Purchase - accrue interest."""
return self._stf_ad(state_pre, attributes, time, risk_factor_observer)
def _stf_td(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_TD_PAM: Termination - zero out amounts, preserve rate."""
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=jnp.array(0.0, dtype=jnp.float32),
ipnr=state_pre.ipnr,
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=state_pre.nsc,
isc=state_pre.isc,
)
def _stf_ip(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IP_PAM: Interest Payment - reset accrued interest."""
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=state_pre.nt,
ipnr=state_pre.ipnr,
ipac=jnp.array(0.0, dtype=jnp.float32), # Reset accrued interest
feac=state_pre.feac,
nsc=state_pre.nsc,
isc=state_pre.isc,
)
def _stf_ipci(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_IPCI_PAM: Interest Capitalization - add accrued interest to notional."""
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state_pre.sd, time, dcc)
# Calculate total accrued interest
total_ipac = float(state_pre.ipac) + yf * float(state_pre.ipnr) * float(state_pre.nt)
# Add to notional
nt = float(state_pre.nt) + total_ipac
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=jnp.array(nt, dtype=jnp.float32),
ipnr=state_pre.ipnr,
ipac=jnp.array(0.0, dtype=jnp.float32), # Reset after capitalization
feac=state_pre.feac,
nsc=state_pre.nsc,
isc=state_pre.isc,
)
def _stf_rr(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_RR_PAM: Rate Reset - observe market rate and apply caps/floors.
Per ACTUS spec (Section 7.1):
Ipac_t = Ipac_(t-) + Y(Sd_(t-), t) * Ipnr_(t-) * Nt_(t-)
Ipnr_t = min(max(RRMLT * O_rf(RRMO, t) + RRSP, RRLF), RRLC)
Sd_t = t
"""
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state_pre.sd, time, dcc)
# Accrue interest up to reset time using old rate
ipac = float(state_pre.ipac) + yf * float(state_pre.ipnr) * float(state_pre.nt)
# Observe new market rate
market_object = attributes.rate_reset_market_object or ""
observed_rate = float(
risk_factor_observer.observe_risk_factor(market_object, time, state_pre, attributes)
)
# Apply rate multiplier and spread
multiplier = (
attributes.rate_reset_multiplier
if attributes.rate_reset_multiplier is not None
else 1.0
)
spread = attributes.rate_reset_spread if attributes.rate_reset_spread is not None else 0.0
new_rate = multiplier * observed_rate + spread
# Apply floor and cap
if attributes.rate_reset_floor is not None:
new_rate = max(new_rate, attributes.rate_reset_floor)
if attributes.rate_reset_cap is not None:
new_rate = min(new_rate, attributes.rate_reset_cap)
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=state_pre.nt,
ipnr=jnp.array(new_rate, dtype=jnp.float32),
ipac=jnp.array(ipac, dtype=jnp.float32),
feac=state_pre.feac,
nsc=state_pre.nsc,
isc=state_pre.isc,
)
def _stf_rrf(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_RRF_PAM: Rate Reset Fixing - set rate to predefined value.
Per ACTUS spec (Section 7.1):
Ipac_t = Ipac_(t-) + Y(Sd_(t-), t) * Ipnr_(t-) * Nt_(t-)
Ipnr_t = RRNXT
Sd_t = t
"""
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state_pre.sd, time, dcc)
# Accrue interest up to reset time using old rate
ipac = float(state_pre.ipac) + yf * float(state_pre.ipnr) * float(state_pre.nt)
# Set rate to next predefined value
new_rate = (
attributes.rate_reset_next
if attributes.rate_reset_next is not None
else float(state_pre.ipnr)
)
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=state_pre.nt,
ipnr=jnp.array(new_rate, dtype=jnp.float32),
ipac=jnp.array(ipac, dtype=jnp.float32),
feac=state_pre.feac,
nsc=state_pre.nsc,
isc=state_pre.isc,
)
def _stf_sc(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_SC_PAM: Scaling - accrue interest and update scaling multipliers.
Formula:
Ipac = Ipac + Y(Sd, t) × Ipnr × Nt
scaling_ratio = O_rf(SCMO, t) / SCIXCDD
If SCEF[0] == 'I': Isc = scaling_ratio
If SCEF[1] == 'N': Nsc = scaling_ratio
Sd = t
"""
dcc = attributes.day_count_convention or DayCountConvention.A360
yf = year_fraction(state_pre.sd, time, dcc)
new_ipac = float(state_pre.ipac) + yf * float(state_pre.ipnr) * float(state_pre.nt)
new_isc = state_pre.isc
new_nsc = state_pre.nsc
scaling_mo = attributes.scaling_market_object
if scaling_mo:
current_index = float(
risk_factor_observer.observe_risk_factor(scaling_mo, time, state_pre, attributes)
)
ref_index = attributes.scaling_index_at_contract_deal_date or 1.0
if ref_index != 0.0:
scaling_ratio = current_index / ref_index
else:
scaling_ratio = 1.0
effect_str = (
str(attributes.scaling_effect.value) if attributes.scaling_effect else "000"
)
if len(effect_str) >= 1 and effect_str[0] == "I":
new_isc = jnp.array(scaling_ratio, dtype=jnp.float32)
if len(effect_str) >= 2 and effect_str[1] == "N":
new_nsc = jnp.array(scaling_ratio, dtype=jnp.float32)
return ContractState(
sd=time,
tmd=state_pre.tmd,
nt=state_pre.nt,
ipnr=state_pre.ipnr,
ipac=jnp.array(new_ipac, dtype=jnp.float32),
feac=state_pre.feac,
nsc=new_nsc,
isc=new_isc,
)
def _stf_ce(
self,
state_pre: ContractState,
attributes: ContractAttributes,
time: ActusDateTime,
risk_factor_observer: RiskFactorObserver,
) -> ContractState:
"""STF_CE_PAM: Credit Event - same as AD."""
return self._stf_ad(state_pre, attributes, time, risk_factor_observer)
def _get_role_sign(self, contract_role: ContractRole | None) -> float:
"""Get the sign for contract role."""
if contract_role in (ContractRole.RPA, ContractRole.RFL):
return 1.0
if contract_role in (ContractRole.RPL, ContractRole.PFL):
return -1.0
return 1.0
[docs]
class PrincipalAtMaturityContract(BaseContract):
"""Principal At Maturity (PAM) contract implementation.
Represents a bullet loan where principal is repaid in full at maturity
with regular interest payments.
ACTUS Reference:
ACTUS v1.1 Section 7.1 - PAM: Principal At Maturity
Attributes:
attributes: Contract attributes
risk_factor_observer: Observer for market data
child_contract_observer: Observer for child contracts
rngs: Random number generators for JAX
"""
[docs]
def __init__(
self,
attributes: ContractAttributes,
risk_factor_observer: RiskFactorObserver,
child_contract_observer: ChildContractObserver | None = None,
*,
rngs: nnx.Rngs | None = None,
):
"""Initialize PAM contract.
Args:
attributes: Contract attributes
risk_factor_observer: Observer for market data
child_contract_observer: Observer for child contracts (optional)
rngs: Random number generators (optional)
Raises:
ValueError: If validation fails
"""
super().__init__(
attributes=attributes,
risk_factor_observer=risk_factor_observer,
child_contract_observer=child_contract_observer,
rngs=rngs,
)
# Validate contract type
if attributes.contract_type != ContractType.PAM:
raise ValueError(f"Contract type must be PAM, got {attributes.contract_type}")
# Validate required attributes
if attributes.initial_exchange_date is None:
raise ValueError("PAM contract requires initial_exchange_date (IED)")
if attributes.maturity_date is None:
raise ValueError("PAM contract requires maturity_date (MD)")
if attributes.notional_principal is None:
raise ValueError("PAM contract requires notional_principal (NT)")
# Validate date ordering
# Note: IED < SD is allowed per ACTUS spec (contract already existed
# before the status/observation date).
if attributes.maturity_date <= attributes.initial_exchange_date:
raise ValueError("maturity_date (MD) must be after initial_exchange_date (IED)")
[docs]
def generate_event_schedule(self) -> EventSchedule:
"""Generate PAM event schedule per ACTUS specification.
Schedule formula for each event type:
IED: Single event at initial_exchange_date (if IED >= SD)
IP: S(IPANX, IPCL, MD) - from interest payment anchor
IPCI: S(IPANX, IPCL, IPCED) - interest capitalization until end date
RR: S(RRANX, RRCL, MD) - rate reset schedule
RRF: S(RRANX, RRCL, MD) - if rateResetFixing defined
FP: S(FEANX, FECL, MD) - fee payment schedule
PRD: Single event at purchase_date
TD: Single event at termination_date (truncates schedule)
MD: Single event at maturity_date
Events before status_date are excluded.
"""
events: list[ContractEvent] = []
attrs = self.attributes
ied = attrs.initial_exchange_date
md = attrs.maturity_date
sd = attrs.status_date
currency = attrs.currency or "XXX"
assert ied is not None
assert md is not None
bdc = attrs.business_day_convention
eomc = attrs.end_of_month_convention
cal = attrs.calendar
# CS (Calculate/Shift) conventions: generate dates WITHOUT BDC,
# then shift event_time while preserving original date for calculations.
# SC (Shift/Calculate) conventions: BDC applied during schedule generation.
is_cs = bdc in (
BusinessDayConvention.CSF,
BusinessDayConvention.CSMF,
BusinessDayConvention.CSP,
BusinessDayConvention.CSMP,
)
sched_bdc = BusinessDayConvention.NULL if is_cs else (bdc or BusinessDayConvention.NULL)
def _sched(anchor: ActusDateTime, cycle: str, end: ActusDateTime) -> list[ActusDateTime]:
"""Generate schedule with EOMC/calendar from attributes."""
return generate_schedule(
start=anchor,
cycle=cycle,
end=end,
end_of_month_convention=eomc or EndOfMonthConvention.SD,
business_day_convention=sched_bdc,
calendar=cal or Calendar.NO_CALENDAR,
)
def _add(
etype: EventType, time: ActusDateTime, calc_time: ActusDateTime | None = None
) -> None:
"""Add event. For CS conventions on cycle dates, shift time and set calc_time."""
event_time = time
event_calc_time = calc_time
if is_cs and calc_time is None:
# Shift the event time for display/matching, keep original for calculation
shifted = adjust_to_business_day(time, bdc, cal or Calendar.NO_CALENDAR)
if shifted != time:
event_calc_time = time
event_time = shifted
events.append(
ContractEvent(
event_type=etype,
event_time=event_time,
payoff=jnp.array(0.0, dtype=jnp.float32),
currency=currency,
sequence=0,
calculation_time=event_calc_time,
)
)
# IED: only if IED >= SD
if ied >= sd:
_add(EventType.IED, ied)
# IP: Interest Payment schedule from IPANX (or IED)
if attrs.interest_payment_cycle:
ip_anchor = attrs.interest_payment_anchor or ied
ipced = attrs.interest_capitalization_end_date
ip_dates = _sched(ip_anchor, attrs.interest_payment_cycle, md)
# Add IPCI at IPCED if it's not already on a cycle date
if ipced and ipced not in ip_dates:
ip_dates = sorted(set(ip_dates + [ipced]))
# Stub handling: if MD not on cycle, add IP at MD
ip_cycle_str = attrs.interest_payment_cycle or ""
if md and md not in ip_dates and ip_dates:
if ip_cycle_str.endswith("+"):
# Long stub: replace last cycle date with MD
ip_dates[-1] = md
else:
# Short stub (default): keep last cycle date and add MD
ip_dates.append(md)
ip_dates = sorted(set(ip_dates))
for dt in ip_dates:
if dt < ied:
continue
if ipced and dt <= ipced:
_add(EventType.IPCI, dt)
else:
_add(EventType.IP, dt)
# RR: Rate Reset schedule (exclude events at MD, handle long stub)
if attrs.rate_reset_cycle and attrs.rate_reset_anchor:
rr_dates = _sched(attrs.rate_reset_anchor, attrs.rate_reset_cycle, md)
rr_cycle_str = attrs.rate_reset_cycle or ""
if rr_cycle_str.endswith("+") and rr_dates and md and rr_dates[-1] != md:
rr_dates = rr_dates[:-1]
first_rr = True
for dt in rr_dates:
if md and dt >= md:
break
# First reset is RRF when nextResetRate is provided
if first_rr and attrs.rate_reset_next is not None:
_add(EventType.RRF, dt)
first_rr = False
else:
_add(EventType.RR, dt)
first_rr = False
# FP: Fee Payment schedule
if attrs.fee_payment_cycle:
fp_anchor = attrs.fee_payment_anchor or ied
fp_dates = _sched(fp_anchor, attrs.fee_payment_cycle, md)
for dt in fp_dates:
if dt > ied:
_add(EventType.FP, dt)
# SC: Scaling Index schedule
if attrs.scaling_index_cycle:
sc_anchor = attrs.scaling_index_anchor or ied
sc_dates = _sched(sc_anchor, attrs.scaling_index_cycle, md)
for dt in sc_dates:
if dt > ied:
_add(EventType.SC, dt)
# PRD: Purchase date
if attrs.purchase_date:
_add(EventType.PRD, attrs.purchase_date)
# TD: Termination date
if attrs.termination_date:
_add(EventType.TD, attrs.termination_date)
# MD: Maturity Date
_add(EventType.MD, md)
# Filter out events before SD, sort, and handle TD truncation
events = [e for e in events if e.event_time >= sd]
# If PRD exists, remove IED and events before PRD
if attrs.purchase_date:
events = [
e
for e in events
if e.event_type != EventType.IED and e.event_time >= attrs.purchase_date
]
events.sort(
key=lambda e: (e.event_time.to_iso(), EVENT_SCHEDULE_PRIORITY.get(e.event_type, 99))
)
# If TD exists, remove all events after TD (except TD itself)
if attrs.termination_date:
td_time = attrs.termination_date
events = [e for e in events if e.event_time <= td_time]
# Reassign sequence numbers
for i in range(len(events)):
events[i] = ContractEvent(
event_type=events[i].event_type,
event_time=events[i].event_time,
payoff=events[i].payoff,
currency=events[i].currency,
sequence=i,
calculation_time=events[i].calculation_time,
)
return EventSchedule(events=tuple(events), contract_id=attrs.contract_id)
[docs]
def initialize_state(self) -> ContractState:
"""Initialize PAM contract state.
When IED < SD (contract already existed), state is initialized
as if STF_IED already ran: Nt, Ipnr are set from attributes,
and interest is accrued from IED (or IPANX) to SD.
Returns:
Initial contract state
"""
attrs = self.attributes
sd = attrs.status_date
ied = attrs.initial_exchange_date
# When IED < SD or PRD is set (IED excluded from schedule),
# initialize state as if STF_IED already ran
needs_post_ied = (ied and ied < sd) or attrs.purchase_date
if needs_post_ied:
role_sign = 1.0
if attrs.contract_role in (ContractRole.RPL, ContractRole.PFL):
role_sign = -1.0
nt = role_sign * (attrs.notional_principal or 0.0)
ipnr = attrs.nominal_interest_rate or 0.0
dcc = attrs.day_count_convention or DayCountConvention.A360
# When IED >= SD and PRD is set, the contract starts at IED.
# Set sd=IED so interest accrues from IED to PRD (not SD to PRD).
if ied and ied >= sd and attrs.purchase_date:
init_sd = ied
ipac = 0.0
else:
# IED < SD: contract already started, accrue from last IP date to SD
init_sd = sd
accrual_start = attrs.interest_payment_anchor or ied
if attrs.accrued_interest is not None:
ipac = attrs.accrued_interest
elif accrual_start and accrual_start < sd:
# Walk forward through the IP cycle to find the last
# payment date before SD — past IP events reset ipac.
if attrs.interest_payment_cycle and accrual_start < sd:
accrual_start = _last_cycle_date_before(
accrual_start, attrs.interest_payment_cycle, sd
)
yf = year_fraction(accrual_start, sd, dcc)
ipac = yf * ipnr * abs(nt)
else:
ipac = 0.0
return ContractState(
sd=init_sd,
tmd=attrs.maturity_date or sd,
nt=jnp.array(nt, dtype=jnp.float32),
ipnr=jnp.array(ipnr, dtype=jnp.float32),
ipac=jnp.array(ipac, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=jnp.array(1.0, dtype=jnp.float32),
isc=jnp.array(1.0, dtype=jnp.float32),
)
# Normal case: initialize before IED
return ContractState(
sd=sd,
tmd=attrs.maturity_date or sd,
nt=jnp.array(0.0, dtype=jnp.float32),
ipnr=jnp.array(0.0, dtype=jnp.float32),
ipac=jnp.array(0.0, dtype=jnp.float32),
feac=jnp.array(0.0, dtype=jnp.float32),
nsc=jnp.array(1.0, dtype=jnp.float32),
isc=jnp.array(1.0, dtype=jnp.float32),
)
[docs]
def get_payoff_function(self, event_type: Any) -> PAMPayoffFunction:
"""Get payoff function for PAM events."""
return PAMPayoffFunction(
contract_role=self.attributes.contract_role,
currency=self.attributes.currency,
settlement_currency=None,
)
[docs]
def get_state_transition_function(self, event_type: Any) -> PAMStateTransitionFunction:
"""Get state transition function for PAM events."""
return PAMStateTransitionFunction()