"""Array-mode PAM simulation — JIT-compiled, vmap-able pure JAX.
This module provides a high-performance simulation path for PAM (Principal at
Maturity) contracts using ``jax.lax.scan`` for the event loop and
``jax.lax.switch`` for payoff/state-transition dispatch. The entire simulation
kernel is JIT-compilable and can be vectorized across a portfolio with
``jax.vmap``.
Architecture:
Pre-computation (Python) → Pure JAX kernel (jit + vmap)
The existing ``PrincipalAtMaturityContract`` generates event schedules and
initializes state (Python-level, runs once per contract). This module
converts the results to JAX arrays and runs the numerical simulation as a
pure function.
Example::
from jactus.contracts.pam_array import precompute_pam_arrays, simulate_pam_array
arrays = precompute_pam_arrays(attrs, rf_observer)
final_state, payoffs = simulate_pam_array(*arrays)
# Portfolio:
from jactus.contracts.pam_array import simulate_pam_portfolio
result = simulate_pam_portfolio(contracts, discount_rate=0.05)
"""
from __future__ import annotations
from datetime import datetime as _datetime
from typing import Any, NamedTuple
import jax
import jax.numpy as jnp
import numpy as np
from jactus.contracts.array_common import (
# Cached EventType indices
AD_IDX as _AD_IDX,
)
from jactus.contracts.array_common import (
CE_IDX as _CE_IDX,
)
from jactus.contracts.array_common import (
# Schedule helpers
CYCLE_MONTHS_MAP as _CYCLE_MONTHS_MAP,
)
from jactus.contracts.array_common import (
F32 as _F32,
)
from jactus.contracts.array_common import (
FP_IDX as _FP_IDX,
)
from jactus.contracts.array_common import (
IED_IDX as _IED_IDX,
)
from jactus.contracts.array_common import (
IP_IDX as _IP_IDX,
)
from jactus.contracts.array_common import (
IPCI_IDX as _IPCI_IDX,
)
from jactus.contracts.array_common import (
MD_IDX as _MD_IDX,
)
# Import shared infrastructure from array_common
from jactus.contracts.array_common import (
NOP_EVENT_IDX,
)
from jactus.contracts.array_common import (
PP_IDX as _PP_IDX,
)
from jactus.contracts.array_common import (
PRD_IDX as _PRD_IDX,
)
from jactus.contracts.array_common import (
PY_IDX as _PY_IDX,
)
from jactus.contracts.array_common import (
RR_IDX as _RR_IDX,
)
from jactus.contracts.array_common import (
RRF_IDX as _RRF_IDX,
)
from jactus.contracts.array_common import (
SC_IDX as _SC_IDX,
)
from jactus.contracts.array_common import (
TD_IDX as _TD_IDX,
)
from jactus.contracts.array_common import (
USE_BATCH_SCHEDULE as _USE_BATCH_SCHEDULE,
)
from jactus.contracts.array_common import (
USE_DATE_ARRAY as _USE_DATE_ARRAY,
)
from jactus.contracts.array_common import (
# Batch infrastructure
BatchContractParams as _BatchContractParams,
)
from jactus.contracts.array_common import (
RawPrecomputed as _RawPrecomputed,
)
from jactus.contracts.array_common import (
# Date helpers
adt_to_dt as _adt_to_dt,
)
from jactus.contracts.array_common import (
compute_max_ip as _compute_max_ip,
)
from jactus.contracts.array_common import (
compute_vectorised_year_fractions as _compute_vectorised_year_fractions,
)
from jactus.contracts.array_common import (
dt_to_adt as _dt_to_adt,
)
from jactus.contracts.array_common import (
# Encoding helpers
encode_fee_basis as _encode_fee_basis,
)
from jactus.contracts.array_common import (
encode_penalty_type as _encode_penalty_type,
)
from jactus.contracts.array_common import (
extract_batch_params as _extract_batch_params,
)
from jactus.contracts.array_common import (
fast_schedule as _fast_schedule,
)
from jactus.contracts.array_common import (
get_evt_priority as _get_evt_priority,
)
from jactus.contracts.array_common import (
get_role_sign as _get_role_sign,
)
from jactus.contracts.array_common import (
jax_batch_ip_schedule as _jax_batch_ip_schedule,
)
from jactus.contracts.array_common import (
jax_batch_year_fractions as _jax_batch_year_fractions,
)
from jactus.contracts.array_common import (
parse_cycle_fast as _parse_cycle_fast,
)
from jactus.contracts.array_common import (
prequery_risk_factors as _prequery_risk_factors,
)
from jactus.core import (
ContractAttributes,
)
from jactus.observers import RiskFactorObserver
from jactus.utilities.conventions import year_fraction
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
[docs]
class PAMArrayState(NamedTuple):
"""Minimal scan-loop state for PAM simulation.
All fields are scalar ``jnp.ndarray`` (float32). ``sd`` (status date) is
omitted because year fractions are pre-computed before the JIT boundary.
"""
nt: jnp.ndarray # Notional principal (signed)
ipnr: jnp.ndarray # Nominal interest rate
ipac: jnp.ndarray # Accrued interest
feac: jnp.ndarray # Accrued fees
nsc: jnp.ndarray # Notional scaling multiplier
isc: jnp.ndarray # Interest scaling multiplier
[docs]
class PAMArrayParams(NamedTuple):
"""Static contract parameters extracted from ``ContractAttributes``.
These do not change during the scan loop. Enum-based branches
(penalty type, fee basis) are encoded as integers for ``jnp.where``.
"""
role_sign: jnp.ndarray # +1.0 or -1.0
notional_principal: jnp.ndarray
nominal_interest_rate: jnp.ndarray
premium_discount_at_ied: jnp.ndarray
accrued_interest: jnp.ndarray # IPAC attribute (pre-computed for IED)
fee_rate: jnp.ndarray
fee_basis: jnp.ndarray # 0=A, 1=N, 2=other
penalty_rate: jnp.ndarray
penalty_type: jnp.ndarray # 0=A, 1=N, 2=I
price_at_purchase_date: jnp.ndarray
price_at_termination_date: jnp.ndarray
rate_reset_spread: jnp.ndarray
rate_reset_multiplier: jnp.ndarray
rate_reset_floor: jnp.ndarray
rate_reset_cap: jnp.ndarray
rate_reset_next: jnp.ndarray
has_rate_floor: jnp.ndarray # 1.0 if floor is active, else 0.0
has_rate_cap: jnp.ndarray # 1.0 if cap is active, else 0.0
ied_ipac: jnp.ndarray # Pre-computed accrued interest at IED
# ============================================================================
# Pure JAX payoff functions (state, params, yf, rf) → scalar payoff
# ============================================================================
def _pof_ad(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return jnp.array(0.0, dtype=_F32)
def _pof_ied(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return params.role_sign * (-1.0) * (params.notional_principal + params.premium_discount_at_ied)
def _pof_md(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return state.nsc * state.nt + state.isc * state.ipac + state.feac
def _pof_pp(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
# rf = pre-computed prepayment amount from observer
return rf
def _pof_py(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
# penalty_type: 0=A, 1=N, 2=I
pof_a = params.penalty_rate
pof_ni = yf * state.nt * params.penalty_rate
return jnp.where(params.penalty_type == 0, pof_a, pof_ni)
def _pof_fp(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
# fee_basis: 0=A, 1=N, 2=other
pof_a = params.fee_rate
pof_n = yf * state.nt * params.fee_rate + state.feac
return jnp.where(
params.fee_basis == 0,
pof_a,
jnp.where(params.fee_basis == 1, pof_n, state.feac),
)
def _pof_prd(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return (-1.0) * (params.price_at_purchase_date + state.ipac + yf * state.ipnr * state.nt)
def _pof_td(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return params.price_at_termination_date + state.ipac + yf * state.ipnr * state.nt
def _pof_ip(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return state.isc * (state.ipac + yf * state.ipnr * state.nt)
def _pof_ipci(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return jnp.array(0.0, dtype=_F32)
def _pof_rr(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return jnp.array(0.0, dtype=_F32)
def _pof_rrf(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return jnp.array(0.0, dtype=_F32)
def _pof_sc(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return jnp.array(0.0, dtype=_F32)
def _pof_ce(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return jnp.array(0.0, dtype=_F32)
def _pof_noop(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> jnp.ndarray:
return jnp.array(0.0, dtype=_F32)
# ============================================================================
# Pure JAX state transition functions (state, params, yf, rf) → new state
# ============================================================================
def _accrue_interest(state: PAMArrayState, yf: jnp.ndarray) -> jnp.ndarray:
"""Common sub-expression: ipac + yf * ipnr * nt."""
return state.ipac + yf * state.ipnr * state.nt
def _stf_ad(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return state._replace(ipac=_accrue_interest(state, yf))
def _stf_ied(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return PAMArrayState(
nt=params.role_sign * params.notional_principal,
ipnr=params.nominal_interest_rate,
ipac=params.ied_ipac,
feac=jnp.array(0.0, dtype=_F32),
nsc=jnp.array(1.0, dtype=_F32),
isc=jnp.array(1.0, dtype=_F32),
)
def _stf_md(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return state._replace(
nt=jnp.array(0.0, dtype=_F32),
ipac=jnp.array(0.0, dtype=_F32),
feac=jnp.array(0.0, dtype=_F32),
)
def _stf_pp(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
new_ipac = _accrue_interest(state, yf)
new_nt = state.nt - rf # rf = prepayment amount
return state._replace(nt=new_nt, ipac=new_ipac)
def _stf_py(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return _stf_ad(state, params, yf, rf)
def _stf_fp(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
new_ipac = _accrue_interest(state, yf)
return state._replace(ipac=new_ipac, feac=jnp.array(0.0, dtype=_F32))
def _stf_prd(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return _stf_ad(state, params, yf, rf)
def _stf_td(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return state._replace(
nt=jnp.array(0.0, dtype=_F32),
ipac=jnp.array(0.0, dtype=_F32),
feac=jnp.array(0.0, dtype=_F32),
)
def _stf_ip(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return state._replace(ipac=jnp.array(0.0, dtype=_F32))
def _stf_ipci(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
total_ipac = _accrue_interest(state, yf)
return state._replace(nt=state.nt + total_ipac, ipac=jnp.array(0.0, dtype=_F32))
def _stf_rr(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
new_ipac = _accrue_interest(state, yf)
# rf = observed market rate
raw_rate = params.rate_reset_multiplier * rf + params.rate_reset_spread
# Branchless floor/cap using jnp.where + jnp.maximum/minimum
clamped = jnp.where(
params.has_rate_floor > 0.5,
jnp.maximum(raw_rate, params.rate_reset_floor),
raw_rate,
)
clamped = jnp.where(
params.has_rate_cap > 0.5,
jnp.minimum(clamped, params.rate_reset_cap),
clamped,
)
return state._replace(ipnr=clamped, ipac=new_ipac)
def _stf_rrf(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
new_ipac = _accrue_interest(state, yf)
return state._replace(ipnr=params.rate_reset_next, ipac=new_ipac)
def _stf_sc(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return _stf_ad(state, params, yf, rf)
def _stf_ce(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return _stf_ad(state, params, yf, rf)
def _stf_noop(
state: PAMArrayState, params: PAMArrayParams, yf: jnp.ndarray, rf: jnp.ndarray
) -> PAMArrayState:
return state
# ============================================================================
# Dispatch tables — indexed by EventType.index (0..23) + NOP (24)
# ============================================================================
# fmt: off
_POF_TABLE: list[Any] = [
_pof_ad, # 0 AD
_pof_ied, # 1 IED
_pof_md, # 2 MD
_pof_noop, # 3 PR (not used in PAM)
_pof_noop, # 4 PI (not used in PAM)
_pof_pp, # 5 PP
_pof_py, # 6 PY
_pof_noop, # 7 PRF (not used in PAM)
_pof_fp, # 8 FP
_pof_prd, # 9 PRD
_pof_td, # 10 TD
_pof_ip, # 11 IP
_pof_ipci, # 12 IPCI
_pof_noop, # 13 IPCB (not used in PAM)
_pof_rr, # 14 RR
_pof_rrf, # 15 RRF
_pof_noop, # 16 DV (not used in PAM)
_pof_noop, # 17 DVF (not used in PAM)
_pof_sc, # 18 SC
_pof_noop, # 19 STD (not used in PAM)
_pof_noop, # 20 XD (not used in PAM)
_pof_ce, # 21 CE
_pof_noop, # 22 IPFX (not used in PAM)
_pof_noop, # 23 IPFL (not used in PAM)
_pof_noop, # 24 NOP (padding)
]
_STF_TABLE: list[Any] = [
_stf_ad, # 0 AD
_stf_ied, # 1 IED
_stf_md, # 2 MD
_stf_noop, # 3 PR
_stf_noop, # 4 PI
_stf_pp, # 5 PP
_stf_py, # 6 PY
_stf_noop, # 7 PRF
_stf_fp, # 8 FP
_stf_prd, # 9 PRD
_stf_td, # 10 TD
_stf_ip, # 11 IP
_stf_ipci, # 12 IPCI
_stf_noop, # 13 IPCB
_stf_rr, # 14 RR
_stf_rrf, # 15 RRF
_stf_noop, # 16 DV
_stf_noop, # 17 DVF
_stf_sc, # 18 SC
_stf_noop, # 19 STD
_stf_noop, # 20 XD
_stf_ce, # 21 CE
_stf_noop, # 22 IPFX
_stf_noop, # 23 IPFL
_stf_noop, # 24 NOP
]
# fmt: on
assert len(_POF_TABLE) == NOP_EVENT_IDX + 1
assert len(_STF_TABLE) == NOP_EVENT_IDX + 1
# ============================================================================
# JIT-compiled simulation kernel
# ============================================================================
[docs]
def simulate_pam_array(
initial_state: PAMArrayState,
event_types: jnp.ndarray,
year_fractions: jnp.ndarray,
rf_values: jnp.ndarray,
params: PAMArrayParams,
) -> tuple[PAMArrayState, jnp.ndarray]:
"""Run a PAM simulation as a pure JAX function.
This function is JIT-compilable and vmap-able.
Args:
initial_state: Starting state (6 scalar fields).
event_types: ``(num_events,)`` int32 — ``EventType.index`` values.
year_fractions: ``(num_events,)`` float32 — pre-computed YF per event.
rf_values: ``(num_events,)`` float32 — pre-computed risk factor per event.
params: Static contract parameters.
Returns:
``(final_state, payoffs)`` where payoffs is ``(num_events,)`` float32.
"""
def step(
state: PAMArrayState, inputs: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
) -> tuple[PAMArrayState, jnp.ndarray]:
evt_idx, yf, rf = inputs
payoff = jax.lax.switch(evt_idx, _POF_TABLE, state, params, yf, rf)
new_state = jax.lax.switch(evt_idx, _STF_TABLE, state, params, yf, rf)
return new_state, payoff
final_state, payoffs = jax.lax.scan(
step, initial_state, (event_types, year_fractions, rf_values), unroll=8
)
return final_state, payoffs
# JIT-compiled version for single-contract use
simulate_pam_array_jit = jax.jit(simulate_pam_array)
# Vmapped version (kept as fallback, e.g. for GPU where vmap is efficient)
batch_simulate_pam_vmap = jax.vmap(simulate_pam_array)
[docs]
def batch_simulate_pam_auto(
initial_states: PAMArrayState,
event_types: jnp.ndarray,
year_fractions: jnp.ndarray,
rf_values: jnp.ndarray,
params: PAMArrayParams,
) -> tuple[PAMArrayState, jnp.ndarray]:
"""Batched simulation using the optimal strategy for all backends.
Uses the single-scan batch approach (``batch_simulate_pam``) which
processes all contracts in shaped ``[B, T]`` arrays via a single
``lax.scan``. This is faster than ``vmap`` on CPU, GPU, *and* TPU
because it avoids per-contract dispatch overhead.
The ``vmap`` variant (``batch_simulate_pam_vmap``) remains available
for explicit use but is not selected automatically.
"""
return batch_simulate_pam(initial_states, event_types, year_fractions, rf_values, params) # type: ignore[no-any-return]
# ============================================================================
# Manually-batched simulation — eliminates vmap dispatch overhead on CPU
# ============================================================================
[docs]
@jax.jit
def batch_simulate_pam(
initial_states: PAMArrayState,
event_types: jnp.ndarray,
year_fractions: jnp.ndarray,
rf_values: jnp.ndarray,
params: PAMArrayParams,
) -> tuple[PAMArrayState, jnp.ndarray]:
"""Batched PAM simulation without vmap — single scan over ``[B]`` arrays.
Eliminates JAX vmap CPU dispatch overhead by operating directly on
batch-dimensioned arrays. Each scan step computes all event-type
outcomes for all contracts simultaneously using branchless
``jnp.where`` dispatch.
Args:
initial_states: ``PAMArrayState`` with each field shape ``[B]``.
event_types: ``[B, T]`` int32 — event type indices per contract.
year_fractions: ``[B, T]`` float32.
rf_values: ``[B, T]`` float32.
params: ``PAMArrayParams`` with each field shape ``[B]``.
Returns:
``(final_states, payoffs)`` where ``payoffs`` is ``[B, T]``.
"""
# Transpose to [T, B] so scan iterates over time steps
et_t = event_types.T
yf_t = year_fractions.T
rf_t = rf_values.T
def step(
states: PAMArrayState,
inputs: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
) -> tuple[PAMArrayState, jnp.ndarray]:
et, yf, rf = inputs # each [B]
# Common sub-expression: interest accrual
accrue = states.ipac + yf * states.ipnr * states.nt
# ---- Payoffs (branchless jnp.where dispatch) ----
payoff = jnp.zeros_like(states.nt)
# IED: role_sign * (-1) * (notional + premium)
payoff = jnp.where(
et == _IED_IDX,
params.role_sign
* (-1.0)
* (params.notional_principal + params.premium_discount_at_ied),
payoff,
)
# MD: nsc * nt + isc * ipac + feac (uses state.ipac, NOT accrue)
payoff = jnp.where(
et == _MD_IDX,
states.nsc * states.nt + states.isc * states.ipac + states.feac,
payoff,
)
# PP: rf (pre-computed prepayment amount)
payoff = jnp.where(et == _PP_IDX, rf, payoff)
# PY: penalty (type-dependent)
payoff = jnp.where(
et == _PY_IDX,
jnp.where(
params.penalty_type == 0,
params.penalty_rate,
yf * states.nt * params.penalty_rate,
),
payoff,
)
# FP: fee payment (basis-dependent)
payoff = jnp.where(
et == _FP_IDX,
jnp.where(
params.fee_basis == 0,
params.fee_rate,
jnp.where(
params.fee_basis == 1,
yf * states.nt * params.fee_rate + states.feac,
states.feac,
),
),
payoff,
)
# PRD: -(price + accrue)
payoff = jnp.where(
et == _PRD_IDX,
(-1.0) * (params.price_at_purchase_date + accrue),
payoff,
)
# TD: price + accrue
payoff = jnp.where(
et == _TD_IDX,
params.price_at_termination_date + accrue,
payoff,
)
# IP: isc * accrue
payoff = jnp.where(et == _IP_IDX, states.isc * accrue, payoff)
# ---- State transitions (branchless) ----
# nt: default unchanged
new_nt = states.nt
new_nt = jnp.where(
et == _IED_IDX,
params.role_sign * params.notional_principal,
new_nt,
)
new_nt = jnp.where((et == _MD_IDX) | (et == _TD_IDX), 0.0, new_nt)
new_nt = jnp.where(et == _PP_IDX, states.nt - rf, new_nt)
new_nt = jnp.where(et == _IPCI_IDX, states.nt + accrue, new_nt)
# ipnr: default unchanged
new_ipnr = states.ipnr
new_ipnr = jnp.where(et == _IED_IDX, params.nominal_interest_rate, new_ipnr)
# RR: clamp(multiplier * rf + spread, floor, cap)
raw_rate = params.rate_reset_multiplier * rf + params.rate_reset_spread
clamped = jnp.where(
params.has_rate_floor > 0.5,
jnp.maximum(raw_rate, params.rate_reset_floor),
raw_rate,
)
clamped = jnp.where(
params.has_rate_cap > 0.5,
jnp.minimum(clamped, params.rate_reset_cap),
clamped,
)
new_ipnr = jnp.where(et == _RR_IDX, clamped, new_ipnr)
new_ipnr = jnp.where(et == _RRF_IDX, params.rate_reset_next, new_ipnr)
# ipac: three distinct behaviours
# accrue group: AD, PP, PY, FP, PRD, RR, RRF, SC, CE
# zero group: MD, TD, IP, IPCI
# special: IED → ied_ipac
# default: unchanged (NOP and unused event types)
is_accrue = (
(et == _AD_IDX)
| (et == _PP_IDX)
| (et == _PY_IDX)
| (et == _FP_IDX)
| (et == _PRD_IDX)
| (et == _RR_IDX)
| (et == _RRF_IDX)
| (et == _SC_IDX)
| (et == _CE_IDX)
)
is_zero_ipac = (et == _MD_IDX) | (et == _TD_IDX) | (et == _IP_IDX) | (et == _IPCI_IDX)
new_ipac = states.ipac
new_ipac = jnp.where(is_accrue, accrue, new_ipac)
new_ipac = jnp.where(is_zero_ipac, 0.0, new_ipac)
new_ipac = jnp.where(et == _IED_IDX, params.ied_ipac, new_ipac)
# feac: zero at IED, MD, FP, TD; unchanged otherwise
new_feac = jnp.where(
(et == _IED_IDX) | (et == _MD_IDX) | (et == _FP_IDX) | (et == _TD_IDX),
0.0,
states.feac,
)
# nsc, isc: only change at IED (set to 1.0)
new_nsc = jnp.where(et == _IED_IDX, 1.0, states.nsc)
new_isc = jnp.where(et == _IED_IDX, 1.0, states.isc)
new_state = PAMArrayState(
nt=new_nt,
ipnr=new_ipnr,
ipac=new_ipac,
feac=new_feac,
nsc=new_nsc,
isc=new_isc,
)
return new_state, payoff
final_states, payoffs_t = jax.lax.scan(step, initial_states, (et_t, yf_t, rf_t), unroll=8)
# payoffs_t is [T, B]; transpose back to [B, T]
return final_states, payoffs_t.T
# ============================================================================
# Pre-computation bridge — Python → JAX arrays
# ============================================================================
def _extract_params(attrs: ContractAttributes) -> PAMArrayParams:
"""Extract ``PAMArrayParams`` from ``ContractAttributes``."""
role_sign = _get_role_sign(attrs.contract_role)
nt = attrs.notional_principal or 0.0
ipnr = attrs.nominal_interest_rate or 0.0
# Pre-compute IED accrued interest (same logic as _stf_ied)
ied_ipac = 0.0
if attrs.accrued_interest is not None:
ied_ipac = attrs.accrued_interest
elif (
attrs.interest_payment_anchor is not None
and attrs.initial_exchange_date is not None
and attrs.interest_payment_anchor < attrs.initial_exchange_date
):
from jactus.core.types import DayCountConvention
dcc = attrs.day_count_convention or DayCountConvention.A360
yf = year_fraction(attrs.interest_payment_anchor, attrs.initial_exchange_date, dcc)
ied_ipac = yf * ipnr * abs(role_sign * nt)
has_floor = attrs.rate_reset_floor is not None
has_cap = attrs.rate_reset_cap is not None
return PAMArrayParams(
role_sign=jnp.array(role_sign, dtype=_F32),
notional_principal=jnp.array(nt, dtype=_F32),
nominal_interest_rate=jnp.array(ipnr, dtype=_F32),
premium_discount_at_ied=jnp.array(attrs.premium_discount_at_ied or 0.0, dtype=_F32),
accrued_interest=jnp.array(attrs.accrued_interest or 0.0, dtype=_F32),
fee_rate=jnp.array(attrs.fee_rate or 0.0, dtype=_F32),
fee_basis=jnp.array(_encode_fee_basis(attrs), dtype=jnp.int32),
penalty_rate=jnp.array(attrs.penalty_rate or 0.0, dtype=_F32),
penalty_type=jnp.array(_encode_penalty_type(attrs), dtype=jnp.int32),
price_at_purchase_date=jnp.array(attrs.price_at_purchase_date or 0.0, dtype=_F32),
price_at_termination_date=jnp.array(attrs.price_at_termination_date or 0.0, dtype=_F32),
rate_reset_spread=jnp.array(attrs.rate_reset_spread or 0.0, dtype=_F32),
rate_reset_multiplier=jnp.array(
attrs.rate_reset_multiplier if attrs.rate_reset_multiplier is not None else 1.0,
dtype=_F32,
),
rate_reset_floor=jnp.array(attrs.rate_reset_floor or 0.0, dtype=_F32),
rate_reset_cap=jnp.array(attrs.rate_reset_cap or 1.0, dtype=_F32),
rate_reset_next=jnp.array(
attrs.rate_reset_next if attrs.rate_reset_next is not None else ipnr,
dtype=_F32,
),
has_rate_floor=jnp.array(1.0 if has_floor else 0.0, dtype=_F32),
has_rate_cap=jnp.array(1.0 if has_cap else 0.0, dtype=_F32),
ied_ipac=jnp.array(ied_ipac, dtype=_F32),
)
def _extract_params_raw(attrs: ContractAttributes) -> dict[str, float | int]:
"""Extract params as plain Python floats/ints (no jnp.array overhead)."""
role_sign = _get_role_sign(attrs.contract_role)
nt = attrs.notional_principal or 0.0
ipnr = attrs.nominal_interest_rate or 0.0
ied_ipac = 0.0
if attrs.accrued_interest is not None:
ied_ipac = attrs.accrued_interest
elif (
attrs.interest_payment_anchor is not None
and attrs.initial_exchange_date is not None
and attrs.interest_payment_anchor < attrs.initial_exchange_date
):
from jactus.core.types import DayCountConvention
dcc = attrs.day_count_convention or DayCountConvention.A360
yf = year_fraction(attrs.interest_payment_anchor, attrs.initial_exchange_date, dcc)
ied_ipac = yf * ipnr * abs(role_sign * nt)
return {
"role_sign": role_sign,
"notional_principal": nt,
"nominal_interest_rate": ipnr,
"premium_discount_at_ied": attrs.premium_discount_at_ied or 0.0,
"accrued_interest": attrs.accrued_interest or 0.0,
"fee_rate": attrs.fee_rate or 0.0,
"fee_basis": _encode_fee_basis(attrs),
"penalty_rate": attrs.penalty_rate or 0.0,
"penalty_type": _encode_penalty_type(attrs),
"price_at_purchase_date": attrs.price_at_purchase_date or 0.0,
"price_at_termination_date": attrs.price_at_termination_date or 0.0,
"rate_reset_spread": attrs.rate_reset_spread or 0.0,
"rate_reset_multiplier": (
attrs.rate_reset_multiplier if attrs.rate_reset_multiplier is not None else 1.0
),
"rate_reset_floor": attrs.rate_reset_floor or 0.0,
"rate_reset_cap": attrs.rate_reset_cap or 1.0,
"rate_reset_next": attrs.rate_reset_next if attrs.rate_reset_next is not None else ipnr,
"has_rate_floor": 1.0 if attrs.rate_reset_floor is not None else 0.0,
"has_rate_cap": 1.0 if attrs.rate_reset_cap is not None else 0.0,
"ied_ipac": ied_ipac,
}
def _params_raw_to_jax(raw: dict[str, float | int]) -> PAMArrayParams:
"""Convert raw Python params to JAX PAMArrayParams."""
return PAMArrayParams(
role_sign=jnp.array(raw["role_sign"], dtype=_F32),
notional_principal=jnp.array(raw["notional_principal"], dtype=_F32),
nominal_interest_rate=jnp.array(raw["nominal_interest_rate"], dtype=_F32),
premium_discount_at_ied=jnp.array(raw["premium_discount_at_ied"], dtype=_F32),
accrued_interest=jnp.array(raw["accrued_interest"], dtype=_F32),
fee_rate=jnp.array(raw["fee_rate"], dtype=_F32),
fee_basis=jnp.array(raw["fee_basis"], dtype=jnp.int32),
penalty_rate=jnp.array(raw["penalty_rate"], dtype=_F32),
penalty_type=jnp.array(raw["penalty_type"], dtype=jnp.int32),
price_at_purchase_date=jnp.array(raw["price_at_purchase_date"], dtype=_F32),
price_at_termination_date=jnp.array(raw["price_at_termination_date"], dtype=_F32),
rate_reset_spread=jnp.array(raw["rate_reset_spread"], dtype=_F32),
rate_reset_multiplier=jnp.array(raw["rate_reset_multiplier"], dtype=_F32),
rate_reset_floor=jnp.array(raw["rate_reset_floor"], dtype=_F32),
rate_reset_cap=jnp.array(raw["rate_reset_cap"], dtype=_F32),
rate_reset_next=jnp.array(raw["rate_reset_next"], dtype=_F32),
has_rate_floor=jnp.array(raw["has_rate_floor"], dtype=_F32),
has_rate_cap=jnp.array(raw["has_rate_cap"], dtype=_F32),
ied_ipac=jnp.array(raw["ied_ipac"], dtype=_F32),
)
# ---------------------------------------------------------------------------
# Fast schedule generation — bypasses PrincipalAtMaturityContract entirely
# ---------------------------------------------------------------------------
def _fast_pam_schedule(
attrs: ContractAttributes,
) -> list[tuple[int, _datetime, _datetime]]:
"""Generate PAM schedule as lightweight (evt_idx, evt_dt, calc_dt) tuples.
Replicates the logic of ``PrincipalAtMaturityContract.generate_event_schedule``
without creating ``ContractEvent`` objects or a ``PrincipalAtMaturityContract``.
"""
from jactus.core.types import BusinessDayConvention
ied = attrs.initial_exchange_date
md = attrs.maturity_date
sd = attrs.status_date
assert ied is not None
assert md is not None
bdc = attrs.business_day_convention
# For non-NULL BDC or non-SD EOMC, fall back to the full path
has_bdc = bdc is not None and bdc != BusinessDayConvention.NULL
has_eomc = (
attrs.end_of_month_convention is not None and attrs.end_of_month_convention.value != "SD"
)
if has_bdc or has_eomc:
return _fallback_pam_schedule(attrs)
ied_dt = _adt_to_dt(ied)
md_dt = _adt_to_dt(md)
sd_dt = _adt_to_dt(sd)
# events: (evt_type_idx, event_time_dt, calc_time_dt)
events: list[tuple[int, _datetime, _datetime]] = []
# IED
if ied_dt >= sd_dt:
events.append((_IED_IDX, ied_dt, ied_dt))
# IP / IPCI
if attrs.interest_payment_cycle:
ip_anchor = attrs.interest_payment_anchor or ied
ipced = attrs.interest_capitalization_end_date
ip_dates = _fast_schedule(ip_anchor, attrs.interest_payment_cycle, md)
ipced_dt = _adt_to_dt(ipced) if ipced else None
# Add IPCED if not already on a cycle date
if ipced_dt and ipced_dt not in ip_dates:
ip_dates = sorted(set(ip_dates + [ipced_dt]))
# Stub handling
if md_dt not in ip_dates and ip_dates:
ip_cycle_str = attrs.interest_payment_cycle or ""
if ip_cycle_str.endswith("+"):
ip_dates[-1] = md_dt
else:
ip_dates.append(md_dt)
ip_dates = sorted(set(ip_dates))
for dt in ip_dates:
if dt < ied_dt:
continue
if ipced_dt and dt <= ipced_dt:
events.append((_IPCI_IDX, dt, dt))
else:
events.append((_IP_IDX, dt, dt))
# RR / RRF
if attrs.rate_reset_cycle and attrs.rate_reset_anchor:
rr_dates = _fast_schedule(attrs.rate_reset_anchor, attrs.rate_reset_cycle, md)
rr_cycle_str = attrs.rate_reset_cycle or ""
if rr_cycle_str.endswith("+") and rr_dates and rr_dates[-1] != md_dt:
rr_dates = rr_dates[:-1]
first_rr = True
for dt in rr_dates:
if dt >= md_dt:
break
if first_rr and attrs.rate_reset_next is not None:
events.append((_RRF_IDX, dt, dt))
first_rr = False
else:
events.append((_RR_IDX, dt, dt))
first_rr = False
# FP
if attrs.fee_payment_cycle:
fp_anchor = attrs.fee_payment_anchor or ied
fp_dates = _fast_schedule(fp_anchor, attrs.fee_payment_cycle, md)
for dt in fp_dates:
if dt > ied_dt:
events.append((_FP_IDX, dt, dt))
# SC
if attrs.scaling_index_cycle:
sc_anchor = attrs.scaling_index_anchor or ied
sc_dates = _fast_schedule(sc_anchor, attrs.scaling_index_cycle, md)
for dt in sc_dates:
if dt > ied_dt:
events.append((_SC_IDX, dt, dt))
# PRD
if attrs.purchase_date:
events.append((_PRD_IDX, _adt_to_dt(attrs.purchase_date), _adt_to_dt(attrs.purchase_date)))
# TD
if attrs.termination_date:
events.append(
(_TD_IDX, _adt_to_dt(attrs.termination_date), _adt_to_dt(attrs.termination_date))
)
# MD
events.append((_MD_IDX, md_dt, md_dt))
# Filter: remove events before SD
events = [(ei, et, ct) for ei, et, ct in events if et >= sd_dt]
# If PRD exists, remove IED and events before PRD
if attrs.purchase_date:
prd_dt = _adt_to_dt(attrs.purchase_date)
events = [(ei, et, ct) for ei, et, ct in events if ei != _IED_IDX and et >= prd_dt]
# Sort by (event_time, priority)
events.sort(key=lambda e: (e[1], _get_evt_priority(e[0])))
# If TD exists, remove all events after TD
if attrs.termination_date:
td_dt = _adt_to_dt(attrs.termination_date)
events = [(ei, et, ct) for ei, et, ct in events if et <= td_dt]
return events
def _fallback_pam_schedule(
attrs: ContractAttributes,
) -> list[tuple[int, _datetime, _datetime]]:
"""Fall back to the full PrincipalAtMaturityContract for BDC/EOMC cases."""
from jactus.contracts.pam import PrincipalAtMaturityContract
from jactus.observers import ConstantRiskFactorObserver
rf_obs = ConstantRiskFactorObserver(constant_value=0.0)
contract = PrincipalAtMaturityContract(attrs, rf_obs)
schedule = contract.generate_event_schedule()
result: list[tuple[int, _datetime, _datetime]] = []
for event in schedule.events:
evt_dt = _adt_to_dt(event.event_time)
calc_dt = _adt_to_dt(event.calculation_time) if event.calculation_time else evt_dt
result.append((event.event_type.index, evt_dt, calc_dt))
return result
def _fast_pam_init_state(
attrs: ContractAttributes,
) -> tuple[float, float, float, float, float, float, _datetime]:
"""Compute initial PAM state as Python floats.
Returns ``(nt, ipnr, ipac, feac, nsc, isc, sd_datetime)``.
"""
sd = attrs.status_date
ied = attrs.initial_exchange_date
sd_dt = _adt_to_dt(sd)
needs_post_ied = (ied and ied < sd) or attrs.purchase_date
if needs_post_ied:
role_sign = _get_role_sign(attrs.contract_role)
nt = role_sign * (attrs.notional_principal or 0.0)
ipnr = attrs.nominal_interest_rate or 0.0
if ied and ied >= sd and attrs.purchase_date:
init_sd_dt = _adt_to_dt(ied)
ipac = 0.0
else:
init_sd_dt = sd_dt
accrual_start = attrs.interest_payment_anchor or ied
if attrs.accrued_interest is not None:
ipac = attrs.accrued_interest
elif accrual_start and accrual_start < sd:
from jactus.core.types import DayCountConvention
# Walk forward through the IP cycle to find the last
# payment date before SD — past IP events reset ipac.
if attrs.interest_payment_cycle:
from jactus.contracts.pam import _last_cycle_date_before
accrual_start = _last_cycle_date_before(
accrual_start, attrs.interest_payment_cycle, sd
)
dcc = attrs.day_count_convention or DayCountConvention.A360
ipac = year_fraction(accrual_start, sd, dcc) * ipnr * abs(nt)
else:
ipac = 0.0
return (nt, ipnr, ipac, 0.0, 1.0, 1.0, init_sd_dt)
return (0.0, 0.0, 0.0, 0.0, 1.0, 1.0, sd_dt)
def _precompute_raw(
attrs: ContractAttributes,
rf_observer: RiskFactorObserver,
) -> _RawPrecomputed:
"""Pre-compute all data as pure Python types (no JAX arrays).
This is the core pre-computation that can be batched efficiently —
all JAX array creation is deferred to the caller.
"""
from jactus.contracts.array_common import get_yf_fn
if _USE_DATE_ARRAY:
return _precompute_raw_da(attrs, rf_observer)
from jactus.core.types import DayCountConvention
# 1. Fast schedule generation (no contract object)
schedule = _fast_pam_schedule(attrs)
# 2. Fast state initialization (no contract object)
nt, ipnr, ipac, feac, nsc, isc, init_sd_dt = _fast_pam_init_state(attrs)
# 3. Compute year fractions and risk factors
dcc = attrs.day_count_convention or DayCountConvention.A360
yf_fn = get_yf_fn(dcc)
event_type_list: list[int] = []
yf_list: list[float] = []
current_sd_dt = init_sd_dt
for evt_idx, evt_dt, calc_dt in schedule:
event_type_list.append(evt_idx)
if yf_fn is not None:
yf_list.append(yf_fn(current_sd_dt, calc_dt))
else:
yf_list.append(year_fraction(_dt_to_adt(current_sd_dt), _dt_to_adt(calc_dt), dcc))
current_sd_dt = evt_dt
# Risk factor pre-query
rf_list = _prequery_risk_factors(schedule, attrs, rf_observer)
# 4. Extract params as raw Python dict
params_raw = _extract_params_raw(attrs)
return _RawPrecomputed(
state=(nt, ipnr, ipac, feac, nsc, isc),
event_types=event_type_list,
year_fractions=yf_list,
rf_values=rf_list,
params=params_raw,
)
# ---------------------------------------------------------------------------
# DateArray-based pre-computation (vectorised year fractions)
# ---------------------------------------------------------------------------
def _precompute_raw_da(
attrs: ContractAttributes,
rf_observer: RiskFactorObserver,
) -> _RawPrecomputed:
"""Pre-compute using vectorised year fractions (NumPy, no JAX overhead).
Schedule generation reuses ``_fast_pam_schedule`` (Python business logic),
but year fractions are computed in a single vectorised NumPy pass using
the Hinnant ordinal algorithm from :mod:`jactus.utilities.date_array`.
"""
from jactus.core.types import DayCountConvention
# 1. Schedule (same as before — Python business logic)
schedule = _fast_pam_schedule(attrs)
# 2. State initialisation (same as before)
nt, ipnr, ipac, feac, nsc, isc, init_sd_dt = _fast_pam_init_state(attrs)
if not schedule:
params_raw = _extract_params_raw(attrs)
return _RawPrecomputed(
state=(nt, ipnr, ipac, feac, nsc, isc),
event_types=[],
year_fractions=[],
rf_values=[],
params=params_raw,
)
# 3. Event types + vectorised year fractions
event_type_list = [evt_idx for evt_idx, _, _ in schedule]
dcc = attrs.day_count_convention or DayCountConvention.A360
yf_list = _compute_vectorised_year_fractions(schedule, init_sd_dt, dcc)
# 4. Risk factor pre-query
rf_list = _prequery_risk_factors(schedule, attrs, rf_observer)
# 5. Extract params
params_raw = _extract_params_raw(attrs)
return _RawPrecomputed(
state=(nt, ipnr, ipac, feac, nsc, isc),
event_types=event_type_list,
year_fractions=yf_list,
rf_values=rf_list,
params=params_raw,
)
def _classify_contracts_for_batch(
contracts: list[tuple[ContractAttributes, RiskFactorObserver]],
) -> tuple[list[int], list[int]]:
"""Partition contract indices into batch-eligible vs fallback.
Batch-eligible criteria (conservative):
- BDC is NULL or None, EOMC is SD or None
- IP cycle is month-based (M, Q, H, Y), no ``+`` stub
- No RR/FP/SC cycles, no PRD/TD/IPCED
- DCC in {A360, A365, E30360, B30360}
"""
from jactus.core.types import BusinessDayConvention, DayCountConvention
batch_dccs = frozenset(
{
DayCountConvention.A360,
DayCountConvention.A365,
DayCountConvention.E30360,
DayCountConvention.B30360,
}
)
batch_idx: list[int] = []
fallback_idx: list[int] = []
for i, (attrs, _obs) in enumerate(contracts):
# BDC / EOMC check
bdc = attrs.business_day_convention
if bdc is not None and bdc != BusinessDayConvention.NULL:
fallback_idx.append(i)
continue
if (
attrs.end_of_month_convention is not None
and attrs.end_of_month_convention.value != "SD"
):
fallback_idx.append(i)
continue
# IP cycle must be month-based, no + stub
ip_cycle = attrs.interest_payment_cycle
if ip_cycle:
_mult, period, stub = _parse_cycle_fast(ip_cycle)
if period not in _CYCLE_MONTHS_MAP:
fallback_idx.append(i)
continue
if stub == "+":
fallback_idx.append(i)
continue
# No complex features
if (
attrs.rate_reset_cycle
or attrs.fee_payment_cycle
or attrs.scaling_index_cycle
or attrs.purchase_date
or attrs.termination_date
or attrs.interest_capitalization_end_date
):
fallback_idx.append(i)
continue
# DCC check
dcc = attrs.day_count_convention or DayCountConvention.A360
if dcc not in batch_dccs:
fallback_idx.append(i)
continue
batch_idx.append(i)
return batch_idx, fallback_idx
def _jax_batch_assemble(
params: _BatchContractParams,
ip_ordinals: jnp.ndarray,
ip_valid: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Assemble full event schedules: IED + IP + stub + MD (JAX-native).
Returns:
``(event_types, event_ordinals, event_valid, n_events)``
Shapes: ``(N, max_events)``, ``(N, max_events)``, ``(N, max_events)``, ``(N,)``.
``max_events = max_ip + 3`` (IED + IP dates + stub + MD).
"""
n, max_ip = ip_ordinals.shape
max_events = max_ip + 3
# Initialise all as NOP (padding)
event_types = jnp.full((n, max_events), NOP_EVENT_IDX, dtype=jnp.int32)
event_ordinals = jnp.zeros((n, max_events), dtype=jnp.int32)
event_valid = jnp.zeros((n, max_events), dtype=jnp.bool_)
# --- Column 0: IED ---
ied_present = params.ied_ord >= params.sd_ord # (N,)
event_types = event_types.at[:, 0].set(jnp.where(ied_present, _IED_IDX, NOP_EVENT_IDX))
event_ordinals = event_ordinals.at[:, 0].set(params.ied_ord)
event_valid = event_valid.at[:, 0].set(ied_present)
# --- Columns 1..max_ip: IP events ---
# Filter: IP dates must be >= IED (already done in ip_valid)
# Also filter: >= SD
sd_ord = params.sd_ord.reshape(-1, 1)
ip_after_sd = ip_valid & (ip_ordinals >= sd_ord)
event_ordinals = event_ordinals.at[:, 1 : max_ip + 1].set(ip_ordinals)
event_valid = event_valid.at[:, 1 : max_ip + 1].set(ip_after_sd)
event_types = event_types.at[:, 1 : max_ip + 1].set(
jnp.where(ip_after_sd, _IP_IDX, NOP_EVENT_IDX)
)
# --- Column max_ip+1: Stub IP at MD (if no IP falls on MD) ---
md_ord = params.md_ord.reshape(-1, 1)
ip_at_md = jnp.any(ip_after_sd & (ip_ordinals == md_ord), axis=1) # (N,)
needs_stub = (params.has_ip_cycle.astype(jnp.bool_)) & (~ip_at_md)
# Stub must also be >= SD
needs_stub = needs_stub & (params.md_ord >= params.sd_ord)
stub_col = max_ip + 1
event_types = event_types.at[:, stub_col].set(jnp.where(needs_stub, _IP_IDX, NOP_EVENT_IDX))
event_ordinals = event_ordinals.at[:, stub_col].set(params.md_ord)
event_valid = event_valid.at[:, stub_col].set(needs_stub)
# --- Column max_ip+2: MD (always present if >= SD) ---
md_col = max_ip + 2
md_present = params.md_ord >= params.sd_ord
event_types = event_types.at[:, md_col].set(jnp.where(md_present, _MD_IDX, NOP_EVENT_IDX))
event_ordinals = event_ordinals.at[:, md_col].set(params.md_ord)
event_valid = event_valid.at[:, md_col].set(md_present)
# --- Sort each row by (ordinal, priority) ---
# Build priority lookup array (index by event type idx)
_evt_priorities = jnp.array(
[_get_evt_priority(i) for i in range(NOP_EVENT_IDX + 1)],
dtype=jnp.int32,
)
evt_priority = _evt_priorities[event_types] # (N, max_events)
# Composite sort key: invalid events get MAX ordinal to sort last
max_ord = jnp.int32(2_000_000) # ~5480 years, well beyond any contract
sort_ordinal = jnp.where(event_valid, event_ordinals, max_ord)
sort_key = sort_ordinal * 100 + evt_priority # max ~200M, fits int32
sort_idx = jnp.argsort(sort_key, axis=1) # (N, max_events)
# Apply sort (gather along axis=1)
row_idx = jnp.arange(n).reshape(-1, 1)
event_types = event_types[row_idx, sort_idx]
event_ordinals = event_ordinals[row_idx, sort_idx]
event_valid = event_valid[row_idx, sort_idx]
n_events = event_valid.sum(axis=1).astype(jnp.int32)
return event_types, event_ordinals, event_valid, n_events
def _batch_precompute_pam_impl(
params: _BatchContractParams,
max_ip: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Inner implementation for batch pre-computation (pure JAX)."""
ip_ords, ip_valid = _jax_batch_ip_schedule(params, max_ip)
evt_types, evt_ords, evt_valid, _n_events = _jax_batch_assemble(params, ip_ords, ip_valid)
yf = _jax_batch_year_fractions(evt_ords, evt_valid, params)
rf = jnp.zeros_like(yf) # no RR/FP/SC in batch-eligible contracts
masks = evt_valid.astype(jnp.float32)
return evt_types, yf, rf, masks
_batch_precompute_pam_jit = jax.jit(_batch_precompute_pam_impl, static_argnums=(1,))
[docs]
def batch_precompute_pam(
params: _BatchContractParams,
max_ip: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""JAX-native batch schedule generation + year fractions.
Wraps a JIT-compiled kernel. ``max_ip`` is a compile-time constant
(recompiles only when ``max_ip`` changes).
Args:
params: Batch contract parameters (shape ``(N,)`` per field).
max_ip: Maximum IP events (static, determines array shapes).
Returns:
``(event_types, year_fractions, rf_values, masks)`` —
all shape ``(N, max_events)`` where ``max_events = max_ip + 3``.
"""
return _batch_precompute_pam_jit(params, max_ip) # type: ignore[no-any-return]
def _raw_to_jax(
raw: _RawPrecomputed,
) -> tuple[PAMArrayState, jnp.ndarray, jnp.ndarray, jnp.ndarray, PAMArrayParams]:
"""Convert raw pre-computed data to JAX arrays."""
nt, ipnr, ipac, feac, nsc, isc = raw.state
return (
PAMArrayState(
nt=jnp.array(nt, dtype=_F32),
ipnr=jnp.array(ipnr, dtype=_F32),
ipac=jnp.array(ipac, dtype=_F32),
feac=jnp.array(feac, dtype=_F32),
nsc=jnp.array(nsc, dtype=_F32),
isc=jnp.array(isc, dtype=_F32),
),
jnp.array(raw.event_types, dtype=jnp.int32),
jnp.array(raw.year_fractions, dtype=_F32),
jnp.array(raw.rf_values, dtype=_F32),
_params_raw_to_jax(raw.params),
)
[docs]
def precompute_pam_arrays(
attrs: ContractAttributes,
rf_observer: RiskFactorObserver,
) -> tuple[PAMArrayState, jnp.ndarray, jnp.ndarray, jnp.ndarray, PAMArrayParams]:
"""Pre-compute JAX arrays for array-mode PAM simulation.
Generates the event schedule and initial state directly from attributes
(bypassing ``PrincipalAtMaturityContract``), then converts to JAX arrays
suitable for ``simulate_pam_array``.
Args:
attrs: Contract attributes (must be PAM type).
rf_observer: Risk factor observer (queried for RR/PP events).
Returns:
``(initial_state, event_types, year_fractions, rf_values, params)``
"""
return _raw_to_jax(_precompute_raw(attrs, rf_observer))
# ============================================================================
# Batch / portfolio API
# ============================================================================
def _raw_list_to_jax_batch(
raw_list: list[_RawPrecomputed],
) -> tuple[PAMArrayState, jnp.ndarray, jnp.ndarray, jnp.ndarray, PAMArrayParams, jnp.ndarray]:
"""Convert a list of ``_RawPrecomputed`` to padded JAX batch arrays.
Pads shorter contracts with NOP events and builds NumPy arrays first
(fast C-level construction) then transfers to JAX via ``jnp.asarray``.
"""
max_events = max(len(r.event_types) for r in raw_list)
# State fields: (batch,) each
state_nt = [r.state[0] for r in raw_list]
state_ipnr = [r.state[1] for r in raw_list]
state_ipac = [r.state[2] for r in raw_list]
state_feac = [r.state[3] for r in raw_list]
state_nsc = [r.state[4] for r in raw_list]
state_isc = [r.state[5] for r in raw_list]
# Event arrays: (batch, max_events) each, with padding
et_batch: list[list[int]] = []
yf_batch: list[list[float]] = []
rf_batch: list[list[float]] = []
mask_batch: list[list[float]] = []
for r in raw_list:
n_events = len(r.event_types)
pad_n = max_events - n_events
et_batch.append(r.event_types + [NOP_EVENT_IDX] * pad_n)
yf_batch.append(r.year_fractions + [0.0] * pad_n)
rf_batch.append(r.rf_values + [0.0] * pad_n)
mask_batch.append([1.0] * n_events + [0.0] * pad_n)
# Param fields: (batch,) each
param_fields: dict[str, list[float | int]] = {k: [] for k in PAMArrayParams._fields}
for r in raw_list:
for k in PAMArrayParams._fields:
param_fields[k].append(r.params[k])
# Build NumPy arrays first (fast C-level), then transfer to JAX
batched_states = PAMArrayState(
nt=jnp.asarray(np.array(state_nt, dtype=np.float32)),
ipnr=jnp.asarray(np.array(state_ipnr, dtype=np.float32)),
ipac=jnp.asarray(np.array(state_ipac, dtype=np.float32)),
feac=jnp.asarray(np.array(state_feac, dtype=np.float32)),
nsc=jnp.asarray(np.array(state_nsc, dtype=np.float32)),
isc=jnp.asarray(np.array(state_isc, dtype=np.float32)),
)
batched_et = jnp.asarray(np.array(et_batch, dtype=np.int32))
batched_yf = jnp.asarray(np.array(yf_batch, dtype=np.float32))
batched_rf = jnp.asarray(np.array(rf_batch, dtype=np.float32))
batched_masks = jnp.asarray(np.array(mask_batch, dtype=np.float32))
_int_fields = {"fee_basis", "penalty_type"}
batched_params = PAMArrayParams(
**{
k: jnp.asarray(
np.array(
param_fields[k],
dtype=np.int32 if k in _int_fields else np.float32,
)
)
for k in PAMArrayParams._fields
}
)
return batched_states, batched_et, batched_yf, batched_rf, batched_params, batched_masks
def _prepare_pam_batch_sequential(
contracts: list[tuple[ContractAttributes, RiskFactorObserver]],
) -> tuple[PAMArrayState, jnp.ndarray, jnp.ndarray, jnp.ndarray, PAMArrayParams, jnp.ndarray]:
"""Per-contract sequential pre-computation (original path)."""
raw_list = [_precompute_raw(attrs, obs) for attrs, obs in contracts]
return _raw_list_to_jax_batch(raw_list)
def _extract_batch_states_and_params(
contracts: list[tuple[ContractAttributes, RiskFactorObserver]],
indices: list[int],
) -> tuple[PAMArrayState, PAMArrayParams]:
"""Extract initial states and simulation params in bulk NumPy.
Replaces the per-contract loop over ``_fast_pam_init_state`` +
``_extract_params_raw`` with a single vectorised pass.
Returns JAX arrays ready for the simulation kernel.
"""
n = len(indices)
_int_fields = {"fee_basis", "penalty_type"}
# Pre-allocate NumPy arrays for states
s_nt = np.zeros(n, dtype=np.float32)
s_ipnr = np.zeros(n, dtype=np.float32)
s_ipac = np.zeros(n, dtype=np.float32)
s_feac = np.zeros(n, dtype=np.float32)
s_nsc = np.ones(n, dtype=np.float32)
s_isc = np.ones(n, dtype=np.float32)
# Pre-allocate NumPy arrays for params
p_arrays: dict[str, np.ndarray] = {
k: np.zeros(n, dtype=np.int32 if k in _int_fields else np.float32)
for k in PAMArrayParams._fields
}
for j, idx in enumerate(indices):
attrs, _ = contracts[idx]
nt, ipnr, ipac, feac, nsc, isc, _ = _fast_pam_init_state(attrs)
s_nt[j] = nt
s_ipnr[j] = ipnr
s_ipac[j] = ipac
s_feac[j] = feac
s_nsc[j] = nsc
s_isc[j] = isc
p = _extract_params_raw(attrs)
for k in PAMArrayParams._fields:
p_arrays[k][j] = p[k]
# Single NumPy → JAX transfer
states = PAMArrayState(
nt=jnp.asarray(s_nt),
ipnr=jnp.asarray(s_ipnr),
ipac=jnp.asarray(s_ipac),
feac=jnp.asarray(s_feac),
nsc=jnp.asarray(s_nsc),
isc=jnp.asarray(s_isc),
)
params = PAMArrayParams(**{k: jnp.asarray(p_arrays[k]) for k in PAMArrayParams._fields})
return states, params
def _prepare_pam_batch_all_eligible(
contracts: list[tuple[ContractAttributes, RiskFactorObserver]],
batch_idx: list[int],
) -> tuple[PAMArrayState, jnp.ndarray, jnp.ndarray, jnp.ndarray, PAMArrayParams, jnp.ndarray]:
"""Fast path when ALL contracts are batch-eligible.
Avoids the JAX→NumPy→JAX round-trip used by the mixed batch/fallback
path. Schedule arrays stay as JAX arrays throughout.
"""
bp = _extract_batch_params(contracts, batch_idx)
max_ip = _compute_max_ip(bp)
evt_types, yf, rf, masks = batch_precompute_pam(bp, max_ip)
# Trim trailing NOP padding
actual_max = int(masks.sum(axis=1).max())
evt_types = evt_types[:, :actual_max]
yf = yf[:, :actual_max]
rf = rf[:, :actual_max]
masks = masks[:, :actual_max]
# Extract states + params (NumPy bulk → single JAX transfer)
states, params = _extract_batch_states_and_params(contracts, batch_idx)
return states, evt_types, yf, rf, params, masks
[docs]
def prepare_pam_batch(
contracts: list[tuple[ContractAttributes, RiskFactorObserver]],
) -> tuple[PAMArrayState, jnp.ndarray, jnp.ndarray, jnp.ndarray, PAMArrayParams, jnp.ndarray]:
"""Pre-compute and pad arrays for a batch of PAM contracts.
When ``_USE_BATCH_SCHEDULE`` is enabled, eligible contracts have their
schedules and year fractions generated via a JAX-native batch path
(GPU/TPU-ready). Ineligible contracts fall back to per-contract
Python pre-computation.
When **all** contracts are batch-eligible, a fast path avoids the
JAX→NumPy→JAX round-trip, keeping schedule arrays on-device.
Args:
contracts: List of ``(attributes, rf_observer)`` pairs.
Returns:
``(initial_states, event_types, year_fractions, rf_values, params, masks)``
where each array has a leading batch dimension.
"""
if not _USE_BATCH_SCHEDULE or len(contracts) <= 1:
return _prepare_pam_batch_sequential(contracts)
batch_idx, fallback_idx = _classify_contracts_for_batch(contracts)
if not batch_idx:
return _prepare_pam_batch_sequential(contracts)
# --- Fast path: all contracts are batch-eligible ---
if not fallback_idx:
return _prepare_pam_batch_all_eligible(contracts, batch_idx)
# --- Mixed path: batch + fallback ---
bp = _extract_batch_params(contracts, batch_idx)
max_ip = _compute_max_ip(bp)
evt_types_jax, yf_jax, rf_jax, masks_jax = batch_precompute_pam(bp, max_ip)
# Trim batch arrays to actual max valid events (remove trailing NOP padding)
actual_max_batch = int(masks_jax.sum(axis=1).max())
evt_types_jax = evt_types_jax[:, :actual_max_batch]
yf_jax = yf_jax[:, :actual_max_batch]
rf_jax = rf_jax[:, :actual_max_batch]
masks_jax = masks_jax[:, :actual_max_batch]
max_events_batch = actual_max_batch
# --- Fallback path: per-contract Python precompute ---
fallback_raws = [_precompute_raw(*contracts[i]) for i in fallback_idx]
max_events_fallback = max((len(r.event_types) for r in fallback_raws), default=0)
# --- Determine final padded width ---
max_events = max(max_events_batch, max_events_fallback)
n_total = len(contracts)
_int_fields = {"fee_basis", "penalty_type"}
# --- Allocate final NumPy arrays ---
final_et = np.full((n_total, max_events), NOP_EVENT_IDX, dtype=np.int32)
final_yf = np.zeros((n_total, max_events), dtype=np.float32)
final_rf = np.zeros((n_total, max_events), dtype=np.float32)
final_mask = np.zeros((n_total, max_events), dtype=np.float32)
final_nt = np.zeros(n_total, dtype=np.float32)
final_ipnr = np.zeros(n_total, dtype=np.float32)
final_ipac = np.zeros(n_total, dtype=np.float32)
final_feac = np.zeros(n_total, dtype=np.float32)
final_nsc = np.zeros(n_total, dtype=np.float32)
final_isc = np.zeros(n_total, dtype=np.float32)
param_arrays = {
k: np.zeros(n_total, dtype=np.int32 if k in _int_fields else np.float32)
for k in PAMArrayParams._fields
}
# --- Place batch results (single bulk JAX → NumPy transfer) ---
batch_idx_np = np.array(batch_idx, dtype=np.intp)
final_et[batch_idx_np, :max_events_batch] = np.asarray(evt_types_jax)
final_yf[batch_idx_np, :max_events_batch] = np.asarray(yf_jax)
final_rf[batch_idx_np, :max_events_batch] = np.asarray(rf_jax)
final_mask[batch_idx_np, :max_events_batch] = np.asarray(masks_jax)
# States + params for batch contracts
for _j, idx in enumerate(batch_idx):
attrs, _ = contracts[idx]
nt, ipnr, ipac, feac, nsc, isc, _ = _fast_pam_init_state(attrs)
final_nt[idx] = nt
final_ipnr[idx] = ipnr
final_ipac[idx] = ipac
final_feac[idx] = feac
final_nsc[idx] = nsc
final_isc[idx] = isc
p = _extract_params_raw(attrs)
for k in PAMArrayParams._fields:
param_arrays[k][idx] = p[k]
# --- Place fallback results ---
for j, idx in enumerate(fallback_idx):
r = fallback_raws[j]
n_ev = len(r.event_types)
final_et[idx, :n_ev] = r.event_types
final_yf[idx, :n_ev] = r.year_fractions
final_rf[idx, :n_ev] = r.rf_values
final_mask[idx, :n_ev] = 1.0
final_nt[idx] = r.state[0]
final_ipnr[idx] = r.state[1]
final_ipac[idx] = r.state[2]
final_feac[idx] = r.state[3]
final_nsc[idx] = r.state[4]
final_isc[idx] = r.state[5]
for k in PAMArrayParams._fields:
param_arrays[k][idx] = r.params[k]
# --- Single NumPy → JAX transfer ---
return (
PAMArrayState(
nt=jnp.asarray(final_nt),
ipnr=jnp.asarray(final_ipnr),
ipac=jnp.asarray(final_ipac),
feac=jnp.asarray(final_feac),
nsc=jnp.asarray(final_nsc),
isc=jnp.asarray(final_isc),
),
jnp.asarray(final_et),
jnp.asarray(final_yf),
jnp.asarray(final_rf),
PAMArrayParams(**{k: jnp.asarray(param_arrays[k]) for k in PAMArrayParams._fields}),
jnp.asarray(final_mask),
)
[docs]
def simulate_pam_portfolio(
contracts: list[tuple[ContractAttributes, RiskFactorObserver]],
discount_rate: float | None = None,
year_fractions_from_valuation: jnp.ndarray | None = None,
) -> dict[str, Any]:
"""End-to-end portfolio simulation with optional PV.
Args:
contracts: List of ``(attributes, rf_observer)`` pairs.
discount_rate: If provided, compute present values.
year_fractions_from_valuation: ``(batch, max_events)`` year fractions
from valuation date for PV discounting. If ``None`` and
``discount_rate`` is set, year fractions are computed from each
contract's ``status_date``.
Returns:
Dict with ``payoffs``, ``masks``, ``final_states``, and optionally
``present_values`` and ``total_pv``.
"""
(
batched_states,
batched_et,
batched_yf,
batched_rf,
batched_params,
batched_masks,
) = prepare_pam_batch(contracts)
# Run batched simulation (auto-selects vmap on GPU/TPU, manual on CPU)
final_states, payoffs = batch_simulate_pam_auto(
batched_states, batched_et, batched_yf, batched_rf, batched_params
)
# Mask padding
masked_payoffs = payoffs * batched_masks
total_cashflows = jnp.sum(masked_payoffs, axis=1)
result = {
"payoffs": masked_payoffs,
"masks": batched_masks,
"final_states": final_states,
"total_cashflows": total_cashflows,
"num_contracts": len(contracts),
}
if discount_rate is not None:
# Use cumulative year fractions for discounting
if year_fractions_from_valuation is not None:
disc_yfs = year_fractions_from_valuation
else:
# Approximate: use cumulative sum of per-event year fractions
disc_yfs = jnp.cumsum(batched_yf, axis=1)
discount_factors = 1.0 / (1.0 + discount_rate * disc_yfs)
pvs = jnp.sum(masked_payoffs * discount_factors, axis=1)
result["present_values"] = pvs
result["total_pv"] = jnp.sum(pvs)
return result