Skip to content

Refactor RNGHierarchy to enable tying the random number generators to detectors#523

Draft
anand-avinash wants to merge 16 commits into
masterfrom
seeder
Draft

Refactor RNGHierarchy to enable tying the random number generators to detectors#523
anand-avinash wants to merge 16 commits into
masterfrom
seeder

Conversation

@anand-avinash
Copy link
Copy Markdown
Contributor

The current RNGHierarchy class guarantees reproducible RNG hierarchy on a given MPI rank provided that the user provided seed, size of the MPI communicator, and number of detectors on that rank remain the same. This can be modified to pass the MPI communicator to RNGHierarchy instead of the size of the communicator. The size of the communicator will be inferred from the communicator itself within the class during its instantiation, keeping the overall behavior of the class unchanged, when the user provides the global MPI communicator to this class.

However, this updated interface can be used to provided same RNG for a given detector across different rank (or time blocks) if a suitable MPI communicator is supplied. Consider an example where 3 detector and their TODs are distributed across 2 detector blocks and 3 time blocks with a total 6 MPI processes:

import time
import litebird_sim as lbs
from litebird_sim import RNGHierarchy

sim = lbs.Simulation(
    start_time=0.0,
    duration_s=86400.0,
    random_seed=12345,
)

sim.create_observations(
    detectors=[
        lbs.DetectorInfo(name="det_A", sampling_rate_hz=1.0),
        lbs.DetectorInfo(name="det_B", sampling_rate_hz=1.0),
        lbs.DetectorInfo(name="det_C", sampling_rate_hz=1.0),
    ],
    split_list_over_processes=False,
    num_of_obs_per_detector=1,
    n_blocks_det=2,
    n_blocks_time=3,
)


global_comm = lbs.MPI_COMM_WORLD

global_rng = RNGHierarchy(
    base_seed=123,
    comm=global_comm,
    num_detectors_per_rank=sim.observations[0].n_detectors,
)


det_block_rng = RNGHierarchy(
    base_seed=456,
    comm=sim.observations[0].comm_det_block,
    num_detectors_per_rank=sim.observations[0].n_detectors,
)


time_block_rng = RNGHierarchy(
    base_seed=789,
    comm=sim.observations[0].comm_time_block,
    num_detectors_per_rank=sim.observations[0].n_detectors,
)


def print_det_level_generators(
    global_rank, local_rank, det_names, det_level_generators
):
    assert len(det_names) == len(det_level_generators)
    for det_name, gen in zip(det_names, det_level_generators):
        print(
            f"global_rank = {global_rank}, local_rank = {local_rank}",
            det_name,
            gen.__getstate__(),
        )


### RNG hierarchy for global communicator

print_det_level_generators(
    global_comm.rank,
    global_comm.rank,
    sim.observations[0].name,
    global_rng.get_detector_level_generators_on_rank(global_comm.rank),
)

time.sleep(0.5)
if global_comm.rank == 0:
    print("\n")
time.sleep(0.5)


### RNG hierarchy for time block communicator

print_det_level_generators(
    global_comm.rank,
    sim.observations[0].comm_time_block.rank,
    sim.observations[0].name,
    time_block_rng.get_detector_level_generators_on_rank(
        sim.observations[0].comm_time_block.rank
    ),
)


time.sleep(0.5)
if global_comm.rank == 0:
    print("\n")
time.sleep(0.5)


### RNG hierarchy for detector block communicator
### This serves no purpose

# print_det_level_generators(
#     global_comm.rank,
#     sim.observations[0].comm_det_block.rank,
#     sim.observations[0].name,
#     det_block_rng.get_detector_level_generators_on_rank(
#         sim.observations[0].comm_det_block.rank
#     ),
# )

The data distributions would look like the following:

det A, det B det C
time_block 1 rank 0 rank 3
time_block 2 rank 1 rank 4
time_block 3 rank 2 rank 5

With this distribution, the global communicator will be partitioned into three time block communicators accessible with sim.observations[0].comm_time_block. The first one will contain ranks 0 and 3, second will contain ranks 1 and 4, and the third will contain ranks 2 and 5. As a result, since for ranks 0, 1, and 2 the size of sim.observations[0].comm_time_block and the number of detectors per rank are same, RNGHierarchy will produce exactly the same hierarchy with a common seed on these ranks. The same goes for ranks 3, 4, and 5. So, the detector level RNGs for detector A will be same across rank 0, 1, and 2; detector level RNGs for detector B will be same across ranks 0, 1, and 2; and detector level RNGs for detector C will be same across ranks 3, 4,and 5 -- guaranteeing that random numbers produced for given detector are same across different MPI ranks. This completely solves the issue raised with #510.

Here I am showing the result of the print statements from the script I added above for RNG state verification:

### RNG hierarchy for global communicator - each block of TOD has an RNG with unique state


global_rank = 0, local_rank = 0 det_A {'bit_generator': 'PCG64', 'state': {'state': 144619848914138535302906084289083140116, 'inc': 137901272774577361021997142407036157741}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 0, local_rank = 0 det_B {'bit_generator': 'PCG64', 'state': {'state': 68488701009948191967673224139517843068, 'inc': 181070714709730587431689905102663183521}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 1, local_rank = 1 det_A {'bit_generator': 'PCG64', 'state': {'state': 31074665416037212809444864448348876291, 'inc': 271684247880976466715616379799581786115}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 1, local_rank = 1 det_B {'bit_generator': 'PCG64', 'state': {'state': 39625932023379125989704284680160134973, 'inc': 245272176840412043495140278152880377797}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 2, local_rank = 2 det_A {'bit_generator': 'PCG64', 'state': {'state': 327913130787634335600406052353961127019, 'inc': 285897430671043131730482796764641238999}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 2, local_rank = 2 det_B {'bit_generator': 'PCG64', 'state': {'state': 310112831948863994768429857504937681947, 'inc': 27809136912511448367194439464235745091}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 3, local_rank = 3 det_C {'bit_generator': 'PCG64', 'state': {'state': 83821247557234737719374691102984488495, 'inc': 106074201384900303510227639608675842753}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 4, local_rank = 4 det_C {'bit_generator': 'PCG64', 'state': {'state': 94339045799319636164492511438536753607, 'inc': 246373086755039278226331009459390934807}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 5, local_rank = 5 det_C {'bit_generator': 'PCG64', 'state': {'state': 228074635286621183168880197838952638317, 'inc': 14417719495420430718010116231659948761}, 'has_uint32': 0, 'uinteger': 0}



### RNG hierarchy for time block communicator - RNGs corresponding to a given detector have same state across different MPI ranks


global_rank = 0, local_rank = 0 det_A {'bit_generator': 'PCG64', 'state': {'state': 211423005980925913106542892991315253702, 'inc': 103874591357702541628135110433674472181}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 1, local_rank = 0 det_A {'bit_generator': 'PCG64', 'state': {'state': 211423005980925913106542892991315253702, 'inc': 103874591357702541628135110433674472181}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 2, local_rank = 0 det_A {'bit_generator': 'PCG64', 'state': {'state': 211423005980925913106542892991315253702, 'inc': 103874591357702541628135110433674472181}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 0, local_rank = 0 det_B {'bit_generator': 'PCG64', 'state': {'state': 333871133010337104201869106881567755991, 'inc': 156961187590396004783913536309161093731}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 1, local_rank = 0 det_B {'bit_generator': 'PCG64', 'state': {'state': 333871133010337104201869106881567755991, 'inc': 156961187590396004783913536309161093731}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 2, local_rank = 0 det_B {'bit_generator': 'PCG64', 'state': {'state': 333871133010337104201869106881567755991, 'inc': 156961187590396004783913536309161093731}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 3, local_rank = 1 det_C {'bit_generator': 'PCG64', 'state': {'state': 103944651775319475029580060687993738622, 'inc': 51519893583648684950972087145344717905}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 4, local_rank = 1 det_C {'bit_generator': 'PCG64', 'state': {'state': 103944651775319475029580060687993738622, 'inc': 51519893583648684950972087145344717905}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 5, local_rank = 1 det_C {'bit_generator': 'PCG64', 'state': {'state': 103944651775319475029580060687993738622, 'inc': 51519893583648684950972087145344717905}, 'has_uint32': 0, 'uinteger': 0}

Please take a look at this, if it seems suitable and sufficient, I would update the docstrings before merging this to master branch.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 20, 2026

Coverage report

Click to see where and how coverage changed

FileStatementsMissingCoverageCoverage
(new stmts)
Lines missing
  litebird_sim
  __init__.py
  beam_convolution.py
  detectors.py
  gaindrifts.py
  grasp2alm.py
  hwp.py
  noise.py
  non_linearity.py 174
  observations.py
  scan_map.py
  seeding.py
  simulations.py 2114
  litebird_sim/mapmaking
  binner.py
  destriper.py
  pair_differencing.py
Project Total  

This report was generated by python-coverage-comment-action

@mj-gomes
Copy link
Copy Markdown
Contributor

mj-gomes commented May 26, 2026

Hi Avinash, thanks ! So, if I understand correctly, for the non linearity case, where we want the state to be the same for all time divisions in a given detector, the communicator to be used should be comm_time_block, right? But, in that case, shouldn't we have comm as an argument to regenerate_or_check_detector_generators() ? Because, in that way, when I call regenerate_or_check_detector_generators() in non_linearity.apply_quadratic_nonlin_to_observations(), I could call it with comm=sim.observations[0].comm_time_block to force the correct state for each detector, independently of the division in time. Did I understand this correctly?

@anand-avinash
Copy link
Copy Markdown
Contributor Author

Hi @mj-gomes! Yes you are right. Since we use regenerate_or_check_detector_generators() in multiple places to generate appropriate hierarchies, and it currently assumes the communicator is global, it makes sense to pass the specific communicator to it instead. I will update the function now.

@mj-gomes
Copy link
Copy Markdown
Contributor

We should probably also add a test in test_mpi.py or test_mpi_n4.py to check that for the non linearity case this works when multiple MPI tasks share the same detector. I can do it if you want.

@mj-gomes
Copy link
Copy Markdown
Contributor

mj-gomes commented May 28, 2026

Hi @anand-avinash, I had to remove the rng_hierarchy handling in sim.apply_quadratic_nonlin, as it was always creating the rng hierarchy as previously defined, not using your changes.

Now we have to pass user_seed to this method, to ensure it always works. I did a test in test_mpi.py using 2 MPI ranks.

We have 2 detectors, 4 time samples, 2 MPI blocks of time, 1 MPI block of detectors.

This means the following division:
rank 0 has detector 1 from t=0 to t=1; and detector 2 from t=0 to t=1;
rank 1 has detector 1 from t=2 to t=3; and detector 2 from t=2 to t=3;

The seeding is now correct when I run the test locally in my pc (using mpirun -n 2). However we are getting ValueError: Too many blocks: n_blocks_det x n_blocks_time = 2 but the number processes is 1 in the actions. I believe tests in test_mpi.py are automatically run in github actions with mpi, but maybe I am wrong, and I should specify it somewhere?

@anand-avinash
Copy link
Copy Markdown
Contributor Author

Hi @mj-gomes!

Hi @anand-avinash, I had to remove the rng_hierarchy handling in sim.apply_quadratic_nonlin, as it was always creating the rng hierarchy as previously defined, not using your changes.

Good catch!

The seeding is now correct when I run the test locally in my pc (using mpirun -n 2). However we are getting ValueError: Too many blocks: n_blocks_det x n_blocks_time = 2 but the number processes is 1 in the actions. I believe tests in test_mpi are automatically run in github actions with mpi.py, but maybe I am wrong, and I should specify it somewhere?

If you look at the github action workflow for the tests, after installing litebird_sim, it runs pytest without mpiexec/mpirun for all available tests (including test/test_mpi.py and test/test_mpi_n4.py). So by default it takes the size of global communicator as 1. But in your test, you are requesting 2 blocks so it is producing an error. A simple fix would be to skip the test if the size of communicator is not suitable for your test, as it is done here.

@mj-gomes
Copy link
Copy Markdown
Contributor

Hi @anand-avinash,

A simple fix would be to skip the test if the size of communicator is not suitable for your test, as it is done here.

Oh, right! Thanks, I did that. I also finished what was missing (updated the docs, merged master etc.). I think this is good to go. @anand-avinash @ziotom78 What do you think?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants