# Source code for jactus.observers.prepayment

"""Prepayment behavioral risk model for ACTUS contracts.

This module implements a 2D surface-based prepayment model that computes
prepayment rates as a function of:

- **Spread** (dimension 1): The difference between the contract's nominal
  interest rate and the current market reference rate. A positive spread
  means the borrower has an incentive to refinance.
- **Loan age** (dimension 2): Time elapsed since the initial exchange date.
  Prepayment behavior typically follows a seasoning pattern, peaking in
  the middle years of a loan's life.

The model returns a **Multiplicative Reduction Delta (MRD)** — a fraction
by which the notional principal is reduced at each prepayment observation time.

This mirrors the ``TwoDimensionalPrepaymentModel`` from the ACTUS risk service,
with the addition of JAX compatibility for automatic differentiation.

References:
    ACTUS Risk Service v2.0 - TwoDimensionalPrepaymentModel
    ACTUS Technical Specification v1.1 - PP (Principal Prepayment) events
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import jax.numpy as jnp

from jactus.observers.behavioral import BaseBehaviorRiskFactorObserver, CalloutEvent
from jactus.utilities.surface import Surface2D

if TYPE_CHECKING:
    from jactus.core import ActusDateTime, ContractAttributes, ContractState
    from jactus.core.types import EventType


class PrepaymentSurfaceObserver(BaseBehaviorRiskFactorObserver):
    """Prepayment model using a 2D surface (spread x loan age).

    At each prepayment observation time, the model:

    1. Computes the **spread** = ``state.ipnr - market_rate(time)``
    2. Computes the **loan age** = days since
       ``attributes.initial_exchange_date`` divided by 365.25
    3. Looks up the prepayment rate from the 2D surface
    4. Returns the prepayment rate as a JAX array

    The market reference rate is read from ``market_observer`` when both
    ``market_observer`` and ``market_rate_id`` are set; otherwise
    ``fixed_market_rate`` is used.

    The surface's x-margins must be expressed in the same units as the
    spread, i.e. the units of ``state.ipnr`` and of the market rate. The
    example below uses decimal rates (``fixed_market_rate=0.04``), so its
    spread margins are decimals too (``0.01`` = 1%).

    Attributes:
        surface: 2D surface mapping (spread, age) to prepayment rate.
        market_rate_id: Identifier for the market reference rate
            (e.g., ``"UST-5Y"``). If not provided, ``fixed_market_rate``
            is used.
        market_observer: Optional companion market risk factor observer
            for fetching the current market rate.
        fixed_market_rate: Fixed market rate used when no market observer
            is provided (default 0.0).
        prepayment_cycle: Cycle string for prepayment event frequency
            (e.g., ``"6M"`` for semi-annual). Used by ``contract_start()``
            to generate callout events.
        model_id: Identifier for this prepayment model instance.

    Example:
        >>> import jax.numpy as jnp
        >>> from jactus.utilities.surface import Surface2D
        >>> surface = Surface2D(
        ...     x_margins=jnp.array([-0.05, 0.0, 0.01, 0.02, 0.03]),
        ...     y_margins=jnp.array([0.0, 1.0, 2.0, 3.0, 5.0, 10.0]),
        ...     values=jnp.array([
        ...         [0.00, 0.00, 0.00, 0.00, 0.00, 0.00],  # spread=-5%
        ...         [0.00, 0.00, 0.01, 0.00, 0.00, 0.00],  # spread= 0%
        ...         [0.00, 0.01, 0.02, 0.00, 0.00, 0.00],  # spread= 1%
        ...         [0.00, 0.02, 0.05, 0.03, 0.005, 0.00], # spread= 2%
        ...         [0.01, 0.05, 0.10, 0.07, 0.02, 0.00],  # spread= 3%
        ...     ]),
        ... )
        >>> observer = PrepaymentSurfaceObserver(
        ...     surface=surface,
        ...     fixed_market_rate=0.04,
        ...     prepayment_cycle="6M",
        ... )
    """

    def __init__(
        self,
        surface: Surface2D,
        market_rate_id: str | None = None,
        market_observer: Any | None = None,
        fixed_market_rate: float = 0.0,
        prepayment_cycle: str = "6M",
        model_id: str = "prepayment-model",
        name: str | None = None,
    ):
        """Initialize prepayment surface observer.

        Args:
            surface: 2D surface mapping (spread, age) to prepayment rate.
            market_rate_id: Identifier for the market reference rate.
            market_observer: Optional market risk factor observer.
            fixed_market_rate: Fixed market rate when no observer is provided.
            prepayment_cycle: Cycle for prepayment observation frequency.
            model_id: Unique model identifier.
            name: Optional observer name for debugging. Defaults to
                ``"PrepaymentSurface(<model_id>)"``.
        """
        super().__init__(name or f"PrepaymentSurface({model_id})")
        self.surface = surface
        self.market_rate_id = market_rate_id
        self.market_observer = market_observer
        self.fixed_market_rate = fixed_market_rate
        self.prepayment_cycle = prepayment_cycle
        self.model_id = model_id

    @staticmethod
    def _to_scalar(value: Any) -> jnp.ndarray:
        """Convert a single-element rate to a 0-d float JAX array.

        This replaces a ``float(value)`` cast. It accepts the same
        single-element inputs, but unlike ``float()`` it also works on JAX
        tracers, so the rate stays usable under ``jax.grad``/``jax.jit``.
        """
        return jnp.reshape(jnp.asarray(value, dtype=float), ())

    def _get_market_rate(self, time: ActusDateTime) -> float | jnp.ndarray:
        """Get the current market reference rate.

        Queries ``market_observer`` for ``market_rate_id`` at ``time`` when
        both are set; otherwise returns ``fixed_market_rate`` unchanged.
        """
        if self.market_observer is not None and self.market_rate_id is not None:
            return self._to_scalar(
                self.market_observer.observe_risk_factor(self.market_rate_id, time)
            )
        return self.fixed_market_rate

    def _get_risk_factor(
        self,
        identifier: str,
        time: ActusDateTime,
        state: ContractState | None,
        attributes: ContractAttributes | None,
    ) -> jnp.ndarray:
        """Compute prepayment rate from surface.

        Uses the contract's current nominal interest rate and loan age to
        look up the prepayment rate from the 2D surface.

        Args:
            identifier: Risk factor identifier (typically the model_id).
                Not used in the computation.
            time: Current simulation time.
            state: Current contract state (must contain ``ipnr``).
            attributes: Contract attributes (must contain
                ``initial_exchange_date``).

        Returns:
            Prepayment rate as JAX array (MRD value). A float32 ``0.0`` is
            returned when ``state``, ``attributes`` or the initial exchange
            date is missing.
        """
        if state is None or attributes is None:
            return jnp.array(0.0, dtype=jnp.float32)

        # Spread: contract rate - market rate. This is kept as JAX arithmetic
        # (no float() casts) so traced rates remain differentiable.
        contract_rate = self._to_scalar(state.ipnr)
        market_rate = self._get_market_rate(time)
        spread = contract_rate - market_rate

        # Loan age in years (actual days / 365.25).
        ied = attributes.initial_exchange_date
        if ied is None:
            return jnp.array(0.0, dtype=jnp.float32)
        age_years = ied.days_between(time) / 365.25

        return self.surface.evaluate(spread, age_years)

    def _get_event_data(
        self,
        identifier: str,
        event_type: EventType,
        time: ActusDateTime,
        state: ContractState | None,
        attributes: ContractAttributes | None,
    ) -> Any:
        """Prepayment observer does not provide event data.

        Raises:
            KeyError: Always.
        """
        raise KeyError(f"PrepaymentSurfaceObserver does not support event data for '{identifier}'")

    def contract_start(
        self,
        attributes: ContractAttributes,
    ) -> list[CalloutEvent]:
        """Generate prepayment observation events over the contract life.

        Creates callout events every ``prepayment_cycle``, starting one cycle
        after the initial exchange date. Only dates strictly before the
        maturity date are included.

        Args:
            attributes: Contract attributes.

        Returns:
            List of CalloutEvent objects with callout_type ``"MRD"``. The list
            is empty when the initial exchange date or maturity date is
            missing.

        Raises:
            ValueError: If ``prepayment_cycle`` does not move time forward
                (e.g. a zero or negative period), which would otherwise loop
                forever.
        """
        from jactus.core.time import add_period

        ied = attributes.initial_exchange_date
        md = attributes.maturity_date
        if ied is None or md is None:
            return []

        events: list[CalloutEvent] = []
        current = add_period(ied, self.prepayment_cycle)
        while current < md:
            events.append(
                CalloutEvent(
                    model_id=self.model_id,
                    time=current,
                    callout_type="MRD",
                )
            )
            following = add_period(current, self.prepayment_cycle)
            # Guard against a non-advancing cycle, which never reaches maturity.
            if not current < following:
                raise ValueError(
                    f"prepayment_cycle {self.prepayment_cycle!r} does not advance time; "
                    "cannot generate prepayment events"
                )
            current = following
        return events