Source code for jactus.contracts.portfolio

"""Unified portfolio simulation API for mixed contract types.

Accepts a portfolio of contracts with mixed types, groups them by type,
dispatches to each type's batch kernel, and returns per-contract results.

Contracts without a dedicated array-mode implementation (CLM, UMP, SWAPS,
CAPFL, CEG, CEC) fall back to the scalar Python simulation path.

Example::

    from jactus.contracts.portfolio import simulate_portfolio

    contracts = [
        (pam_attrs, rf_obs),
        (lam_attrs, rf_obs),
        (csh_attrs, rf_obs),
        (optns_attrs, rf_obs),
    ]
    result = simulate_portfolio(contracts)
    # result["total_cashflows"]  -> jnp.array of shape (4,)
    # result["payoffs"]          -> list of per-contract payoff arrays
"""

from __future__ import annotations

from typing import Any

import jax.numpy as jnp
import numpy as np

from jactus.core import ContractAttributes, ContractType
from jactus.observers import RiskFactorObserver

# ---------------------------------------------------------------------------
# Type -> portfolio function registry
# ---------------------------------------------------------------------------

_PORTFOLIO_FN_REGISTRY: dict[ContractType, Any] = {}


def _get_portfolio_fn(ct: ContractType) -> Any | None:
    """Lazy-load portfolio function for a contract type."""
    if ct in _PORTFOLIO_FN_REGISTRY:
        return _PORTFOLIO_FN_REGISTRY[ct]

    fn: Any = None
    try:
        if ct == ContractType.PAM:
            from jactus.contracts.pam_array import simulate_pam_portfolio

            fn = simulate_pam_portfolio
        elif ct == ContractType.LAM:
            from jactus.contracts.lam_array import simulate_lam_portfolio

            fn = simulate_lam_portfolio
        elif ct == ContractType.NAM:
            from jactus.contracts.nam_array import simulate_nam_portfolio

            fn = simulate_nam_portfolio
        elif ct == ContractType.ANN:
            from jactus.contracts.ann_array import simulate_ann_portfolio

            fn = simulate_ann_portfolio
        elif ct == ContractType.LAX:
            from jactus.contracts.lax_array import simulate_lax_portfolio

            fn = simulate_lax_portfolio
        elif ct == ContractType.CSH:
            from jactus.contracts.csh_array import simulate_csh_portfolio

            fn = simulate_csh_portfolio
        elif ct == ContractType.STK:
            from jactus.contracts.stk_array import simulate_stk_portfolio

            fn = simulate_stk_portfolio
        elif ct == ContractType.COM:
            from jactus.contracts.com_array import simulate_com_portfolio

            fn = simulate_com_portfolio
        elif ct == ContractType.FXOUT:
            from jactus.contracts.fxout_array import simulate_fxout_portfolio

            fn = simulate_fxout_portfolio
        elif ct == ContractType.FUTUR:
            from jactus.contracts.futur_array import simulate_futur_portfolio

            fn = simulate_futur_portfolio
        elif ct == ContractType.OPTNS:
            from jactus.contracts.optns_array import simulate_optns_portfolio

            fn = simulate_optns_portfolio
        elif ct == ContractType.SWPPV:
            from jactus.contracts.swppv_array import simulate_swppv_portfolio

            fn = simulate_swppv_portfolio
    except ImportError:
        fn = None

    _PORTFOLIO_FN_REGISTRY[ct] = fn
    return fn


# Contract types that use the scalar Python fallback path
_FALLBACK_TYPES = frozenset(
    {
        ContractType.CLM,
        ContractType.UMP,
        ContractType.SWAPS,
        ContractType.CAPFL,
        ContractType.CEG,
        ContractType.CEC,
    }
)

# All types with dedicated array-mode batch kernels
BATCH_SUPPORTED_TYPES = frozenset(
    {
        ContractType.PAM,
        ContractType.LAM,
        ContractType.NAM,
        ContractType.ANN,
        ContractType.LAX,
        ContractType.CSH,
        ContractType.STK,
        ContractType.COM,
        ContractType.FXOUT,
        ContractType.FUTUR,
        ContractType.OPTNS,
        ContractType.SWPPV,
    }
)


def _simulate_scalar_fallback(
    attrs: ContractAttributes,
    rf_observer: RiskFactorObserver,
) -> float:
    """Simulate a single contract via the scalar Python path.

    Returns total cashflow (sum of all event payoffs).
    """
    from jactus.contracts import create_contract

    contract = create_contract(attrs, rf_observer)
    result = contract.simulate()
    return sum(float(e.payoff) for e in result.events)


[docs] def simulate_portfolio( contracts: list[tuple[ContractAttributes, RiskFactorObserver]], discount_rate: float | None = None, ) -> dict[str, Any]: """Simulate a mixed-type portfolio using optimal batch strategies. Groups contracts by type, dispatches to each type's batch kernel, and falls back to the scalar Python path for unsupported types. Args: contracts: List of ``(attributes, rf_observer)`` pairs. Contract types may be mixed. discount_rate: If provided, compute present values (passed to each type's portfolio function where supported). Returns: Dict with: - ``total_cashflows``: ``jnp.ndarray`` of shape ``(N,)`` — total cashflow for each contract in input order. - ``num_contracts``: Total number of contracts. - ``batch_contracts``: Number of contracts simulated via batch kernels. - ``fallback_contracts``: Number of contracts simulated via the scalar Python path. - ``types_used``: Set of ``ContractType`` values present. - ``per_type_results``: Dict mapping ``ContractType`` to the raw result dict from each type's portfolio function (only for batch-simulated types). """ n = len(contracts) if n == 0: return { "total_cashflows": jnp.zeros(0), "num_contracts": 0, "batch_contracts": 0, "fallback_contracts": 0, "types_used": set(), "per_type_results": {}, } # Group contracts by type, preserving original indices type_groups: dict[ContractType, list[tuple[int, ContractAttributes, RiskFactorObserver]]] = {} for i, (attrs, rf_obs) in enumerate(contracts): ct = attrs.contract_type if ct not in type_groups: type_groups[ct] = [] type_groups[ct].append((i, attrs, rf_obs)) # Output array total_cashflows = np.zeros(n, dtype=np.float32) per_type_results: dict[ContractType, dict[str, Any]] = {} batch_count = 0 fallback_count = 0 for ct, group in type_groups.items(): indices = [g[0] for g in group] group_contracts = [(g[1], g[2]) for g in group] portfolio_fn = _get_portfolio_fn(ct) if portfolio_fn is not None: # Batch simulation path kwargs: dict[str, Any] = {} if discount_rate is not None: kwargs["discount_rate"] = discount_rate result = portfolio_fn(group_contracts, **kwargs) per_type_results[ct] = result group_totals = result["total_cashflows"] for j, idx in enumerate(indices): total_cashflows[idx] = float(group_totals[j]) batch_count += len(group) else: # Scalar fallback path for _j, (idx, attrs, rf_obs) in enumerate(group): total_cashflows[idx] = _simulate_scalar_fallback(attrs, rf_obs) fallback_count += len(group) return { "total_cashflows": jnp.asarray(total_cashflows), "num_contracts": n, "batch_contracts": batch_count, "fallback_contracts": fallback_count, "types_used": set(type_groups.keys()), "per_type_results": per_type_results, }