97 changes: 3 additions & 94 deletions bluemath_tk/core/dask.py
@@ -1,6 +1,6 @@
import numpy as np
import os

import psutil
import xarray as xr
from dask.distributed import Client, LocalCluster


@@ -30,97 +30,6 @@ def get_available_ram() -> int:
return psutil.virtual_memory().available


def get_available_cpus() -> int:
"""
Get the available CPU cores in the system.

Returns
-------
int
The number of available CPU cores.
"""

return int(psutil.cpu_count() * 0.5)


def calculate_optimal_chunks(
dataset: xr.Dataset,
cpu_cores_to_use: int = None,
total_ram_percentage_to_use: float = 0.5,
full_dims: list = None,
) -> dict:
"""
Calculate optimal chunk sizes for each variable in an xarray Dataset.
NOTE: This function is not being used in the code; it is a first
approximation for testing how data could be chunked given the hardware.

Parameters
----------
dataset : xr.Dataset
The input dataset containing multiple variables.
cpu_cores_to_use : int, optional
Number of CPU cores to use. If None, half of available cores are used.
Default is None.
total_ram_percentage_to_use : float, optional
Fraction of total RAM to use for chunking. Default is 0.5.
full_dims : list, optional
List of dimension names that should use all values (not be chunked).
Default is None.

Returns
-------
dict
Dictionary with variable names as keys and chunk dictionaries as values.
Example: {'var1': {'time': 1000, 'lat': 50, 'lon': 50}}
"""

# Get number of available CPU cores if not specified
cpu_cores_to_use = cpu_cores_to_use or get_available_cpus()

# Get available memory for chunking
available_mem = get_available_ram()
target_bytes = (available_mem * total_ram_percentage_to_use) / cpu_cores_to_use

full_dims = full_dims or []
chunks_dict = {}

# Process each variable in the dataset
for var_name, da in dataset.data_vars.items():
# Get shape and dtype info
shape = da.shape
dims = da.dims
dtype = da.dtype
bytes_per_elem = np.dtype(dtype).itemsize

# Separate chunked and full dimensions
chunk_dims = [d for d in dims if d not in full_dims]

# Calculate elements for chunked dimensions only
if chunk_dims:
# Calculate total elements considering full dimensions
full_dims_size = np.prod(
[s for d, s in zip(dims, shape) if d in full_dims], dtype=np.float64
)
total_chunk_elements = target_bytes / (bytes_per_elem * full_dims_size)

# Calculate base chunk size for remaining dimensions
chunk_size = int(np.power(total_chunk_elements, 1 / len(chunk_dims)))
else:
chunk_size = 0 # Not used if all dimensions are full

# Create chunks dictionary for this variable
var_chunks = {}
for dim_name, dim_size in zip(dims, shape):
if dim_name in full_dims:
var_chunks[dim_name] = dim_size # Use full dimension
else:
var_chunks[dim_name] = min(chunk_size, dim_size)

chunks_dict[var_name] = var_chunks

return chunks_dict


def setup_dask_client(n_workers: int = None, memory_limit: str = 0.5):
"""
Setup a Dask client with controlled resources.
@@ -146,7 +55,7 @@ def setup_dask_client(n_workers: int = None, memory_limit: str = 0.5):
"""

if n_workers is None:
n_workers = get_available_cpus()
n_workers = int(os.environ.get("BLUEMATH_NUM_WORKERS", "1"))
if isinstance(memory_limit, float):
memory_limit *= get_available_ram() / get_total_ram()

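As an illustration of the new default above, here is a minimal sketch of driving the worker count through the environment; the value "4" and the surrounding script are hypothetical, and it is assumed that setup_dask_client returns the dask.distributed Client, as its name suggests:

import os

from bluemath_tk.core.dask import setup_dask_client

# Hypothetical value for illustration; if unset, n_workers falls back to 1.
os.environ["BLUEMATH_NUM_WORKERS"] = "4"

# A float memory_limit is scaled by available/total RAM, as shown above.
client = setup_dask_client(n_workers=None, memory_limit=0.5)
print(client)
client.close()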
10 changes: 5 additions & 5 deletions bluemath_tk/core/decorators.py
@@ -275,7 +275,7 @@ def wrapper(
subset_custom_scale_factor: dict = {},
normalize_target_data: bool = True,
target_custom_scale_factor: dict = {},
num_threads: int = None,
num_workers: int = None,
iteratively_update_sigma: bool = False,
):
if subset_data is None:
@@ -306,9 +306,9 @@ def wrapper(
raise TypeError("Normalize target data must be a bool")
if not isinstance(target_custom_scale_factor, dict):
raise TypeError("Target custom scale factor must be a dict")
if num_threads is not None:
if not isinstance(num_threads, int) or num_threads <= 0:
raise ValueError("Number of threads must be integer and > 0")
if num_workers is not None:
if not isinstance(num_workers, int) or num_workers <= 0:
raise ValueError("Number of workers must be integer and > 0")
if not isinstance(iteratively_update_sigma, bool):
raise TypeError("Iteratively update sigma must be a boolean")
return func(
@@ -320,7 +320,7 @@ def wrapper(
subset_custom_scale_factor,
normalize_target_data,
target_custom_scale_factor,
num_threads,
num_workers,
iteratively_update_sigma,
)

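For clarity, a standalone sketch of the renamed check; the helper function is an assumption for illustration, since in the PR the check lives inside the decorator's wrapper:

def validate_num_workers(num_workers=None):
    # Mirrors the wrapper above: None is accepted, anything else must be
    # a positive integer.
    if num_workers is not None:
        if not isinstance(num_workers, int) or num_workers <= 0:
            raise ValueError("Number of workers must be integer and > 0")
    return num_workers

validate_num_workers(4)     # OK
validate_num_workers(None)  # OK, caller kept the default
# validate_num_workers(0)   # would raise ValueError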
144 changes: 117 additions & 27 deletions bluemath_tk/core/models.py
@@ -4,7 +4,8 @@
import pickle
import sys
from abc import ABC, abstractmethod
from typing import List, Tuple, Union
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import Any, Callable, List, Tuple, TypeVar, Union

import numpy as np
import pandas as pd
@@ -27,6 +28,8 @@
standarize,
)

T = TypeVar("T")


class BlueMathModel(ABC):
"""
@@ -37,6 +40,33 @@ class BlueMathModel(ABC):
def __init__(self) -> None:
self._logger: logging.Logger = None
self._exclude_attributes: List[str] = []
self.num_workers: int = 1

# [UNDER DEVELOPMENT] Below, we try to generalise parallel processing
bluemath_num_workers = os.environ.get("BLUEMATH_NUM_WORKERS", None)
omp_num_threads = os.environ.get("OMP_NUM_THREADS", None)
if bluemath_num_workers is not None:
self.logger.warning(
f"Setting self.num_workers to {bluemath_num_workers} due to BLUEMATH_NUM_WORKERS. \n"
"Change it using self.set_num_processors_to_use method. \n"
"Also setting OMP_NUM_THREADS to 1, to avoid conflicts with BlueMath parallel processing."
)
self.set_num_processors_to_use(num_processors=int(bluemath_num_workers))
self.set_omp_num_threads(num_threads=1)
elif omp_num_threads is not None:
self.logger.warning(
f"Changing variable OMP_NUM_THREADS from {omp_num_threads} to 1. \n"
f"And setting self.num_workers to {omp_num_threads}. \n"
"To avoid conflicts with BlueMath parallel processing."
)
self.set_omp_num_threads(num_threads=1)
self.set_num_processors_to_use(num_processors=int(omp_num_threads))
else:
self.num_workers = 1 # self.get_num_processors_available()
self.logger.warning(
f"Setting self.num_workers to {self.num_workers}. \n"
"Change it using self.set_num_processors_to_use method."
)

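A short sketch of the precedence implemented above; the environment values and the concrete subclass are hypothetical, and the subclass is assumed to satisfy BlueMathModel's abstract interface:

import os

from bluemath_tk.core.models import BlueMathModel

class DemoModel(BlueMathModel):  # hypothetical subclass for illustration
    pass

os.environ["BLUEMATH_NUM_WORKERS"] = "4"  # hypothetical value
os.environ["OMP_NUM_THREADS"] = "8"       # hypothetical value

model = DemoModel()
assert model.num_workers == 4                # BLUEMATH_NUM_WORKERS wins
assert os.environ["OMP_NUM_THREADS"] == "1"  # forced to avoid oversubscription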
def __getstate__(self):
"""Exclude certain attributes from being pickled."""
@@ -422,6 +452,26 @@ def get_degrees_from_uv(xu: np.ndarray, xv: np.ndarray) -> np.ndarray:

return get_degrees_from_uv(xu, xv)

def set_omp_num_threads(self, num_threads: int) -> None:
"""
Sets the number of threads for OpenMP.

Parameters
----------
num_threads : int
The number of threads.

Warning
-------
- This method is under development.
"""

os.environ["OMP_NUM_THREADS"] = str(num_threads)

# Re-import numpy if it is already imported
if "numpy" in sys.modules:
importlib.reload(np)

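One caveat on the reload above: most BLAS/OpenMP runtimes read OMP_NUM_THREADS once at initialization, so reloading numpy may not change thread counts in an already-running process. A sketch of an alternative using the third-party threadpoolctl package (an assumption; it is not used by this PR):

import numpy as np
from threadpoolctl import threadpool_limits

a = np.random.rand(1000, 1000)
with threadpool_limits(limits=1):
    # BLAS calls inside this block are capped at one thread, regardless
    # of when OMP_NUM_THREADS was exported.
    _ = a @ a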
def get_num_processors_available(self) -> int:
"""
Gets the number of processors available.
@@ -431,11 +481,12 @@ def get_num_processors_available(self) -> int:
int
The number of processors available.

TODO:
TODO
----
- Check whether available processors are used or not.
"""

return int(os.cpu_count() * 0.5)
return int(os.cpu_count() * 0.9)

def set_num_processors_to_use(self, num_processors: int) -> None:
"""
@@ -446,11 +497,6 @@ def set_num_processors_to_use(self, num_processors: int) -> None:
num_processors : int
The number of processors to use.
If -1, all available processors will be used.

Raises
------
ValueError
If the number of processors requested exceeds the number of processors available
"""

# Retrieve the number of processors available
@@ -462,7 +508,7 @@ def set_num_processors_to_use(self, num_processors: int) -> None:
elif num_processors <= 0:
raise ValueError("Number of processors must be greater than 0")
elif num_processors > num_processors_available:
raise ValueError(
self.logger.warning(
f"Number of processors requested ({num_processors}) "
f"exceeds the number of processors available ({num_processors_available})"
)
@@ -480,29 +526,73 @@ def set_num_processors_to_use(self, num_processors: int) -> None:
f"is more than 50% of the available processors ({num_processors_available})"
)
self.logger.info(f"Using {percentage * 100}% of the available processors")
os.environ["OMP_NUM_THREADS"] = str(num_processors)

# Re-import numpy if it is already imported
if "numpy" in sys.modules:
importlib.reload(np)
# Set the number of processors to use
self.num_workers = num_processors

def get_num_processors_used(self) -> int:
def parallel_execute(
self,
func: Callable,
items: List[Any],
num_workers: int,
cpu_intensive: bool = False,
**kwargs,
) -> List[T]:
"""
Gets the number of processors used.
Execute a function in parallel using concurrent.futures.

Parameters
----------
func : Callable
Function to execute for each item.
items : List[Any]
List of items to process.
num_workers : int
Number of parallel workers.
cpu_intensive : bool, optional
Whether the function is CPU intensive. Default is False.
**kwargs : dict
Additional keyword arguments for func.

Returns
-------
int
The number of processors used.

Notes
-----
- This method returns the number of processors used by the application.
- 1 is returned if the number of processors used is not set, as is the case of
serial processing like Python's built-in functions.
- Remember that if we run a parallel processing task, the number of processors used
will be the number of processors set by the task, which can be > 1.
Examples: np.linalg. or numerical models compiled with OpenMP or MPI.
List[T]
List of results from each function call.
"""

return int(os.environ.get("OMP_NUM_THREADS", 1))
results = {}

if cpu_intensive:
self.logger.info("Using ProcessPoolExecutor for CPU intensive tasks.")
with ProcessPoolExecutor(max_workers=num_workers) as executor:
future_to_item = {
executor.submit(func, *item, **kwargs)
if isinstance(item, tuple)
else executor.submit(func, item, **kwargs): i
for i, item in enumerate(items)
}
for future in as_completed(future_to_item):
i = future_to_item[future]
try:
result = future.result()
results[i] = result
except Exception as exc:
self.logger.error(f"Job for {i} generated an exception: {exc}")
else:
self.logger.info("Using ThreadPoolExecutor for I/O bound tasks.")
with ThreadPoolExecutor(max_workers=num_workers) as executor:
future_to_item = {
executor.submit(func, *item, **kwargs)
if isinstance(item, tuple)
else executor.submit(func, item, **kwargs): i
for i, item in enumerate(items)
}
for future in as_completed(future_to_item):
i = future_to_item[future]
try:
result = future.result()
results[i] = result
except Exception as exc:
self.logger.error(f"Job for {i} generated an exception: {exc}")

# Order results by input index so the return value matches List[T]
return [results[i] for i in sorted(results)]
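To show the intended call pattern, a minimal sketch with a hypothetical concrete subclass and a toy task (both are assumptions for illustration):

from bluemath_tk.core.models import BlueMathModel

class DemoModel(BlueMathModel):  # hypothetical; assumes no abstract methods are left unimplemented
    pass

def square(x):
    return x * x

model = DemoModel()
# The default path uses ThreadPoolExecutor; pass cpu_intensive=True to switch
# to ProcessPoolExecutor (func must then be picklable, i.e. module-level).
results = model.parallel_execute(square, items=[1, 2, 3, 4], num_workers=2)
print(results)  # [1, 4, 9, 16], ordered by input index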