
Stochastic Transitions

pylcm has three mechanisms for stochastic transitions:

  1. Regime transitions — which regime an agent moves to next.

  2. Discrete state transitions (MarkovTransition) — which discrete state an agent occupies next, with user-supplied probability functions.

  3. Continuous shocks (ShockGrid) — stochastic continuous states whose transition probabilities are defined intrinsically by the grid object.

All three produce probability weights that feed into expected value computations, but they use different runtime representations. This page explains why.

The Computation Hierarchy

The action-value function Q decomposes into nested expectations:

$$
Q(s, a) = U(s, a) + \beta \sum_{r'} P(r' \mid s, a) \sum_{s'} P(s' \mid s, a, r') \, V(s', r')
$$

The outer sum runs over target regimes $r'$, weighted by regime transition probabilities. The inner sum runs over next-period states $s'$, weighted by state transition probabilities.
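To make the nesting concrete, here is the formula evaluated with invented numbers (assumed purely for illustration): $U(s, a) = 1$, $\beta = 0.95$, two target regimes reached with probabilities 0.8 (working) and 0.2 (retired), and two next-period states in each regime.

$$
\begin{aligned}
\sum_{s'} P(s' \mid s, a, \text{working}) \, V(s', \text{working}) &= 0.7 \cdot 10 + 0.3 \cdot 20 = 13 \\
\sum_{s'} P(s' \mid s, a, \text{retired}) \, V(s', \text{retired}) &= 0.5 \cdot 4 + 0.5 \cdot 8 = 6 \\
Q(s, a) &= 1 + 0.95 \, (0.8 \cdot 13 + 0.2 \cdot 6) = 1 + 0.95 \cdot 11.6 = 12.02
\end{aligned}
$$

Each inner sum is one expectation over next-period states; the outer sum only reweights those per-regime results.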

This hierarchy is reflected directly in the solve loop (src/lcm/Q_and_F.py):

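for target_regime_name in active_regimes_next_period:
    next_states = state_transitions[target_regime_name](...)
    weights = next_stochastic_states_weights[target_regime_name](...)  # marginal weights
    # (elided) joint_weights = outer product of `weights`; next_V_arr = next-period V at next_states
    next_V_expected = jnp.average(next_V_arr, weights=joint_weights)  # inner sum over states
    continuation_value += regime_probs[target_regime_name] * next_V_expected  # outer sum over regimes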
  • The outer sum is a Python for loop over regime names.

  • The inner sum is a vectorised jnp.average over state grid points.

This structural difference — loop-by-name vs array operation — is the fundamental reason why regime and state transition probabilities use different data structures.
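The same two-level structure can be sketched in plain Python. This is a simplified stand-in, not pylcm's solve code: NumPy replaces jax.numpy, the per-regime next-period values and weights are hypothetical precomputed arrays, and the function name continuation_value is illustrative.

```python
import numpy as np
from types import MappingProxyType


def continuation_value(regime_probs, next_V, state_weights):
    """Hypothetical stand-in for the solve loop: outer loop by name, inner vectorised average."""
    total = 0.0
    for name in regime_probs:  # outer sum: Python loop over regime names
        # inner sum: one vectorised weighted average over next-period grid points
        expected = np.average(next_V[name], weights=state_weights[name])
        total += regime_probs[name] * expected
    return total


# All three structures are keyed by the same regime names (hypothetical values)
regime_probs = MappingProxyType({"working": 0.8, "retired": 0.2})
next_V = {"working": np.array([10.0, 20.0]), "retired": np.array([4.0, 8.0])}
state_weights = {"working": np.array([0.7, 0.3]), "retired": np.array([0.5, 0.5])}

beta, utility = 0.95, 1.0
q_value = utility + beta * continuation_value(regime_probs, next_V, state_weights)
```

Because next_V and state_weights share their keys with regime_probs, the loop variable indexes all three directly; an integer regime ID would only add a translation step.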

Regime Transition Probabilities

Runtime format

MappingProxyType[str, Array] — an immutable dict mapping regime names to per-subject probability arrays.

# Example: 3 subjects, 2 regimes
{
    "working": Array([0.8, 0.65, 0.9]),
    "retired": Array([0.2, 0.35, 0.1]),
}
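A quick check of what the MappingProxyType wrapper provides, using NumPy arrays as stand-ins for JAX arrays: the mapping itself is read-only, and for each subject the probabilities across regimes sum to one.

```python
import numpy as np
from types import MappingProxyType

# Per-subject regime probabilities (3 subjects, 2 regimes), as in the example above
regime_probs = MappingProxyType({
    "working": np.array([0.8, 0.65, 0.9]),
    "retired": np.array([0.2, 0.35, 0.1]),
})

# Element-wise sum over regimes: one total per subject
totals = sum(regime_probs.values())

# MappingProxyType is a read-only view: item assignment raises TypeError
try:
    regime_probs["working"] = np.zeros(3)
except TypeError:
    is_read_only = True
else:
    is_read_only = False
```

The proxy only protects the mapping (keys cannot be replaced or added); whether the arrays themselves can be mutated depends on the array library, and JAX arrays are immutable.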

Why a dict?

The solve loop iterates over target regimes by name, and every co-indexed data structure — state_transitions, next_V, next_stochastic_states_weights — is also keyed by regime name. The dict format fits naturally:

continuation_value += regime_probs[target_regime_name] * next_V_expected

Switching to a plain array indexed by integer regime ID would require:

continuation_value += regime_probs_arr[regime_names_to_ids[target_regime_name]] * next_V_expected

This gains nothing: the loop must still iterate by name, because the other co-indexed structures are keyed by name, and the extra ID lookup only adds noise.

The array-to-dict-to-array round-trip

The user’s transition function returns a plain array indexed by integer regime ID. The processing pipeline immediately wraps it into a dict (_wrap_regime_transition in src/lcm/input_processing/regime_components.py):

def wrapped(*args, **kwargs):
    result = func(*args, **kwargs)    # Array[n_regimes]
    return MappingProxyType(
        {name: result[idx] for idx, name in enumerate(regime_names)}
    )
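A self-contained version of this wrapping step, for experimentation. wrap_regime_transition and the user function next_regime are hypothetical names and NumPy stands in for JAX; the only property the sketch relies on is that the order of regime_names matches the integer regime IDs used by the user function.

```python
import numpy as np
from types import MappingProxyType


def wrap_regime_transition(func, regime_names):
    """Sketch of the wrapping step: array indexed by regime ID -> name-keyed read-only dict."""
    def wrapped(*args, **kwargs):
        result = func(*args, **kwargs)  # first axis indexes regimes
        return MappingProxyType(
            {name: result[idx] for idx, name in enumerate(regime_names)}
        )
    return wrapped


# Hypothetical user function: retirement probability rises linearly from age 60 to 70
def next_regime(age):
    p_retire = np.clip((age - 60) / 10, 0.0, 1.0)
    return np.array([1.0 - p_retire, p_retire])  # index 0 = working, 1 = retired


probs = wrap_regime_transition(next_regime, regime_names=["working", "retired"])(age=63)
```

Here probs["retired"] is 0.3: the integer position from the user's array has been replaced by a name lookup.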

During simulation, draw_key_from_dict converts it back to a matrix for sampling (src/lcm/simulation/utils.py):

regime_transition_probs = jnp.array(list(d.values())).T  # dict → matrix
regime_ids = jnp.array([regime_names_to_ids[name] for name in d])
# ... jax.random.choice(key, regime_ids, p=p)

This round-trip is the cost of keeping the solve path readable. The alternative — using plain arrays everywhere — would require restructuring the solve loop and all co-indexed data structures, a large change with no functional benefit.
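The dict-to-matrix step and the subsequent draw can be reproduced as follows. The sketch assumes NumPy, with Generator.choice standing in for jax.random.choice; the probabilities follow the three-subject example on this page, and the regime_names_to_ids mapping is assumed.

```python
import numpy as np
from types import MappingProxyType

regime_names_to_ids = {"working": 0, "retired": 1}  # assumed ID assignment
d = MappingProxyType({
    "working": np.array([0.8, 0.65, 0.9]),
    "retired": np.array([0.2, 0.35, 0.1]),
})

# dict -> matrix: one row per subject, one column per regime (in dict order)
regime_transition_probs = np.array(list(d.values())).T  # shape (3, 2)
regime_ids = np.array([regime_names_to_ids[name] for name in d])

# Draw one next regime per subject (NumPy stand-in for jax.random.choice)
rng = np.random.default_rng(seed=0)
next_regime_ids = np.array(
    [rng.choice(regime_ids, p=p) for p in regime_transition_probs]
)
```

The transpose turns the name-keyed entries into one probability row per subject, and regime_ids keeps the column order aligned with the dict's insertion order.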

Discrete State Transitions (MarkovTransition)

Runtime format

Plain Array — the last axis indexes outcome grid points.

# Health has 3 categories: good, fair, bad
# Weight function returns: probabilities over next-health values
weight_next_health = Array([0.7, 0.2, 0.1])  # P(good), P(fair), P(bad)

How they flow through the pipeline

The user provides a transition function wrapped in MarkovTransition:

def next_health(health: DiscreteState, period: Period, probs_array: FloatND) -> FloatND:
    return probs_array[period, health]

During processing (_get_internal_functions in src/lcm/input_processing/regime_processing.py), this gets split into two internal functions:

  • weight_next_health — the user’s function, with its parameters renamed. It returns the probability distribution over outcomes.

  • next_health — auto-generated. It returns the full grid of possible outcomes (jnp.arange(n_categories)), so the solution can evaluate V at every possible next state.
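The two roles can be mimicked in a few lines. This is an illustrative sketch only: the function names follow the text, but the bodies, the probs_array shape (period, current health, next health) and the next-period values V_next are assumptions rather than pylcm's generated code, and NumPy stands in for jax.numpy.

```python
import numpy as np

# Hypothetical inputs: probs_array[period, current_health, next_health]
probs_array = np.array([
    [[0.7, 0.2, 0.1], [0.3, 0.5, 0.2], [0.1, 0.3, 0.6]],
    [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.1, 0.2, 0.7]],
])
n_categories = probs_array.shape[-1]
V_next = np.array([10.0, 6.0, 2.0])  # hypothetical V at next health = good, fair, bad


def weight_next_health(health, period, probs_array):
    # Role of the user's function: a distribution over next-period health
    return probs_array[period, health]


def next_health():
    # Role of the auto-generated function: every outcome on the DiscreteGrid
    return np.arange(n_categories)


# Evaluate V at every possible next state, then weight by the distribution
expected_V = np.average(
    V_next[next_health()],
    weights=weight_next_health(health=0, period=0, probs_array=probs_array),
)
```

With these numbers, expected_V is 0.7 · 10 + 0.2 · 6 + 0.1 · 2 = 8.4.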

When multiple states are stochastic, their marginal weights combine via outer product (joint_weights_from_marginals in Q_and_F.py), forming a joint distribution over the product space. The expected value is then a single vectorised call:

next_V_expected = jnp.average(
    next_V_at_stochastic_states,
    weights=joint_weights,
)
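A small numerical sketch of the outer-product step, with two hypothetical stochastic states (health with 3 outcomes, employment with 2) and NumPy in place of jax.numpy. It also checks that one weighted average over the joint grid equals the nested expectation; note that an outer product treats the states as independent given the current state and action.

```python
import numpy as np

# Marginal weights for two stochastic states (hypothetical values)
w_health = np.array([0.7, 0.2, 0.1])  # 3 health outcomes
w_employment = np.array([0.9, 0.1])   # 2 employment outcomes

# Outer product -> joint distribution over the 3 x 2 product space
joint_weights = np.multiply.outer(w_health, w_employment)

# V evaluated at every combination of next-period outcomes (hypothetical)
next_V_at_stochastic_states = np.array([
    [10.0, 8.0],
    [6.0, 5.0],
    [2.0, 1.0],
])

# With weights of the same shape and no axis, np.average reduces over all axes at once
next_V_expected = np.average(next_V_at_stochastic_states, weights=joint_weights)
```

A third stochastic state would add one more axis via another outer product; the final average call stays the same.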

Why plain arrays work

  • Outcomes are always exactly on the DiscreteGrid — no interpolation needed.

  • The array’s last axis corresponds directly to grid positions.

  • No name-based lookup is required — everything is positional within a single regime’s computation.

  • jnp.average consumes the weights directly.

Continuous Shocks (ShockGrids)

Runtime format

Plain Array — same downstream format as discrete MarkovTransition states. The difference is entirely in how the weights are produced.

Intrinsic transitions

ShockGrids — Rouwenhorst, Tauchen, Normal, Uniform, etc. — carry their own transition probabilities. The grid object has two methods:

  • compute_gridpoints(**params) — the discretised shock values.

  • compute_transition_probs(**params) — an $n \times n$ transition matrix $P_{ij} = \Pr(\text{next} = x_j \mid \text{current} = x_i)$.

For IID shocks (Normal, Uniform), all rows of $P$ are identical — the draw is independent of the current state. For AR(1) shocks (Tauchen, Rouwenhorst), rows differ based on persistence. See Approximating Continuous Shocks for the underlying quadrature methods.

This is why ShockGrids must NOT appear in state_transitions: the transition probabilities are intrinsic to the grid definition, not user-supplied functions. Placing them in state_transitions would duplicate information and create conflicting sources of truth.
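To illustrate how a grid object can carry its own transition matrix, here is the textbook Tauchen construction for an AR(1) process $x' = \rho x + \varepsilon$ with $\varepsilon \sim N(0, \sigma^2)$. This is a generic sketch, not pylcm's implementation: the function name tauchen, the n_std grid-width parameter and the equal spacing are assumptions. Setting rho to 0 gives the IID case, in which every row is the same.

```python
import math

import numpy as np


def normal_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def tauchen(n_points, rho, sigma, n_std=3.0):
    """Textbook Tauchen discretisation of x' = rho * x + eps, eps ~ N(0, sigma^2)."""
    std_x = sigma / math.sqrt(1.0 - rho**2)  # unconditional std of x
    gridpoints = np.linspace(-n_std * std_x, n_std * std_x, n_points)
    step = gridpoints[1] - gridpoints[0]
    probs = np.empty((n_points, n_points))
    for i, x_i in enumerate(gridpoints):
        for j, x_j in enumerate(gridpoints):
            upper = normal_cdf((x_j - rho * x_i + step / 2) / sigma)
            lower = normal_cdf((x_j - rho * x_i - step / 2) / sigma)
            if j == 0:
                probs[i, j] = upper  # all mass below the first midpoint
            elif j == n_points - 1:
                probs[i, j] = 1.0 - lower  # all mass above the last midpoint
            else:
                probs[i, j] = upper - lower
    return gridpoints, probs


gridpoints, P_ar1 = tauchen(n_points=5, rho=0.9, sigma=0.1)
_, P_iid = tauchen(n_points=5, rho=0.0, sigma=0.1)
```

Both outputs come from the same parameters, which is why gridpoints and transition matrix belong to one object rather than to a separate user-supplied function.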

The interpolation problem

For discrete MarkovTransition states, the agent’s current state is always exactly on a grid point — it’s a categorical variable with a finite set of values. Looking up the relevant row of the transition matrix is a simple index operation.

For continuous shocks, the situation is different. During the DP solution, the agent’s current shock value may not coincide with any grid point. (The finite discretisation approximates a continuous distribution.) The weight function must interpolate the transition matrix at the agent’s actual position.

This is implemented in _get_weights_func_for_shock (src/lcm/input_processing/regime_processing.py):

def weights_func(**kwargs):
    # Find fractional position on the shock grid
    coordinate = get_irreg_coordinate(value=kwargs[name], points=gridpoints)
    # Interpolate the transition matrix at that position
    return map_coordinates(
        input=transition_probs,      # (n_points, n_points)
        coordinates=[
            jnp.full(n_points, fill_value=coordinate),  # row: current state
            jnp.arange(n_points),                        # col: all next states
        ],
    )

get_irreg_coordinate finds a fractional index on the grid (e.g., 2.7 means 70% of the way between grid points 2 and 3). map_coordinates then interpolates the transition matrix at that fractional row. Because the column coordinates are the integers 0 to $n-1$, the bilinear interpolation reduces to a linear blend of the two neighbouring rows, producing smoothly varying weights across all $n$ next-period grid points.

For IID shocks, the interpolation is a no-op in effect — all rows are identical, so interpolating between them returns the same weights regardless of position. But the same code path handles both IID and AR(1) cases uniformly.
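The interpolation can be sketched without pylcm's helpers. In this hypothetical version, np.interp against the grid indices stands in for get_irreg_coordinate (np.interp clamps values outside the grid; how get_irreg_coordinate treats such values is not covered here), and the row blend is written out explicitly instead of calling map_coordinates. The grid and matrix values are invented.

```python
import numpy as np


def interpolated_weights(value, gridpoints, transition_probs):
    """Linearly interpolate the rows of a transition matrix at an off-grid shock value."""
    n_points = len(gridpoints)
    # Fractional row index, e.g. 2.7 = 70% of the way from point 2 to point 3
    coordinate = np.interp(value, gridpoints, np.arange(n_points))
    lower = int(np.floor(coordinate))
    upper = min(lower + 1, n_points - 1)
    frac = coordinate - lower
    # Convex combination of two neighbouring rows; each row sums to one, so the result does too
    return (1.0 - frac) * transition_probs[lower] + frac * transition_probs[upper]


gridpoints = np.array([-1.0, -0.4, 0.0, 0.5, 1.0])  # irregularly spaced
transition_probs = np.array([
    [0.6, 0.3, 0.1, 0.0, 0.0],
    [0.2, 0.5, 0.2, 0.1, 0.0],
    [0.0, 0.2, 0.6, 0.2, 0.0],
    [0.0, 0.1, 0.2, 0.5, 0.2],
    [0.0, 0.0, 0.1, 0.3, 0.6],
])

# 0.35 lies 70% of the way from grid point 2 (0.0) to grid point 3 (0.5): coordinate 2.7
weights = interpolated_weights(0.35, gridpoints, transition_probs)
```

Passing a matrix whose rows are all identical returns that row for any value, which is the IID case: the same code path serves both kinds of shock.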

Runtime parameters

ShockGrid parameters (e.g., rho and sigma for Tauchen) can be supplied at runtime via the model’s parameter dict. When params_to_pass_at_runtime is non-empty, the weight function calls compute_gridpoints() and compute_transition_probs() inside JIT with the runtime values. This allows the same model to be solved with different shock parameters without re-processing.

Summary

| Aspect | Regime transitions | Discrete Markov | Continuous shocks |
|---|---|---|---|
| Runtime format | MappingProxyType[str, Array] | Array | Array |
| Defined via | Regime(transition=...) | state_transitions + MarkovTransition | Grid object (_ShockGrid) |
| Probabilities from | User function | User function | Grid’s compute_transition_probs() |
| Interpolation | No | No | Yes (map_coordinates) |
| Consumed by | Named loop in Q_and_F | jnp.average (vectorised) | jnp.average (vectorised) |
| Why this format | Co-indexed with regime-keyed dicts | Positional within a single regime | Same as discrete Markov; interpolation happens upstream |

The three mechanisms converge at two points:

  1. Downstream consumption — discrete Markov states and continuous shocks both produce plain weight arrays consumed by jnp.average. They differ only in how weights are produced (direct evaluation vs interpolation).

  2. User-facing pandas utilities: transition_probs_from_series and validate_transition_probs provide a unified API for building and validating probability arrays for both regime and state transitions.

Full harmonisation of the runtime formats is blocked by the solve loop structure: regime probabilities are consumed in a Python loop over named regimes (because every co-indexed data structure is also keyed by name), while state probabilities are consumed in a vectorised array operation within each iteration of that loop. Flattening this hierarchy would require restructuring the core DP computation.