Source code for jactus.contracts.swppv

"""Plain Vanilla Interest Rate Swap (SWPPV) contract implementation.

This module implements a plain vanilla interest rate swap where one party pays a
fixed rate and receives a floating rate (or vice versa). The swap exchanges interest
payments on a notional amount without exchanging the principal.

Key Features:
    - Fixed leg: Uses nominal_interest_rate (IPNR)
    - Floating leg: Uses nominal_interest_rate_2 (IPNR2) with rate resets
    - Separate accrual tracking (ipac1 for fixed, ipac2 for floating)
    - Net or gross settlement modes
    - No notional exchange at inception or maturity

Example:
    >>> from jactus.contracts import PlainVanillaSwapContract
    >>> from jactus.core import ActusDateTime, ContractAttributes, ContractRole, ContractType
    >>> from jactus.observers import ConstantRiskFactorObserver
    >>>
    >>> # Receive fixed, pay floating
    >>> attrs = ContractAttributes(
    ...     contract_id="SWAP-001",
    ...     contract_type=ContractType.SWPPV,
    ...     contract_role=ContractRole.RPA,  # Receive fixed
    ...     status_date=ActusDateTime(2024, 1, 1, 0, 0, 0),
    ...     contract_deal_date=ActusDateTime(2024, 1, 1, 0, 0, 0),
    ...     initial_exchange_date=ActusDateTime(2024, 1, 15, 0, 0, 0),
    ...     maturity_date=ActusDateTime(2029, 1, 15, 0, 0, 0),
    ...     notional_principal=1000000.0,
    ...     nominal_interest_rate=0.05,  # Fixed leg: 5%
    ...     nominal_interest_rate_2=0.03,  # Floating leg initial: 3%
    ...     interest_payment_cycle="P6M",  # Semi-annual
    ...     rate_reset_cycle="P3M",  # Quarterly resets
    ...     delivery_settlement="D",  # Separate IPFX/IPFL events per leg ("S" = net IP)
    ... )
    >>> rf_obs = ConstantRiskFactorObserver(0.04)
    >>> swap = PlainVanillaSwapContract(attrs, rf_obs)
    >>> cashflows = swap.simulate(rf_obs)
"""

from typing import Any

import jax.numpy as jnp

from jactus.contracts.base import BaseContract, SimulationHistory
from jactus.core import (
    ActusDateTime,
    ContractAttributes,
    ContractEvent,
    ContractPerformance,
    ContractRole,
    ContractState,
    ContractType,
    EventSchedule,
    EventType,
)
from jactus.core.types import DayCountConvention
from jactus.functions import BasePayoffFunction, BaseStateTransitionFunction
from jactus.observers import RiskFactorObserver
from jactus.observers.behavioral import BehaviorRiskFactorObserver
from jactus.observers.scenario import Scenario
from jactus.utilities import year_fraction
from jactus.utilities.schedules import generate_schedule


class PlainVanillaSwapPayoffFunction(BasePayoffFunction):
    """Payoff function (POF) for SWPPV contracts.

    Produces the cashflow of every supported event type. Only the interest
    payment events (IP, IPFX, IPFL), the purchase event (PRD) and the
    termination event (TD) can yield a non-zero amount; all other events
    return a zero payoff.
    """

    # (event type, handler method name) pairs, scanned in this order by
    # ``calculate_payoff`` using ``==`` exactly like an if-chain would.
    _POF_HANDLERS = (
        (EventType.AD, "_pof_ad"),
        (EventType.IED, "_pof_ied"),
        (EventType.PRD, "_pof_prd"),
        (EventType.PR, "_pof_pr"),
        (EventType.IP, "_pof_ip"),
        (EventType.IPFX, "_pof_ipfx"),
        (EventType.IPFL, "_pof_ipfl"),
        (EventType.MD, "_pof_md"),
        (EventType.RR, "_pof_rr"),
        (EventType.TD, "_pof_td"),
        (EventType.CE, "_pof_ce"),
    )

    def calculate_payoff(
        self,
        event_type: EventType,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """Route an event to its payoff handler.

        Args:
            event_type: Type of the event being evaluated.
            state: Contract state before the event is applied.
            attributes: Contract terms.
            time: Event time.
            risk_factor_observer: Market data observer (forwarded to the handler).

        Returns:
            The payoff as a float32 scalar array; zero for unknown event types.
        """
        for known_type, handler_name in self._POF_HANDLERS:
            if event_type == known_type:
                handler = getattr(self, handler_name)
                return handler(state, attributes, time, risk_factor_observer)
        # Unknown event type
        return self._no_cashflow()

    @staticmethod
    def _no_cashflow() -> jnp.ndarray:
        """Return a fresh float32 zero payoff."""
        return jnp.array(0.0, dtype=jnp.float32)

    @staticmethod
    def _leg_total(accrued: Any, yf: float, rate: float, nt: float) -> float:
        """Return carried accrual (0.0 when ``None``) plus ``yf * rate * nt``."""
        carried = float(accrued) if accrued is not None else 0.0
        return carried + yf * rate * nt

    def _pof_ad(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_AD_SWPPV: Analysis Date payoff (always zero)."""
        return self._no_cashflow()

    def _pof_ied(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_IED_SWPPV: Initial Exchange Date payoff.

        Always zero: a plain vanilla swap exchanges no notional.
        """
        return self._no_cashflow()

    def _pof_prd(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_PRD_SWPPV: Purchase Date payoff.

        Formula:
            POF_PRD = R(CNTRL) × (-PPRD)

        A missing ``price_at_purchase_date`` is treated as 0.0.
        """
        from jactus.utilities import contract_role_sign

        purchase_price = attributes.price_at_purchase_date or 0.0
        direction = contract_role_sign(attributes.contract_role)
        return jnp.array(direction * (-purchase_price), dtype=jnp.float32)

    def _pof_md(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_MD_SWPPV: Maturity Date payoff.

        Always zero: no notional is exchanged at maturity.
        """
        return self._no_cashflow()

    def _pof_pr(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_PR_SWPPV: Principal Redemption payoff (always zero)."""
        return self._no_cashflow()

    def _pof_ip(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_IP_SWPPV: Interest Payment (net settlement).

        The payoff is evaluated on the pre-event state, so both legs are
        accrued here from ``state.sd`` to ``time`` before netting:

            net = (ipac1 + fixed_accrual) - (ipac2 + floating_accrual)

        The result is ``+net`` for roles RPA/RFL and ``-net`` for every other
        role. The day count convention defaults to A360 when unset.
        """
        convention = attributes.day_count_convention or DayCountConvention.A360
        yf = year_fraction(state.sd, time, convention)
        notional = float(state.nt)

        # Fixed leg: rate comes from the contract terms.
        fixed_rate = attributes.nominal_interest_rate or 0.0
        fixed_total = self._leg_total(state.ipac1, yf, fixed_rate, notional)

        # Floating leg: rate comes from the state (updated by rate resets).
        floating_rate = float(state.ipnr)
        floating_total = self._leg_total(state.ipac2, yf, floating_rate, notional)

        net_accrual = fixed_total - floating_total

        # RPA/RFL receive the fixed leg; all other roles pay it.
        receives_fixed = attributes.contract_role in (ContractRole.RPA, ContractRole.RFL)
        direction = 1.0 if receives_fixed else -1.0
        return jnp.array(direction * net_accrual, dtype=jnp.float32)

    def _pof_ipfx(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_IPFX_SWPPV: Fixed Leg Interest Payment.

        Formula:
            POF_IPFX = R(CNTRL) × (Ipac1 + Y(Sd, t) × IPNR × Nt)

        Where:
            Ipac1: Accumulated fixed leg accrual (0.0 when unset)
            IPNR: Fixed nominal interest rate (0.0 when unset)
            R(CNTRL): Sign returned by ``contract_role_sign``
        """
        from jactus.utilities import contract_role_sign

        convention = attributes.day_count_convention or DayCountConvention.A360
        yf = year_fraction(state.sd, time, convention)
        notional = float(state.nt)

        fixed_rate = attributes.nominal_interest_rate or 0.0
        fixed_total = self._leg_total(state.ipac1, yf, fixed_rate, notional)

        direction = contract_role_sign(attributes.contract_role)
        return jnp.array(direction * fixed_total, dtype=jnp.float32)

    def _pof_ipfl(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_IPFL_SWPPV: Floating Leg Interest Payment.

        Formula:
            POF_IPFL = -R(CNTRL) × (Ipac2 + Y(Sd, t) × Ipnr × Nt)

        Where:
            Ipac2: Accumulated floating leg accrual (0.0 when unset)
            Ipnr: Current floating rate read from the state
            R(CNTRL): Sign returned by ``contract_role_sign``
        """
        from jactus.utilities import contract_role_sign

        convention = attributes.day_count_convention or DayCountConvention.A360
        yf = year_fraction(state.sd, time, convention)
        notional = float(state.nt)

        floating_rate = float(state.ipnr)
        floating_total = self._leg_total(state.ipac2, yf, floating_rate, notional)

        direction = contract_role_sign(attributes.contract_role)
        return jnp.array(-direction * floating_total, dtype=jnp.float32)

    def _pof_rr(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_RR_SWPPV: Rate Reset payoff (always zero)."""
        return self._no_cashflow()

    def _pof_td(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_TD_SWPPV: Termination Date payoff.

        Returns ``price_at_termination_date`` as-is (0.0 when unset); no role
        sign is applied because the amount is taken to be already directional.
        """
        settlement = attributes.price_at_termination_date or 0.0
        return jnp.array(settlement, dtype=jnp.float32)

    def _pof_ce(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_CE_SWPPV: Credit Event payoff (always zero)."""
        return self._no_cashflow()
class PlainVanillaSwapStateTransitionFunction(BaseStateTransitionFunction):
    """State transition function (STF) for SWPPV contracts.

    Tracks the accrued interest of the fixed leg (``ipac1``) and the floating
    leg (``ipac2``) separately, and keeps the current floating rate in
    ``ipnr``. ``ipac`` carries the role-signed fixed leg accrual.
    """

    # (event type, handler method name) pairs, scanned in this order by
    # ``transition_state`` using ``==``.
    _STF_HANDLERS = (
        (EventType.AD, "_stf_ad"),
        (EventType.IED, "_stf_ied"),
        (EventType.PRD, "_stf_prd"),
        (EventType.PR, "_stf_pr"),
        (EventType.IP, "_stf_ip"),
        (EventType.IPFX, "_stf_ipfx"),
        (EventType.IPFL, "_stf_ipfl"),
        (EventType.MD, "_stf_md"),
        (EventType.RR, "_stf_rr"),
        (EventType.TD, "_stf_td"),
        (EventType.CE, "_stf_ce"),
    )

    def transition_state(
        self,
        event_type: EventType,
        state_pre: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """Calculate the post-event state for a swap event.

        Args:
            event_type: Type of the event being applied.
            state_pre: State before the event.
            attributes: Contract terms.
            time: Event time.
            risk_factor_observer: Market data observer (used by rate resets).

        Returns:
            The new state, or ``state_pre`` unchanged for unknown event types.
        """
        # The handlers take an event object; build one with a zero payoff.
        event = ContractEvent(
            event_type=event_type,
            event_time=time,
            payoff=jnp.array(0.0, dtype=jnp.float32),
            currency=attributes.currency or "USD",
        )
        for known_type, handler_name in self._STF_HANDLERS:
            if event_type == known_type:
                handler = getattr(self, handler_name)
                return handler(state_pre, event, attributes, risk_factor_observer)
        # Unknown event, return state unchanged
        return state_pre

    @staticmethod
    def _f32(value: Any) -> jnp.ndarray:
        """Wrap a scalar in a float32 array."""
        return jnp.array(value, dtype=jnp.float32)

    def _evolve(
        self,
        state: ContractState,
        status_date: ActusDateTime,
        *,
        ipac: Any,
        ipac1: Any,
        ipac2: Any,
        tmd: ActusDateTime | None = None,
        nt: Any = None,
        ipnr: Any = None,
    ) -> ContractState:
        """Build the next state, carrying over every field not overridden.

        ``feac``, ``nsc`` and ``isc`` are always carried over. ``tmd``, ``nt``
        and ``ipnr`` are carried over unless a replacement is given. ``prf``
        falls back to ``ContractPerformance.PF`` when the state has no such
        attribute.
        """
        return ContractState(
            tmd=state.tmd if tmd is None else tmd,
            sd=status_date,
            nt=state.nt if nt is None else nt,
            ipnr=state.ipnr if ipnr is None else ipnr,
            ipac=ipac,
            feac=state.feac,
            nsc=state.nsc,
            isc=state.isc,
            ipac1=ipac1,
            ipac2=ipac2,
            prf=state.prf if hasattr(state, "prf") else ContractPerformance.PF,
        )

    @staticmethod
    def _adjust_eod_time(time: ActusDateTime) -> ActusDateTime:
        """Move an end-of-day time (23:59:xx) to midnight of the next day.

        Any other time is returned unchanged. The calendar date is advanced by
        one full day rather than by one second, so a time such as 23:59:00 is
        rolled forward too instead of being truncated back to the midnight of
        its own day.
        """
        if time.hour == 23 and time.minute == 59:
            from datetime import timedelta

            next_day = time.to_datetime() + timedelta(days=1)
            return ActusDateTime(next_day.year, next_day.month, next_day.day, 0, 0, 0)
        return time

    def _accrue_legs(
        self,
        state: ContractState,
        time: ActusDateTime,
        attributes: ContractAttributes,
    ) -> tuple[float, float, float]:
        """Accrue interest for both legs from ``state.sd`` up to ``time``.

        Returns:
            Tuple of (new_ipac1, new_ipac2, new_ipac) where:
                ipac1 = fixed leg accrual
                ipac2 = floating leg accrual
                ipac = R(CNTRL) × ipac1 (signed fixed leg accrual)
        """
        from jactus.utilities import contract_role_sign

        dcc = attributes.day_count_convention or DayCountConvention.A360
        # End-of-day times are rolled to the next midnight before day counting.
        adj_time = self._adjust_eod_time(time)
        yf = year_fraction(state.sd, adj_time, dcc)
        nt = float(state.nt)

        # Fixed leg accrual: uses IPNR (nominal_interest_rate)
        fixed_rate = attributes.nominal_interest_rate or 0.0
        ipac1 = float(state.ipac1) if state.ipac1 is not None else 0.0
        new_ipac1 = ipac1 + yf * fixed_rate * nt

        # Floating leg accrual: uses current ipnr (updated by rate resets)
        ipac2 = float(state.ipac2) if state.ipac2 is not None else 0.0
        new_ipac2 = ipac2 + yf * float(state.ipnr) * nt

        # ipac = R(CNTRL) × ipac1 (signed fixed leg accrual for ACTUS state)
        role_sign = contract_role_sign(attributes.contract_role)
        new_ipac = role_sign * new_ipac1

        return new_ipac1, new_ipac2, new_ipac

    def _stf_ad(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_AD_SWPPV: Analysis Date state transition.

        Accrues both legs up to the event time and advances the status date.
        """
        new_ipac1, new_ipac2, new_ipac = self._accrue_legs(state, event.event_time, attributes)
        return self._evolve(
            state,
            event.event_time,
            ipac=self._f32(new_ipac),
            ipac1=self._f32(new_ipac1),
            ipac2=self._f32(new_ipac2),
        )

    def _stf_ied(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_IED_SWPPV: Initial Exchange Date state transition.

        Builds a fresh state from the contract terms: notional from
        ``notional_principal`` (1.0 when unset), floating rate from
        ``nominal_interest_rate_2`` (0.0 when unset), all accruals zero and
        both scaling factors one.
        """
        return ContractState(
            tmd=attributes.maturity_date or event.event_time,
            sd=event.event_time,
            nt=self._f32(attributes.notional_principal or 1.0),
            ipnr=self._f32(attributes.nominal_interest_rate_2 or 0.0),  # Floating rate
            ipac=self._f32(0.0),
            feac=self._f32(0.0),
            nsc=self._f32(1.0),
            isc=self._f32(1.0),
            ipac1=self._f32(0.0),  # Fixed leg accrual
            ipac2=self._f32(0.0),  # Floating leg accrual
            prf=attributes.contract_performance or ContractPerformance.PF,
        )

    def _stf_prd(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_PRD_SWPPV: Purchase Date state transition.

        Accrues both legs from the status date to the purchase date.
        """
        new_ipac1, new_ipac2, new_ipac = self._accrue_legs(state, event.event_time, attributes)
        return self._evolve(
            state,
            event.event_time,
            ipac=self._f32(new_ipac),
            ipac1=self._f32(new_ipac1),
            ipac2=self._f32(new_ipac2),
        )

    def _stf_md(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_MD_SWPPV: Maturity Date state transition.

        Sets the notional and all accruals to zero; the floating rate is kept.
        """
        return self._evolve(
            state,
            event.event_time,
            tmd=event.event_time,
            nt=self._f32(0.0),
            ipac=self._f32(0.0),
            ipac1=self._f32(0.0),
            ipac2=self._f32(0.0),
        )

    def _stf_pr(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_PR_SWPPV: Principal Redemption state transition.

        Returns the state unchanged (swaps redeem no principal).
        """
        return state

    def _stf_ip(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_IP_SWPPV: Interest Payment (net settlement) state transition.

        Resets ``ipac``, ``ipac1`` and ``ipac2`` to zero and advances the
        status date. No accrual is computed here: the accruals are discarded
        anyway, and the net payoff is derived by the payoff function from the
        pre-event state.
        """
        return self._evolve(
            state,
            event.event_time,
            ipac=self._f32(0.0),
            ipac1=self._f32(0.0),
            ipac2=self._f32(0.0),
        )

    def _stf_ipfx(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_IPFX_SWPPV: Fixed Leg Interest Payment state transition.

        Resets the fixed leg accrual and accrues the floating leg up to the
        payment time so that it is not lost when the status date advances.
        Unlike ``_accrue_legs`` this uses the raw event time (no end-of-day
        adjustment).
        """
        dcc = attributes.day_count_convention or DayCountConvention.A360
        yf = year_fraction(state.sd, event.event_time, dcc)
        nt = float(state.nt)

        # Accrue floating leg (don't reset it - IPFL handles that)
        floating_rate = float(state.ipnr)
        ipac2 = float(state.ipac2) if state.ipac2 is not None else 0.0
        new_ipac2 = ipac2 + yf * floating_rate * nt

        return self._evolve(
            state,
            event.event_time,
            ipac=self._f32(0.0),
            ipac1=self._f32(0.0),  # Reset fixed accrual
            ipac2=self._f32(new_ipac2),
        )

    def _stf_ipfl(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_IPFL_SWPPV: Floating Leg Interest Payment state transition.

        Resets the floating leg accrual and carries ``ipac1`` over untouched.
        Note: IPFL is expected to follow IPFX on the same date, in which case
        the fixed accrual has already been reset.
        """
        return self._evolve(
            state,
            event.event_time,
            ipac=self._f32(0.0),
            ipac1=state.ipac1,
            ipac2=self._f32(0.0),  # Reset floating accrual
        )

    def _stf_rr(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_RR_SWPPV: Rate Reset state transition.

        Accrues both legs up to the reset time using the old floating rate,
        then stores the newly observed floating rate.

        Formula:
            Ipac1_t = Ipac1_(t-) + Y(Sd_(t-), t) * IPNR * Nt
            Ipac2_t = Ipac2_(t-) + Y(Sd_(t-), t) * Ipnr_(t-) * Nt
            Ipnr_t = min(max(RRMLT * O_rf(RRMO, t) + RRSP, RRLF), RRLC)

        The multiplier defaults to 1.0 and the spread to 0.0; floor and cap
        are applied only when set.
        """
        # First accrue both legs up to reset time using old rates
        new_ipac1, new_ipac2, new_ipac = self._accrue_legs(state, event.event_time, attributes)

        # Rate reset parameters with their defaults
        rrmlt = (
            attributes.rate_reset_multiplier
            if attributes.rate_reset_multiplier is not None
            else 1.0
        )
        rrsp = attributes.rate_reset_spread if attributes.rate_reset_spread is not None else 0.0
        rrmo = attributes.rate_reset_market_object or ""

        # Observe market rate and derive the new floating rate
        market_rate = risk_factor_observer.observe_risk_factor(rrmo, event.event_time)
        new_rate = rrmlt * float(market_rate) + rrsp

        # Floor first, then cap, matching the formula above
        if attributes.rate_reset_floor is not None:
            new_rate = max(new_rate, attributes.rate_reset_floor)
        if attributes.rate_reset_cap is not None:
            new_rate = min(new_rate, attributes.rate_reset_cap)

        return self._evolve(
            state,
            event.event_time,
            ipnr=self._f32(new_rate),
            ipac=self._f32(new_ipac),
            ipac1=self._f32(new_ipac1),
            ipac2=self._f32(new_ipac2),
        )

    def _stf_td(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_TD_SWPPV: Termination Date state transition.

        Sets the notional and all accruals to zero; the current floating rate
        is preserved.
        """
        return self._evolve(
            state,
            event.event_time,
            tmd=event.event_time,
            nt=self._f32(0.0),
            ipac=self._f32(0.0),
            ipac1=self._f32(0.0),
            ipac2=self._f32(0.0),
        )

    def _stf_ce(
        self,
        state: ContractState,
        event: ContractEvent,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_CE_SWPPV: Credit Event state transition.

        Returns the state unchanged.
        """
        return state
class PlainVanillaSwapContract(BaseContract):
    """Plain Vanilla Interest Rate Swap (SWPPV) contract.

    Exchanges fixed against floating interest payments on a notional amount.
    The notional itself is never exchanged.

    Attributes:
        attributes: Contract terms and conditions
        risk_factor_observer: Observer for market rates
    """

    def __init__(
        self,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: Any = None,
    ):
        """Validate the contract terms and initialize the SWPPV contract.

        Args:
            attributes: Contract attributes
            risk_factor_observer: Observer for market data
            child_contract_observer: Not used for SWPPV

        Raises:
            ValueError: If the contract type is not SWPPV, the notional is
                missing or zero, or another required attribute is missing
        """
        if attributes.contract_type != ContractType.SWPPV:
            raise ValueError(f"Expected contract_type=SWPPV, got {attributes.contract_type}")

        notional = attributes.notional_principal
        if notional is None or notional == 0:
            raise ValueError("notional_principal (NT) is required and must be non-zero")

        # Remaining mandatory terms, checked in this order.
        mandatory = (
            ("nominal_interest_rate", "nominal_interest_rate (IPNR) is required for fixed leg"),
            (
                "nominal_interest_rate_2",
                "nominal_interest_rate_2 (IPNR2) is required for initial floating rate",
            ),
            ("interest_payment_cycle", "interest_payment_cycle (IPCL) is required"),
            ("maturity_date", "maturity_date (MD) is required"),
        )
        for attr_name, message in mandatory:
            if getattr(attributes, attr_name) is None:
                raise ValueError(message)

        super().__init__(attributes, risk_factor_observer, child_contract_observer)

    # Event ordering for same-date events (lower = earlier)
    _EVENT_ORDER = {
        EventType.IED: 0,
        EventType.PRD: 1,
        EventType.IPFX: 2,
        EventType.IPFL: 3,
        EventType.IP: 4,
        EventType.RR: 5,
        EventType.MD: 10,
        EventType.TD: 11,
        EventType.AD: 12,
    }

    def generate_event_schedule(self) -> EventSchedule:
        """Generate the event schedule for the SWPPV contract.

        Events are created with a zero payoff placeholder and sorted by
        calendar date, then by ``_EVENT_ORDER`` (unlisted types sort last).

        Returns:
            EventSchedule with all contract events
        """
        terms = self.attributes
        ccy = terms.currency or "USD"
        maturity = terms.maturity_date
        events = []

        def day_of(moment: ActusDateTime) -> str:
            """Return the ``YYYY-MM-DD`` part of a timestamp."""
            return moment.to_iso()[:10]

        def add(kind: EventType, when: ActusDateTime) -> None:
            """Append a zero-payoff event of the given type and time."""
            events.append(
                ContractEvent(
                    event_type=kind,
                    event_time=when,
                    payoff=jnp.array(0.0, dtype=jnp.float32),
                    currency=ccy,
                )
            )

        # Settlement mode: D = separate (IPFX/IPFL), S = net (IP)
        use_separate = (terms.delivery_settlement or "D") == "D"

        # IED: Initial Exchange Date
        if terms.initial_exchange_date:
            add(EventType.IED, terms.initial_exchange_date)

        # Rate resets; a reset falling on the maturity date is dropped.
        if terms.rate_reset_cycle and terms.rate_reset_anchor:
            rr_times = generate_schedule(
                start=terms.rate_reset_anchor,
                end=maturity,
                cycle=terms.rate_reset_cycle,
            )
            for rr_time in rr_times:
                if maturity and day_of(rr_time) == day_of(maturity):
                    continue
                add(EventType.RR, rr_time)
        elif terms.rate_reset_anchor:
            # Single RR at anchor date (no cycle)
            rr_anchor = terms.rate_reset_anchor
            if not maturity or day_of(rr_anchor) != day_of(maturity):
                add(EventType.RR, rr_anchor)

        # Interest payments; the anchor falls back through three attributes.
        ip_anchor = (
            terms.interest_payment_anchor
            or terms.interest_calculation_base_anchor
            or terms.initial_exchange_date
        )
        if terms.interest_payment_cycle and ip_anchor:
            ip_times = generate_schedule(
                start=ip_anchor,
                end=maturity,
                cycle=terms.interest_payment_cycle,
            )
            # Add maturity date as final payment if not already included
            if maturity:
                maturity_iso = day_of(maturity)
                already_there = any(day_of(t) == maturity_iso for t in ip_times)
                if not already_there:
                    ip_times.append(maturity)

            for ip_time in ip_times:
                if use_separate:
                    add(EventType.IPFX, ip_time)
                    add(EventType.IPFL, ip_time)
                else:
                    add(EventType.IP, ip_time)

        # PRD: Purchase Date
        if terms.purchase_date:
            add(EventType.PRD, terms.purchase_date)

        # MD: Maturity Date
        if maturity:
            add(EventType.MD, maturity)

        # Analysis dates
        if terms.analysis_dates:
            for ad_time in terms.analysis_dates:
                add(EventType.AD, ad_time)

        # Termination date
        if terms.termination_date:
            add(EventType.TD, terms.termination_date)

        def ordering(entry: ContractEvent) -> tuple:
            """Sort key: calendar date, then same-date event priority."""
            return (day_of(entry.event_time), self._EVENT_ORDER.get(entry.event_type, 99))

        # Stable sort, so insertion order breaks any remaining ties.
        ordered = sorted(events, key=ordering)

        return EventSchedule(
            contract_id=terms.contract_id,
            events=tuple(ordered),
        )

    def initialize_state(self) -> ContractState:
        """Initialize the contract state at the status date.

        The notional defaults to 1.0 and the floating rate to 0.0 when the
        corresponding attribute is unset; accruals start at zero and both
        scaling factors at one.

        Returns:
            Initial ContractState
        """
        terms = self.attributes

        def scalar(value: float) -> jnp.ndarray:
            """Wrap a number in a float32 array."""
            return jnp.array(value, dtype=jnp.float32)

        return ContractState(
            tmd=terms.maturity_date or terms.status_date,
            sd=terms.status_date,
            nt=scalar(terms.notional_principal or 1.0),
            ipnr=scalar(terms.nominal_interest_rate_2 or 0.0),
            ipac=scalar(0.0),
            feac=scalar(0.0),
            nsc=scalar(1.0),
            isc=scalar(1.0),
            ipac1=scalar(0.0),
            ipac2=scalar(0.0),
            prf=terms.contract_performance or ContractPerformance.PF,
        )

    def get_payoff_function(self, event_type: Any) -> PlainVanillaSwapPayoffFunction:
        """Get the payoff function for the SWPPV contract.

        Args:
            event_type: Type of event (not used, kept for interface compatibility)

        Returns:
            PlainVanillaSwapPayoffFunction built with this contract's role and currency
        """
        terms = self.attributes
        return PlainVanillaSwapPayoffFunction(
            contract_role=terms.contract_role,
            currency=terms.currency,
        )

    def get_state_transition_function(
        self, event_type: Any
    ) -> PlainVanillaSwapStateTransitionFunction:
        """Get the state transition function for the SWPPV contract.

        Args:
            event_type: Type of event (not used, kept for interface compatibility)

        Returns:
            A new PlainVanillaSwapStateTransitionFunction instance
        """
        return PlainVanillaSwapStateTransitionFunction()

    def simulate(
        self,
        risk_factor_observer: RiskFactorObserver | None = None,
        child_contract_observer: Any = None,
        scenario: Scenario | None = None,
        behavior_observers: list[BehaviorRiskFactorObserver] | None = None,
    ) -> SimulationHistory:
        """Simulate the SWPPV contract and hide events outside the holding period.

        The base simulation runs over the full schedule. Afterwards events
        before the purchase date (full timestamp comparison) and events dated
        after the termination date (calendar-day comparison) are removed from
        the returned event list. Without either date the base result is
        returned as-is.
        """
        result = super().simulate(
            risk_factor_observer,
            child_contract_observer,
            scenario=scenario,
            behavior_observers=behavior_observers,
        )

        visible = result.events
        trimmed = False

        # Keep only PRD onwards when purchaseDate is set
        if self.attributes.purchase_date:
            purchase_iso = self.attributes.purchase_date.to_iso()
            visible = [ev for ev in visible if ev.event_time.to_iso() >= purchase_iso]
            trimmed = True

        # Keep only up to and including the TD day when terminationDate is set
        if self.attributes.termination_date:
            td_iso = self.attributes.termination_date.to_iso()[:10]
            visible = [ev for ev in visible if ev.event_time.to_iso()[:10] <= td_iso]
            trimmed = True

        if not trimmed:
            return result

        # NOTE(review): states/initial_state/final_state are passed through
        # unfiltered; if ``states`` is meant to parallel ``events`` they no
        # longer line up after trimming - confirm against SimulationHistory.
        return SimulationHistory(
            events=visible,
            states=result.states,
            initial_state=result.initial_state,
            final_state=result.final_state,
        )