"""Negative Amortizer (NAM) contract implementation.
This module implements the NAM contract type - an amortizing loan where principal
can increase (negative amortization) when payments are less than accrued interest.
NAM extends the LAM pattern with modified payoff and state transition functions.
ACTUS Reference:
ACTUS v1.1 Section 7.4 - NAM: Negative Amortizer
Key Features:
- Negative amortization: Notional can increase if payment < interest
- Modified PR payoff: Prnxt - accrued interest (can be negative)
- Modified PR STF: Notional changes by net payment amount
- IP schedule ends one period before PR schedule
- Maturity calculation accounts for negative amortization effect
- Same states as LAM: prnxt, ipcb
- Same events as LAM
Negative Amortization:
When the scheduled periodic payment (Prnxt, which must cover the accrued interest
before any principal is repaid) is less than the accrued interest, the shortfall is
added to the notional principal:
- If Prnxt > interest: Normal amortization (notional decreases)
- If Prnxt < interest: Negative amortization (notional increases)
- If Prnxt = interest: Interest-only payment (notional unchanged)
Example:
>>> from jactus.contracts import create_contract
>>> from jactus.core import ContractAttributes, ContractType, ContractRole
>>> from jactus.core import ActusDateTime, DayCountConvention
>>> from jactus.observers import ConstantRiskFactorObserver
>>>
>>> attrs = ContractAttributes(
... contract_id="NAM-001",
... contract_type=ContractType.NAM,
... contract_role=ContractRole.RPA,
... status_date=ActusDateTime(2024, 1, 1, 0, 0, 0),
... initial_exchange_date=ActusDateTime(2024, 1, 15, 0, 0, 0),
... maturity_date=ActusDateTime(2054, 1, 15, 0, 0, 0),
... currency="USD",
... notional_principal=300000.0,
... nominal_interest_rate=0.065,
... day_count_convention=DayCountConvention.A360,
... principal_redemption_cycle="1M", # Monthly payments
... next_principal_redemption_amount=800.0, # Low payment → negative amort
... interest_calculation_base="NT"
... )
>>>
>>> rf_obs = ConstantRiskFactorObserver(constant_value=0.065)
>>> contract = create_contract(attrs, rf_obs)
>>> result = contract.simulate()
"""
from typing import Any
import jax.numpy as jnp
from jactus.contracts.base import BaseContract
from jactus.core import (
ActusDateTime,
ContractAttributes,
ContractEvent,
ContractState,
ContractType,
DayCountConvention,
EventSchedule,
EventType,
)
from jactus.core.types import (
EVENT_SCHEDULE_PRIORITY,
BusinessDayConvention,
Calendar,
EndOfMonthConvention,
)
from jactus.functions import BasePayoffFunction, BaseStateTransitionFunction
from jactus.observers import RiskFactorObserver
from jactus.utilities import contract_role_sign, generate_schedule, year_fraction
class NAMPayoffFunction(BasePayoffFunction):
    """Payoff function for NAM contracts.

    Implements all NAM payoff functions according to ACTUS specification.
    The key difference from LAM is the PR event payoff, which is net of
    accrued interest and can be negative.

    State variables (Nt, Ipac, Ipcb, Prnxt) are stored already signed with
    R(CNTRL), so payoffs built purely from state carry no extra role sign.
    Only payoffs that use unsigned attributes (IED, PRD, TD) apply R(CNTRL).

    ACTUS Reference:
        ACTUS v1.1 Section 7.4 - NAM Payoff Functions

    Events:
        Same as LAM, but PR payoff modified:
        POF_PR = Nsc × (Prnxt - Ipac - Y(Sd, t) × Ipnr × Ipcb)
    """

    def _build_dispatch_table(self) -> dict[EventType, Any]:
        """Build event type → handler dispatch table.

        Lambdas normalize varying handler signatures to a uniform
        (state, attributes, time, risk_factor_observer) interface.
        """
        return {
            EventType.AD: lambda s, a, t, r: self._pof_ad(s, a),
            EventType.IED: lambda s, a, t, r: self._pof_ied(s, a),
            EventType.PR: lambda s, a, t, r: self._pof_pr(s, a, t),
            EventType.MD: lambda s, a, t, r: self._pof_md(s, a),
            EventType.PP: self._pof_pp,
            EventType.PY: self._pof_py,
            EventType.FP: self._pof_fp,
            EventType.PRD: lambda s, a, t, r: self._pof_prd(s, a, t),
            EventType.TD: lambda s, a, t, r: self._pof_td(s, a, t),
            EventType.IP: lambda s, a, t, r: self._pof_ip(s, a, t),
            EventType.IPCI: lambda s, a, t, r: self._pof_ipci(s, a),
            EventType.IPCB: lambda s, a, t, r: self._pof_ipcb(s, a),
            EventType.RR: lambda s, a, t, r: self._pof_rr(s, a),
            EventType.RRF: lambda s, a, t, r: self._pof_rrf(s, a),
            EventType.PRF: lambda s, a, t, r: jnp.array(0.0, dtype=jnp.float32),
            EventType.SC: lambda s, a, t, r: self._pof_sc(s, a),
            EventType.CE: lambda s, a, t, r: self._pof_ce(s, a),
        }

    def calculate_payoff(
        self,
        event_type: Any,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> jnp.ndarray:
        """Calculate payoff for NAM event via dict dispatch.

        Args:
            event_type: Type of event
            state: Contract state before event
            attributes: Contract attributes
            time: Event time
            risk_factor_observer: Risk factor observer
            child_contract_observer: Not used by NAM; accepted for interface
                compatibility.

        Returns:
            JAX array containing the payoff amount (zero for event types
            that have no handler).
        """
        handler = self._build_dispatch_table().get(event_type)
        if handler is not None:
            return handler(state, attributes, time, risk_factor_observer)  # type: ignore[no-any-return]
        return jnp.array(0.0, dtype=jnp.float32)

    def _accrued_interest(
        self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
    ) -> jnp.ndarray:
        """Return interest accrued since the status date: Y(Sd, t) × Ipnr × Ipcb.

        Falls back to Nt as the calculation base when Ipcb has not been set.
        """
        dcc = attrs.day_count_convention or DayCountConvention.A360
        yf = year_fraction(state.sd, time, dcc)
        ipcb = state.ipcb if state.ipcb is not None else state.nt
        return yf * state.ipnr * ipcb

    def _pof_ad(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
        """POF_AD: Analysis Date - no cashflow."""
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_ied(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
        """POF_IED: Initial Exchange - disburse principal.

        Formula: R(CNTRL) × (-1) × Nsc × NT
        R(CNTRL) needed because NT is an unsigned attribute.
        """
        role_sign = contract_role_sign(self.contract_role)
        return (
            jnp.array(role_sign * (-1.0), dtype=jnp.float32)
            * state.nsc
            * (attrs.notional_principal or 0.0)
        )

    def _pof_pr(
        self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
    ) -> jnp.ndarray:
        """POF_PR: Principal Redemption - pay principal NET of accrued interest.

        Formula: Nsc × (Prnxt - Ipac - Y(Sd, t) × Ipnr × Ipcb)
        No R(CNTRL) — all signed state variables.

        The result has the opposite sign to a normal redemption when Prnxt
        does not cover the interest (negative amortization).
        """
        # `is not None` instead of `or`: truthiness of a JAX array fails for
        # traced or non-scalar values, and yields the same value otherwise.
        prnxt = state.prnxt if state.prnxt is not None else jnp.array(0.0, dtype=jnp.float32)
        net_payment = prnxt - state.ipac - self._accrued_interest(state, attrs, time)
        # Clamp: don't pay more principal than remaining notional.
        # NOTE(review): the magnitude test also fires for a large net payment
        # of the opposite sign; kept as-is to stay in lockstep with STF_PR.
        net_payment = jnp.where(
            jnp.abs(net_payment) > jnp.abs(state.nt),
            state.nt,
            net_payment,
        )
        return state.nsc * net_payment

    def _pof_md(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
        """POF_MD: Maturity - final principal + accrued interest + fees.

        Formula: Nsc × Nt + Isc × Ipac + Feac
        No R(CNTRL) — all signed state variables.
        """
        return state.nsc * state.nt + state.isc * state.ipac + state.feac

    def _pof_pp(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        rf_obs: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_PP_NAM: Principal Prepayment.

        Formula:
            POF_PP_NAM = X^CURS_CUR(t) × f(O_ev(CID, PP, t))

        The prepayment amount is observed from the risk factor observer;
        zero is returned when the observer has no PP data for this contract.
        """
        try:
            pp_amount = rf_obs.observe_event(
                attrs.contract_id or "",
                EventType.PP,
                time,
                state,
                attrs,
            )
            return jnp.array(float(pp_amount), dtype=jnp.float32)
        except (KeyError, NotImplementedError, TypeError):
            return jnp.array(0.0, dtype=jnp.float32)

    def _pof_py(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        rf_obs: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_PY: Penalty payment - not modelled here; always returns zero."""
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_fp(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        rf_obs: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_FP: Fee payment - accrued fees.

        Formula: Feac (signed state variable)
        """
        return state.feac

    def _pof_prd(
        self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
    ) -> jnp.ndarray:
        """POF_PRD: Purchase - purchase price + accrued interest on IPCB.

        Formula: R(CNTRL) × (-1) × (PPRD + Ipac + Y(Sd, t) × Ipnr × Ipcb)
        R(CNTRL) needed because PPRD is an unsigned attribute.
        """
        accrued_interest = self._accrued_interest(state, attrs, time)
        role_sign = contract_role_sign(self.contract_role)
        pprd = attrs.price_at_purchase_date or 0.0
        return jnp.array(role_sign * -1.0, dtype=jnp.float32) * (
            jnp.array(pprd, dtype=jnp.float32) + state.ipac + accrued_interest
        )

    def _pof_td(
        self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
    ) -> jnp.ndarray:
        """POF_TD: Termination - termination price + accrued interest on IPCB.

        Formula: R(CNTRL) × (PTD + Ipac + Y(Sd, t) × Ipnr × Ipcb)
        R(CNTRL) needed because PTD is an unsigned attribute.
        """
        accrued_interest = self._accrued_interest(state, attrs, time)
        role_sign = contract_role_sign(self.contract_role)
        ptd = attrs.price_at_termination_date or 0.0
        return jnp.array(role_sign, dtype=jnp.float32) * (
            jnp.array(ptd, dtype=jnp.float32) + state.ipac + accrued_interest
        )

    def _pof_ip(
        self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
    ) -> jnp.ndarray:
        """POF_IP: Interest Payment - accrued interest on IPCB.

        Formula: Isc × (Ipac + Y(Sd, t) × Ipnr × Ipcb)
        No R(CNTRL) — all signed state variables.
        """
        return state.isc * (state.ipac + self._accrued_interest(state, attrs, time))

    def _pof_ipci(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
        """POF_IPCI: Interest Capitalization - no cashflow."""
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_ipcb(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
        """POF_IPCB: Interest Calculation Base fixing - no cashflow."""
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_rr(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
        """POF_RR: Rate Reset - no cashflow."""
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_rrf(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
        """POF_RRF: Rate Reset Fixing - no cashflow."""
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_sc(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
        """POF_SC: Scaling - no cashflow."""
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_ce(self, state: ContractState, attrs: ContractAttributes) -> jnp.ndarray:
        """POF_CE: Credit Event - no cashflow."""
        return jnp.array(0.0, dtype=jnp.float32)
class NAMStateTransitionFunction(BaseStateTransitionFunction):
    """State transition function for NAM contracts.

    Implements all NAM state transitions according to ACTUS specification.
    The key difference from LAM is the PR event, which adjusts notional by
    the NET payment amount (payment - interest), allowing notional to increase.

    State variables (Nt, Ipac, Ipcb, Prnxt) are kept signed with R(CNTRL)
    from IED onwards, so no transition applies an extra role sign.

    ACTUS Reference:
        ACTUS v1.1 Section 7.4 - NAM State Transition Functions
    """

    def _build_dispatch_table(self) -> dict[EventType, Any]:
        """Build event type → handler dispatch table."""
        return {
            EventType.AD: self._stf_ad,
            EventType.IED: self._stf_ied,
            EventType.PR: self._stf_pr,
            EventType.MD: self._stf_md,
            EventType.PP: self._stf_pp,
            EventType.PY: self._stf_py,
            EventType.FP: self._stf_fp,
            EventType.PRD: self._stf_prd,
            EventType.TD: self._stf_td,
            EventType.IP: self._stf_ip,
            EventType.IPCI: self._stf_ipci,
            EventType.IPCB: self._stf_ipcb,
            EventType.RR: self._stf_rr,
            EventType.RRF: self._stf_rrf,
            EventType.SC: self._stf_sc,
            EventType.CE: self._stf_ce,
        }

    def transition_state(
        self,
        event_type: Any,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """Apply state transition for NAM event via dict dispatch.

        Args:
            event_type: Type of event
            state: Contract state before event
            attributes: Contract attributes
            time: Event time
            risk_factor_observer: Observer for market data
            child_contract_observer: Not used by NAM; accepted for interface
                compatibility.

        Returns:
            New contract state after event (the input state unchanged for
            event types that have no handler).
        """
        handler = self._build_dispatch_table().get(event_type)
        if handler is not None:
            return handler(state, attributes, time, risk_factor_observer)  # type: ignore[no-any-return]
        return state

    def _accrued_interest(
        self, state: ContractState, attrs: ContractAttributes, time: ActusDateTime
    ) -> jnp.ndarray:
        """Return interest accrued since the status date: Y(Sd, t) × Ipnr × Ipcb.

        Falls back to Nt as the calculation base when Ipcb has not been set.
        """
        dcc = attrs.day_count_convention or DayCountConvention.A360
        yf = year_fraction(state.sd, time, dcc)
        ipcb = state.ipcb if state.ipcb is not None else state.nt
        return yf * state.ipnr * ipcb

    @staticmethod
    def _previous_ipcb(state: ContractState) -> jnp.ndarray:
        """Return the current Ipcb unchanged, or zero if it was never set.

        Uses an explicit `is not None` test: truthiness of a JAX array fails
        for traced or non-scalar values.
        """
        if state.ipcb is not None:
            return state.ipcb
        return jnp.array(0.0, dtype=jnp.float32)

    def _stf_ad(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_AD: Analysis Date - update status date only."""
        return state.replace(sd=time)

    def _stf_ied(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_IED: Initial Exchange - initialize all state variables.

        Same as LAM. Nt, Ipcb, Prnxt and Ipac are signed with R(CNTRL).
        A missing notional is treated as 0 for both Nt and Ipcb.
        """
        role_sign = contract_role_sign(attrs.contract_role)
        nt = role_sign * jnp.array(attrs.notional_principal or 0.0, dtype=jnp.float32)

        # Every IPCB mode starts from the IED notional: 'NT' then tracks Nt at
        # PR events, 'NTIED' keeps it fixed, and 'NTL' only changes it at
        # IPCB events.
        ipcb = nt

        # Initialize prnxt (signed with role_sign); otherwise keep the value
        # prepared by initialize_state.
        if attrs.next_principal_redemption_amount is not None:
            prnxt = role_sign * jnp.array(attrs.next_principal_redemption_amount, dtype=jnp.float32)
        elif state.prnxt is not None:
            prnxt = state.prnxt
        else:
            prnxt = jnp.array(0.0, dtype=jnp.float32)

        # Use accrued_interest from attributes if provided (signed with role_sign)
        ipac_val = role_sign * attrs.accrued_interest if attrs.accrued_interest is not None else 0.0

        return state.replace(
            sd=time,
            nt=nt,
            ipnr=jnp.array(attrs.nominal_interest_rate or 0.0, dtype=jnp.float32),
            ipac=jnp.array(ipac_val, dtype=jnp.float32),
            feac=jnp.array(0.0, dtype=jnp.float32),
            nsc=jnp.array(1.0, dtype=jnp.float32),
            isc=jnp.array(1.0, dtype=jnp.float32),
            ipcb=ipcb,
            prnxt=prnxt,
        )

    def _stf_pr(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_PR: Principal Redemption - adjust notional by NET payment.

        Formula (all signed state variables, no R(CNTRL)):
            Ipac = Ipac + Y(Sd, t) × Ipnr × Ipcb   (kept for the IP event to pay)
            Nt   = Nt - (Prnxt - Ipac)
            Ipcb = Nt (if IPCB='NT'), unchanged otherwise

        Key Feature: If Prnxt < interest, notional INCREASES (negative amortization).
        """
        # Total accrued interest (stored in ipac for IP event to pay out)
        new_ipac = state.ipac + self._accrued_interest(state, attrs, time)

        # Net principal reduction = payment - all interest.
        # If negative, this represents principal increase (negative amortization).
        prnxt = state.prnxt if state.prnxt is not None else jnp.array(0.0, dtype=jnp.float32)
        net_principal_reduction = prnxt - new_ipac

        # Clamp: don't reduce notional past zero.
        # NOTE(review): the magnitude test also fires for a large reduction
        # of the opposite sign; kept as-is to stay in lockstep with POF_PR.
        net_principal_reduction = jnp.where(
            jnp.abs(net_principal_reduction) > jnp.abs(state.nt),
            state.nt,
            net_principal_reduction,
        )
        new_nt = state.nt - net_principal_reduction

        # Only 'NT' tracks the notional; 'NTIED' is fixed at IED and 'NTL'
        # changes only at IPCB events.
        if (attrs.interest_calculation_base or "NT") == "NT":
            new_ipcb = new_nt
        else:
            new_ipcb = self._previous_ipcb(state)

        return state.replace(
            sd=time,
            nt=new_nt,
            ipac=new_ipac,
            ipcb=new_ipcb,
        )

    def _stf_md(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_MD: Maturity - zero out notional, accruals and Ipcb."""
        return state.replace(
            sd=time,
            nt=jnp.array(0.0, dtype=jnp.float32),
            ipac=jnp.array(0.0, dtype=jnp.float32),
            feac=jnp.array(0.0, dtype=jnp.float32),
            ipcb=jnp.array(0.0, dtype=jnp.float32),
        )

    def _stf_pp(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_PP_NAM: Prepayment - accrue interest, reduce notional, update IPCB.

        Updates:
            ipac_t = Ipac_t⁻ + Y(Sd_t⁻, t) × Ipnr_t⁻ × Ipcb_t⁻
            nt_t = Nt_t⁻ - PP_amount
            ipcb_t = Nt_t (if IPCB='NT' or 'NTIED'), unchanged for 'NTL'
            sd_t = t

        A missing PP observation is treated as a zero prepayment.
        """
        new_ipac = state.ipac + self._accrued_interest(state, attrs, time)

        # Get prepayment amount from risk factor observer
        try:
            pp_amount = float(
                risk_factor_observer.observe_event(
                    attrs.contract_id or "",
                    EventType.PP,
                    time,
                    state,
                    attrs,
                )
            )
        except (KeyError, NotImplementedError, TypeError):
            pp_amount = 0.0

        new_nt = state.nt - jnp.array(pp_amount, dtype=jnp.float32)

        # NOTE(review): unlike STF_PR, 'NTIED' is re-based on the new notional
        # here; confirm this is intended before changing it.
        ipcb_mode = attrs.interest_calculation_base or "NT"
        if ipcb_mode in ("NT", "NTIED"):
            new_ipcb = new_nt
        else:  # NTL - only updated at IPCB events
            new_ipcb = self._previous_ipcb(state)

        return state.replace(
            sd=time,
            nt=new_nt,
            ipac=new_ipac,
            ipcb=new_ipcb,
        )

    def _stf_py(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_PY: Penalty - accrue interest and advance the status date."""
        new_ipac = state.ipac + self._accrued_interest(state, attrs, time)
        return state.replace(sd=time, ipac=new_ipac)

    def _stf_fp(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_FP: Fee Payment - accrue interest and reset accrued fees."""
        new_ipac = state.ipac + self._accrued_interest(state, attrs, time)
        return state.replace(sd=time, ipac=new_ipac, feac=jnp.array(0.0, dtype=jnp.float32))

    def _stf_prd(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_PRD: Purchase - accrue interest and update status date."""
        new_ipac = state.ipac + self._accrued_interest(state, attrs, time)
        return state.replace(sd=time, ipac=new_ipac)

    def _stf_td(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_TD: Termination - zero out notional, accruals and Ipcb."""
        return state.replace(
            sd=time,
            nt=jnp.array(0.0, dtype=jnp.float32),
            ipac=jnp.array(0.0, dtype=jnp.float32),
            feac=jnp.array(0.0, dtype=jnp.float32),
            ipcb=jnp.array(0.0, dtype=jnp.float32),
        )

    def _stf_ip(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_IP: Interest Payment - reset accrued interest."""
        return state.replace(sd=time, ipac=jnp.array(0.0, dtype=jnp.float32))

    def _stf_ipci(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_IPCI: Interest Capitalization - add interest to notional.

        Formula:
            Nt = Nt + Ipac + Y(Sd, t) × Ipnr × Ipcb
            Ipac = 0
            Ipcb = Nt (if IPCB='NT'), unchanged otherwise
        """
        new_nt = state.nt + state.ipac + self._accrued_interest(state, attrs, time)
        if (attrs.interest_calculation_base or "NT") == "NT":
            new_ipcb = new_nt
        else:
            new_ipcb = self._previous_ipcb(state)
        return state.replace(
            sd=time, nt=new_nt, ipac=jnp.array(0.0, dtype=jnp.float32), ipcb=new_ipcb
        )

    def _stf_ipcb(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_IPCB: Interest Calculation Base fixing - reset IPCB to current notional.

        Formula:
            Ipac = Ipac + Y(Sd, t) × Ipnr × Ipcb_old
            Ipcb = Nt
        """
        new_ipac = state.ipac + self._accrued_interest(state, attrs, time)
        return state.replace(sd=time, ipcb=state.nt, ipac=new_ipac)

    def _stf_rr(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_RR: Rate Reset - observe market rate and apply caps/floors.

        Formula:
            Ipac = Ipac + Y(Sd, t) * Ipnr * Ipcb
            Ipnr = min(max(RRMLT * O_rf(RRMO, t) + RRSP, RRLF), RRLC)

        Multiplier defaults to 1 and spread to 0; floor/cap are applied only
        when set.
        """
        new_ipac = state.ipac + self._accrued_interest(state, attrs, time)

        # Observe market rate
        market_object = attrs.rate_reset_market_object or ""
        observed_rate = float(
            risk_factor_observer.observe_risk_factor(market_object, time, state, attrs)
        )

        # Apply multiplier and spread
        multiplier = attrs.rate_reset_multiplier if attrs.rate_reset_multiplier is not None else 1.0
        spread = attrs.rate_reset_spread if attrs.rate_reset_spread is not None else 0.0
        new_rate = multiplier * observed_rate + spread

        # Apply floor and cap
        if attrs.rate_reset_floor is not None:
            new_rate = max(new_rate, attrs.rate_reset_floor)
        if attrs.rate_reset_cap is not None:
            new_rate = min(new_rate, attrs.rate_reset_cap)

        return state.replace(
            sd=time,
            ipac=new_ipac,
            ipnr=jnp.array(new_rate, dtype=jnp.float32),
        )

    def _stf_rrf(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_RRF: Rate Reset Fixing - fix interest rate to next reset rate.

        Formula:
            Ipac = Ipac + Y(Sd, t) * Ipnr * Ipcb
            Ipnr = RRNXT (current rate kept if RRNXT is unset)
        """
        new_ipac = state.ipac + self._accrued_interest(state, attrs, time)
        new_rate = attrs.rate_reset_next if attrs.rate_reset_next is not None else state.ipnr
        return state.replace(
            sd=time,
            ipac=new_ipac,
            ipnr=jnp.array(float(new_rate), dtype=jnp.float32),
        )

    def _stf_sc(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_SC: Scaling - accrue interest and advance the status date.

        NOTE(review): Nsc/Isc are not updated here, so scaling-index
        observation is effectively not modelled — confirm whether intended.
        """
        new_ipac = state.ipac + self._accrued_interest(state, attrs, time)
        return state.replace(sd=time, ipac=new_ipac)

    def _stf_ce(
        self,
        state: ContractState,
        attrs: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ) -> ContractState:
        """STF_CE: Credit Event - update status date."""
        return state.replace(sd=time)
class NegativeAmortizerContract(BaseContract):
    """Negative Amortizer (NAM) contract.

    Amortizing loan where principal can increase when payments are less than
    accrued interest (negative amortization). Extends LAM pattern with modified
    principal redemption handling.

    ACTUS Reference:
        ACTUS v1.1 Section 7.4 - NAM: Negative Amortizer

    Attributes:
        attributes: Contract attributes
        risk_factor_observer: Risk factor observer for market data

    Example:
        See module docstring for usage example.
    """

    def __init__(
        self,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any | None = None,
    ):
        """Initialize NAM contract.

        Args:
            attributes: Contract attributes
            risk_factor_observer: Risk factor observer
            child_contract_observer: Optional observer forwarded to BaseContract

        Raises:
            ValueError: If contract_type is not NAM
            ValueError: If required attributes missing
        """
        if attributes.contract_type != ContractType.NAM:
            raise ValueError(f"Contract type must be NAM, got {attributes.contract_type}")

        # Validate required attributes
        if not attributes.initial_exchange_date:
            raise ValueError("initial_exchange_date required for NAM")
        if not attributes.principal_redemption_cycle and not attributes.maturity_date:
            raise ValueError("Either principal_redemption_cycle or maturity_date required")

        super().__init__(
            attributes=attributes,
            risk_factor_observer=risk_factor_observer,
            child_contract_observer=child_contract_observer,
        )

    def generate_event_schedule(self) -> EventSchedule:
        """Generate NAM event schedule per ACTUS specification.

        Same as LAM schedule, with all event types supported.
        Events before status_date are excluded; when a purchase date is set,
        IED and every other event at or before it are dropped; when a
        termination date is set, events after it are dropped. Events are
        ordered by time, then by EVENT_SCHEDULE_PRIORITY, and renumbered.
        """
        events: list[ContractEvent] = []
        attrs = self.attributes
        ied = attrs.initial_exchange_date
        md = attrs.maturity_date
        sd = attrs.status_date
        currency = attrs.currency or "XXX"
        assert ied is not None

        bdc = attrs.business_day_convention
        eomc = attrs.end_of_month_convention
        cal = attrs.calendar

        def _sched(
            anchor: ActusDateTime, cycle: str, end: ActusDateTime | None
        ) -> list[ActusDateTime]:
            """Generate cycle dates with the contract's date conventions."""
            return generate_schedule(
                start=anchor,
                cycle=cycle,
                end=end,
                end_of_month_convention=eomc or EndOfMonthConvention.SD,
                business_day_convention=bdc or BusinessDayConvention.NULL,
                calendar=cal or Calendar.NO_CALENDAR,
            )

        def _add(etype: EventType, time: ActusDateTime) -> None:
            """Append a placeholder event (payoff filled in during simulation)."""
            events.append(
                ContractEvent(
                    event_type=etype,
                    event_time=time,
                    payoff=jnp.array(0.0, dtype=jnp.float32),
                    currency=currency,
                    sequence=0,
                )
            )

        # IED: only if IED >= SD
        if ied >= sd:
            _add(EventType.IED, ied)

        # PR: Principal Redemption schedule (never at or after MD)
        if attrs.principal_redemption_cycle:
            pr_anchor = attrs.principal_redemption_anchor or ied
            pr_dates = _sched(pr_anchor, attrs.principal_redemption_cycle, md)
            # Long stub: remove last PR date before MD if it's not on cycle end
            pr_cycle_str = attrs.principal_redemption_cycle or ""
            if pr_cycle_str.endswith("+") and pr_dates and md and pr_dates[-1] != md:
                pr_dates = pr_dates[:-1]
            for dt in pr_dates:
                if md and dt >= md:
                    break
                if dt >= ied:
                    _add(EventType.PR, dt)

        # IP: Interest Payment schedule from IPANX (or IED); dates up to and
        # including IPCED become IPCI events instead.
        if attrs.interest_payment_cycle:
            ip_anchor = attrs.interest_payment_anchor or ied
            ipced = attrs.interest_capitalization_end_date
            ip_dates = _sched(ip_anchor, attrs.interest_payment_cycle, md)
            if ipced and ipced not in ip_dates:
                ip_dates = sorted(set(ip_dates + [ipced]))
            ip_cycle_str = attrs.interest_payment_cycle or ""
            if md and md not in ip_dates and ip_dates:
                if ip_cycle_str.endswith("+"):
                    # Long stub: the last cycle date is replaced by MD
                    ip_dates[-1] = md
                else:
                    ip_dates.append(md)
                ip_dates = sorted(set(ip_dates))
            for dt in ip_dates:
                if dt < ied:
                    continue
                if ipced and dt <= ipced:
                    _add(EventType.IPCI, dt)
                else:
                    _add(EventType.IP, dt)

        # IPCB: Interest Calculation Base schedule (only if IPCB='NTL')
        if attrs.interest_calculation_base == "NTL" and attrs.interest_calculation_base_cycle:
            ipcb_anchor = attrs.interest_calculation_base_anchor or ied
            ipcb_dates = _sched(ipcb_anchor, attrs.interest_calculation_base_cycle, md)
            for dt in ipcb_dates:
                if dt > ied:
                    _add(EventType.IPCB, dt)

        # RR: Rate Reset schedule (exclude events at MD, handle long stub).
        # The first reset becomes RRF when a next reset rate is given.
        if attrs.rate_reset_cycle and attrs.rate_reset_anchor:
            rr_dates = _sched(attrs.rate_reset_anchor, attrs.rate_reset_cycle, md)
            rr_cycle_str = attrs.rate_reset_cycle or ""
            if rr_cycle_str.endswith("+") and rr_dates and md and rr_dates[-1] != md:
                rr_dates = rr_dates[:-1]
            first_rr = True
            for dt in rr_dates:
                if md and dt >= md:
                    break
                if first_rr and attrs.rate_reset_next is not None:
                    _add(EventType.RRF, dt)
                else:
                    _add(EventType.RR, dt)
                first_rr = False

        # FP: Fee Payment schedule
        if attrs.fee_payment_cycle:
            fp_anchor = attrs.fee_payment_anchor or ied
            fp_dates = _sched(fp_anchor, attrs.fee_payment_cycle, md)
            for dt in fp_dates:
                if dt > ied:
                    _add(EventType.FP, dt)

        # SC: Scaling Index schedule
        if attrs.scaling_index_cycle:
            sc_anchor = attrs.scaling_index_anchor or ied
            sc_dates = _sched(sc_anchor, attrs.scaling_index_cycle, md)
            for dt in sc_dates:
                if dt > ied:
                    _add(EventType.SC, dt)

        # PRD: Purchase date
        if attrs.purchase_date:
            _add(EventType.PRD, attrs.purchase_date)

        # TD: Termination date
        if attrs.termination_date:
            _add(EventType.TD, attrs.termination_date)

        # MD: Maturity Date
        if md:
            _add(EventType.MD, md)

        # Filter out events before SD
        events = [e for e in events if e.event_time >= sd]

        # If PRD exists, remove IED and non-PRD events before/at PRD
        if attrs.purchase_date:
            prd_time = attrs.purchase_date
            events = [
                e
                for e in events
                if e.event_type == EventType.PRD
                or (e.event_type != EventType.IED and e.event_time > prd_time)
            ]

        events.sort(
            key=lambda e: (e.event_time.to_iso(), EVENT_SCHEDULE_PRIORITY.get(e.event_type, 99))
        )

        # If TD exists, remove all events after TD
        if attrs.termination_date:
            td_time = attrs.termination_date
            events = [e for e in events if e.event_time <= td_time]

        # Reassign sequence numbers in final order
        events = [
            ContractEvent(
                event_type=ev.event_type,
                event_time=ev.event_time,
                payoff=ev.payoff,
                currency=ev.currency,
                sequence=idx,
            )
            for idx, ev in enumerate(events)
        ]

        return EventSchedule(events=tuple(events), contract_id=attrs.contract_id)

    def _pre_simulate_to_prd(self, attrs: ContractAttributes, prnxt: jnp.ndarray) -> ContractState:
        """Pre-simulate events from IED to PRD to get correct initial state.

        Replays IED plus the PR, IP/IPCI, RR/RRF and (for 'NTL') IPCB events
        up to and including PRD through the NAM state transition function.
        If SD falls strictly between IED and PRD, the resulting state is
        re-dated to SD with accrued interest reset.
        """
        ied = attrs.initial_exchange_date
        prd = attrs.purchase_date
        sd = attrs.status_date
        md = attrs.maturity_date
        assert ied is not None
        assert prd is not None

        bdc = attrs.business_day_convention
        eomc = attrs.end_of_month_convention
        cal = attrs.calendar

        def _sched(
            anchor: ActusDateTime | None, cycle: str | None, end: ActusDateTime | None
        ) -> list[ActusDateTime]:
            """Generate cycle dates, or nothing if any input is missing."""
            if anchor is None or cycle is None or end is None:
                return []
            return generate_schedule(
                start=anchor,
                cycle=cycle,
                end=end,
                end_of_month_convention=eomc or EndOfMonthConvention.SD,
                business_day_convention=bdc or BusinessDayConvention.NULL,
                calendar=cal or Calendar.NO_CALENDAR,
            )

        pre_events: list[tuple[ActusDateTime, EventType]] = [(ied, EventType.IED)]

        if attrs.principal_redemption_cycle:
            pr_anchor = attrs.principal_redemption_anchor or ied
            pr_dates = _sched(pr_anchor, attrs.principal_redemption_cycle, md or prd)
            for dt in pr_dates:
                if dt >= ied and dt <= prd:
                    pre_events.append((dt, EventType.PR))

        if attrs.interest_payment_cycle:
            ip_anchor = attrs.interest_payment_anchor or ied
            ipced = attrs.interest_capitalization_end_date
            ip_dates = _sched(ip_anchor, attrs.interest_payment_cycle, md or prd)
            for dt in ip_dates:
                if dt < ied or dt > prd:
                    continue
                if ipced and dt <= ipced:
                    pre_events.append((dt, EventType.IPCI))
                else:
                    pre_events.append((dt, EventType.IP))

        if attrs.rate_reset_cycle and attrs.rate_reset_anchor:
            rr_dates = _sched(attrs.rate_reset_anchor, attrs.rate_reset_cycle, md or prd)
            first_rr = True
            for dt in rr_dates:
                if dt > prd:
                    break
                if first_rr and attrs.rate_reset_next is not None:
                    pre_events.append((dt, EventType.RRF))
                else:
                    pre_events.append((dt, EventType.RR))
                first_rr = False

        if attrs.interest_calculation_base == "NTL" and attrs.interest_calculation_base_cycle:
            ipcb_anchor = attrs.interest_calculation_base_anchor or ied
            ipcb_dates = _sched(ipcb_anchor, attrs.interest_calculation_base_cycle, md or prd)
            for dt in ipcb_dates:
                if dt > ied and dt <= prd:
                    pre_events.append((dt, EventType.IPCB))

        pre_events = [(t, e) for t, e in pre_events if t >= ied]
        pre_events.sort(key=lambda e: (e[0].to_iso(), EVENT_SCHEDULE_PRIORITY.get(e[1], 99)))

        state = ContractState(
            sd=ied,
            tmd=md or ied,
            nt=jnp.array(0.0, dtype=jnp.float32),
            ipnr=jnp.array(0.0, dtype=jnp.float32),
            ipac=jnp.array(0.0, dtype=jnp.float32),
            feac=jnp.array(0.0, dtype=jnp.float32),
            nsc=jnp.array(1.0, dtype=jnp.float32),
            isc=jnp.array(1.0, dtype=jnp.float32),
            prnxt=prnxt,
            ipcb=jnp.array(0.0, dtype=jnp.float32),
        )

        stf = self.get_state_transition_function(None)
        for time, etype in pre_events:
            state = stf.transition_state(
                event_type=etype,
                state=state,
                attributes=attrs,
                time=time,
                risk_factor_observer=self.risk_factor_observer,
            )

        if ied < sd and sd < prd:
            state = state.replace(sd=sd, ipac=jnp.array(0.0, dtype=jnp.float32))

        return state

    def initialize_state(self) -> ContractState:
        """Initialize NAM contract state.

        When a purchase date is set, the state is obtained by replaying events
        up to PRD. When IED < SD (contract already existed), state is
        initialized as if STF_IED already ran: Nt, Ipcb and Ipac are signed
        with R(CNTRL), consistent with STF_IED. Otherwise the state is empty
        and STF_IED fills it in.

        Returns:
            Initial contract state
        """
        attrs = self.attributes
        sd = attrs.status_date
        ied = attrs.initial_exchange_date
        role_sign = contract_role_sign(attrs.contract_role)

        # Initialize Prnxt (next principal redemption amount), signed
        prnxt_val = attrs.next_principal_redemption_amount or 0.0
        prnxt = jnp.array(role_sign * prnxt_val, dtype=jnp.float32)

        # PRD pre-simulation
        if attrs.purchase_date and ied:
            return self._pre_simulate_to_prd(attrs, prnxt)

        needs_post_ied = ied and ied < sd
        if needs_post_ied:
            assert ied is not None
            # Contract already started - initialize post-IED state.
            # All interest-related states carry R(CNTRL) like Nt, because the
            # payoff/state functions apply no extra role sign to them.
            nt = role_sign * (attrs.notional_principal or 0.0)
            ipnr = attrs.nominal_interest_rate or 0.0
            dcc = attrs.day_count_convention or DayCountConvention.A360
            accrual_start = attrs.interest_payment_anchor or ied
            if attrs.accrued_interest is not None:
                ipac = role_sign * attrs.accrued_interest
            else:
                yf = year_fraction(accrual_start, sd, dcc)
                ipac = yf * ipnr * nt
            ipcb_val = nt

            return ContractState(
                sd=sd,
                tmd=attrs.maturity_date or sd,
                nt=jnp.array(nt, dtype=jnp.float32),
                ipnr=jnp.array(ipnr, dtype=jnp.float32),
                ipac=jnp.array(ipac, dtype=jnp.float32),
                feac=jnp.array(0.0, dtype=jnp.float32),
                nsc=jnp.array(1.0, dtype=jnp.float32),
                isc=jnp.array(1.0, dtype=jnp.float32),
                prnxt=prnxt,
                ipcb=jnp.array(ipcb_val, dtype=jnp.float32),
            )

        return ContractState(
            sd=sd,
            tmd=attrs.maturity_date or sd,
            nt=jnp.array(0.0, dtype=jnp.float32),  # Set at IED
            ipnr=jnp.array(0.0, dtype=jnp.float32),  # Set at IED
            ipac=jnp.array(0.0, dtype=jnp.float32),
            feac=jnp.array(0.0, dtype=jnp.float32),
            nsc=jnp.array(1.0, dtype=jnp.float32),
            isc=jnp.array(1.0, dtype=jnp.float32),
            prnxt=prnxt,
            ipcb=jnp.array(0.0, dtype=jnp.float32),  # Set at IED
        )

    def get_payoff_function(self, event_type: Any) -> NAMPayoffFunction:
        """Get NAM payoff function.

        Args:
            event_type: Type of event (not used, all events use same POF)

        Returns:
            NAM payoff function instance
        """
        return NAMPayoffFunction(
            contract_role=self.attributes.contract_role,
            currency=self.attributes.currency,
            settlement_currency=None,
        )

    def get_state_transition_function(self, event_type: Any) -> NAMStateTransitionFunction:
        """Get NAM state transition function.

        Args:
            event_type: Type of event (not used, all events use same STF)

        Returns:
            NAM state transition function instance
        """
        return NAMStateTransitionFunction()