"""Payoff function framework for ACTUS contracts.
This module implements the payoff function (POF) framework as defined in
Sections 2.7, 3.9 and 3.10 of the ACTUS specification v1.1.
The payoff function f(e, S, M, t, o_rf) calculates the cashflow amount for
a contract event, applying contract role sign and FX rate adjustments.
References:
ACTUS Technical Specification v1.1:
- Section 2.7: Payoff Functions
- Section 3.9: Canonical Contract Payoff Function F(x,t)
- Section 3.10: Settlement Currency FX Rate X^CURS_CUR(t)
"""
from abc import ABC, abstractmethod
from typing import Protocol, runtime_checkable
import jax.numpy as jnp
from jactus.core import ActusDateTime, ContractAttributes, ContractState
from jactus.core.types import ContractRole, EventType
from jactus.utilities import contract_role_sign
@runtime_checkable
class PayoffFunction(Protocol):
    """Structural interface for ACTUS payoff functions (POF).

    Anything callable with the signature below satisfies this protocol,
    for example a ``BasePayoffFunction`` subclass or a plain function::

        f(e, S, M, t, o_rf) -> payoff

    where

    - ``e``    -- event type
    - ``S``    -- pre-event contract state
    - ``M``    -- contract attributes
    - ``t``    -- event time
    - ``o_rf`` -- risk factor observer

    Because the protocol is ``@runtime_checkable``, ``isinstance`` checks
    only verify that a ``__call__`` member exists; the signature itself is
    not validated at runtime.

    Returns:
        Payoff amount as JAX array (scalar)
    """

    def __call__(
        self,
        event_type: EventType,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: "RiskFactorObserver",  # type: ignore # noqa: F821
    ) -> jnp.ndarray:
        """Return the payoff of a single contract event.

        Args:
            event_type: Type of contract event
            state: Contract state before the event is applied
            attributes: Contract attributes/terms
            time: Event time
            risk_factor_observer: Observer for market data

        Returns:
            Payoff amount as JAX array (scalar)

        References:
            ACTUS v1.1 Section 2.7
        """
class BasePayoffFunction(ABC):
    """Base class for payoff functions with common logic.

    ``__call__`` implements the common payoff pipeline:

    1. Compute the payoff with :meth:`calculate_payoff` (contract-specific,
       abstract). The returned amount must already carry the correct sign:
       formulas built on unsigned attributes (e.g. IED) apply the contract
       role sign R(CNTRL) themselves (:meth:`apply_role_sign` is provided as
       a helper for this), while formulas built on signed state variables
       (IP, MD, PR) inherit their sign from ``state.nt``.
    2. Convert the amount to the settlement currency with
       :meth:`apply_fx_rate` (X^CURS_CUR(t)). This multiplies by 1.0 when
       no distinct settlement currency is configured.

    ``__call__`` does NOT apply the role sign itself.

    Subclasses must implement :meth:`calculate_payoff`.

    Attributes:
        contract_role: Contract role (RPA, RPL, etc.)
        currency: Contract currency
        settlement_currency: Settlement currency (None = same as contract currency)

    References:
        ACTUS v1.1 Section 2.7, 3.10
    """

    def __init__(
        self,
        contract_role: ContractRole,
        currency: str,
        settlement_currency: str | None = None,
    ):
        """Initialize base payoff function.

        Args:
            contract_role: Contract role, used by :meth:`apply_role_sign`
            currency: Contract currency (e.g., "USD")
            settlement_currency: Settlement currency (None = same as contract)
        """
        self.contract_role = contract_role
        self.currency = currency
        self.settlement_currency = settlement_currency

    @abstractmethod
    def calculate_payoff(
        self,
        event_type: EventType,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: "RiskFactorObserver",  # type: ignore # noqa: F821
    ) -> jnp.ndarray:
        """Calculate the signed payoff in contract currency (before FX).

        This is the contract-specific payoff logic that subclasses must
        implement. The result must already include the correct sign:
        ``__call__`` only applies the FX conversion afterwards and never
        applies the contract role sign.

        Args:
            event_type: Type of contract event
            state: Pre-event contract state
            attributes: Contract attributes/terms
            time: Event time
            risk_factor_observer: Observer for market data

        Returns:
            Signed payoff amount in contract currency as JAX array (scalar)
        """
        ...

    def apply_role_sign(self, amount: jnp.ndarray) -> jnp.ndarray:
        """Apply contract role sign R(CNTRL).

        Helper for ``calculate_payoff`` implementations whose formulas use
        unsigned attributes. It is not called by ``__call__``.

        The contract role determines the sign of cashflows:
        - RPA, LG, BUY, etc.: +1 (receive cashflows)
        - RPL, ST, SEL, etc.: -1 (pay cashflows)

        Args:
            amount: Unsigned payoff amount

        Returns:
            Signed payoff amount

        Formula:
            signed_amount = amount * R(CNTRL)

        References:
            ACTUS v1.1 Table 1 (Contract Role Signs)
        """
        sign = contract_role_sign(self.contract_role)
        return amount * jnp.array(sign, dtype=jnp.float32)

    def apply_fx_rate(
        self,
        amount: jnp.ndarray,
        time: ActusDateTime,
        risk_factor_observer: "RiskFactorObserver",  # type: ignore # noqa: F821
    ) -> jnp.ndarray:
        """Apply FX rate X^CURS_CUR(t) if settlement currency differs.

        If the settlement currency differs from the contract currency, the
        payoff is converted using the FX rate observed at the event time.

        Args:
            amount: Payoff in contract currency
            time: Event time
            risk_factor_observer: Observer for FX rates

        Returns:
            Payoff in settlement currency

        Formula:
            If CURS != CUR:
                payoff_settlement = payoff_contract * X^CURS_CUR(t)
            Else:
                payoff_settlement = payoff_contract

        References:
            ACTUS v1.1 Section 3.10
        """
        fx_rate = settlement_currency_fx_rate(
            time=time,
            contract_currency=self.currency,
            settlement_currency=self.settlement_currency,
            risk_factor_observer=risk_factor_observer,
        )
        return amount * fx_rate

    def __call__(
        self,
        event_type: EventType,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: "RiskFactorObserver",  # type: ignore # noqa: F821
    ) -> jnp.ndarray:
        """Calculate the final payoff: signed base amount, then FX conversion.

        Pipeline:
            1. ``calculate_payoff`` returns the payoff already carrying
               the correct sign (see the class docstring).
            2. ``apply_fx_rate`` converts it to the settlement currency.

        The contract role sign is intentionally NOT applied here: payoffs
        whose sign comes from ``state.nt`` would otherwise be flipped.

        Args:
            event_type: Type of contract event
            state: Pre-event contract state
            attributes: Contract attributes/terms
            time: Event time
            risk_factor_observer: Observer for market data

        Returns:
            Final payoff amount as JAX array (scalar)

        References:
            ACTUS v1.1 Section 2.7
        """
        # R(CNTRL) is applied only inside formulas that use unsigned
        # attributes (e.g. IED); formulas using signed state variables
        # (IP, MD, PR) already have the correct sign from state.nt.
        payoff = self.calculate_payoff(event_type, state, attributes, time, risk_factor_observer)
        return self.apply_fx_rate(payoff, time, risk_factor_observer)
def settlement_currency_fx_rate(
    time: ActusDateTime,
    contract_currency: str,
    settlement_currency: str | None,
    risk_factor_observer: "RiskFactorObserver",  # type: ignore # noqa: F821
) -> jnp.ndarray:
    """Get FX rate X^CURS_CUR(t) for settlement currency conversion.

    Returns the FX rate to convert from contract currency (CUR) to
    settlement currency (CURS) at the given time.

    Args:
        time: Time at which to observe FX rate
        contract_currency: Contract currency code (e.g., "USD")
        settlement_currency: Settlement currency code (None = same as contract)
        risk_factor_observer: Observer for FX rate data; only consulted when
            the two currencies differ

    Returns:
        FX rate as JAX array (float32 1.0 if currencies are the same). The
        observed rate is converted with ``jnp.asarray``, so a plain Python
        number returned by the observer still yields a JAX array.

    Logic:
        If settlement_currency is None: return 1.0
        If settlement_currency == contract_currency: return 1.0
        Otherwise: observe FX rate "contract_currency/settlement_currency"

    Example:
        >>> # Contract in EUR, settled in USD
        >>> fx_rate = settlement_currency_fx_rate(  # doctest: +SKIP
        ...     time=t,
        ...     contract_currency="EUR",
        ...     settlement_currency="USD",
        ...     risk_factor_observer=observer
        ... )
        >>> # Returns EUR/USD rate, e.g., 1.18

    References:
        ACTUS v1.1 Section 3.10
    """
    # No distinct settlement currency: the conversion factor is the identity.
    if settlement_currency is None or settlement_currency == contract_currency:
        return jnp.array(1.0, dtype=jnp.float32)
    # Identifier convention "BASE/QUOTE", with the contract currency as BASE.
    fx_identifier = f"{contract_currency}/{settlement_currency}"
    observed = risk_factor_observer.observe_risk_factor(identifier=fx_identifier, time=time)
    # The observer's return type is not constrained here (it may be a plain
    # float); normalize so callers always receive the documented JAX array.
    return jnp.asarray(observed)
def canonical_contract_payoff(
    contract: "BaseContract",  # type: ignore # noqa: F821
    time: ActusDateTime,
    risk_factor_observer: "RiskFactorObserver",  # type: ignore # noqa: F821
) -> jnp.ndarray:
    """Calculate canonical contract payoff F(x, t).

    The canonical contract payoff is the sum of the payoffs of all events
    occurring at or after time t. The contract is simulated once with the
    given risk factor observer, and only events with ``event_time >= time``
    contribute. Earlier events are still simulated but are not summed.

    This function is used for contract valuation and mark-to-market
    calculations.

    Args:
        contract: Contract instance exposing
            ``simulate(risk_factor_observer=...)``. The returned result must
            have an iterable ``events`` whose items provide ``event_time``
            (comparable with ``time``) and ``payoff`` (convertible with
            ``float()``).
        time: Valuation time
        risk_factor_observer: Observer for risk factors, passed unchanged to
            ``contract.simulate``

    Returns:
        Total payoff of all future events as a float32 JAX array (scalar);
        0.0 when no event falls at or after ``time``.

    Formula:
        F(x, t) = Σ f(e_i, S_i, M, t_i, o_rf) for all events e_i where t_i >= t

        Where:
        - e_i = i-th future event
        - S_i = state at event time
        - M = contract attributes
        - t_i = event time
        - o_rf = risk factor observer (frozen at current state)

    Note:
        This uses current risk factor conditions for all future events, which
        may differ from the actual risk factors at those event times.

    Example:
        >>> contract = PAMContract(attributes)  # doctest: +SKIP
        >>> observer = MockRiskFactorObserver({'LIBOR': {t: 0.03}})  # doctest: +SKIP
        >>> f_xt = canonical_contract_payoff(contract, t, observer)  # doctest: +SKIP
        >>> print(f"Contract value: {f_xt}")  # doctest: +SKIP

    References:
        ACTUS v1.1 Section 3.9
    """
    # Simulate the contract to get all events with computed payoffs.
    result = contract.simulate(risk_factor_observer=risk_factor_observer)
    # Accumulate future payoffs (t_i >= t) as Python floats, left to right,
    # and cast to float32 only once at the end.
    total = sum(
        (float(event.payoff) for event in result.events if event.event_time >= time),
        0.0,
    )
    return jnp.array(total, dtype=jnp.float32)