3 changes: 1 addition & 2 deletions litebird_sim/__init__.py
@@ -82,7 +82,7 @@
 )
 from .madam import save_simulation_for_madam
 from .mbs.mbs import Mbs, MbsParameters, MbsSavedMapInfo
-from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, MPI_COMM_GRID
+from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION
 from .mueller_convolver import MuellerConvolver
 from .noise import (
     add_white_noise,
@@ -239,7 +239,6 @@
     "MPI_COMM_WORLD",
     "MPI_ENABLED",
     "MPI_CONFIGURATION",
-    "MPI_COMM_GRID",
     # mueller_convolver.py
     "MuellerConvolver",
     # observations.py
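With `MPI_COMM_GRID` gone from the package's public exports, downstream code has to fall back on `MPI_COMM_WORLD`, which is still exported. A minimal migration sketch (hypothetical user code, not part of this patch):

```python
from litebird_sim import MPI_COMM_WORLD, MPI_ENABLED

if MPI_ENABLED:
    # MPI_COMM_WORLD is a plain mpi4py communicator, so .rank and .size
    # behave as usual
    print(f"running as rank {MPI_COMM_WORLD.rank} of {MPI_COMM_WORLD.size}")
```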
2 changes: 1 addition & 1 deletion litebird_sim/io.py
@@ -470,7 +470,7 @@ def write_list_of_observations(
         observations = [observations]
     except IndexError:
         # Empty list
-        # We do not want to return here, as we still need to participate to
+        # We do not want to return here, as we still need to participate in
         # the call to _compute_global_start_index below
         observations = []  # type: List[Observation]

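The reworded comment encodes a real MPI constraint: a rank whose observation list is empty must still reach the collective call below, because a collective entered by only some ranks of a communicator deadlocks the rest. A standalone sketch of the rule in plain mpi4py (the prefix sum is only one plausible shape for a global start index; all names here are illustrative, not the library's internals):

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
local_samples = 0  # e.g. this rank received an empty list of observations

# Every rank must enter the collective, even with nothing to contribute;
# an early return on one rank would hang the others.
global_start_index = comm.scan(local_samples) - local_samples
```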
18 changes: 8 additions & 10 deletions litebird_sim/mapmaking/common.py
@@ -8,7 +8,6 @@
 from numba import njit
 
 from litebird_sim.coordinates import CoordinateSystem
-from litebird_sim.mpi import MPI_COMM_GRID
 from litebird_sim.observations import Observation
 from litebird_sim.pointings_in_obs import _get_pointings_array, _get_pol_angle
@@ -109,15 +108,14 @@ def get_map_making_weights(
     except AttributeError:
         weights = np.ones(observations.n_detectors)
 
-    if check and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
-        # Check that there are no weird weights
-        assert np.all(np.isfinite(weights)), (
-            f"Not all the detectors' weights are finite numbers: {weights}"
-        )
-        assert np.all(weights > 0.0), (
-            f"Not all the detectors' weights are positive: {weights}"
-        )
+    if check:
+        # Check that there are no weird weights
+        assert np.all(np.isfinite(weights)), (
+            f"Not all the detectors' weights are finite numbers: {weights}"
+        )
+        assert np.all(weights > 0.0), (
+            f"Not all the detectors' weights are positive: {weights}"
+        )
 
     return weights

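With the observation-grid guard dropped, the sanity check now runs on every rank that calls `get_map_making_weights`. A tiny usage sketch of the two invariants (NumPy only; the weights are invented):

```python
import numpy as np

weights = np.array([1.0, 0.5, 2.0])

# The same checks enforced above: weights must be finite and strictly positive
assert np.all(np.isfinite(weights))
assert np.all(weights > 0.0)
```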
204 changes: 93 additions & 111 deletions litebird_sim/mapmaking/destriper.py
@@ -14,7 +14,7 @@
 
 from litebird_sim.coordinates import CoordinateSystem, coord_sys_to_healpix_string
 from litebird_sim.hwp import HWP
-from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID
+from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD
 from litebird_sim.observations import Observation
 from litebird_sim.pointings_in_obs import (
     _get_hwp_angle,
@@ -44,7 +44,7 @@
 
 
 __DESTRIPER_RESULTS_FILE_NAME = "destriper_results.fits"
-__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_GRID.COMM_OBS_GRID.rank:04d}.fits"
+__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_WORLD.rank:04d}.fits"
 
 
 def _split_items_into_n_segments(n: int, num_of_segments: int) -> List[int]:
@@ -495,10 +495,8 @@ def _build_nobs_matrix(
     )
 
     # Now we must accumulate the result of every MPI process
-    if MPI_ENABLED and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
-        MPI_COMM_GRID.COMM_OBS_GRID.Allreduce(
-            mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM
-        )
+    if MPI_ENABLED:
+        MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM)
 
     # `nobs_matrix_cholesky` will *not* contain the M_i maps shown in
     # Eq. 9 of KurkiSuonio2009, but its Cholesky decomposition, i.e.,
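The collapsed call keeps the same in-place reduction semantics: each rank's partial `nobs_matrix` is summed element-wise across all ranks, and the result overwrites every rank's buffer without allocating a scratch array. A self-contained sketch of the pattern (plain mpi4py plus NumPy; the array shape is invented):

```python
import mpi4py.MPI
import numpy as np

comm = mpi4py.MPI.COMM_WORLD

# Hypothetical per-rank accumulator, e.g. one 3x3 block per pixel
partial = np.full((12, 3, 3), float(comm.rank))

# After this call every rank holds the element-wise sum over all ranks
comm.Allreduce(mpi4py.MPI.IN_PLACE, partial, op=mpi4py.MPI.SUM)
```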
@@ -745,12 +743,8 @@ def _compute_binned_map(
     )
 
     if MPI_ENABLED:
-        MPI_COMM_GRID.COMM_OBS_GRID.Allreduce(
-            mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM
-        )
-        MPI_COMM_GRID.COMM_OBS_GRID.Allreduce(
-            mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM
-        )
+        MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM)
+        MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM)
 
     # Step 2: compute the “binned map” (Eq. 21)
     _sum_map_to_binned_map(
@@ -990,7 +984,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float:
     # the dot product
     local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)])
     if MPI_ENABLED:
-        return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM)
+        return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.SUM)
     else:
         return local_result

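`_mpi_dot` is the distributed inner product used by the iterative destriping solver: each rank dots its local chunks, and a scalar `allreduce` sums the partial results. A free-standing version of the same idea (plain mpi4py; chunk contents are up to the caller):

```python
import numpy as np
from mpi4py import MPI


def mpi_dot(local_a: np.ndarray, local_b: np.ndarray) -> float:
    # Partial dot product over this rank's chunk of the two vectors
    partial = float(np.dot(local_a.ravel(), local_b.ravel()))
    # Lowercase (object) allreduce: sums the Python scalars across ranks
    return MPI.COMM_WORLD.allreduce(partial, op=MPI.SUM)
```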
@@ -1007,7 +1001,7 @@ def _get_stopping_factor(residual: List[npt.ArrayLike]) -> float:
     """
     local_result = np.max(np.abs(residual))
     if MPI_ENABLED:
-        return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX)
+        return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.MAX)
     else:
         return local_result

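Unlike the dot product, the stopping factor is a global infinity norm, so the reduction operator is `MAX` rather than `SUM`. A sketch under the same assumptions (plain mpi4py; residual values invented):

```python
import numpy as np
from mpi4py import MPI

local_residual = np.array([0.1, -0.5, 0.3])  # this rank's residual chunk

# Convergence test: global maximum of |residual| over all ranks
stopping_factor = MPI.COMM_WORLD.allreduce(
    np.max(np.abs(local_residual)), op=MPI.MAX
)
```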
@@ -1421,7 +1415,7 @@ def _run_destriper(
         bytes_in_temporary_buffers += mask.nbytes
 
     if MPI_ENABLED:
-        bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(
+        bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce(
            bytes_in_temporary_buffers,
            op=mpi4py.MPI.SUM,
        )
@@ -1623,103 +1617,91 @@ def my_gui_callback(
     binned_map = np.empty((3, number_of_pixels))
     hit_map = np.empty(number_of_pixels)
 
-    if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
-        # perform the following operations when MPI is not being used
-        # OR when the MPI_COMM_GRID.COMM_OBS_GRID is not a NULL communicator
-        if do_destriping:
-            try:
-                # This will fail if the parameter is a scalar
-                len(params.samples_per_baseline)
-
-                baseline_lengths_list = params.samples_per_baseline
-                assert len(baseline_lengths_list) == len(obs_list), (
-                    f"The list baseline_lengths_list has {len(baseline_lengths_list)} "
-                    f"elements, but there are {len(obs_list)} observations"
-                )
-            except TypeError:
-                # Ok, params.samples_per_baseline is a scalar, so we must
-                # figure out the number of samples in each baseline within
-                # each observation
-                baseline_lengths_list = [
-                    split_items_evenly(
-                        n=getattr(cur_obs, components[0]).shape[1],
-                        sub_n=int(params.samples_per_baseline),
-                    )
-                    for cur_obs in obs_list
-                ]
-
-            # Each element of this list is a 2D array with shape (N_det, N_baselines),
-            # where N_det is the number of detectors in the i-th Observation object
-            recycle_baselines = False
-            if baselines_list is None:
-                baselines_list = [
-                    np.zeros(
-                        (getattr(cur_obs, components[0]).shape[0], len(cur_baseline))
-                    )
-                    for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list)
-                ]
-            else:
-                recycle_baselines = True
-
-            destriped_map = np.empty((3, number_of_pixels))
-            (
-                baselines_list,
-                baseline_errors_list,
-                history_of_stopping_factors,
-                best_stopping_factor,
-                converged,
-                bytes_in_temporary_buffers,
-            ) = _run_destriper(
-                obs_list=obs_list,
-                nobs_matrix_cholesky=nobs_matrix_cholesky,
-                binned_map=binned_map,
-                destriped_map=destriped_map,
-                hit_map=hit_map,
-                baseline_lengths_list=baseline_lengths_list,
-                baselines_list_start=baselines_list,
-                recycle_baselines=recycle_baselines,
-                recycled_convergence=recycled_convergence,
-                dm_list=detector_mask_list,
-                tm_list=time_mask_list,
-                component=components[0],
-                threshold=params.threshold,
-                max_steps=params.iter_max,
-                use_preconditioner=params.use_preconditioner,
-                callback=callback,
-                callback_kwargs=callback_kwargs if callback_kwargs else {},
-            )
-
-            if MPI_ENABLED:
-                bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(
-                    bytes_in_temporary_buffers,
-                    op=mpi4py.MPI.SUM,
-                )
-        else:
-            # No need to run the destriping, just compute the binned map with
-            # one single baseline set to zero
-            _compute_binned_map(
-                obs_list=obs_list,
-                output_sky_map=binned_map,
-                output_hit_map=hit_map,
-                nobs_matrix_cholesky=nobs_matrix_cholesky,
-                component=components[0],
-                dm_list=detector_mask_list,
-                tm_list=time_mask_list,
-                baselines_list=None,
-                baseline_lengths_list=[
-                    np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int)
-                    for cur_obs in obs_list
-                ],
-            )
-            bytes_in_temporary_buffers = 0
-            destriped_map = None
-            baseline_lengths_list = None
-            baselines_list = None
-            baseline_errors_list = None
-            history_of_stopping_factors = None
-            best_stopping_factor = None
-            converged = True
+    # perform the following operations when MPI is not being used
+    if do_destriping:
+        try:
+            # This will fail if the parameter is a scalar
+            len(params.samples_per_baseline)
+
+            baseline_lengths_list = params.samples_per_baseline
+            assert len(baseline_lengths_list) == len(obs_list), (
+                f"The list baseline_lengths_list has {len(baseline_lengths_list)} "
+                f"elements, but there are {len(obs_list)} observations"
+            )
+        except TypeError:
+            # Ok, params.samples_per_baseline is a scalar, so we must
+            # figure out the number of samples in each baseline within
+            # each observation
+            baseline_lengths_list = [
+                split_items_evenly(
+                    n=getattr(cur_obs, components[0]).shape[1],
+                    sub_n=int(params.samples_per_baseline),
+                )
+                for cur_obs in obs_list
+            ]
+
+        # Each element of this list is a 2D array with shape (N_det, N_baselines),
+        # where N_det is the number of detectors in the i-th Observation object
+        recycle_baselines = False
+        if baselines_list is None:
+            baselines_list = [
+                np.zeros((getattr(cur_obs, components[0]).shape[0], len(cur_baseline)))
+                for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list)
+            ]
+        else:
+            recycle_baselines = True
+
+        destriped_map = np.empty((3, number_of_pixels))
+        (
+            baselines_list,
+            baseline_errors_list,
+            history_of_stopping_factors,
+            best_stopping_factor,
+            converged,
+            bytes_in_temporary_buffers,
+        ) = _run_destriper(
+            obs_list=obs_list,
+            nobs_matrix_cholesky=nobs_matrix_cholesky,
+            binned_map=binned_map,
+            destriped_map=destriped_map,
+            hit_map=hit_map,
+            baseline_lengths_list=baseline_lengths_list,
+            baselines_list_start=baselines_list,
+            recycle_baselines=recycle_baselines,
+            recycled_convergence=recycled_convergence,
+            dm_list=detector_mask_list,
+            tm_list=time_mask_list,
+            component=components[0],
+            threshold=params.threshold,
+            max_steps=params.iter_max,
+            use_preconditioner=params.use_preconditioner,
+            callback=callback,
+            callback_kwargs=callback_kwargs if callback_kwargs else {},
+        )
+
+        if MPI_ENABLED:
+            bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce(
+                bytes_in_temporary_buffers,
+                op=mpi4py.MPI.SUM,
+            )
+    else:
+        # No need to run the destriping, just compute the binned map with
+        # one single baseline set to zero
+        _compute_binned_map(
+            obs_list=obs_list,
+            output_sky_map=binned_map,
+            output_hit_map=hit_map,
+            nobs_matrix_cholesky=nobs_matrix_cholesky,
+            component=components[0],
+            dm_list=detector_mask_list,
+            tm_list=time_mask_list,
+            baselines_list=None,
+            baseline_lengths_list=[
+                np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int)
+                for cur_obs in obs_list
+            ],
+        )
+        bytes_in_temporary_buffers = 0
+        destriped_map = None
+        baseline_lengths_list = None
+        baselines_list = None
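For context, the deleted guard is the standard mpi4py idiom for a process excluded from a sub-communicator: such a rank holds `MPI.COMM_NULL` and must skip the work. Once every rank computes on `MPI_COMM_WORLD`, the test is always true and the whole fallback branch of `None` placeholders becomes dead code. A sketch of the old pattern (plain mpi4py; the split rule is invented):

```python
from mpi4py import MPI

# Ranks given color=MPI.UNDEFINED are left out and receive MPI.COMM_NULL
color = 0 if MPI.COMM_WORLD.rank % 2 == 0 else MPI.UNDEFINED
sub_comm = MPI.COMM_WORLD.Split(color=color)

if sub_comm != MPI.COMM_NULL:
    pass  # ranks inside the sub-communicator do the real work
else:
    pass  # excluded ranks only fill the result fields with None
```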
@@ -2018,11 +2000,11 @@ def _save_baselines(results: DestriperResult, output_file: Path) -> None:

     primary_hdu = fits.PrimaryHDU()
     primary_hdu.header["MPIRANK"] = (
-        MPI_COMM_GRID.COMM_OBS_GRID.rank,
+        MPI_COMM_WORLD.rank,
         "The rank of the MPI process that wrote this file",
     )
     primary_hdu.header["MPISIZE"] = (
-        MPI_COMM_GRID.COMM_OBS_GRID.size,
+        MPI_COMM_WORLD.size,
         "The number of MPI processes used in the computation",
     )

@@ -2238,11 +2220,11 @@ def load_destriper_results(
     baselines_file_name = folder / __BASELINES_FILE_NAME
 
     with fits.open(baselines_file_name) as inpf:
-        assert MPI_COMM_GRID.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], (
+        assert MPI_COMM_WORLD.rank == inpf[0].header["MPIRANK"], (
             "You must call load_destriper_results using the "
             "same MPI layout that was used for save_destriper_results "
         )
-        assert MPI_COMM_GRID.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], (
+        assert MPI_COMM_WORLD.size == inpf[0].header["MPISIZE"], (
             "You must call load_destriper_results using the "
             "same MPI layout that was used for save_destriper_results"
         )
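Baselines are saved as one FITS file per MPI process, so a result can only be reloaded under the identical layout; the `MPIRANK`/`MPISIZE` header keywords make a mismatch fail fast. A compact sketch of the check (astropy and mpi4py assumed; the file name mirrors the `__BASELINES_FILE_NAME` pattern above):

```python
from astropy.io import fits
from mpi4py import MPI

file_name = f"baselines_mpi{MPI.COMM_WORLD.rank:04d}.fits"

with fits.open(file_name) as inpf:
    # Refuse to load results written by a different MPI configuration
    assert inpf[0].header["MPIRANK"] == MPI.COMM_WORLD.rank
    assert inpf[0].header["MPISIZE"] == MPI.COMM_WORLD.size
```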