"""Financial mathematics utilities for ACTUS contracts.
This module provides mathematical functions for ACTUS contract calculations,
including contract role signs, annuity amounts, discount and compound factors,
and present values.
References:
ACTUS Technical Specification v1.1, Table 1 (Contract Role Signs)
ACTUS Technical Specification v1.1, Section 5 (Mathematical Functions)
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from jactus.core.time import ActusDateTime
from jactus.core.types import ContractRole, DayCountConvention
from jactus.utilities.conventions import year_fraction
[docs]
def contract_role_sign(role: ContractRole) -> int:
    """Get the sign (+1 or -1) for a contract role.

    The contract role sign determines the direction of cash flows from the
    perspective of the contract holder. The mapping is delegated to
    ``ContractRole.get_sign()``.

    Args:
        role: Contract role

    Returns:
        +1 for long/receiving positions, -1 for short/paying positions

    Example:
        >>> contract_role_sign(ContractRole.RPA)  # Real Position Asset
        1
        >>> contract_role_sign(ContractRole.RPL)  # Real Position Liability
        -1

    References:
        ACTUS Technical Specification v1.1, Table 1
    """
    # Delegate to the enum's own role -> sign mapping.
    return role.get_sign()
[docs]
@jax.jit
def contract_role_sign_vectorized(roles: jnp.ndarray) -> jnp.ndarray:
    """Look up ACTUS contract role signs for an array of integer role codes.

    Each entry of ``roles`` is used as a position into a fixed sign table
    laid out in the same order as the ``ContractRole`` enum (ACTUS Table 1).

    Args:
        roles: Array of contract role values (as integers)

    Returns:
        int32 array of signs (+1 or -1), same shape as ``roles``

    Example:
        >>> roles = jnp.array([0, 1, 2])  # RPA, RPL, LG
        >>> signs = contract_role_sign_vectorized(roles)
        >>> signs
        Array([1, -1, 1], dtype=int32)

    Note:
        This function is JIT-compiled for performance. JAX indexing does not
        raise on out-of-range indices, so invalid role codes are not detected
        here.
    """
    # Position i holds the sign of the role with enum value i.
    signs_by_code = (
        1,   # 0  RPA  - Real Position Asset
        -1,  # 1  RPL  - Real Position Liability
        1,   # 2  LG   - Long Position
        -1,  # 3  ST   - Short Position
        1,   # 4  BUY  - Protection Buyer
        -1,  # 5  SEL  - Protection Seller
        1,   # 6  RFL  - Receive First Leg
        -1,  # 7  PFL  - Pay First Leg
        1,   # 8  COL  - Collateral Instrument
        1,   # 9  CNO  - Close-out Netting Instrument
        -1,  # 10 GUA  - Guarantor
        1,   # 11 OBL  - Obligee
        1,   # 12 UDL  - Underlying
        1,   # 13 UDLP - Underlying Positive
        -1,  # 14 UDLM - Underlying Negative
    )
    table = jnp.asarray(signs_by_code, dtype=jnp.int32)
    return table[roles]
[docs]
def annuity_amount(
    notional: float,
    rate: float,
    tenor: ActusDateTime,  # noqa: ARG001
    maturity: ActusDateTime,  # noqa: ARG001
    n_periods: int,
    day_count_convention: DayCountConvention,  # noqa: ARG001
) -> float:
    """Return the level per-period payment that amortizes ``notional``.

    Closed-form annuity: ``A = N * r / (1 - (1 + r) ** -n)``. Zero periods
    yield 0.0, and a rate with magnitude below 1e-10 falls back to
    straight-line repayment ``N / n``.

    Args:
        notional: Notional principal amount
        rate: Periodic interest rate (e.g., 0.05 for 5% per period)
        tenor: Start date (currently unused; reserved for future use)
        maturity: End date (currently unused; reserved for future use)
        n_periods: Number of payment periods
        day_count_convention: Day count convention (currently unused;
            reserved for future use)

    Returns:
        Annuity payment amount per period

    Example:
        >>> # $100,000 loan at 5% annual rate, 12 monthly payments
        >>> tenor = ActusDateTime(2024, 1, 1, 0, 0, 0)
        >>> maturity = ActusDateTime(2025, 1, 1, 0, 0, 0)
        >>> amount = annuity_amount(100000, 0.05/12, tenor, maturity, 12, DayCountConvention.A360)
        >>> abs(amount - 8560.75) < 1  # Approximately $8,560.75 per month
        True

    References:
        ACTUS Technical Specification v1.1, Section 5.1
    """
    # Guard clauses first: no periods, then the interest-free case.
    if n_periods == 0:
        return 0.0
    if abs(rate) < 1e-10:
        return notional / n_periods

    annuity_factor = 1.0 - (1.0 + rate) ** (-n_periods)
    return notional * rate / annuity_factor
[docs]
@jax.jit
def annuity_amount_vectorized(
    notional: jnp.ndarray,
    rate: jnp.ndarray,
    n_periods: jnp.ndarray,
) -> jnp.ndarray:
    """Vectorized annuity calculation for JAX arrays.

    Element-wise version of ``annuity_amount``:

    * ``n_periods == 0``  -> 0.0
    * ``|rate| < 1e-10``  -> ``notional / n_periods``
    * otherwise           -> ``notional * rate / (1 - (1 + rate) ** -n_periods)``

    Args:
        notional: Array of notional amounts
        rate: Array of periodic rates
        n_periods: Array of number of periods

    Returns:
        Array of annuity amounts

    Example:
        >>> notionals = jnp.array([100000.0, 200000.0])
        >>> rates = jnp.array([0.05/12, 0.04/12])
        >>> periods = jnp.array([12, 24])
        >>> amounts = annuity_amount_vectorized(notionals, rates, periods)

    Note:
        This function is JIT-compiled for performance. Before any branch is
        evaluated, the operands it would discard are replaced with safe values,
        so no branch ever produces inf/NaN. This keeps ``jax.grad`` finite for
        zero-period and zero-rate elements.
    """
    zero_periods = n_periods == 0
    near_zero_rate = jnp.abs(rate) < 1e-10

    # "Double where": jnp.where still backpropagates through the branch it
    # discards, and 0 * inf = NaN. So substitute harmless operands in lanes
    # whose branch will be thrown away. Those lanes get n=1 and r=1, which
    # keeps 1 - (1 + r) ** -n away from 0.
    safe_periods = jnp.where(zero_periods, 1, n_periods)
    safe_rate = jnp.where(near_zero_rate, 1.0, rate)

    # Interest-free repayment: straight-line amortization.
    simple_annuity = notional / safe_periods

    # Standard annuity formula: A = N * r / (1 - (1 + r)^(-n))
    denominator = 1.0 - jnp.power(1.0 + safe_rate, -safe_periods)
    standard_annuity = notional * safe_rate / denominator

    annuity = jnp.where(near_zero_rate, simple_annuity, standard_annuity)
    return jnp.where(zero_periods, 0.0, annuity)
[docs]
def discount_factor(
    rate: float,
    start: ActusDateTime,
    end: ActusDateTime,
    day_count_convention: DayCountConvention,
) -> float:
    """Return the simple-interest discount factor between two dates.

    DF = 1 / (1 + r * t), where t is the year fraction from ``start`` to
    ``end`` under ``day_count_convention``.

    Args:
        rate: Annual interest rate (e.g., 0.05 for 5%)
        start: Start date
        end: End date
        day_count_convention: Day count convention used for the year fraction

    Returns:
        Discount factor

    Example:
        >>> start = ActusDateTime(2024, 1, 1, 0, 0, 0)
        >>> end = ActusDateTime(2024, 7, 1, 0, 0, 0)
        >>> df = discount_factor(0.05, start, end, DayCountConvention.AA)
        >>> abs(df - 0.9756) < 0.001  # Approximately 0.9756
        True

    References:
        ACTUS Technical Specification v1.1, Section 5.2
    """
    accrual = rate * year_fraction(start, end, day_count_convention)
    return 1.0 / (1.0 + accrual)
[docs]
@jax.jit
def discount_factor_vectorized(
    rate: jnp.ndarray,
    year_fraction: jnp.ndarray,
) -> jnp.ndarray:
    """Element-wise simple-interest discount factors ``1 / (1 + r * t)``.

    Args:
        rate: Array of interest rates
        year_fraction: Array of year fractions (broadcast against ``rate``)

    Returns:
        Array of discount factors

    Example:
        >>> rates = jnp.array([0.05, 0.04, 0.06])
        >>> yfs = jnp.array([0.5, 1.0, 0.25])
        >>> dfs = discount_factor_vectorized(rates, yfs)

    Note:
        This function is JIT-compiled for performance.
    """
    growth = 1.0 + rate * year_fraction
    return 1.0 / growth
[docs]
def compound_factor(
    rate: float,
    start: ActusDateTime,
    end: ActusDateTime,
    day_count_convention: DayCountConvention,
    compounding_frequency: int = 1,
) -> float:
    """Return the growth factor of one unit invested from ``start`` to ``end``.

    With t the year fraction under ``day_count_convention`` and m the
    compounding frequency:

    * m > 0 (or any non-zero m): discrete compounding ``(1 + r/m) ** (m * t)``
    * m == 0: continuous compounding ``exp(r * t)``

    Args:
        rate: Annual interest rate (e.g., 0.05 for 5%)
        start: Start date
        end: End date
        day_count_convention: Day count convention used for the year fraction
        compounding_frequency: Compounding periods per year (default 1);
            0 selects continuous compounding

    Returns:
        Compound factor

    Example:
        >>> start = ActusDateTime(2024, 1, 1, 0, 0, 0)
        >>> end = ActusDateTime(2025, 1, 1, 0, 0, 0)
        >>> # Annual compounding
        >>> cf = compound_factor(0.05, start, end, DayCountConvention.AA, 1)
        >>> abs(cf - 1.05) < 0.001
        True
        >>> # Monthly compounding
        >>> cf_monthly = compound_factor(0.05, start, end, DayCountConvention.AA, 12)
        >>> abs(cf_monthly - 1.05116) < 0.001
        True

    References:
        Standard financial mathematics
    """
    t = year_fraction(start, end, day_count_convention)
    if compounding_frequency != 0:
        # Discrete compounding: (1 + r/m)^(m*t)
        m = compounding_frequency
        return float((1.0 + rate / m) ** (m * t))

    # Frequency 0 is the sentinel for continuous compounding: e^(r*t)
    import math

    return math.exp(rate * t)
[docs]
@jax.jit
def compound_factor_vectorized(
    rate: jnp.ndarray,
    year_fraction: jnp.ndarray,
    compounding_frequency: jnp.ndarray,
) -> jnp.ndarray:
    """Vectorized compound factor calculation.

    Element-wise version of ``compound_factor``:

    * ``compounding_frequency == 0`` -> continuous compounding ``exp(r * t)``
    * otherwise -> discrete compounding ``(1 + r / m) ** (m * t)``

    Args:
        rate: Array of interest rates
        year_fraction: Array of year fractions
        compounding_frequency: Array of compounding frequencies (0 means
            continuous compounding)

    Returns:
        Array of compound factors

    Example:
        >>> rates = jnp.array([0.05, 0.04])
        >>> yfs = jnp.array([1.0, 1.0])
        >>> freqs = jnp.array([1, 12])
        >>> cfs = compound_factor_vectorized(rates, yfs, freqs)

    Note:
        This function is JIT-compiled for performance. Continuous compounding
        is evaluated exactly with ``exp``. Approximating it with a huge
        frequency breaks down in float32, where ``1 + r / 1e6`` rounds to
        (nearly) 1.0.
    """
    is_continuous = compounding_frequency == 0

    # Give continuous lanes a harmless frequency so the discrete branch never
    # divides by zero. Its value in those lanes is discarded by jnp.where below.
    safe_frequency = jnp.where(is_continuous, 1, compounding_frequency)

    # (1 + r/m)^(m*t)
    discrete = jnp.power(1.0 + rate / safe_frequency, safe_frequency * year_fraction)
    # e^(r*t)
    continuous = jnp.exp(rate * year_fraction)

    return jnp.where(is_continuous, continuous, discrete)
[docs]
def present_value(
    cash_flows: list[float],
    dates: list[ActusDateTime],
    valuation_date: ActusDateTime,
    discount_rate: float,
    day_count_convention: DayCountConvention,
) -> float:
    """Calculate present value of a series of cash flows.

    Cash flows dated on or after ``valuation_date`` are discounted with
    ``discount_factor`` (simple interest, ``1 / (1 + r * t)``). Cash flows
    dated before ``valuation_date`` are grown forward to it with
    ``compound_factor`` at its default annual compounding (``(1 + r) ** t``).

    Args:
        cash_flows: List of cash flow amounts
        dates: List of cash flow dates (same length as ``cash_flows``)
        valuation_date: Date to discount to
        discount_rate: Annual discount rate
        day_count_convention: Day count convention

    Returns:
        Present value of all cash flows

    Raises:
        ValueError: If ``cash_flows`` and ``dates`` have different lengths.

    Example:
        >>> cfs = [100, 100, 100]
        >>> dates = [
        ...     ActusDateTime(2024, 1, 1, 0, 0, 0),
        ...     ActusDateTime(2024, 7, 1, 0, 0, 0),
        ...     ActusDateTime(2025, 1, 1, 0, 0, 0),
        ... ]
        >>> val_date = ActusDateTime(2024, 1, 1, 0, 0, 0)
        >>> pv = present_value(cfs, dates, val_date, 0.05, DayCountConvention.AA)
        >>> abs(pv - 292.81) < 1.0  # ~ 100 + 97.57 + 95.24
        True

    References:
        Standard financial mathematics
    """
    if len(cash_flows) != len(dates):
        raise ValueError("cash_flows and dates must have the same length")

    total_pv = 0.0
    for cf, date in zip(cash_flows, dates, strict=True):
        if date >= valuation_date:
            df = discount_factor(discount_rate, valuation_date, date, day_count_convention)
            total_pv += cf * df
        else:
            # Cash flow in the past - compound forward to the valuation date.
            # NOTE(review): past flows use annual compounding, while future flows
            # use simple-interest discounting. Confirm this asymmetry is intended.
            cf_factor = compound_factor(discount_rate, date, valuation_date, day_count_convention)
            total_pv += cf * cf_factor
    return total_pv
[docs]
@jax.jit
def present_value_vectorized(
    cash_flows: jnp.ndarray,
    year_fractions: jnp.ndarray,
    discount_rate: float,
) -> jnp.ndarray:
    """Sum of cash flows discounted with simple interest ``1 / (1 + r * t)``.

    Args:
        cash_flows: Array of cash flow amounts
        year_fractions: Array of year fractions from the valuation date
        discount_rate: Discount rate

    Returns:
        Present value (scalar JAX array)

    Example:
        >>> cfs = jnp.array([100.0, 100.0, 100.0])
        >>> yfs = jnp.array([0.0, 0.5, 1.0])
        >>> pv = present_value_vectorized(cfs, yfs, 0.05)

    Note:
        This function is JIT-compiled for performance.
        Assumes all cash flows are in the future (year_fractions >= 0).
    """
    growth = 1.0 + discount_rate * year_fractions
    weights = 1.0 / growth
    return (cash_flows * weights).sum()
[docs]
def calculate_actus_annuity(
    start: ActusDateTime,
    pr_schedule: list[ActusDateTime],
    notional: float,
    accrued_interest: float,
    rate: float,
    day_count_convention: DayCountConvention,
) -> float:
    """Calculate annuity amount using ACTUS specification formula.

    Implements the ACTUS annuity formula from Section 3.8:
        A(s, T, n, a, r) = (n + a) / Σ[∏((1 + Y_i × r)^-1)]

    Where:
        s = start time
        T = maturity (last PR date)
        n = notional principal
        a = accrued interest
        r = nominal interest rate
        Y_i = year fraction for period i
        Σ = sum over all PR events
        ∏ = product up to each PR event

    This calculates the constant payment amount such that the total of all
    payments exactly amortizes the notional plus accrued interest.

    Args:
        start: Start time for calculation
        pr_schedule: List of principal redemption dates, in chronological order
        notional: Notional principal amount
        accrued_interest: Already accrued interest
        rate: Annual interest rate (e.g., 0.05 for 5%)
        day_count_convention: Day count convention for year fractions

    Returns:
        Annuity payment amount per period. Returns 0.0 for an empty schedule.
        For ``|rate| < 1e-10``, returns ``(notional + accrued_interest)``
        divided by the number of PR dates.

    Example:
        >>> # $100,000 loan at 5% for 12 months (Feb 2024 .. Jan 2025)
        >>> start = ActusDateTime(2024, 1, 15, 0, 0, 0)
        >>> pr_dates = [
        ...     ActusDateTime(2024 + (m - 1) // 12, (m - 1) % 12 + 1, 15, 0, 0, 0)
        ...     for m in range(2, 14)
        ... ]
        >>> amount = calculate_actus_annuity(
        ...     start, pr_dates, 100000.0, 0.0, 0.05, DayCountConvention.A360
        ... )
        >>> 8500 < amount < 8600  # Approximately $8,560
        True

    References:
        ACTUS Technical Specification v1.1, Section 3.8
    """
    if not pr_schedule:
        return 0.0

    if abs(rate) < 1e-10:
        # For zero or near-zero rates, payment is just (notional + accrued) / periods
        return (notional + accrued_interest) / len(pr_schedule)

    # Denominator: Σ[∏((1 + Y_i × r)^-1)], the sum of the cumulative
    # discount factors at each PR date.
    cumulative_discount = 0.0
    product_term = 1.0  # Running product of (1 + Y_i × r)
    prev_date = start
    for pr_date in pr_schedule:
        # Year fraction for this period
        yf = year_fraction(prev_date, pr_date, day_count_convention)
        # Update the cumulative product: (1 + Y_1 × r) × (1 + Y_2 × r) × ...
        product_term *= 1.0 + yf * rate
        # Add the discount factor for this period: 1 / ∏(1 + Y_i × r)
        cumulative_discount += 1.0 / product_term
        prev_date = pr_date

    # A = (n + a) / Σ[∏((1 + Y_i × r)^-1)]
    return (notional + accrued_interest) / cumulative_discount
[docs]
@jax.jit
def calculate_actus_annuity_jax(
    year_fractions: jnp.ndarray,
    notional: float,
    accrued_interest: float,
    rate: float,
) -> float:
    """JAX-compiled version of ACTUS annuity calculation.

    Implements the ACTUS annuity formula:
        A(s, T, n, a, r) = (n + a) / Σ[∏((1 + Y_i × r)^-1)]

    Args:
        year_fractions: 1-D array of year fractions for each period
        notional: Notional principal amount
        accrued_interest: Already accrued interest
        rate: Annual interest rate

    Returns:
        Annuity payment amount (scalar JAX array). Returns 0.0 for an empty
        ``year_fractions`` array, matching ``calculate_actus_annuity`` for an
        empty schedule. For ``|rate| < 1e-10``, returns the straight-line
        payment ``(notional + accrued_interest) / len(year_fractions)``.

    Example:
        >>> # 12 equal monthly periods (30/360 convention)
        >>> yfs = jnp.array([30/360] * 12)
        >>> amount = calculate_actus_annuity_jax(yfs, 100000.0, 0.0, 0.05)

    Note:
        This function is JIT-compiled for performance.

    References:
        ACTUS Technical Specification v1.1, Section 3.8
    """
    n_periods = year_fractions.shape[0]
    principal = notional + accrued_interest

    # Shapes are static under jit, so this Python branch is resolved at trace
    # time. Without it, an empty schedule divides by zero and returns inf.
    if n_periods == 0:
        return jnp.zeros((), dtype=jnp.result_type(year_fractions, 0.0))  # type: ignore[return-value]

    is_zero_rate = jnp.abs(rate) < 1e-10

    # For zero rate, simple division
    simple_payment = principal / n_periods

    # product_terms[i] = ∏_{j=0}^{i} (1 + Y_j × r)
    factors = 1.0 + year_fractions * rate
    product_terms = jnp.cumprod(factors)

    # Σ_i 1 / product_terms[i]
    denominator = jnp.sum(1.0 / product_terms)

    # A = (n + a) / Σ[discount_factors]
    actus_payment = principal / denominator

    # Return simple payment if rate is zero, otherwise ACTUS payment
    result: float = jnp.where(is_zero_rate, simple_payment, actus_payment)  # type: ignore[assignment]
    return result