"""Stock (STK) contract implementation.

This module implements the STK contract type - an equity position with dividend payments.
STK is a simple contract representing stock ownership with potential dividend income.

ACTUS Reference:
    ACTUS v1.1 Section 7.9 - STK: Stock

Key Features:
    - Equity position value from market observation
    - Fixed or market-observed dividend payments
    - Purchase and termination events
    - Minimal state (only performance and status date)
    - 6 event types total

Example:
    >>> from jactus.contracts.stk import StockContract
    >>> from jactus.core import (
    ...     ActusDateTime,
    ...     ContractAttributes,
    ...     ContractRole,
    ...     ContractType,
    ... )
    >>> from jactus.observers import ConstantRiskFactorObserver
    >>>
    >>> attrs = ContractAttributes(
    ...     contract_id="STK-001",
    ...     contract_type=ContractType.STK,
    ...     contract_role=ContractRole.RPA,
    ...     status_date=ActusDateTime(2024, 1, 1, 0, 0, 0),
    ...     currency="USD",
    ...     market_object_code="AAPL",  # Stock ticker
    ... )
    >>>
    >>> rf_obs = ConstantRiskFactorObserver(constant_value=150.0)  # Stock price
    >>> contract = StockContract(
    ...     attributes=attrs,
    ...     risk_factor_observer=rf_obs,
    ... )
    >>> result = contract.simulate()
"""
from typing import Any
import flax.nnx as nnx
import jax.numpy as jnp
from jactus.contracts.base import BaseContract
from jactus.core import (
ActusDateTime,
ContractAttributes,
ContractEvent,
ContractState,
ContractType,
EventSchedule,
EventType,
)
from jactus.functions import BasePayoffFunction, BaseStateTransitionFunction
from jactus.observers import ChildContractObserver, RiskFactorObserver
from jactus.utilities import contract_role_sign
[docs]
class StockPayoffFunction(BasePayoffFunction):
    """Payoff function for STK (stock) contracts.

    Translates an STK event into its cashflow. Five event types are
    dispatched; every other event type yields a zero payoff.

    ACTUS Reference:
        ACTUS v1.1 Section 7.9 - STK Payoff Functions

    Events:
        AD: Analysis Date (0.0)
        PRD: Purchase Date (role sign times negated purchase price)
        TD: Termination Date (role sign times termination price)
        DV: Dividend Payment (role sign times observed dividend)
        CE: Credit Event (0.0)

    Note:
        The ACTUS formulas for PRD and TD carry an FX factor
        X^CURS_CUR(t). The handlers in this class do not multiply by any
        FX rate. There is also no separate handler for a fixed dividend
        (DV(fix)); all DV events go through the market-observed path.
    """

    def calculate_payoff(
        self,
        event_type: Any,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """Return the payoff of a single STK event.

        Args:
            event_type: Type of event being evaluated
            state: Contract state before the event
            attributes: Contract attributes
            time: Event time
            risk_factor_observer: Observer used for market data lookups

        Returns:
            Payoff as a float32 JAX scalar; 0.0 for event types that have
            no dedicated handler.

        ACTUS Reference:
            POF_[event]_STK functions from Section 7.9
        """
        # Ordered (event, handler) pairs; matching uses ``==`` so that any
        # value comparing equal to an EventType member is accepted.
        handlers = (
            (EventType.AD, self._pof_ad),
            (EventType.PRD, self._pof_prd),
            (EventType.TD, self._pof_td),
            (EventType.DV, self._pof_dv),
            (EventType.CE, self._pof_ce),
        )
        for known_event, handler in handlers:
            if event_type == known_event:
                return handler(state, attributes, time, risk_factor_observer)

        # Unrecognised event type: no cashflow.
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_ad(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_AD_STK: an analysis date produces no cashflow.

        Returns:
            0.0
        """
        return jnp.array(0.0, dtype=jnp.float32)

    def _pof_prd(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_PRD_STK: cashflow at the purchase date.

        Formula (as implemented):
            POF_PRD_STK = R(CNTRL) × (-PPRD)

        Where:
            PPRD: Price at purchase date (treated as 0.0 when unset)
            R(CNTRL): Role sign

        The FX factor X^CURS_CUR(t) of the ACTUS formula is not applied.

        Returns:
            Role sign multiplied by the negated purchase price
        """
        purchase_price = attributes.price_at_purchase_date
        if not purchase_price:
            purchase_price = 0.0
        sign = contract_role_sign(attributes.contract_role)
        return jnp.array(sign * (-purchase_price), dtype=jnp.float32)

    def _pof_td(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_TD_STK: cashflow at the termination date.

        Formula (as implemented):
            POF_TD_STK = R(CNTRL) × PTD

        Where:
            PTD: Price at termination date (treated as 0.0 when unset)
            R(CNTRL): Role sign

        The FX factor X^CURS_CUR(t) of the ACTUS formula is not applied.

        Returns:
            Role sign multiplied by the termination price
        """
        termination_price = attributes.price_at_termination_date
        if not termination_price:
            termination_price = 0.0
        sign = contract_role_sign(attributes.contract_role)
        return jnp.array(sign * termination_price, dtype=jnp.float32)

    def _pof_dv(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_DV_STK: dividend payment.

        Formula (observed):
            POF_DV_STK = R(CNTRL) × O_dv(DVMO, t)

        Where:
            O_dv: Dividend amount observed from the risk factor observer
            R(CNTRL): Role sign

        Returns:
            Role sign multiplied by the observed dividend; the dividend is
            taken as 0.0 when no dividend market object code is set.
        """
        sign = contract_role_sign(attributes.contract_role)
        dividend_code = attributes.market_object_code_of_dividends
        # Without a dividend market object code there is nothing to observe.
        observed = (
            float(risk_factor_observer.observe_risk_factor(dividend_code, time))
            if dividend_code
            else 0.0
        )
        return jnp.array(sign * observed, dtype=jnp.float32)

    def _pof_ce(
        self,
        state: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> jnp.ndarray:
        """POF_CE_STK: a credit event produces no cashflow.

        Returns:
            0.0
        """
        return jnp.array(0.0, dtype=jnp.float32)
[docs]
class StockStateTransitionFunction(BaseStateTransitionFunction):
    """State transition function for STK contracts.

    Every STK event is handled the same way: the status date moves to the
    event time and the remaining forwarded fields are copied unchanged.

    ACTUS Reference:
        ACTUS v1.1 Section 7.9 - STK State Transition Functions

    Note:
        Only ``tmd``, ``nt``, ``ipnr``, ``ipac``, ``feac``, ``nsc`` and
        ``isc`` are forwarded from the previous state. Any other
        ``ContractState`` field is not passed to the constructor here.
    """

    def transition_state(
        self,
        event_type: Any,
        state_pre: ContractState,
        attributes: ContractAttributes,
        time: ActusDateTime,
        risk_factor_observer: RiskFactorObserver,
    ) -> ContractState:
        """Build the post-event state for an STK event.

        Args:
            event_type: Type of event (not used; all events behave alike)
            state_pre: State before the event
            attributes: Contract attributes (not used)
            time: Event time; becomes the new status date
            risk_factor_observer: Observer for market data (not used)

        Returns:
            New contract state with ``sd`` set to ``time`` and the
            forwarded fields copied from ``state_pre``

        ACTUS Reference:
            STF_[event]_STK functions from Section 7.9
        """
        # Fields carried over verbatim from the pre-event state.
        forwarded = ("tmd", "nt", "ipnr", "ipac", "feac", "nsc", "isc")
        carried = {name: getattr(state_pre, name) for name in forwarded}
        return ContractState(sd=time, **carried)
[docs]
class StockContract(BaseContract):
    """Stock (STK) contract implementation.

    Models an equity position with optional purchase, dividend and
    termination events. The schedule, initial state and the payoff /
    state-transition function objects are all derived from the contract
    attributes.

    ACTUS Reference:
        ACTUS v1.1 Section 7.9

    Attributes:
        attributes: Contract terms and parameters
        risk_factor_observer: Observer for market prices and dividends
        child_contract_observer: Observer for child contracts (optional)

    Example:
        >>> attrs = ContractAttributes(
        ...     contract_id="STK-001",
        ...     contract_type=ContractType.STK,
        ...     contract_role=ContractRole.RPA,
        ...     status_date=ActusDateTime(2024, 1, 1, 0, 0, 0),
        ...     currency="USD",
        ...     market_object_code="AAPL",
        ... )
        >>> contract = StockContract(attrs, risk_obs)
        >>> result = contract.simulate()
    """

    def __init__(
        self,
        attributes: ContractAttributes,
        risk_factor_observer: RiskFactorObserver,
        child_contract_observer: ChildContractObserver | None = None,
        rngs: nnx.Rngs | None = None,
    ):
        """Initialize an STK contract.

        Args:
            attributes: Contract attributes
            risk_factor_observer: Observer for market data
            child_contract_observer: Optional observer for child contracts
            rngs: Optional Flax NNX random number generators

        Raises:
            ValueError: If ``attributes.contract_type`` is not STK
        """
        super().__init__(
            attributes=attributes,
            risk_factor_observer=risk_factor_observer,
            child_contract_observer=child_contract_observer,
            rngs=rngs,
        )

        # The contract type is the only attribute validated here;
        # market_object_code is recommended for price observation but is
        # not enforced.
        actual_type = attributes.contract_type
        if actual_type != ContractType.STK:
            raise ValueError(f"Contract type must be STK, got {actual_type}")

    def _apply_bdc(self, date: ActusDateTime) -> ActusDateTime:
        """Shift ``date`` according to the business day convention.

        The date is returned unchanged when no convention is configured
        (unset or ``"NULL"``) or when no calendar is configured (unset,
        ``"NO_CALENDAR"`` or ``"NC"``). Otherwise a Monday-to-Friday
        calendar is always used for the shift, whichever calendar is named
        in the attributes.

        Following-type codes (CSF, SCF, CSMF, SCMF) roll to the next
        business day and preceding-type codes (CSP, SCP, CSMP, SCMP) roll
        to the previous one; the "modified" variants are handled exactly
        like their unmodified counterparts. Any other code leaves the date
        unchanged.

        Args:
            date: Date to adjust

        Returns:
            The adjusted date, or ``date`` itself when no shift applies
        """
        from jactus.utilities.calendars import MondayToFridayCalendar

        convention = self.attributes.business_day_convention
        calendar_code = self.attributes.calendar

        if not convention or convention == "NULL":
            return date
        if not calendar_code or calendar_code in ("NO_CALENDAR", "NC"):
            return date

        business_days = MondayToFridayCalendar()

        # Accept both enum-like values (with ``.value``) and plain strings.
        if hasattr(convention, "value"):
            code = convention.value
        else:
            code = str(convention)

        if code in ("CSF", "SCF", "CSMF", "SCMF"):
            adjusted = business_days.next_business_day(date)
        elif code in ("CSP", "SCP", "CSMP", "SCMP"):
            adjusted = business_days.previous_business_day(date)
        else:
            adjusted = date
        return adjusted

    def generate_event_schedule(self) -> EventSchedule:
        """Generate the STK event schedule.

        Events produced:
            - PRD: at the purchase date, if one is set
            - DV: on each date of the dividend cycle, if a cycle is set
              and an end date (termination date, else maturity date)
              exists; only dates strictly after the status date are kept,
              and each kept date is business-day adjusted
            - TD: at the termination date, if one is set

        All events start with a zero payoff and no attached states. They
        are ordered by the ISO form of their event time, ties keeping
        creation order (PRD, then DV, then TD), and are then renumbered
        consecutively from 0.

        Returns:
            EventSchedule holding the ordered events

        ACTUS Reference:
            STK Contract Schedule from Section 7.9
        """
        attrs = self.attributes
        events: list[ContractEvent] = []

        def pending_event(kind: Any, when: ActusDateTime) -> ContractEvent:
            # Placeholder event: zero payoff, no states, provisional
            # sequence equal to the number of events created so far.
            return ContractEvent(
                event_type=kind,
                event_time=when,
                payoff=jnp.array(0.0, dtype=jnp.float32),
                currency=attrs.currency or "XXX",
                state_pre=None,
                state_post=None,
                sequence=len(events),
            )

        # PRD: purchase date, when defined.
        if attrs.purchase_date:
            events.append(pending_event(EventType.PRD, attrs.purchase_date))

        # DV: dividend dates along the dividend cycle.
        if attrs.dividend_cycle:
            from jactus.utilities.schedules import generate_schedule

            cycle_start = (
                attrs.dividend_anchor or attrs.purchase_date or attrs.status_date
            )
            cycle_end = attrs.termination_date or attrs.maturity_date
            if cycle_end:
                dividend_dates = generate_schedule(
                    start=cycle_start,
                    cycle=attrs.dividend_cycle,
                    end=cycle_end,
                )
                for raw_date in dividend_dates:
                    # The status-date filter uses the unadjusted date.
                    if raw_date > attrs.status_date:
                        events.append(
                            pending_event(EventType.DV, self._apply_bdc(raw_date))
                        )

        # TD: termination date, when defined.
        if attrs.termination_date:
            events.append(pending_event(EventType.TD, attrs.termination_date))

        # Chronological order; provisional sequence breaks ties.
        ordered = sorted(
            events,
            key=lambda ev: (ev.event_time.to_iso(), ev.sequence),
        )

        # Rebuild each event with its final position as sequence number.
        resequenced = tuple(
            ContractEvent(
                event_type=ev.event_type,
                event_time=ev.event_time,
                payoff=ev.payoff,
                currency=ev.currency,
                state_pre=ev.state_pre,
                state_post=ev.state_post,
                sequence=position,
            )
            for position, ev in enumerate(ordered)
        )
        return EventSchedule(
            events=resequenced,
            contract_id=attrs.contract_id,
        )

    def initialize_state(self) -> ContractState:
        """Build the initial STK contract state.

        The status date comes from the attributes; ``tmd`` is the
        termination date, falling back to the status date. Notional,
        interest rate, accrued interest and accrued fees start at 0.0 and
        both scaling multipliers start at 1.0.

        ACTUS Reference:
            STK State Initialization from Section 7.9

        Returns:
            Initial contract state
        """

        def scalar(value: float) -> jnp.ndarray:
            return jnp.array(value, dtype=jnp.float32)

        attrs = self.attributes
        return ContractState(
            sd=attrs.status_date,
            tmd=attrs.termination_date or attrs.status_date,
            nt=scalar(0.0),
            ipnr=scalar(0.0),
            ipac=scalar(0.0),
            feac=scalar(0.0),
            nsc=scalar(1.0),
            isc=scalar(1.0),
        )

    def get_payoff_function(self, event_type: Any) -> StockPayoffFunction:
        """Return a payoff function for STK events.

        The same kind of object is returned for every event type; it is
        built from the contract role and currency, with no settlement
        currency.
        """
        attrs = self.attributes
        return StockPayoffFunction(
            contract_role=attrs.contract_role,
            currency=attrs.currency,
            settlement_currency=None,
        )

    def get_state_transition_function(self, event_type: Any) -> StockStateTransitionFunction:
        """Return a state transition function for STK events.

        The same kind of object is returned for every event type.
        """
        return StockStateTransitionFunction()