152 changes: 152 additions & 0 deletions docs/source/guides/batching.rst
@@ -0,0 +1,152 @@
.. _batching_guide:

Batching Strategy and Segmentation
==================================

The backward induction in ``dcegm`` is solved in batches rather than period by period. Batching is a computational detail: it gives the arrays fixed shapes that work with fast JAX scans, and it does not change the model logic.

Why batching exists
-------------------

The number of feasible state-choice combinations usually changes over the life cycle. However, vectorized scan steps work best with equal leading dimensions. Batching groups state-choice rows into equal-sized chunks so each scan step can run with fixed shapes.

Two batching modes
------------------

``dcegm`` supports two batching modes:

- ``largest_block``:

  - Finds large dependency-safe batches.
  - Typically yields fewer and larger batches.
  - Good default for smooth state-choice profiles.
  - **Default configuration**: ``dcegm`` uses this batching mode with no segmentation if batching is not configured otherwise.

- ``period_max``:

  - Uses one batch per period within a segment.
  - Pads smaller period batches to the segment-specific maximum number of state choices per period.
  - Useful when state-choice counts vary strongly by period.
  - **Padding rule**: If a period has fewer state choices than the segment maximum, the batch is padded with a valid dummy state-choice index from the same batch (deterministically the first one). This keeps shapes aligned and does not change the solution logic.
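The padding rule can be illustrated with a small NumPy sketch. The helper ``pad_period_batches`` and its input layout are hypothetical and not part of ``dcegm``; the sketch only mirrors the rule described above, padding each period's index array with its own first entry up to the segment maximum.

```python
import numpy as np


def pad_period_batches(idx_per_period):
    """Pad each period's state-choice indexes to the longest period in the segment.

    Hypothetical helper for illustration; not part of the dcegm API.
    """
    max_len = max(len(idx) for idx in idx_per_period)
    padded = []
    for idx in idx_per_period:
        idx = np.asarray(idx, dtype=int)
        # Fill the missing slots with the first valid index of the same batch.
        filler = np.full(max_len - len(idx), idx[0], dtype=int)
        padded.append(np.concatenate([idx, filler]))
    # All rows now have equal length, so they stack into one rectangular array.
    return np.stack(padded)


# Three periods with 4, 2 and 3 state choices (made-up indexes).
segment = [np.array([0, 1, 2, 3]), np.array([4, 5]), np.array([6, 7, 8])]
batches = pad_period_batches(segment)
print(batches)
# [[0 1 2 3]
#  [4 5 4 4]
#  [6 7 8 6]]
```

In this sketch the padded entries repeat a state choice that is already in the batch, so they add redundant work but no new results.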

Segmenting the horizon
----------------------

Use ``min_period_batch_segments`` to split the pre-terminal part of the horizon into segments.

- Without segmentation:

  - ``batch_mode`` must be a single string.

- With segmentation:

  - ``batch_mode`` can be a string (reused for all segments), or
  - ``batch_mode`` can be a list with one entry per segment.

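The number of segments is ``len(min_period_batch_segments) + 1``. The split periods must be increasing and at least two periods apart. A ``batch_mode`` list whose length differs from the number of segments raises a ``ValueError``.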

Valid strings are ``"largest_block"`` and ``"period_max"``.
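The rules above can be sketched as a small standalone function. ``resolve_batch_modes`` is a hypothetical name used only for illustration; the checks follow the validation described in this guide, but the function is not the library's own code.

```python
VALID_BATCH_MODES = {"largest_block", "period_max"}


def resolve_batch_modes(batch_mode="largest_block", min_period_batch_segments=None):
    """Return one batching mode per segment (hypothetical helper, not dcegm API)."""
    if min_period_batch_segments is None:
        # Without segmentation there is a single segment and a single mode.
        if isinstance(batch_mode, list):
            raise ValueError("Without segmentation, batch_mode must be a string.")
        if batch_mode not in VALID_BATCH_MODES:
            raise ValueError(f"Unknown batch_mode: {batch_mode}.")
        return [batch_mode]

    n_segments = len(min_period_batch_segments) + 1
    if isinstance(batch_mode, str):
        if batch_mode not in VALID_BATCH_MODES:
            raise ValueError(f"Unknown batch_mode: {batch_mode}.")
        # A single string is reused for every segment.
        return [batch_mode] * n_segments
    if isinstance(batch_mode, list):
        if len(batch_mode) != n_segments:
            raise ValueError("batch_mode needs one entry per segment.")
        if not all(mode in VALID_BATCH_MODES for mode in batch_mode):
            raise ValueError("All entries in batch_mode must be valid modes.")
        return list(batch_mode)
    raise ValueError("batch_mode must be a string or a list of strings.")


print(resolve_batch_modes())  # ['largest_block']
print(resolve_batch_modes("period_max", [8, 14]))  # ['period_max', 'period_max', 'period_max']
```

Two split periods give three segments, so a list with two entries would be rejected in this sketch.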

Examples
~~~~~~~~

No segmentation:

.. code-block:: python

    import numpy as np

    model_config = {
        "n_periods": 20,
        "choices": np.arange(3, dtype=int),
        "continuous_states": {"assets_end_of_period": np.linspace(0, 100, 200)},
        "n_quad_points": 5,
        # One batching mode for the whole pre-terminal horizon.
        "batch_mode": "period_max",
    }

With segmentation:

.. code-block:: python

    import numpy as np

    model_config = {
        "n_periods": 20,
        "choices": np.arange(3, dtype=int),
        "continuous_states": {"assets_end_of_period": np.linspace(0, 100, 200)},
        "n_quad_points": 5,
        # Two split periods give three segments.
        "min_period_batch_segments": [8, 14],
        # One batching mode per segment.
        "batch_mode": ["period_max", "largest_block", "period_max"],
    }

Tip: Use ``get_n_state_choices_per_period`` to choose segments
--------------------------------------------------------------

To determine sensible segments for batching, inspect the number of state-choice combinations per period.

.. code-block:: python

    import dcegm

    # model_specs and the function containers are assumed to be defined elsewhere.
    model = dcegm.setup_model(
        model_config=model_config,
        model_specs=model_specs,
        utility_functions=utility_functions,
        utility_functions_final_period=utility_functions_final_period,
        budget_constraint=budget_constraint,
        state_space_functions=state_space_functions,
        stochastic_states_transitions=stochastic_states_transitions,
    )

    # Number of state-choice combinations in each period.
    n_state_choices = model.get_n_state_choices_per_period()
    print(n_state_choices)

This series can be used to detect structural breaks in complexity. Typical heuristics are:

- Keep periods with similar counts in one segment.
- Split where there are abrupt jumps/drops.
- Use ``period_max`` in highly uneven segments.
- Keep ``largest_block`` in smoother segments.
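As a rough illustration of the "split where there are abrupt jumps/drops" heuristic, the sketch below flags periods where the count changes by more than a relative threshold. The function ``suggest_split_periods``, its threshold, and the example counts are all hypothetical; ``dcegm`` does not provide such a helper, and the result is only a starting point for experiments.

```python
import numpy as np


def suggest_split_periods(counts, rel_threshold=0.5, min_gap=2):
    """Suggest split periods where the state-choice count jumps or drops sharply.

    Hypothetical helper; rel_threshold and min_gap are illustrative defaults.
    """
    counts = np.asarray(counts, dtype=float)
    # Relative change from period t - 1 to period t.
    rel_change = np.abs(np.diff(counts)) / counts[:-1]
    candidates = np.flatnonzero(rel_change > rel_threshold) + 1

    splits = []
    for period in candidates:
        # Keep split periods increasing and at least min_gap periods apart.
        if not splits or period - splits[-1] >= min_gap:
            splits.append(int(period))
    return splits


# Made-up counts with a jump at period 4 and a drop at period 8.
counts = [10, 12, 14, 16, 40, 42, 44, 46, 12, 12]
print(suggest_split_periods(counts))  # [4, 8]
```

The ``min_gap`` default of two follows the requirement that split periods be at least two periods apart.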

Example: experience growth and retirement regimes
-------------------------------------------------

Consider a model with a discrete experience state where:

- choice 0: no work, experience unchanged,
- choice 1: regular work, experience increases by 1,
- choice 2: intensive work, experience increases by 2,
- choice 3: retirement.

Suppose retirement becomes available from period 8, and is mandatory from period 14.

.. code-block:: python

    import numpy as np

    def choice_set(period, lagged_choice):
        if period >= 14:
            return np.array([3], dtype=int)  # mandatory retirement
        if period >= 8:
            return np.array([0, 1, 2, 3], dtype=int)  # retirement becomes available
        return np.array([0, 1, 2], dtype=int)

    def next_period_deterministic_state(period, choice, experience):
        if choice == 1:
            experience_next = experience + 1
        elif choice == 2:
            experience_next = experience + 2
        else:
            experience_next = experience
        return {
            "period": period + 1,
            "lagged_choice": choice,
            "experience": experience_next,
        }

In this setup you often see:

- gradual growth in state-choice counts early on,
- a jump when retirement becomes optional,
- a drop when retirement becomes mandatory.

This pattern is a good reason to separate segments around the two regime changes:

.. code-block:: python

    model_config["min_period_batch_segments"] = [8, 14]
    model_config["batch_mode"] = ["period_max", "largest_block", "period_max"]

We suggest testing different segmentation choices to determine the fastest solution for your model.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -30,6 +30,7 @@ Check out our :ref:`guides<guides/index.rst>` to find information on getting sta
    :hidden:

    guides/practitioner_guide
+   guides/batching
    guides/templates
    guides/minimal_example.ipynb

52 changes: 43 additions & 9 deletions src/dcegm/pre_processing/batches/batch_creation.py
@@ -10,16 +10,16 @@ def create_batches_and_information(
     model_structure,
     n_periods,
     min_period_batch_segments=None,
+    batch_mode="largest_block",
 ):
     """Batches are used instead of periods to have chunks of equal sized state choices.
-    The batch inparams=paramsformation dictionary contains the following arrays
-    reflecting the.
+    The returned batch information dictionary contains the following arrays
+    reflecting steps in the backward induction:

-    steps in the backward induction:
     - batches_state_choice_idx: The state choice indexes in each batch to be solved.
-    To solve the state choices in the egm step, we have to look at the child states
-    and the corresponding state choice indexes in the child states. For that we save
-    the following:
+        To solve the state choices in the egm step, we have to look at the child states
+        and the corresponding state choice indexes in the child states. For that we save
+        the following:
     - child_state_choice_idxs_to_interp: The state choice indexes in we need to
         interpolate the wealth on.
     - child_states_idxs: The parent state indexes of the child states, i.e. the
@@ -64,10 +64,22 @@ def create_batches_and_information(
     state_choice_space = model_structure["state_choice_space"]
     bool_state_choices_to_batch = state_choice_space[:, 0] < n_periods - 2

+    valid_batch_modes = {"largest_block", "period_max"}
+
     if min_period_batch_segments is None:
+        if isinstance(batch_mode, list):
+            raise ValueError(
+                "If min_period_batch_segments is not supplied, batch_mode must be a string."
+            )
+        if batch_mode not in valid_batch_modes:
+            raise ValueError(
+                f"batch_mode must be one of {valid_batch_modes}. Got {batch_mode}."
+            )

         single_batch_segment_info = create_single_segment_of_batches(
-            bool_state_choices_to_batch, model_structure
+            bool_state_choices_to_batch,
+            model_structure,
+            batch_mode=batch_mode,
         )
         segment_infos = {
             "n_segments": 1,
@@ -97,6 +109,24 @@ def create_batches_and_information(
                     "The periods to split the batches have to be increasing and at least two periods apart."
                 )

+        if isinstance(batch_mode, str):
+            if batch_mode not in valid_batch_modes:
+                raise ValueError(
+                    f"batch_mode must be one of {valid_batch_modes}. Got {batch_mode}."
+                )
+            batch_mode = [batch_mode] * n_segments
+        elif isinstance(batch_mode, list):
+            if len(batch_mode) != n_segments:
+                raise ValueError(
+                    "If min_period_batch_segments is supplied, batch_mode must be a list with one entry per segment."
+                )
+            if not all(mode in valid_batch_modes for mode in batch_mode):
+                raise ValueError(
+                    f"All entries in batch_mode must be one of {valid_batch_modes}."
+                )
+        else:
+            raise ValueError("batch_mode must be a string or a list of strings.")
+
         segment_infos = {
             "n_segments": n_segments,
         }
@@ -111,15 +141,19 @@ def create_batches_and_information(
             bool_state_choices_segment = bool_state_choices_to_batch & (~split_cond)

             segment_batch_info = create_single_segment_of_batches(
-                bool_state_choices_segment, model_structure
+                bool_state_choices_segment,
+                model_structure,
+                batch_mode=batch_mode[id_segment],
             )
             segment_infos[f"batches_info_segment_{id_segment}"] = segment_batch_info

             # Set the bools to False which have been batched already
             bool_state_choices_to_batch = bool_state_choices_to_batch & split_cond

         last_segment_batch_info = create_single_segment_of_batches(
-            bool_state_choices_to_batch, model_structure
+            bool_state_choices_to_batch,
+            model_structure,
+            batch_mode=batch_mode[n_segments - 1],
         )

         # We loop until n_segments - 2 and then add the last segment