Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions src/pyrecest/sampling/euclidean_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def sample_rejection(

@staticmethod
def _validate_rejection_args(n_candidates, dim, max_density, bounding_box):
n_candidates = int(n_candidates)
dim = int(dim)
n_candidates = _validate_integral_argument(n_candidates, "n_candidates")
dim = _validate_integral_argument(dim, "dim")
max_density = float(max_density)

if n_candidates < 0:
Expand Down Expand Up @@ -175,8 +175,8 @@ def get_uniform_samples(self, n_samples: int, dim: int):

@staticmethod
def _validate_grid_args(n_samples: int, dim: int):
n_samples = int(n_samples)
dim = int(dim)
n_samples = _validate_integral_argument(n_samples, "n_samples")
dim = _validate_integral_argument(dim, "dim")
if n_samples < 0:
raise ValueError("n_samples must be nonnegative")
if dim < 1:
Expand Down Expand Up @@ -210,6 +210,27 @@ def _is_prime(n):
return True


def _validate_integral_argument(value, name: str) -> int:
"""Return a scalar integer argument without silently truncating floats."""
try:
array_value = np.asarray(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"{name} must be an integer") from exc

if array_value.ndim != 0:
raise ValueError(f"{name} must be a scalar integer")
if np.issubdtype(array_value.dtype, np.bool_):
raise ValueError(f"{name} must be an integer")
if np.issubdtype(array_value.dtype, np.integer):
return int(array_value)
if np.issubdtype(array_value.dtype, np.floating):
float_value = float(array_value)
if np.isfinite(float_value) and float_value.is_integer():
return int(float_value)

raise ValueError(f"{name} must be an integer")


def _validate_gaussian_transform_args(d, covariance, mean):
if covariance is None:
covariance = np.eye(d)
Expand Down Expand Up @@ -393,8 +414,8 @@ def _fibonacci_grid(
xy_gauss : np.ndarray of shape (d, n_points)
Gaussian grid on R^d with the given covariance and mean.
"""
d = int(d)
n_points = int(n_points)
d = _validate_integral_argument(d, "d")
n_points = _validate_integral_argument(n_points, "n_points")
if d < 1:
raise ValueError("d must be positive")
if n_points < 0:
Expand Down Expand Up @@ -510,4 +531,4 @@ def _fibonacci_grid(
C_vals = np.maximum(C_vals, 0.0)
xy_gauss = C_vecs @ np.diag(np.sqrt(C_vals)) @ xy_stdMM + mean.reshape(-1, 1)

return xy_equal, xy_stdMM, xy_gauss
return xy_equal, xy_stdMM, xy_gauss
68 changes: 68 additions & 0 deletions tests/test_euclidean_sampler_integer_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np
import pytest

from pyrecest.sampling.euclidean_sampler import (
FibonacciGridSampler,
FibonacciRejectionSampler,
HaltonGridSampler,
SobolGridSampler,
)


@pytest.mark.parametrize("sampler", [SobolGridSampler(), HaltonGridSampler()])
def test_qmc_grid_samplers_reject_non_integral_grid_arguments(sampler):
with pytest.raises(ValueError, match="n_samples"):
sampler.get_uniform_samples(2.5, 2)
with pytest.raises(ValueError, match="dim"):
sampler.get_uniform_samples(4, 1.5)


@pytest.mark.parametrize("sampler", [SobolGridSampler(), HaltonGridSampler()])
def test_qmc_grid_samplers_accept_integer_like_scalar_arguments(sampler):
samples = sampler.get_uniform_samples(np.array(4), np.float64(2.0))
assert samples.shape == (4, 2)


def test_fibonacci_grid_sampler_rejects_non_integral_grid_arguments():
sampler = FibonacciGridSampler()

with pytest.raises(ValueError, match="n_points"):
sampler.get_uniform_samples(2.5, 2)
with pytest.raises(ValueError, match="d"):
sampler.get_uniform_samples(4, 1.5)


def test_fibonacci_grid_sampler_accepts_integer_like_scalar_arguments():
samples = FibonacciGridSampler().get_uniform_samples(np.float64(4.0), np.array(2))
assert samples.shape == (4, 2)


def test_fibonacci_rejection_sampler_rejects_non_integral_rejection_arguments():
sampler = FibonacciRejectionSampler()

with pytest.raises(ValueError, match="n_candidates"):
sampler.sample_rejection(
lambda xs: np.ones(xs.shape[0]),
n_candidates=2.5,
dim=2,
max_density=1.0,
)
with pytest.raises(ValueError, match="dim"):
sampler.sample_rejection(
lambda xs: np.ones(xs.shape[0]),
n_candidates=4,
dim=1.5,
max_density=1.0,
)


def test_fibonacci_rejection_sampler_accepts_integer_like_scalar_arguments():
samples, info = FibonacciRejectionSampler().sample_rejection(
lambda xs: np.ones(xs.shape[0]),
n_candidates=np.float64(4.0),
dim=np.array(2),
max_density=1.0,
)

assert samples.shape == (4, 2)
assert info["n_candidates"] == 4
Loading