A three-period consumption-savings model with two regimes:
Working life (ages 25 and 45): The agent chooses whether to work and how much to consume. A simple tax-and-transfer system guarantees a consumption floor. Savings earn interest.
Retirement (age 65): Terminal regime. The agent consumes out of remaining wealth.
Model¶
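An agent lives for three periods (ages 25, 45, and 65). In the first two periods (working life), the agent chooses whether to work $d_t \in \{0, 1\}$ and how much to consume $c_t$. In the final period (retirement), the agent consumes out of remaining wealth.
Working life (ages 25 and 45):
$$V_t(w_t) = \max_{d_t,\, c_t} \left\{ \frac{c_t^{1-\sigma}}{1-\sigma} - \phi\, d_t + \beta\, V_{t+1}(w_{t+1}) \right\}$$
subject to
$$e_t = d_t \cdot \bar{w}$$
$$\tau_t = \begin{cases} \theta\,(e_t - \underline{c}) & \text{if } e_t \geq \underline{c} \\ \min\{0,\; w_t + e_t - \underline{c}\} & \text{otherwise} \end{cases}$$
$$a_t = w_t + e_t - \tau_t - c_t \geq 0$$
$$w_{t+1} = (1 + r)\, a_t$$
where $w_t$ is wealth, $e_t$ earnings, $\bar{w}$ the wage, $\underline{c}$ a consumption floor guaranteed by transfers, $\theta$ the tax rate, and $a_t$ end-of-period wealth; $\sigma$ is the coefficient of risk aversion, $\phi$ the disutility of work, $\beta$ the discount factor, $r$ the interest rate, and $\tau_t$ net taxes (negative values are transfers). The transfer only kicks in when the agent’s resources ($w_t + e_t$) fall below the consumption floor.
Retirement (age 65, terminal):
$$V_T(w_T) = \frac{w_T^{1-\sigma}}{1-\sigma}$$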
from pprint import pprint
import jax.numpy as jnp
import numpy as np
import pandas as pd
import plotly.express as px
from lcm import (
AgeGrid,
DiscreteGrid,
LinSpacedGrid,
LogSpacedGrid,
Model,
Regime,
categorical,
)
from lcm.typing import (
BoolND,
ContinuousAction,
ContinuousState,
DiscreteAction,
FloatND,
ScalarInt,
)
Categorical Variables¶
@categorical(ordered=True)
class LaborSupply:
    """Categories of the discrete ``labor_supply`` action.

    Declared as an ordered categorical; the working-life regime builds its
    ``labor_supply`` action grid from this class.
    """

    do_not_work: int  # the agent does not work in the period
    work: int  # the agent works in the period
@categorical(ordered=False)
class RegimeId:
    """Identifiers of the model's two regimes.

    Declared as an unordered categorical; ``next_regime`` returns one of these
    members and the class is handed to ``Model`` as ``regime_id_class``.
    """

    working_life: int  # regime active before the retirement age
    retirement: int  # terminal regime


# Model Functions
# Utility
def utility(
    consumption: ContinuousAction,
    labor_supply: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    """Return working-life utility: CRRA in consumption minus the cost of working.

    Args:
        consumption: Consumption chosen in the period.
        labor_supply: Labor-supply choice, compared against ``LaborSupply.work``.
        disutility_of_work: Utility cost subtracted when the agent works.
        risk_aversion: CRRA coefficient. The expression divides by
            ``1 - risk_aversion``, so it is undefined for ``risk_aversion == 1``.

    Returns:
        ``consumption ** (1 - risk_aversion) / (1 - risk_aversion)`` minus
        ``disutility_of_work`` if the agent works.
    """
    crra_utility = consumption ** (1 - risk_aversion) / (1 - risk_aversion)
    # The comparison is a boolean indicator; multiplying by it applies the
    # disutility only in the "work" state.
    is_working = labor_supply == LaborSupply.work
    return crra_utility - disutility_of_work * is_working
def utility_retirement(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    """Return CRRA utility in the retirement regime, evaluated at ``wealth``.

    Wealth takes the place of consumption here. The expression divides by
    ``1 - risk_aversion``, so it is undefined for ``risk_aversion == 1``.
    """
    curvature = 1 - risk_aversion
    return wealth**curvature / curvature
# Auxiliary functions
def earnings(labor_supply: DiscreteAction, wage: float) -> FloatND:
    """Return labor earnings: ``wage`` if the agent works, ``0.0`` otherwise."""
    is_working = labor_supply == LaborSupply.work
    return jnp.where(is_working, wage, 0.0)
def taxes_transfers(
    earnings: FloatND,
    wealth: ContinuousState,
    consumption_floor: float,
    tax_rate: float,
) -> FloatND:
    """Return net taxes paid; a negative value is a transfer received.

    Args:
        earnings: Labor earnings in the period.
        wealth: Beginning-of-period wealth.
        consumption_floor: Level of resources guaranteed by transfers.
        tax_rate: Rate applied to earnings above the consumption floor.

    Returns:
        ``tax_rate * (earnings - consumption_floor)`` when earnings are at
        least the floor. Otherwise ``min(0, wealth + earnings -
        consumption_floor)``: the (non-positive) shortfall of resources
        relative to the floor, which is zero when wealth alone covers it.
    """
    tax_on_earnings = tax_rate * (earnings - consumption_floor)
    # Non-positive by construction: only a shortfall produces a transfer.
    transfer = jnp.minimum(0.0, wealth + earnings - consumption_floor)
    return jnp.where(earnings >= consumption_floor, tax_on_earnings, transfer)
def end_of_period_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    taxes_transfers: FloatND,
    consumption: ContinuousAction,
) -> FloatND:
    """Return wealth left at the end of the period.

    Computed as beginning-of-period wealth plus earnings, minus net taxes
    (``taxes_transfers``) and minus consumption.
    """
    resources_after_tax = wealth + earnings - taxes_transfers
    return resources_after_tax - consumption
# State transition
def next_wealth(end_of_period_wealth: FloatND, interest_rate: float) -> ContinuousState:
    """Return next period's wealth: end-of-period wealth grown by the interest rate."""
    gross_return = 1 + interest_rate
    return gross_return * end_of_period_wealth
# Constraints
def borrowing_constraint_working(end_of_period_wealth: FloatND) -> BoolND:
    """Return whether end-of-period wealth is non-negative (no borrowing)."""
    return end_of_period_wealth >= 0
# Regime transition
def next_regime(age: float, last_working_age: float) -> ScalarInt:
    """Return the id of the regime the agent is in next period.

    Once ``age`` has reached ``last_working_age`` the next regime is
    ``RegimeId.retirement``; before that it is ``RegimeId.working_life``.
    """
    return jnp.where(
        age >= last_working_age,
        RegimeId.retirement,
        RegimeId.working_life,
    )


# Regimes and Model
# Age grid from 25 to 65 in 20-year steps: the three model periods.
age_grid = AgeGrid(start=25, stop=65, step="20Y")
# The last age on the grid is the retirement period.
retirement_age = age_grid.exact_values[-1]

# Working life: active at every age strictly below the retirement age.
working_life = Regime(
    transition=next_regime,
    active=lambda age: age < retirement_age,
    states={
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25),
    },
    state_transitions={
        "wealth": next_wealth,
    },
    actions={
        "labor_supply": DiscreteGrid(LaborSupply),
        "consumption": LogSpacedGrid(start=4, stop=50, n_points=100),
    },
    functions={
        "utility": utility,
        "earnings": earnings,
        "taxes_transfers": taxes_transfers,
        "end_of_period_wealth": end_of_period_wealth,
    },
    constraints={
        "borrowing_constraint_working": borrowing_constraint_working,
    },
)

# Retirement: no regime transition and no actions; utility depends on wealth only.
retirement = Regime(
    transition=None,
    active=lambda age: age >= retirement_age,
    states={
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25),
    },
    functions={"utility": utility_retirement},
)

model = Model(
    regimes={
        "working_life": working_life,
        "retirement": retirement,
    },
    ages=age_grid,
    regime_id_class=RegimeId,
    description="A tiny three-period consumption-savings model.",
)
# NOTE(review): the build output captured directly below shows this
# ``Model(...)`` call failing with a ``TypeError`` raised inside ``lcm``
# (``_build_compute_intermediates_per_period() got an unexpected keyword
# argument 'regime'``); presumably the later cells did not run in that build.
# ---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[4], line 37
33 },
34 functions={"utility": utility_retirement},
35 )
36
---> 37 model = Model(
38 regimes={
39 "working_life": working_life,
40 "retirement": retirement,
File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/model.py:136, in Model.__init__(self, description, ages, regimes, regime_id_class, enable_jit, fixed_params, derived_categoricals)
127 self.regime_names_to_ids = MappingProxyType(
128 dict(
129 sorted(
(...) 133 )
134 )
135 self.regimes = _merge_derived_categoricals(regimes, derived_categoricals)
--> 136 self.internal_regimes, self._params_template = build_regimes_and_template(
137 ages=self.ages,
138 regimes=self.regimes,
139 regime_names_to_ids=self.regime_names_to_ids,
140 enable_jit=enable_jit,
141 fixed_params=self.fixed_params,
142 )
143 self.enable_jit = enable_jit
144 self.simulation_output_dtypes = get_simulation_output_dtypes(
145 regimes=self.regimes,
146 regime_names_to_ids=self.regime_names_to_ids,
147 )
File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/model_processing.py:71, in build_regimes_and_template(ages, regimes, regime_names_to_ids, enable_jit, fixed_params)
53 """Build internal regimes and params template in a single pass.
54
55 Compose regime processing, template creation, and optional fixed-param partialling
(...) 68
69 """
70 if not fixed_params:
---> 71 internal_regimes = process_regimes(
72 ages=ages,
73 regimes=regimes,
74 regime_names_to_ids=regime_names_to_ids,
75 enable_jit=enable_jit,
76 )
77 params_template = create_params_template(internal_regimes)
78 else:
File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/regime_building/processing.py:131, in process_regimes(regimes, ages, regime_names_to_ids, enable_jit)
128 for name, regime in regimes.items():
129 regime_params_template = create_regime_params_template(regime)
--> 131 solve_functions = _build_solve_functions(
132 regime=regime,
133 regime_name=name,
134 nested_transitions=nested_transitions[name],
135 all_grids=all_grids,
136 regime_params_template=regime_params_template,
137 regime_names_to_ids=regime_names_to_ids,
138 variable_info=variable_info[name],
139 regimes_to_active_periods=regimes_to_active_periods,
140 regime_to_v_interpolation_info=regime_to_v_interpolation_info,
141 state_action_space=state_action_spaces[name],
142 ages=ages,
143 enable_jit=enable_jit,
144 )
146 simulate_functions = _build_simulate_functions(
147 regime=regime,
148 regime_name=name,
(...) 161 solve_compute_regime_transition_probs=solve_functions.compute_regime_transition_probs,
162 )
164 internal_regimes[name] = InternalRegime(
165 name=name,
166 terminal=regime.terminal,
(...) 173 _base_state_action_space=state_action_spaces[name],
174 )
File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/324/src/lcm/regime_building/processing.py:280, in _build_solve_functions(regime, regime_name, nested_transitions, all_grids, regime_params_template, regime_names_to_ids, variable_info, regimes_to_active_periods, regime_to_v_interpolation_info, state_action_space, ages, enable_jit)
258 compute_intermediates = _build_compute_intermediates_per_period(
259 flat_param_names=flat_param_names,
260 regimes_to_active_periods=regimes_to_active_periods,
(...) 270 enable_jit=enable_jit,
271 )
273 max_Q_over_a = _build_max_Q_over_a_per_period(
274 state_action_space=state_action_space,
275 Q_and_F_functions=Q_and_F_functions,
276 grids=all_grids[regime_name],
277 enable_jit=enable_jit,
278 )
--> 280 compute_intermediates = _build_compute_intermediates_per_period(
281 regime=regime,
282 flat_param_names=frozenset(get_flat_param_names(regime_params_template)),
283 regimes_to_active_periods=regimes_to_active_periods,
284 functions=core.functions,
285 constraints=core.constraints,
286 transitions=core.transitions,
287 stochastic_transition_names=core.stochastic_transition_names,
288 compute_regime_transition_probs=compute_regime_transition_probs,
289 regime_to_v_interpolation_info=regime_to_v_interpolation_info,
290 state_action_space=state_action_space,
291 grids=all_grids[regime_name],
292 ages=ages,
293 enable_jit=enable_jit,
294 )
296 return SolveFunctions(
297 functions=core.functions,
298 constraints=core.constraints,
(...) 303 compute_intermediates=compute_intermediates,
304 )
TypeError: _build_compute_intermediates_per_period() got an unexpected keyword argument 'regime'
Parameters¶
Use model.get_params_template() to see what parameters the model expects, organized
by regime and function.
pprint(dict(model.get_params_template()))
Parameters shared across regimes (risk_aversion, discount_factor, interest_rate) can be specified at the model level. Parameters unique to one regime go under the regime name.
params = {
    # Given at the model level (not nested under a regime name).
    "discount_factor": 0.95,
    "risk_aversion": 1.5,
    "interest_rate": 0.03,
    # Parameters of the working-life regime, keyed by the function that uses them.
    "working_life": {
        "utility": {"disutility_of_work": 1.0},
        "earnings": {"wage": 20.0},
        "taxes_transfers": {"consumption_floor": 2.0, "tax_rate": 0.2},
        # Second-to-last age on the grid: the last period spent working.
        "next_regime": {"last_working_age": age_grid.exact_values[-2]},
    },
}


# Simulate
n_agents = 100

# All agents start in working life at the first age on the grid, with initial
# wealth spread evenly between 1 and 20.
initial_df = pd.DataFrame(
    {
        "regime": "working_life",
        "age": float(age_grid.exact_values[0]),
        "wealth": np.linspace(1, 20, n_agents),
    }
)
result = model.simulate(
    params=params,
    initial_conditions=initial_df,
    period_to_regime_to_V_arr=None,
)

df = result.to_dataframe(additional_targets="all")
df["age"] = df["age"].astype(int)

# Retirement has no consumption action; fill the column with wealth, which is
# what the retirement utility is evaluated at.
is_retirement_row = df["age"] == retirement_age
df.loc[is_retirement_row, "consumption"] = df.loc[is_retirement_row, "wealth"]

columns = [
    "regime",
    "labor_supply",
    "consumption",
    "wealth",
    "earnings",
    "taxes_transfers",
    "end_of_period_wealth",
    "value",
]
# Show the first 20 rows, indexed by subject and age.
df.set_index(["subject_id", "age"])[columns].head(20).style.format(
    precision=1,
    na_rep="",
)
# Classify agents by work pattern across the two working-life periods
first_working_age = age_grid.exact_values[0]
last_working_age = age_grid.exact_values[-2]
df_working = df[df["regime"] == "working_life"]

# One row per agent, one column per working age, holding the labor-supply choice.
work_by_age = df_working.pivot_table(
    index="subject_id",
    columns="age",
    values="labor_supply",
    aggfunc="first",
)
# String of the form "<choice at first age>, <choice at last working age>".
work_pattern = (
    work_by_age[first_working_age].astype(str)
    + ", "
    + work_by_age[last_working_age].astype(str)
)
# ``label_map`` below has no entry for agents who work in both periods.
assert "work, work" not in work_pattern.to_numpy(), (
    "Plotting assumes that no agent works in both periods of working life."
)
label_map = {
    "work, do_not_work": "low",  # work early, not later
    "do_not_work, work": "medium",  # coast early, work later
    "do_not_work, do_not_work": "high",  # never work
}
groups = work_pattern.map(label_map).rename("initial_wealth")

# Combined descriptives and work decisions table
initial_wealth = df[df["age"] == first_working_age].set_index("subject_id")["wealth"]
# Range of initial wealth within each group.
group_desc = initial_wealth.groupby(groups).agg(["min", "max"]).round(1)

df_groups = df.copy()
df_groups["initial_wealth"] = df_groups["subject_id"].map(groups)
# Group-by-age means of all numeric columns; also used by the plots.
df_mean = df_groups.groupby(["initial_wealth", "age"], as_index=False).mean(
    numeric_only=True,
)

work_table = df_mean[df_mean["age"] < retirement_age].pivot_table(
    index="initial_wealth",
    columns="age",
    values="earnings",
)
# Positive mean earnings at an age means the group works at that age.
work_table = (work_table > 0).astype(int)
work_table.columns = [f"works {c}" for c in work_table.columns]

summary = pd.concat([group_desc, work_table], axis=1)
summary.index.name = "initial_wealth"
summary.loc[["low", "medium", "high"]].style.format(precision=1, na_rep="")
# Mean consumption by age, one line per initial-wealth group.
fig = px.line(
    df_mean,
    x="age",
    y="consumption",
    color="initial_wealth",
    title="Consumption by Age",
    template="plotly_dark",
)
fig.show()
# Mean wealth by age, one line per initial-wealth group.
fig = px.line(
    data_frame=df_mean, x="age", y="wealth", color="initial_wealth",
    title="Wealth by Age", template="plotly_dark",
)
fig.show()