"""Cap-Floor (CAPFL) contract implementation.
This module implements interest rate cap and floor contracts. A cap/floor
protects against interest rate movements:
- Cap: pays max(0, floating_rate - cap_rate) * NT * YF
- Floor: pays max(0, floor_rate - floating_rate) * NT * YF
- Collar: both a cap and a floor; the two payoffs are summed
The CAPFL wraps an underlier (typically PAM or SWPPV) and generates events
on the underlier's IP/RR schedule.
Example:
>>> from jactus.contracts import CapFloorContract
>>> from jactus.core import ContractAttributes, ActusDateTime, ContractType, ContractRole
>>> from jactus.observers import ConstantRiskFactorObserver, MockChildContractObserver
>>>
>>> attrs = ContractAttributes(
... contract_id="CAP-001",
... contract_type=ContractType.CAPFL,
... contract_role=ContractRole.BUY,
... status_date=ActusDateTime(2024, 1, 1, 0, 0, 0),
... maturity_date=ActusDateTime(2029, 1, 1, 0, 0, 0),
... rate_reset_cap=0.06,
... contract_structure='{"Underlying": "SWAP-001"}',
... )
>>> rf_obs = ConstantRiskFactorObserver(0.03)
>>> child_obs = MockChildContractObserver()
>>> cap = CapFloorContract(attrs, rf_obs, child_obs)
References:
ACTUS Technical Specification v1.1, Section 7.14
"""
import json
from typing import Any
import jax.numpy as jnp
from jactus.contracts.base import BaseContract, SimulationHistory
from jactus.core import (
ActusDateTime,
ContractAttributes,
ContractEvent,
ContractPerformance,
ContractRole,
ContractState,
ContractType,
EventSchedule,
EventType,
)
from jactus.core.types import DayCountConvention
from jactus.functions import BasePayoffFunction, BaseStateTransitionFunction
from jactus.observers import ChildContractObserver, RiskFactorObserver
from jactus.observers.behavioral import BehaviorRiskFactorObserver
from jactus.observers.scenario import Scenario
from jactus.utilities.conventions import year_fraction
from jactus.utilities.schedules import generate_schedule
class CapFloorPayoffFunction(BasePayoffFunction):
    """Payoff function for CAPFL contracts.

    Computes cap/floor differential payoffs at IP events; every other event
    type pays zero.
    """

    def __init__(
        self,
        contract_role: ContractRole | None = None,
        currency: str | None = None,
        settlement_currency: str | None = None,
        cap_rate: float | None = None,
        floor_rate: float | None = None,
        notional: float = 0.0,
        day_count_convention: DayCountConvention = DayCountConvention.A365,
    ):
        """Store the cap/floor terms used when paying IP events.

        Args:
            contract_role: Forwarded to the base class; ``ContractRole.RPA`` if falsy.
            currency: Forwarded to the base class; ``"USD"`` if falsy.
            settlement_currency: Forwarded to the base class unchanged.
            cap_rate: Cap strike rate; ``None`` disables the cap leg.
            floor_rate: Floor strike rate; ``None`` disables the floor leg.
            notional: Notional that scales each rate differential.
            day_count_convention: Convention used for the accrual year fraction.
        """
        super().__init__(
            contract_role or ContractRole.RPA,
            currency or "USD",
            settlement_currency,
        )
        self.cap_rate = cap_rate
        self.floor_rate = floor_rate
        self.notional = notional
        self.dcc = day_count_convention

    def calculate_payoff(
        self,
        event_type: EventType,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """Calculate cap/floor payoff.

        For IP events, computes the cap/floor differential:
            Cap:   max(0, floating_rate - cap_rate) * NT * YF
            Floor: max(0, floor_rate - floating_rate) * NT * YF
        Any other event type returns a float32 zero. ``risk_factor_observer``
        is not consulted; the floating rate is read from ``state.ipnr``.
        """
        if event_type == EventType.IP:
            return self._pof_ip(state, attributes, time)
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_ip(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
    ) -> jnp.ndarray:
        """POF_IP_CAPFL: Interest payment - cap/floor differential.

        The accrual period runs from ``state.sd`` to ``time`` under ``self.dcc``.
        When both legs are configured (a collar) their payoffs are added.
        """
        floating_rate = float(state.ipnr)
        nt = self.notional
        yf = year_fraction(state.sd, time, self.dcc)

        payoff = 0.0
        if self.cap_rate is not None:
            payoff += max(0.0, floating_rate - self.cap_rate) * nt * yf
        if self.floor_rate is not None:
            payoff += max(0.0, self.floor_rate - floating_rate) * nt * yf

        # Role sign: BUY receives protection, SEL pays it
        role_sign = 1.0
        if attributes.contract_role in (ContractRole.RPL, ContractRole.ST, ContractRole.SEL):
            role_sign = -1.0
        return jnp.array(role_sign * payoff, dtype=jnp.float32)
class CapFloorStateTransitionFunction(BaseStateTransitionFunction):
    """State transition function for CAPFL contracts.

    Tracks the floating rate through RR events and advances ``sd`` at IP
    events. CAPFL does not accrue ``ipac`` between payments: the IP payoff
    measures its accrual window from ``state.sd``. For that reason, purely
    observational AD events leave ``sd`` untouched.
    """

    def __init__(self, dcc: DayCountConvention = DayCountConvention.A365):
        """Create the transition function.

        Args:
            dcc: Day count convention of the underlier (stored as ``self.dcc``).
        """
        super().__init__()
        self.dcc = dcc

    def transition_state(
        self,
        event_type: EventType,
        state_pre: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """Return the post-event state for ``event_type`` at ``time``.

        - RR: observes a new floating rate and sets ``sd = time``.
        - IP: resets ``ipac`` to zero and sets ``sd = time``.
        - AD: carries the state forward with ``sd`` unchanged, so the next IP
          still pays for its full period.
        - Anything else: carries the state forward with ``sd = time``.

        Contract performance (``prf``) is preserved in every case.
        """
        if event_type == EventType.RR:
            return self._stf_rr(state_pre, attributes, time, risk_factor_observer)
        if event_type == EventType.IP:
            return self._stf_ip(state_pre, time)
        if event_type == EventType.AD:
            # Moving sd here would shorten the year fraction used by the next IP
            # payoff, because nothing is accrued into ipac at this point.
            return self._carry(state_pre, state_pre.sd)
        return self._carry(state_pre, time)

    @staticmethod
    def _carry(
        state_pre: ContractState,
        sd: ActusDateTime,
        *,
        ipnr: jnp.ndarray | None = None,
        ipac: jnp.ndarray | None = None,
    ) -> ContractState:
        """Copy ``state_pre`` with a new ``sd`` and optional ``ipnr``/``ipac`` overrides."""
        return ContractState(
            sd=sd,
            tmd=state_pre.tmd,
            nt=state_pre.nt,
            ipnr=state_pre.ipnr if ipnr is None else ipnr,
            ipac=state_pre.ipac if ipac is None else ipac,
            feac=state_pre.feac,
            nsc=state_pre.nsc,
            isc=state_pre.isc,
            # Previously dropped, which silently reset performance to the default.
            prf=state_pre.prf,
        )

    def _stf_rr(
        self,
        state_pre: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """STF_RR_CAPFL: Rate Reset - observe market rate for cap/floor tracking.

        The rate is observed under ``attributes.rate_reset_market_object``
        (``""`` when unset).
        """
        market_object = attributes.rate_reset_market_object or ""
        observed_rate = float(
            risk_factor_observer.observe_risk_factor(market_object, time, state_pre, attributes)
        )
        return self._carry(
            state_pre,
            time,
            ipnr=jnp.array(observed_rate, dtype=jnp.float32),
        )

    def _stf_ip(self, state_pre: ContractState, time: ActusDateTime) -> ContractState:
        """STF_IP_CAPFL: Interest Payment - advance status date and clear ``ipac``."""
        return self._carry(
            state_pre,
            time,
            ipac=jnp.array(0.0, dtype=jnp.float32),
        )
class CapFloorContract(BaseContract):
    """Cap-Floor (CAPFL) contract.

    An interest rate cap or floor that pays the differential when the
    floating rate breaches the cap or floor rate.

    The contract either:

    1. references an underlier via ``contract_structure`` (child observer mode), or
    2. contains embedded underlier terms for standalone operation.

    ``__init__`` requires a ``child_contract_observer`` in both modes.
    """

    def __init__(
        self,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: ChildContractObserver | None = None,
    ):
        """Validate CAPFL attributes and capture any embedded underlier terms.

        Args:
            attributes: Contract attributes. ``contract_type`` must be CAPFL,
                ``contract_structure`` must be a JSON object containing an
                ``"Underlying"`` key, and at least one of ``rate_reset_cap``
                (RRLC) / ``rate_reset_floor`` (RRLF) must be set.
            risk_factor_observer: Observer used for rate-reset observations.
            child_contract_observer: Child observer; ``None`` is rejected.

        Raises:
            ValueError: If any of the checks above fails.
        """
        if attributes.contract_type != ContractType.CAPFL:
            raise ValueError(f"Expected contract_type=CAPFL, got {attributes.contract_type}")
        if child_contract_observer is None:
            raise ValueError("child_contract_observer is required for CAPFL contracts")
        if attributes.contract_structure is None:
            raise ValueError("contract_structure (CTST) is required and must contain Underlying")

        # Parse contract structure (JSON string)
        try:
            ctst = json.loads(attributes.contract_structure)
        except (json.JSONDecodeError, TypeError) as e:
            raise ValueError(f"contract_structure must be valid JSON: {e}") from e
        if not isinstance(ctst, dict):
            raise ValueError("contract_structure must be a JSON object (dictionary)")
        if "Underlying" not in ctst:
            raise ValueError("contract_structure must contain 'Underlying' key")
        if attributes.rate_reset_cap is None and attributes.rate_reset_floor is None:
            raise ValueError(
                "At least one of rate_reset_cap (RRLC) or rate_reset_floor (RRLF) must be set"
            )

        # A dict under "Underlying" means embedded terms (standalone mode);
        # anything else is treated as an identifier for the child observer.
        # Set before super().__init__ so it exists for anything the base
        # constructor may call.
        underlying = ctst["Underlying"]
        self._underlier_terms: dict[str, Any] | None = None
        if isinstance(underlying, dict):
            self._underlier_terms = underlying

        super().__init__(attributes, risk_factor_observer, child_contract_observer)

    def _parse_contract_structure(self) -> dict[str, Any]:
        """Return ``contract_structure`` decoded from JSON (``{}`` when unset)."""
        return dict(json.loads(self.attributes.contract_structure or "{}"))

    def generate_event_schedule(self) -> EventSchedule:
        """Generate event schedule for CAPFL contract.

        If underlier terms are embedded, generates IP and RR schedules directly
        from those terms. Otherwise, queries the child observer. AD and TD
        events from the contract attributes are appended. Events are then
        ordered by calendar day and ``sequence``; the sort is stable, so ties
        keep insertion order.
        """
        events: list[ContractEvent] = []
        if self._underlier_terms:
            events = self._generate_standalone_schedule()
        else:
            events = self._generate_child_observer_schedule()

        # Add analysis dates
        currency = self.attributes.currency or "USD"
        if self.attributes.analysis_dates:
            for ad_time in self.attributes.analysis_dates:
                events.append(
                    ContractEvent(
                        event_type=EventType.AD,
                        event_time=ad_time,
                        payoff=jnp.array(0.0, dtype=jnp.float32),
                        currency=currency,
                    )
                )

        # Add termination date
        if self.attributes.termination_date:
            events.append(
                ContractEvent(
                    event_type=EventType.TD,
                    event_time=self.attributes.termination_date,
                    payoff=jnp.array(0.0, dtype=jnp.float32),
                    currency=currency,
                )
            )

        events.sort(
            key=lambda e: (
                e.event_time.year,
                e.event_time.month,
                e.event_time.day,
                e.sequence,
            )
        )
        return EventSchedule(
            contract_id=self.attributes.contract_id,
            events=tuple(events),
        )

    def _generate_standalone_schedule(self) -> list[ContractEvent]:
        """Generate IP and RR events from the embedded underlier terms.

        Returns an empty list when the terms have no ``maturityDate``. The
        schedule anchor is ``initialExchangeDate`` when present; otherwise it
        is derived from the maturity date. The anchor date itself is never
        emitted.
        """
        events: list[ContractEvent] = []
        terms = self._underlier_terms
        assert terms is not None

        # Parse underlier dates/cycles
        ied_str = terms.get("initialExchangeDate")
        md_str = terms.get("maturityDate")
        if md_str is None:
            return events
        md = ActusDateTime.from_iso(md_str)

        # Use IED as schedule start, or derive from MD by stepping back
        if ied_str:
            start = ActusDateTime.from_iso(ied_str)
        else:
            # Derive schedule anchor from MD by stepping backward
            start = self._derive_start_from_md(md, terms)

        currency = self.attributes.currency or terms.get("currency", "USD")

        # IP schedule from underlier
        # For CAPFL, IP runs BEFORE RR at the same timestamp so that IP
        # uses the rate from the previous period (not the just-reset rate)
        ip_cycle_str = terms.get("cycleOfInterestPayment", "")
        if ip_cycle_str:
            cycle = self._parse_cycle(ip_cycle_str)
            ip_dates = generate_schedule(start=start, cycle=cycle, end=md)
            for ip_time in ip_dates[1:]:
                events.append(
                    ContractEvent(
                        event_type=EventType.IP,
                        event_time=ip_time,
                        payoff=jnp.array(0.0, dtype=jnp.float32),
                        currency=currency,
                        sequence=0,  # IP BEFORE RR at same time
                    )
                )

        # RR schedule from underlier
        rr_cycle_str = terms.get("cycleOfRateReset", "")
        if rr_cycle_str:
            cycle = self._parse_cycle(rr_cycle_str)
            rr_dates = generate_schedule(start=start, cycle=cycle, end=md)
            # Skip first date, include subsequent
            for rr_time in rr_dates[1:]:
                events.append(
                    ContractEvent(
                        event_type=EventType.RR,
                        event_time=rr_time,
                        payoff=jnp.array(0.0, dtype=jnp.float32),
                        currency=currency,
                        sequence=1,  # RR AFTER IP at same time
                    )
                )

        return events

    def _derive_start_from_md(self, md: ActusDateTime, terms: dict[str, Any]) -> ActusDateTime:
        """Derive schedule start by stepping backward from MD in cycle increments.

        When no IED is specified, return the earliest cycle-aligned date that
        is strictly after the status date, found by stepping backward from MD.
        Every candidate is computed as ``MD - k * cycle``. The day of the
        month is clamped to the length of the target month, so a month-end
        MD (e.g. the 31st) stays aligned instead of drifting to the 28th.
        Stepped dates are built at 00:00:00.

        Only month (``M``) and year (``Y``) cycles are handled. Other units, an
        unparseable or missing cycle, or a zero-length cycle all return
        ``status_date``.
        """
        import calendar
        import re

        # A missing/None/empty IP cycle falls back to the RR cycle.
        cycle_str = terms.get("cycleOfInterestPayment") or terms.get("cycleOfRateReset") or ""
        cycle = self._parse_cycle(cycle_str)
        match = re.match(r"(\d+)([DWMY])", cycle)
        if not match:
            return self.attributes.status_date

        n = int(match.group(1))
        unit = match.group(2)
        if unit == "M":
            months = n
        elif unit == "Y":
            months = n * 12
        else:
            return self.attributes.status_date
        if months <= 0:
            # A zero-length cycle can never step past status_date (would loop forever).
            return self.attributes.status_date

        # Step backward from MD until we pass status_date
        sd = self.attributes.status_date
        md_month_index = md.year * 12 + (md.month - 1)
        current = md
        steps = 0
        while True:
            steps += 1
            year, month_zero_based = divmod(md_month_index - steps * months, 12)
            month = month_zero_based + 1
            # Always offset from MD (not from the previous, possibly clamped,
            # date) so short months don't permanently shift the anchor day.
            day = min(md.day, calendar.monthrange(year, month)[1])
            prev = ActusDateTime(year, month, day, 0, 0, 0)
            if prev <= sd:
                return current
            current = prev

    def _generate_child_observer_schedule(self) -> list[ContractEvent]:
        """Generate schedule using child observer (legacy approach).

        Emits an IP event wherever the absolute difference between the
        "uncapped" and "capped" underlier IP payoffs is positive, plus an MD
        event at ``maturity_date`` when it is set.
        """
        assert self.child_contract_observer is not None
        events: list[ContractEvent] = []
        ctst = self._parse_contract_structure()
        underlier_id = ctst["Underlying"]

        # NOTE(review): the two calls below are identical. For a deterministic
        # observer the "uncapped" and "capped" IP payoffs are therefore equal,
        # the differential is always 0, and no IP events are emitted. The
        # observer API used here exposes no way to request a capped variant;
        # confirm how the capped leg is meant to be obtained.
        uncapped_events = self.child_contract_observer.observe_events(
            underlier_id,
            self.attributes.status_date,
            None,
        )
        capped_events = self.child_contract_observer.observe_events(
            underlier_id,
            self.attributes.status_date,
            None,
        )

        uncapped_map: dict[ActusDateTime, ContractEvent] = {}
        for event in uncapped_events:
            if event.event_type == EventType.IP:
                uncapped_map[event.event_time] = event

        capped_map: dict[ActusDateTime, ContractEvent] = {}
        for event in capped_events:
            if event.event_type == EventType.IP:
                capped_map[event.event_time] = event

        all_times = set(uncapped_map.keys()) | set(capped_map.keys())
        for time in all_times:
            uncapped_payoff = float(uncapped_map[time].payoff) if time in uncapped_map else 0.0
            capped_payoff = float(capped_map[time].payoff) if time in capped_map else 0.0
            differential = abs(uncapped_payoff - capped_payoff)
            if differential > 0.0:
                events.append(
                    ContractEvent(
                        event_type=EventType.IP,
                        event_time=time,
                        payoff=jnp.array(differential, dtype=jnp.float32),
                        currency=self.attributes.currency or "USD",
                    )
                )

        if self.attributes.maturity_date:
            events.append(
                ContractEvent(
                    event_type=EventType.MD,
                    event_time=self.attributes.maturity_date,
                    payoff=jnp.array(0.0, dtype=jnp.float32),
                    currency=self.attributes.currency or "USD",
                )
            )

        return events

    @staticmethod
    def _parse_cycle(cycle_str: str) -> str:
        """Convert ACTUS ISO cycle (P3ML1) to JACTUS format (3M).

        Strips a leading ``P`` and everything from the first ``L`` onward.
        """
        s = cycle_str
        if s.startswith("P"):
            s = s[1:]
        if "L" in s:
            s = s[: s.index("L")]
        return s

    def initialize_state(self) -> ContractState:
        """Initialize CAPFL contract state.

        ``ipnr`` starts at the underlier's ``nominalInterestRate`` (0.0 when
        there are no embedded terms). ``nt`` is 1.0; the payoff uses the
        underlier notional instead.
        """
        # Get initial rate from underlier terms
        ipnr = 0.0
        if self._underlier_terms:
            ipnr = float(self._underlier_terms.get("nominalInterestRate", 0.0))

        return ContractState(
            tmd=self.attributes.maturity_date or self.attributes.status_date,
            sd=self.attributes.status_date,
            nt=jnp.array(1.0, dtype=jnp.float32),
            ipnr=jnp.array(ipnr, dtype=jnp.float32),
            ipac=jnp.array(0.0, dtype=jnp.float32),
            feac=jnp.array(0.0, dtype=jnp.float32),
            nsc=jnp.array(1.0, dtype=jnp.float32),
            isc=jnp.array(1.0, dtype=jnp.float32),
            prf=self.attributes.contract_performance or ContractPerformance.PF,
        )

    def get_payoff_function(self, event_type: Any) -> CapFloorPayoffFunction:
        """Build the payoff function from the underlier's notional and DCC.

        Without embedded terms, the notional is 0.0 and the DCC is A365.
        """
        # Get underlier's notional and DCC
        notional = 0.0
        dcc = DayCountConvention.A365
        if self._underlier_terms:
            notional = float(self._underlier_terms.get("notionalPrincipal", 0.0))
            dcc_str = self._underlier_terms.get("dayCountConvention", "A365")
            dcc = _parse_dcc(dcc_str)

        return CapFloorPayoffFunction(
            contract_role=self.attributes.contract_role,
            currency=self.attributes.currency,
            cap_rate=self.attributes.rate_reset_cap,
            floor_rate=self.attributes.rate_reset_floor,
            notional=notional,
            day_count_convention=dcc,
        )

    def get_state_transition_function(self, event_type: Any) -> CapFloorStateTransitionFunction:
        """Build the state transition function using the underlier's DCC (A365 default)."""
        dcc = DayCountConvention.A365
        if self._underlier_terms:
            dcc_str = self._underlier_terms.get("dayCountConvention", "A365")
            dcc = _parse_dcc(dcc_str)
        return CapFloorStateTransitionFunction(dcc=dcc)

    def simulate(
        self,
        risk_factor_observer: RiskFactorObserver | None = None,
        child_contract_observer: ChildContractObserver | None = None,
        scenario: Scenario | None = None,
        behavior_observers: list[BehaviorRiskFactorObserver] | None = None,
    ) -> SimulationHistory:
        """Simulate CAPFL contract.

        RR events are used internally for rate tracking but filtered from the
        output. Every remaining event is re-emitted with a zeroed state (no
        own notional or rate). ``initial_state`` and ``final_state`` are
        passed through from the underlying simulation.
        """
        risk_obs = risk_factor_observer or self.risk_factor_observer

        # Store market object from underlier for RR observations.
        # NOTE(review): this writes into self.attributes, so the fallback
        # market object persists on the attributes object after simulate()
        # returns.
        if self._underlier_terms:
            market_object = self._underlier_terms.get("marketObjectCodeOfRateReset", "")
            if market_object and not self.attributes.rate_reset_market_object:
                self.attributes.rate_reset_market_object = market_object

        result = super().simulate(
            risk_obs,
            child_contract_observer,
            scenario=scenario,
            behavior_observers=behavior_observers,
        )

        # Filter out internal RR events — CAPFL only outputs IP events
        # Zero out state for IP events (CAPFL is a derivative with no own notional/rate)
        zero_state = ContractState(
            tmd=result.initial_state.tmd,
            sd=result.initial_state.sd,
            nt=jnp.array(0.0, dtype=jnp.float32),
            ipnr=jnp.array(0.0, dtype=jnp.float32),
            ipac=jnp.array(0.0, dtype=jnp.float32),
            feac=jnp.array(0.0, dtype=jnp.float32),
            nsc=jnp.array(1.0, dtype=jnp.float32),
            isc=jnp.array(1.0, dtype=jnp.float32),
            prf=ContractPerformance.PF,
        )
        filtered_events = []
        for e in result.events:
            if e.event_type == EventType.RR:
                continue
            filtered_events.append(
                ContractEvent(
                    event_type=e.event_type,
                    event_time=e.event_time,
                    payoff=e.payoff,
                    currency=e.currency,
                    state_pre=zero_state,
                    state_post=zero_state,
                    sequence=e.sequence,
                )
            )

        return SimulationHistory(
            events=filtered_events,
            states=[zero_state] * len(filtered_events),
            initial_state=result.initial_state,
            final_state=result.final_state,
        )
def _parse_dcc(dcc_str: str) -> DayCountConvention:
    """Translate an ACTUS day-count code into a ``DayCountConvention`` member.

    Matching is exact (case-sensitive). Any code not listed in the table,
    including ``None``, falls back to ``DayCountConvention.A365``.
    """
    code_table = dict(
        [
            ("AA", DayCountConvention.AA),
            ("A360", DayCountConvention.A360),
            ("A365", DayCountConvention.A365),
            ("30E360", DayCountConvention.E30360),
            ("30E360ISDA", DayCountConvention.E30360ISDA),
            ("30360", DayCountConvention.B30360),
            ("BUS252", DayCountConvention.BUS252),
        ]
    )
    try:
        return code_table[dcc_str]
    except KeyError:
        return DayCountConvention.A365