Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

The Function Representation

In dynamic programming, the value function $V_T(x)$ is computed on a discrete grid but must be evaluated at arbitrary points when solving earlier periods. The function representation turns a pre-computed array $V^\text{arr}_T$ into a callable function that:

  • Accepts named arguments (e.g., wealth=150.0)

  • Returns exact values at grid points

  • Linearly interpolates between grid points

This notebook explains how it works, using a minimal consumption-savings model and its terminal (retirement) regime.

The two steps

Converting an array into a callable function requires two things:

  1. Coordinate finding — For continuous variables, convert physical values (e.g., wealth = 150) to generalized coordinates (fractional indices into the grid). See the interpolation notebook for details. Discrete variables use integer codes that directly serve as array indices.

  2. Interpolation — Use the generalized coordinates with map_coordinates to linearly interpolate between grid points.

Worked example

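We set up a minimal consumption-savings model with two regimes: a working-life regime that always transitions into a terminal retirement regime. This notebook focuses on the terminal regime, in which a retiree chooses consumption given wealth and has CRRA utility. The wealth grid is intentionally coarse (10 points) so that the interpolation behavior is clearly visible.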

import jax.numpy as jnp
import plotly.graph_objects as go

from lcm import AgeGrid, LinSpacedGrid, Model, Regime, categorical
from lcm.typing import ContinuousAction, ContinuousState, FloatND

# Plot colours shared by every figure in this notebook.
blue = "#4C78A8"
orange = "#F58518"
green = "#54A24B"


def utility(consumption: ContinuousAction, risk_aversion: float) -> FloatND:
    """CRRA utility of consumption.

    Computes ``consumption ** (1 - risk_aversion) / (1 - risk_aversion)``.
    The expression divides by zero when ``risk_aversion == 1`` (the
    log-utility limit), so that value is not supported here.
    """
    exponent = 1 - risk_aversion
    return consumption**exponent / exponent


def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousAction,
    interest_rate: float,
) -> ContinuousState:
    """Wealth carried into the next period.

    Whatever is not consumed is saved and grows by the gross return
    ``1 + interest_rate``.
    """
    savings = wealth - consumption
    gross_return = 1 + interest_rate
    return gross_return * savings


def borrowing_constraint(
    consumption: ContinuousAction, wealth: ContinuousState
) -> FloatND:
    """Feasibility flag: consumption may not exceed current wealth (no borrowing)."""
    return wealth >= consumption


@categorical(ordered=False)
class RegimeId:
    # Categorical labelling the model's regimes. The field names match the keys
    # of the ``regimes`` mapping passed to ``Model`` below.
    # NOTE(review): the integer code assigned to each field is decided by
    # ``categorical`` -- presumably declaration order; confirm before reordering.
    working_life: int
    retirement: int


# Both regimes use the same consumption (action) and wealth (state) grids.
# Wealth is deliberately coarse (10 points) so interpolation is easy to see.
consumption_action_grid = LinSpacedGrid(start=1, stop=400, n_points=50)
wealth_state_grid = LinSpacedGrid(start=1, stop=400, n_points=10)

# Terminal regime: ``transition=None`` and no state transitions.
retirement_regime = Regime(
    transition=None,
    functions={"utility": utility},
    constraints={"borrowing_constraint": borrowing_constraint},
    actions={"consumption": consumption_action_grid},
    states={"wealth": wealth_state_grid},
)

# Working-life regime: always moves into retirement, with wealth evolving
# according to ``next_wealth``.
working_life_regime = Regime(
    transition=lambda: RegimeId.retirement,
    functions={"utility": utility},
    constraints={"borrowing_constraint": borrowing_constraint},
    actions={"consumption": consumption_action_grid},
    states={"wealth": wealth_state_grid},
    state_transitions={"wealth": next_wealth},
)

model = Model(
    description="Minimal consumption-savings model",
    ages=AgeGrid(start=25, stop=65, step="20Y"),
    regimes={"working_life": working_life_regime, "retirement": retirement_regime},
    regime_id_class=RegimeId,
)

params = dict(
    discount_factor=0.95,
    risk_aversion=1.5,
    interest_rate=0.04,
)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 55
     51         "wealth": next_wealth,
     52     },
     53 )
     54 
---> 55 model = Model(
     56     description="Minimal consumption-savings model",
     57     ages=AgeGrid(start=25, stop=65, step="20Y"),
     58     regimes={"working_life": working_life_regime, "retirement": retirement_regime},

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/model.py:136, in Model.__init__(self, description, ages, regimes, regime_id_class, enable_jit, fixed_params, derived_categoricals)
    127 self.regime_names_to_ids = MappingProxyType(
    128     dict(
    129         sorted(
   (...)    133     )
    134 )
    135 self.regimes = _merge_derived_categoricals(regimes, derived_categoricals)
--> 136 self.internal_regimes, self._params_template = build_regimes_and_template(
    137     ages=self.ages,
    138     regimes=self.regimes,
    139     regime_names_to_ids=self.regime_names_to_ids,
    140     enable_jit=enable_jit,
    141     fixed_params=self.fixed_params,
    142 )
    143 self.enable_jit = enable_jit
    144 self.simulation_output_dtypes = get_simulation_output_dtypes(
    145     regimes=self.regimes,
    146     regime_names_to_ids=self.regime_names_to_ids,
    147 )

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/model_processing.py:71, in build_regimes_and_template(ages, regimes, regime_names_to_ids, enable_jit, fixed_params)
     53 """Build internal regimes and params template in a single pass.
     54 
     55 Compose regime processing, template creation, and optional fixed-param partialling
   (...)     68 
     69 """
     70 if not fixed_params:
---> 71     internal_regimes = process_regimes(
     72         ages=ages,
     73         regimes=regimes,
     74         regime_names_to_ids=regime_names_to_ids,
     75         enable_jit=enable_jit,
     76     )
     77     params_template = create_params_template(internal_regimes)
     78 else:

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/regime_building/processing.py:131, in process_regimes(regimes, ages, regime_names_to_ids, enable_jit)
    128 for name, regime in regimes.items():
    129     regime_params_template = create_regime_params_template(regime)
--> 131     solve_functions = _build_solve_functions(
    132         regime=regime,
    133         regime_name=name,
    134         nested_transitions=nested_transitions[name],
    135         all_grids=all_grids,
    136         regime_params_template=regime_params_template,
    137         regime_names_to_ids=regime_names_to_ids,
    138         variable_info=variable_info[name],
    139         regimes_to_active_periods=regimes_to_active_periods,
    140         regime_to_v_interpolation_info=regime_to_v_interpolation_info,
    141         state_action_space=state_action_spaces[name],
    142         ages=ages,
    143         enable_jit=enable_jit,
    144     )
    146     simulate_functions = _build_simulate_functions(
    147         regime=regime,
    148         regime_name=name,
   (...)    161         solve_compute_regime_transition_probs=solve_functions.compute_regime_transition_probs,
    162     )
    164     internal_regimes[name] = InternalRegime(
    165         name=name,
    166         terminal=regime.terminal,
   (...)    173         _base_state_action_space=state_action_spaces[name],
    174     )

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/regime_building/processing.py:280, in _build_solve_functions(regime, regime_name, nested_transitions, all_grids, regime_params_template, regime_names_to_ids, variable_info, regimes_to_active_periods, regime_to_v_interpolation_info, state_action_space, ages, enable_jit)
    258     compute_intermediates = _build_compute_intermediates_per_period(
    259         flat_param_names=flat_param_names,
    260         regimes_to_active_periods=regimes_to_active_periods,
   (...)    270         enable_jit=enable_jit,
    271     )
    273 max_Q_over_a = _build_max_Q_over_a_per_period(
    274     state_action_space=state_action_space,
    275     Q_and_F_functions=Q_and_F_functions,
    276     grids=all_grids[regime_name],
    277     enable_jit=enable_jit,
    278 )
--> 280 compute_intermediates = _build_compute_intermediates_per_period(
    281     regime=regime,
    282     flat_param_names=frozenset(get_flat_param_names(regime_params_template)),
    283     regimes_to_active_periods=regimes_to_active_periods,
    284     functions=core.functions,
    285     constraints=core.constraints,
    286     transitions=core.transitions,
    287     stochastic_transition_names=core.stochastic_transition_names,
    288     compute_regime_transition_probs=compute_regime_transition_probs,
    289     regime_to_v_interpolation_info=regime_to_v_interpolation_info,
    290     state_action_space=state_action_space,
    291     grids=all_grids[regime_name],
    292     ages=ages,
    293     enable_jit=enable_jit,
    294 )
    296 return SolveFunctions(
    297     functions=core.functions,
    298     constraints=core.constraints,
   (...)    303     compute_intermediates=compute_intermediates,
    304 )

TypeError: _build_compute_intermediates_per_period() got an unexpected keyword argument 'regime'

Computing the last-period value function array

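In the terminal period, the value function is the maximum of utility over feasible actions, $V_T(x) = \max_{c \in \mathcal{C},\; c \le x} u(c)$, where $\mathcal{C}$ is the consumption grid and $c \le x$ is the borrowing constraint. We use the internal regime representation to access the compiled functions and grids.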

from lcm.regime_building.Q_and_F import _get_U_and_F

# The retirement regime's compiled functions live on its internal regime.
internal_regime = model.internal_regimes["retirement"]
solve_functions = internal_regime.solve_functions

# Combine utility and constraints into one function returning
# (utility, feasibility).
u_and_f = _get_U_and_F(
    constraints=solve_functions.constraints,
    functions=solve_functions.functions,
)
u_and_f.__signature__

The function returns (utility, feasibility) for scalar inputs:

# Example call: consumption (100) exceeds wealth (50), so this point violates
# the borrowing constraint and is reported as infeasible. The risk-aversion
# parameter is taken from ``params`` rather than hard-coded, so the example
# stays consistent with the model parameters defined above.
_u, _f = u_and_f(
    consumption=100.0,
    wealth=50.0,
    utility__risk_aversion=params["risk_aversion"],
)
print(f"Utility: {_u}, feasible: {_f}")

To evaluate on the full state-action grid, we use productmap:

from lcm.utils.dispatchers import productmap

# Vectorize over the Cartesian product of the wealth and consumption grids.
# The output axes are assumed to follow the order of ``_variables``
# (axis 0 = wealth, axis 1 = consumption); axis 1 is reduced below.
_variables = ("wealth", "consumption")
u_and_f_mapped = productmap(
    func=u_and_f, variables=_variables, batch_sizes=dict.fromkeys(_variables, 0)
)
grid_arrays = {name: g.to_jax() for name, g in internal_regime.grids.items()}
# Use the model's risk aversion instead of a duplicated literal.
u, f = u_and_f_mapped(
    **grid_arrays, utility__risk_aversion=params["risk_aversion"]
)

# Maximize utility over consumption, ignoring infeasible entries; a wealth
# level with no feasible consumption would get -inf.
V_arr = jnp.max(u, axis=1, where=f, initial=-jnp.inf)
wealth_grid = internal_regime.grids["wealth"].to_jax()

print(f"V_arr shape: {V_arr.shape} ({len(wealth_grid)} wealth grid points)")
# Pre-computed value function at the wealth grid points.
grid_points_trace = go.Scatter(
    x=wealth_grid,
    y=V_arr,
    mode="markers",
    marker={"color": blue, "size": 8},
    name="Pre-calculated values",
)
fig = go.Figure()
fig.add_trace(grid_points_trace)
fig.update_layout(xaxis_title="Wealth (x)", yaxis_title="V(x)", width=600, height=400)
fig.show()

Creating the function representation

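The function representation turns V_arr into a callable that can be evaluated at any wealth value. The V_arr_name argument sets the name of the array parameter in the resulting function. The state_prefix="next_" argument prefixes each state argument, so wealth is passed as next_wealth.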

from lcm.regime_building.V import create_v_interpolation_info, get_V_interpolator

# Interpolation metadata derived from the retirement regime's definition.
v_interpolation_info = create_v_interpolation_info(retirement_regime)

# State arguments get the "next_" prefix (wealth -> next_wealth), and the
# value array is expected under the name "V_arr".
scalar_value_function = get_V_interpolator(
    V_arr_name="V_arr",
    state_prefix="next_",
    v_interpolation_info=v_interpolation_info,
)
scalar_value_function.__signature__

This scalar function is then wrapped with productmap so it can evaluate on arrays:

# Lift the scalar function so it accepts arrays of next-period wealth.
_vf_variables = ("next_wealth",)
value_function = productmap(
    func=scalar_value_function,
    variables=_vf_variables,
    batch_sizes=dict.fromkeys(_vf_variables, 0),
)

Visualizing interpolation

We evaluate the function representation on the original grid points (which should match exactly) and on additional points between grid points (which are interpolated).

# Wealth levels that fall between grid points, appended to the grid itself.
wealth_points_new = jnp.array([10.0, 25.0, 75.0, 210.0, 300.0])
wealth_all = jnp.concatenate((wealth_grid, wealth_points_new))

V_via_func = value_function(V_arr=V_arr, next_wealth=wealth_all)
# Blue: grid values joined by straight lines (the piecewise-linear interpolant).
interpolant_trace = go.Scatter(
    x=wealth_grid,
    y=V_arr,
    mode="lines+markers",
    marker={"color": blue, "size": 8},
    line={"color": blue},
    name="Pre-calculated values (linear interpolation)",
)
# Orange: what the function representation returns on and between grid points.
function_output_trace = go.Scatter(
    x=wealth_all,
    y=V_via_func,
    mode="markers",
    marker={"color": orange, "size": 6},
    name="Function representation output",
)
fig = go.Figure()
for trace in (interpolant_trace, function_output_trace):
    fig.add_trace(trace)
fig.update_layout(xaxis_title="Wealth (x)", yaxis_title="V(x)", width=700, height=400)
fig.show()

The orange points from the function representation lie exactly on the blue line connecting the grid points. The function representation behaves like an analytical function corresponding to this piecewise linear interpolation.

Technical details

The function representation is assembled from three building blocks, each implemented as a small function with a carefully chosen signature. These functions are composed using dags.concatenate_functions.

Lookup function

Indexes into the value function array using named axes. This is important because dags.concatenate_functions matches functions by argument names.

from lcm.regime_building.V import _get_lookup_function

lookup = _get_lookup_function(axis_names=["wealth_index"], array_name="V_arr")
print(f"Signature: {lookup.__signature__}")

# Look up values at indices 0, 2, 5
sample_indices = jnp.array([0, 2, 5])
lookup(V_arr=V_arr, wealth_index=sample_indices)

Coordinate finder

Converts physical values to generalized coordinates — fractional indices into the grid. For the linearly spaced grid [1, 45.33, 89.67, …], the value 23.17 (halfway between the first two grid points) corresponds to coordinate 0.5, i.e. halfway between indices 0 and 1.

from lcm.regime_building.V import _get_coordinate_finder

# NOTE: this rebinds ``wealth_grid`` from the JAX array used above to the grid
# object itself. It has the same specification as the retirement regime's
# wealth state, and later cells call ``wealth_grid.to_jax()`` on it.
wealth_grid = LinSpacedGrid(start=1, stop=400, n_points=10)

wealth_coordinate_finder = _get_coordinate_finder(
    in_name="wealth",
    grid=wealth_grid,
)
print(f"Signature: {wealth_coordinate_finder.__signature__}")

# Three test points:
# - the first grid point (coordinate 0);
# - the midpoint between the first two grid points (coordinate 0.5);
# - a value near the top of the grid.
# The midpoint is computed from the grid instead of hard-coding an
# approximation of the second grid point.
_grid_values = wealth_grid.to_jax()
wealth_values = jnp.array(
    [_grid_values[0], (_grid_values[0] + _grid_values[1]) / 2, 390.0]
)
coords = wealth_coordinate_finder(wealth=wealth_values)

for w, c in zip(wealth_values, coords, strict=True):
    # Convert JAX scalars to Python floats so the format spec is applied reliably.
    print(f"  wealth = {float(w):8.2f}  →  coordinate = {float(c):.4f}")

Interpolator

Uses the generalized coordinates to linearly interpolate on the value function array via map_coordinates.

from lcm.regime_building.V import _get_interpolator

value_function_interpolator = _get_interpolator(
    axis_names=["wealth_index"],
    name_of_values_on_grid="V_arr",
)
print(f"Signature: {value_function_interpolator.__signature__}")

# Chain the building blocks: physical wealth -> generalized coordinates -> values.
wealth_indices = wealth_coordinate_finder(wealth=wealth_values)
V_interpolations = value_function_interpolator(
    V_arr=V_arr,
    wealth_index=wealth_indices,
)
precomputed_trace = go.Scatter(
    x=wealth_grid.to_jax(),
    y=V_arr,
    mode="markers",
    marker={"color": blue, "size": 8},
    name="Pre-calculated values",
)
interpolated_trace = go.Scatter(
    x=wealth_values,
    y=V_interpolations,
    mode="markers",
    marker={"color": orange, "size": 6},
    name="Interpolated values",
)
fig = go.Figure()
for trace in (precomputed_trace, interpolated_trace):
    fig.add_trace(trace)
fig.update_layout(xaxis_title="Wealth (x)", yaxis_title="V(x)", width=600, height=400)
fig.show()

Re-implementation from scratch

To understand how the pieces fit together, let’s re-implement the function representation manually using dags.concatenate_functions.

The general idea: create functions for array lookup, coordinate finding, and interpolation, each with signatures that declare their dependencies. Then let dags wire them together.

Steps

  1. Discrete lookup — index into the array using discrete positions. With no discrete states in our model, this is the identity (returns the array unchanged)

  2. Coordinate finder for each continuous state — maps values to fractional indices

  3. Interpolator — uses coordinates to interpolate on the array

Implementation

# Interpolation metadata for the retirement regime (same call as earlier).
space_info = create_v_interpolation_info(retirement_regime)

# Functions to be wired together by dags. Each key is the name under which that
# function's output is made available to other functions' arguments.
funcs = {}

print(f"Discrete states: {space_info.discrete_states}")
# Step 1: Discrete lookup — identity (no discrete states to index by)
def discrete_lookup(V_arr):
    """Return the value array unchanged; there are no discrete axes to index.

    The argument name ``V_arr`` becomes an input of the composed function.
    """
    return V_arr


# Registered as "__interpolation_data__", the argument name the interpolator
# asks for.
funcs["__interpolation_data__"] = discrete_lookup
# Step 2: Coordinate finder for wealth
from lcm.grids.coordinates import get_linspace_coordinate


# NOTE: this replaces the ``wealth_coordinate_finder`` built earlier with
# ``_get_coordinate_finder``.
def wealth_coordinate_finder(wealth):
    """Map wealth values to fractional indices into the wealth grid.

    start/stop/n_points must match the regime's wealth grid,
    ``LinSpacedGrid(start=1, stop=400, n_points=10)``.
    """
    return get_linspace_coordinate(value=wealth, start=1, stop=400, n_points=10)


# Registered as "__wealth_coord__", the argument name the interpolator asks for.
funcs["__wealth_coord__"] = wealth_coordinate_finder
# Step 3: Interpolator using map_coordinates
from lcm.regime_building.ndimage import map_coordinates


def interpolator(__interpolation_data__, __wealth_coord__):
    """Linearly interpolate the value array at the given wealth coordinates.

    dags uses the argument names to wire in the outputs registered as
    ``funcs["__interpolation_data__"]`` and ``funcs["__wealth_coord__"]``.
    """
    # One coordinate array per axis of the data (scipy.ndimage-style layout,
    # presumably mirrored by lcm's map_coordinates); here there is only the
    # wealth axis.
    per_axis_coordinates = jnp.array([__wealth_coord__])
    return map_coordinates(
        input=__interpolation_data__,
        coordinates=per_axis_coordinates,
    )


funcs["__fval__"] = interpolator
# Compose with dags
from dags import concatenate_functions

# dags builds the dependency graph from argument names. The composed function's
# free arguments are the ones that no registered function provides.
value_function = concatenate_functions(functions=funcs, targets="__fval__")
print(f"Composed signature: {value_function.__signature__}")

V_evaluated = value_function(wealth=wealth_grid.to_jax(), V_arr=V_arr)

# The grid-point evaluation above only shows that the pre-computed values are
# reproduced. To support the claim that the re-implementation matches pylcm's
# built-in function representation, also compare on and between grid points
# against ``V_via_func``, which the built-in computed earlier on ``wealth_all``.
V_evaluated_all = value_function(wealth=wealth_all, V_arr=V_arr)
print(
    "Matches built-in function representation (on and between grid points): "
    f"{bool(jnp.allclose(V_evaluated_all, V_via_func))}"
)
grid_x = wealth_grid.to_jax()
precomputed_trace = go.Scatter(
    x=grid_x,
    y=V_arr,
    mode="markers",
    marker={"color": blue, "size": 8},
    name="Pre-calculated values",
)
reimplemented_trace = go.Scatter(
    x=grid_x,
    y=V_evaluated,
    mode="markers",
    marker={"color": orange, "size": 6},
    name="Re-implemented function representation",
)
fig = go.Figure()
for trace in (precomputed_trace, reimplemented_trace):
    fig.add_trace(trace)
fig.update_layout(xaxis_title="Wealth (x)", yaxis_title="V(x)", width=600, height=400)
fig.show()

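The orange points coincide with the blue grid points: at the grid points, our manual re-implementation reproduces the pre-computed values exactly, just like pylcm's built-in function representation. Both are assembled from the same building blocks (coordinate finding followed by map_coordinates interpolation).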