pylcm has three mechanisms for stochastic transitions:

- **Regime transitions** — which regime an agent moves to next.
- **Discrete state transitions** (`MarkovTransition`) — which discrete state an agent occupies next, with user-supplied probability functions.
- **Continuous shocks** (`ShockGrid`) — stochastic continuous states whose transition probabilities are defined intrinsically by the grid object.
All three produce probability weights that feed into expected value computations, but they use different runtime representations. This page explains why.
The Computation Hierarchy¶
The action-value function decomposes into nested expectations:

$$
Q(s, a) = u(s, a) + \beta \sum_{r'} P(r' \mid s, a) \sum_{s'} P(s' \mid s, a, r')\, V_{r'}(s')
$$

The outer sum runs over target regimes $r'$, weighted by regime transition probabilities. The inner sum runs over next-period states $s'$, weighted by state transition probabilities.
This hierarchy is reflected directly in the solve loop
(src/lcm/Q_and_F.py):
```python
for target_regime_name in active_regimes_next_period:
    next_states = state_transitions[target_regime_name](...)
    weights = next_stochastic_states_weights[target_regime_name](...)
    next_V_expected = jnp.average(next_V_arr, weights=joint_weights)
    continuation_value += regime_probs[target_regime_name] * next_V_expected
```

The outer sum is a Python `for` loop over regime names. The inner sum is a vectorised `jnp.average` over state grid points.
This structural difference — loop-by-name vs array operation — is the fundamental reason why regime and state transition probabilities use different data structures.
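To make the loop-vs-array distinction concrete, here is a self-contained toy version of the two-level computation. All names and numbers are made up for illustration; this is a sketch of the structure, not pylcm's actual code:

```python
import jax.numpy as jnp

# Toy data: two regimes, three next-state grid points each
regime_probs = {"working": 0.7, "retired": 0.3}
next_V = {
    "working": jnp.array([1.0, 2.0, 3.0]),
    "retired": jnp.array([0.5, 1.0, 1.5]),
}
state_weights = {
    "working": jnp.array([0.2, 0.5, 0.3]),
    "retired": jnp.array([0.6, 0.3, 0.1]),
}

continuation_value = 0.0
for regime in regime_probs:  # outer sum: Python loop, keyed by name
    # inner sum: one vectorised weighted average over grid points
    ev = jnp.average(next_V[regime], weights=state_weights[regime])
    continuation_value += regime_probs[regime] * ev
```

The outer level is inherently name-based (dict lookups), while the inner level is purely positional (arrays), mirroring the two runtime formats discussed below.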
Regime Transition Probabilities¶
Runtime format¶
MappingProxyType[str, Array] — an immutable dict mapping regime names to
per-subject probability arrays.
```python
# Example: 3 subjects, 2 regimes
{
    "working": Array([0.8, 0.65, 0.9]),
    "retired": Array([0.2, 0.35, 0.1]),
}
```

Why a dict?¶
The solve loop iterates over target regimes by name, and every co-indexed data
structure — state_transitions, next_V, next_stochastic_states_weights — is
also keyed by regime name. The dict format fits naturally:
```python
continuation_value += regime_probs[target_regime_name] * next_V_expected
```

Switching to a plain array indexed by integer regime ID would require:

```python
continuation_value += regime_probs_arr[regime_names_to_ids[target_regime_name]] * next_V_expected
```

This gains nothing — the loop must still iterate by name because the other dicts force it — and adds noise.
The array-to-dict-to-array round-trip¶
The user’s transition function returns a plain array indexed by integer regime ID.
The processing pipeline immediately wraps it into a dict
(_wrap_regime_transition in src/lcm/input_processing/regime_components.py):
```python
def wrapped(*args, **kwargs):
    result = func(*args, **kwargs)  # Array[n_regimes]
    return MappingProxyType(
        {name: result[idx] for idx, name in enumerate(regime_names)}
    )
```

During simulation, `draw_key_from_dict` converts it back to a matrix for
sampling (`src/lcm/simulation/utils.py`):

```python
regime_transition_probs = jnp.array(list(d.values())).T  # dict → matrix
regime_ids = jnp.array([regime_names_to_ids[name] for name in d])
# ... jax.random.choice(key, regime_ids, p=p)
```

This round-trip is the cost of keeping the solve path readable. The alternative — using plain arrays everywhere — would require restructuring the solve loop and all co-indexed data structures, a large change with no functional benefit.
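The dict-to-matrix conversion can be sketched in isolation. The regime names and probabilities below are hypothetical, and the sampling step is simplified to a single subject:

```python
import jax
import jax.numpy as jnp
from types import MappingProxyType

# Regime-keyed probabilities for 3 subjects (toy numbers)
regime_probs = MappingProxyType({
    "working": jnp.array([0.8, 0.65, 0.9]),
    "retired": jnp.array([0.2, 0.35, 0.1]),
})

# dict → (n_subjects, n_regimes) matrix; each row is one subject's distribution
p = jnp.array(list(regime_probs.values())).T
regime_ids = jnp.arange(len(regime_probs))

# Sample the next regime id for the first subject
key = jax.random.PRNGKey(0)
next_regime = jax.random.choice(key, regime_ids, p=p[0])
```

The transpose matters: stacking the dict values gives a `(n_regimes, n_subjects)` array, while per-subject sampling needs one probability vector per row.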
Discrete State Transitions (MarkovTransition)¶
Runtime format¶
Plain Array — the last axis indexes outcome grid points.
```python
# Health has 3 categories: good, fair, bad
# Weight function returns: probabilities over next-health values
weight_next_health = Array([0.7, 0.2, 0.1])  # P(good), P(fair), P(bad)
```

How they flow through the pipeline¶
The user provides a transition function wrapped in MarkovTransition:
```python
def next_health(health: DiscreteState, period: Period, probs_array: FloatND) -> FloatND:
    return probs_array[period, health]
```

During processing (`_get_internal_functions` in
`src/lcm/input_processing/regime_processing.py`),
this gets split into two internal functions:

- `weight_next_health` — the user's function (renamed params). Returns the probability distribution over outcomes.
- `next_health` — auto-generated. Returns the full grid of possible outcomes (`jnp.arange(n_categories)`), so the solution can evaluate at every possible next state.
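A standalone sketch of the split, with a hypothetical period-varying probability array (the type annotations and wrapping machinery are omitted; the function pair is illustrative, not pylcm's generated code):

```python
import jax.numpy as jnp

N_HEALTH = 3  # categories: good, fair, bad

# Hypothetical probability array: (n_periods, n_current_health, n_next_health)
probs_array = jnp.array([
    [[0.7, 0.2, 0.1],
     [0.3, 0.5, 0.2],
     [0.1, 0.3, 0.6]],
])

def weight_next_health(health, period, probs_array):
    # User-supplied role: distribution over next-period health
    return probs_array[period, health]

def next_health():
    # Auto-generated role: the full grid of possible outcomes
    return jnp.arange(N_HEALTH)

weights = weight_next_health(health=0, period=0, probs_array=probs_array)
outcomes = next_health()
```

The solver evaluates the value function at every point in `outcomes` and averages with `weights`, so the two functions are always consumed together.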
When multiple states are stochastic, their marginal weights combine via outer
product (joint_weights_from_marginals in Q_and_F.py), forming a joint
distribution over the product space. The expected value is then a single
vectorised call:
```python
next_V_expected = jnp.average(
    next_V_at_stochastic_states,
    weights=joint_weights,
)
```

Why plain arrays work¶

- Outcomes are always exactly on the `DiscreteGrid` — no interpolation needed.
- The array's last axis corresponds directly to grid positions.
- No name-based lookup is required — everything is positional within a single regime's computation.
- `jnp.average` consumes the weights directly.
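The outer-product construction of joint weights described above can be sketched with two toy marginals (the state names and numbers are invented for illustration):

```python
import jax.numpy as jnp

# Marginal weights for two stochastic discrete states (toy numbers)
w_health = jnp.array([0.7, 0.2, 0.1])   # 3 health categories
w_employed = jnp.array([0.9, 0.1])      # 2 employment categories

# Outer product forms the joint distribution over the product space
joint_weights = jnp.outer(w_health, w_employed)  # shape (3, 2)

# Value array on the same product grid (made-up numbers)
next_V = jnp.array([[3.0, 1.0],
                    [2.0, 0.5],
                    [1.0, 0.0]])

# Single vectorised expectation over the whole product space
expected = jnp.average(next_V, weights=joint_weights)
```

Because each marginal sums to one, the joint weights do too, so `jnp.average` returns the exact joint expectation without any explicit normalisation step.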
Continuous Shocks (ShockGrids)¶
Runtime format¶
Plain Array — same downstream format as discrete MarkovTransition states.
The difference is entirely in how the weights are produced.
Intrinsic transitions¶
ShockGrids — Rouwenhorst, Tauchen, Normal, Uniform, etc. — carry their
own transition probabilities. The grid object has two methods:
- `compute_gridpoints(**params)` — the discretised shock values.
- `compute_transition_probs(**params)` — an $n \times n$ transition matrix $P$.

For IID shocks (Normal, Uniform), all rows of $P$ are identical — the draw is independent of the current state. For AR(1) shocks (Tauchen, Rouwenhorst), rows differ based on persistence. See Approximating Continuous Shocks for the underlying quadrature methods.
This is why ShockGrids must NOT appear in state_transitions: the transition
probabilities are intrinsic to the grid definition, not user-supplied functions.
Placing them in state_transitions would duplicate information and create
conflicting sources of truth.
The interpolation problem¶
For discrete MarkovTransition states, the agent’s current state is always exactly on a grid point — it’s a categorical variable with a finite set of values. Looking up the relevant row of the transition matrix is a simple index operation.
For continuous shocks, the situation is different. During the DP solution, the agent’s current shock value may not coincide with any grid point. (The finite discretisation approximates a continuous distribution.) The weight function must interpolate the transition matrix at the agent’s actual position.
This is implemented in _get_weights_func_for_shock
(src/lcm/input_processing/regime_processing.py):
```python
def weights_func(**kwargs):
    # Find fractional position on the shock grid
    coordinate = get_irreg_coordinate(value=kwargs[name], points=gridpoints)
    # Interpolate the transition matrix at that position
    return map_coordinates(
        input=transition_probs,  # (n_points, n_points)
        coordinates=[
            jnp.full(n_points, fill_value=coordinate),  # row: current state
            jnp.arange(n_points),  # col: all next states
        ],
    )
```

`get_irreg_coordinate` finds a fractional index on the grid (e.g., 2.7 means
70% of the way between grid points 2 and 3). `map_coordinates` then bilinearly
interpolates the transition matrix at that fractional row, producing smoothly
varying weights across all next-period grid points.
For IID shocks, the interpolation is a no-op in effect — all rows are identical, so interpolating between them returns the same weights regardless of position. But the same code path handles both IID and AR(1) cases uniformly.
Runtime parameters¶
ShockGrid parameters (e.g., rho and sigma for Tauchen) can be supplied at
runtime via the model’s parameter dict. When params_to_pass_at_runtime is
non-empty, the weight function calls compute_gridpoints() and
compute_transition_probs() inside JIT with the runtime values. This allows
the same model to be solved with different shock parameters without
re-processing.
Summary¶
| Aspect | Regime transitions | Discrete Markov | Continuous shocks |
|---|---|---|---|
| Runtime format | MappingProxyType[str, Array] | Array | Array |
| Defined via | Regime(transition=...) | state_transitions + MarkovTransition | Grid object (_ShockGrid) |
| Probabilities from | User function | User function | Grid’s compute_transition_probs() |
| Interpolation | No | No | Yes (map_coordinates) |
| Consumed by | Named loop in Q_and_F | jnp.average (vectorised) | jnp.average (vectorised) |
| Why this format | Co-indexed with regime-keyed dicts | Positional within single regime | Same as discrete Markov; interpolation is upstream |
The three mechanisms converge at two points:
- **Downstream consumption** — discrete Markov states and continuous shocks both produce plain weight arrays consumed by `jnp.average`. They differ only in how weights are produced (direct evaluation vs interpolation).
- **User-facing pandas utilities** — `transition_probs_from_series` and `validate_transition_probs` provide a unified API for building and validating probability arrays for both regime and state transitions.
Full harmonisation of the runtime formats is blocked by the solve loop structure: regime probabilities are consumed in a Python loop over named regimes (because every co-indexed data structure is also keyed by name), while state probabilities are consumed in a vectorised array operation within each iteration of that loop. Flattening this hierarchy would require restructuring the core DP computation.