Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

A Tiny Example

A three-period consumption-savings model with two regimes:

  • Working life (ages 25 and 45): The agent chooses whether to work and how much to consume. A simple tax-and-transfer system guarantees a consumption floor. Savings earn interest.

  • Retirement (age 65): Terminal regime. The agent consumes out of remaining wealth.

Model

An agent lives for three periods (ages 25, 45, and 65). In the first two periods (working life), the agent chooses whether to work, $d_t \in \{0, 1\}$, and how much to consume, $c_t$. In the final period (retirement), the agent consumes out of remaining wealth.

Working life (ages 25 and 45):

$$
V_t(w_t) = \max_{d_t,\, c_t} \left\{ \frac{c_t^{1-\sigma}}{1-\sigma} - \phi \, d_t + \beta \, V_{t+1}(w_{t+1}) \right\}
$$

subject to

$$
\begin{aligned}
e_t &= d_t \cdot \bar{w} \\[4pt]
\tau(e_t, w_t) &= \begin{cases} \theta\,(e_t - \underline{c}) & \text{if } e_t \geq \underline{c} \\ \min(0,\; w_t + e_t - \underline{c}) & \text{otherwise} \end{cases} \\[4pt]
a_t &= w_t + e_t - \tau(e_t, w_t) - c_t \\[4pt]
w_{t+1} &= (1 + r)\, a_t \\[4pt]
a_t &\geq 0
\end{aligned}
$$

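where $w_t$ is wealth, $e_t$ earnings, $\bar{w}$ the wage, $\underline{c}$ a consumption floor guaranteed by transfers, $\theta$ the tax rate, $a_t$ end-of-period wealth, $\sigma$ the coefficient of relative risk aversion, $\phi$ the disutility of work, $\beta$ the discount factor, and $r$ the interest rate. Earnings above the floor are taxed at rate $\theta$. The transfer only kicks in when the agent's resources ($w_t + e_t$) fall below the consumption floor. In that case $\tau$ is negative and tops resources up to exactly $\underline{c}$.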

Retirement (age 65, terminal):

$$
V_2(w_2) = \max_{c_2 \leq w_2} \frac{c_2^{1-\sigma}}{1-\sigma}
$$

Because utility is increasing in consumption, the agent consumes all remaining wealth, $c_2 = w_2$.
from pprint import pprint

import jax.numpy as jnp
import numpy as np
import pandas as pd
import plotly.express as px

from lcm import (
    AgeGrid,
    DiscreteGrid,
    LinSpacedGrid,
    LogSpacedGrid,
    Model,
    Regime,
    categorical,
)
from lcm.typing import (
    BoolND,
    ContinuousAction,
    ContinuousState,
    DiscreteAction,
    FloatND,
    ScalarInt,
)

Categorical Variables

# Discrete labor-supply choice available in each working-life period.
# NOTE(review): ``ordered=True`` presumably makes the field order below
# (``do_not_work`` before ``work``) the category order — confirm in
# ``lcm.categorical``.
@categorical(ordered=True)
class LaborSupply:
    # Not working: earnings are zero and no disutility of work is incurred.
    do_not_work: int
    # Working: the agent earns the wage and pays the disutility of work.
    work: int


# Identifiers for the model's two regimes; ``next_regime`` returns one of these
# to select the regime the agent enters next period.
@categorical(ordered=False)
class RegimeId:
    # Ages 25 and 45: work and consumption choices, with savings earning interest.
    working_life: int
    # Age 65: terminal regime in which remaining wealth is consumed.
    retirement: int

Model Functions

# Utility


def utility(
    consumption: ContinuousAction,
    labor_supply: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    """Per-period working-life utility: CRRA consumption utility minus a work cost.

    Args:
        consumption: Consumption chosen this period.
        labor_supply: Labor-supply choice; compared against ``LaborSupply.work``.
        disutility_of_work: Utility cost subtracted when the agent works.
        risk_aversion: CRRA coefficient (sigma). Must differ from 1, because the
            formula divides by ``1 - risk_aversion``.

    Returns:
        ``consumption**(1 - sigma) / (1 - sigma) - disutility_of_work * works``.
    """
    exponent = 1 - risk_aversion
    consumption_utility = consumption**exponent / exponent
    works = labor_supply == LaborSupply.work
    return consumption_utility - disutility_of_work * works


def utility_retirement(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    """Terminal-period utility when all remaining wealth is consumed.

    Evaluates CRRA utility at ``consumption = wealth``. ``risk_aversion`` must
    differ from 1, because the formula divides by ``1 - risk_aversion``.
    """
    curvature = 1 - risk_aversion
    return wealth**curvature / curvature


# Auxiliary functions


def earnings(labor_supply: DiscreteAction, wage: float) -> FloatND:
    """Return labor earnings: ``wage`` if the agent works, otherwise 0.0."""
    is_working = labor_supply == LaborSupply.work
    return jnp.where(is_working, wage, 0.0)


def taxes_transfers(
    earnings: FloatND,
    wealth: ContinuousState,
    consumption_floor: float,
    tax_rate: float,
) -> FloatND:
    """Net taxes paid (positive) or transfers received (negative).

    If earnings are at or above ``consumption_floor``, the excess over the floor
    is taxed at ``tax_rate``. Otherwise the result is
    ``min(0, wealth + earnings - consumption_floor)``. That is a transfer that
    raises resources to exactly the floor when they fall short of it, and zero
    when they already cover it.
    """
    is_taxed = earnings >= consumption_floor
    tax = tax_rate * (earnings - consumption_floor)
    shortfall = wealth + earnings - consumption_floor
    transfer = jnp.minimum(0.0, shortfall)
    return jnp.where(is_taxed, tax, transfer)


def end_of_period_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    taxes_transfers: FloatND,
    consumption: ContinuousAction,
) -> FloatND:
    """Wealth left after earnings, net taxes/transfers, and consumption."""
    cash_on_hand = wealth + earnings - taxes_transfers
    return cash_on_hand - consumption


# State transition


def next_wealth(end_of_period_wealth: FloatND, interest_rate: float) -> ContinuousState:
    """Next period's wealth: end-of-period savings grown by one period of interest."""
    gross_return = 1 + interest_rate
    return gross_return * end_of_period_wealth


# Constraints


def borrowing_constraint_working(end_of_period_wealth: FloatND) -> BoolND:
    """Feasibility mask: True where end-of-period wealth is non-negative."""
    # ``0 <= x`` (not ``~(x < 0)``) keeps NaN entries infeasible, as before.
    return 0 <= end_of_period_wealth


# Regime transition


def next_regime(age: float, last_working_age: float) -> ScalarInt:
    """Return the id of the regime the agent enters next period.

    Agents whose age is at or beyond ``last_working_age`` move to retirement;
    everyone else stays in working life.
    """
    is_last_working_period = age >= last_working_age
    return jnp.where(
        is_last_working_period, RegimeId.retirement, RegimeId.working_life
    )

Regimes and Model

# Three decision periods spaced 20 years apart, starting at 25 and ending at 65.
age_grid = AgeGrid(start=25, stop=65, step="20Y")
# The final grid age is the terminal retirement period.
retirement_age = age_grid.exact_values[-1]

# NOTE(review): the consumption grid starts at 4, above the consumption floor of
# 2.0 set in ``params``. Whenever a transfer is paid, post-transfer resources
# equal the floor (2.0), so every grid consumption level violates the borrowing
# constraint in that case. Confirm this is intended.
working_life = Regime(
    active=lambda age: age < retirement_age,
    transition=next_regime,
    states={"wealth": LinSpacedGrid(start=0, stop=50, n_points=25)},
    state_transitions={"wealth": next_wealth},
    actions={
        "labor_supply": DiscreteGrid(LaborSupply),
        "consumption": LogSpacedGrid(start=4, stop=50, n_points=100),
    },
    functions={
        "utility": utility,
        "earnings": earnings,
        "taxes_transfers": taxes_transfers,
        "end_of_period_wealth": end_of_period_wealth,
    },
    constraints={"borrowing_constraint_working": borrowing_constraint_working},
)

# No outgoing transition: retirement is the terminal regime.
retirement = Regime(
    active=lambda age: age >= retirement_age,
    transition=None,
    states={"wealth": LinSpacedGrid(start=0, stop=50, n_points=25)},
    functions={"utility": utility_retirement},
)

model = Model(
    description="A tiny three-period consumption-savings model.",
    ages=age_grid,
    regimes={"working_life": working_life, "retirement": retirement},
    regime_id_class=RegimeId,
)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 37
     33     },
     34     functions={"utility": utility_retirement},
     35 )
     36 
---> 37 model = Model(
     38     regimes={
     39         "working_life": working_life,
     40         "retirement": retirement,

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/model.py:136, in Model.__init__(self, description, ages, regimes, regime_id_class, enable_jit, fixed_params, derived_categoricals)
    127 self.regime_names_to_ids = MappingProxyType(
    128     dict(
    129         sorted(
   (...)    133     )
    134 )
    135 self.regimes = _merge_derived_categoricals(regimes, derived_categoricals)
--> 136 self.internal_regimes, self._params_template = build_regimes_and_template(
    137     ages=self.ages,
    138     regimes=self.regimes,
    139     regime_names_to_ids=self.regime_names_to_ids,
    140     enable_jit=enable_jit,
    141     fixed_params=self.fixed_params,
    142 )
    143 self.enable_jit = enable_jit
    144 self.simulation_output_dtypes = get_simulation_output_dtypes(
    145     regimes=self.regimes,
    146     regime_names_to_ids=self.regime_names_to_ids,
    147 )

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/model_processing.py:71, in build_regimes_and_template(ages, regimes, regime_names_to_ids, enable_jit, fixed_params)
     53 """Build internal regimes and params template in a single pass.
     54 
     55 Compose regime processing, template creation, and optional fixed-param partialling
   (...)     68 
     69 """
     70 if not fixed_params:
---> 71     internal_regimes = process_regimes(
     72         ages=ages,
     73         regimes=regimes,
     74         regime_names_to_ids=regime_names_to_ids,
     75         enable_jit=enable_jit,
     76     )
     77     params_template = create_params_template(internal_regimes)
     78 else:

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/regime_building/processing.py:131, in process_regimes(regimes, ages, regime_names_to_ids, enable_jit)
    128 for name, regime in regimes.items():
    129     regime_params_template = create_regime_params_template(regime)
--> 131     solve_functions = _build_solve_functions(
    132         regime=regime,
    133         regime_name=name,
    134         nested_transitions=nested_transitions[name],
    135         all_grids=all_grids,
    136         regime_params_template=regime_params_template,
    137         regime_names_to_ids=regime_names_to_ids,
    138         variable_info=variable_info[name],
    139         regimes_to_active_periods=regimes_to_active_periods,
    140         regime_to_v_interpolation_info=regime_to_v_interpolation_info,
    141         state_action_space=state_action_spaces[name],
    142         ages=ages,
    143         enable_jit=enable_jit,
    144     )
    146     simulate_functions = _build_simulate_functions(
    147         regime=regime,
    148         regime_name=name,
   (...)    161         solve_compute_regime_transition_probs=solve_functions.compute_regime_transition_probs,
    162     )
    164     internal_regimes[name] = InternalRegime(
    165         name=name,
    166         terminal=regime.terminal,
   (...)    173         _base_state_action_space=state_action_spaces[name],
    174     )

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/regime_building/processing.py:280, in _build_solve_functions(regime, regime_name, nested_transitions, all_grids, regime_params_template, regime_names_to_ids, variable_info, regimes_to_active_periods, regime_to_v_interpolation_info, state_action_space, ages, enable_jit)
    258     compute_intermediates = _build_compute_intermediates_per_period(
    259         flat_param_names=flat_param_names,
    260         regimes_to_active_periods=regimes_to_active_periods,
   (...)    270         enable_jit=enable_jit,
    271     )
    273 max_Q_over_a = _build_max_Q_over_a_per_period(
    274     state_action_space=state_action_space,
    275     Q_and_F_functions=Q_and_F_functions,
    276     grids=all_grids[regime_name],
    277     enable_jit=enable_jit,
    278 )
--> 280 compute_intermediates = _build_compute_intermediates_per_period(
    281     regime=regime,
    282     flat_param_names=frozenset(get_flat_param_names(regime_params_template)),
    283     regimes_to_active_periods=regimes_to_active_periods,
    284     functions=core.functions,
    285     constraints=core.constraints,
    286     transitions=core.transitions,
    287     stochastic_transition_names=core.stochastic_transition_names,
    288     compute_regime_transition_probs=compute_regime_transition_probs,
    289     regime_to_v_interpolation_info=regime_to_v_interpolation_info,
    290     state_action_space=state_action_space,
    291     grids=all_grids[regime_name],
    292     ages=ages,
    293     enable_jit=enable_jit,
    294 )
    296 return SolveFunctions(
    297     functions=core.functions,
    298     constraints=core.constraints,
   (...)    303     compute_intermediates=compute_intermediates,
    304 )

TypeError: _build_compute_intermediates_per_period() got an unexpected keyword argument 'regime'

Parameters

Use model.get_params_template() to see what parameters the model expects, organized by regime and function.

# Show the parameters the model expects, grouped by regime and function.
params_template = model.get_params_template()
pprint(dict(params_template))

Parameters shared across regimes (risk_aversion, discount_factor, interest_rate) can be specified at the model level. Parameters unique to one regime go under the regime name.

# Working-life-only parameters, nested by the function that consumes them.
working_life_params = {
    "utility": {"disutility_of_work": 1.0},
    "earnings": {"wage": 20.0},
    "taxes_transfers": {"consumption_floor": 2.0, "tax_rate": 0.2},
    # Agents move to retirement after the second-to-last grid age.
    "next_regime": {"last_working_age": age_grid.exact_values[-2]},
}

# Model-level parameters first, then the regime-specific sub-dictionary.
params = dict(
    discount_factor=0.95,
    risk_aversion=1.5,
    interest_rate=0.03,
    working_life=working_life_params,
)

Simulate

n_agents = 100

# Every agent starts in working life at the first grid age, with initial wealth
# spread evenly between 1 and 20.
initial_df = pd.DataFrame(
    {
        "regime": "working_life",
        "age": float(age_grid.exact_values[0]),
        "wealth": np.linspace(1, 20, n_agents),
    }
)

result = model.simulate(
    params=params,
    initial_conditions=initial_df,
    period_to_regime_to_V_arr=None,
)
df = result.to_dataframe(additional_targets="all")
df["age"] = df["age"].astype(int)

# The retirement regime has no consumption action; the agent consumes all
# remaining wealth, so consumption in retirement rows is filled from wealth.
is_retired = df["age"] == retirement_age
df.loc[is_retired, "consumption"] = df.loc[is_retired, "wealth"]

columns = [
    "regime",
    "labor_supply",
    "consumption",
    "wealth",
    "earnings",
    "taxes_transfers",
    "end_of_period_wealth",
    "value",
]
first_rows = df.set_index(["subject_id", "age"])[columns].head(20)
first_rows.style.format(
    precision=1,
    na_rep="",
)
Source
# Classify agents by work pattern across the two working-life periods
first_working_age = age_grid.exact_values[0]
last_working_age = age_grid.exact_values[-2]

df_working = df[df["regime"] == "working_life"]
work_by_age = df_working.pivot_table(
    index="subject_id",
    columns="age",
    values="labor_supply",
    aggfunc="first",
)
work_pattern = (
    work_by_age[first_working_age].astype(str)
    + ", "
    + work_by_age[last_working_age].astype(str)
)
# Raise explicitly rather than ``assert``: asserts are stripped under
# ``python -O``, which would let the labelling below silently drop agents.
if "work, work" in work_pattern.to_numpy():
    raise ValueError(
        "Plotting assumes that no agent works in both periods of working life."
    )

label_map = {
    "work, do_not_work": "low",  # work early, not later
    "do_not_work, work": "medium",  # coast early, work later
    "do_not_work, do_not_work": "high",  # never work
}
groups = work_pattern.map(label_map).rename("initial_wealth")

# Combined descriptives and work decisions table
initial_wealth = df[df["age"] == first_working_age].set_index("subject_id")["wealth"]
group_desc = initial_wealth.groupby(groups).agg(["min", "max"]).round(1)

df_groups = df.copy()
df_groups["initial_wealth"] = df_groups["subject_id"].map(groups)
df_mean = df_groups.groupby(["initial_wealth", "age"], as_index=False).mean(
    numeric_only=True,
)
work_table = df_mean[df_mean["age"] < retirement_age].pivot_table(
    index="initial_wealth",
    columns="age",
    values="earnings",
)
# Within a group every agent shares one work pattern, so positive mean earnings
# at an age means the group works at that age.
work_table = (work_table > 0).astype(int)
work_table.columns = [f"works {c}" for c in work_table.columns]

summary = pd.concat([group_desc, work_table], axis=1)
summary.index.name = "initial_wealth"
# ``reindex`` (unlike ``.loc``) does not raise when a group has no agents;
# missing groups appear as blank rows instead.
summary.reindex(["low", "medium", "high"]).style.format(precision=1, na_rep="")
Source
# Mean consumption over the life cycle, one line per initial-wealth group.
fig = px.line(
    data_frame=df_mean,
    x="age",
    y="consumption",
    color="initial_wealth",
    template="plotly_dark",
    title="Consumption by Age",
)
fig.show()
Source
# Mean wealth over the life cycle, one line per initial-wealth group.
fig = px.line(
    data_frame=df_mean,
    x="age",
    y="wealth",
    color="initial_wealth",
    template="plotly_dark",
    title="Wealth by Age",
)
fig.show()