Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,17 @@ public double inverseSurvivalProbability(double p) {
/** {@inheritDoc} */
@Override
public Sampler createSampler(UniformRandomProvider rng) {
// Map the bounds to a standard normal distribution
final double u = parentNormal.getMean();
final double s = parentNormal.getStandardDeviation();
final double a = (lower - u) / s;
final double b = (upper - u) / s;
// If the truncation covers a reasonable amount of the normal distribution
// then a rejection sampler can be used.
double threshold = REJECTION_THRESHOLD;
// If the truncation is entirely in the upper or lower half then adjust the
// threshold as twice the samples can be used
if (lower >= 0 || upper <= 0) {
if (a >= 0 || b <= 0) {
threshold *= 0.5;
}

Expand All @@ -249,21 +254,16 @@ public Sampler createSampler(UniformRandomProvider rng) {
final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
final DoubleSupplier gen;
// Use mirroring if possible
if (lower >= 0) {
if (a >= 0) {
// Return the upper-half of the Gaussian
gen = () -> Math.abs(sampler.sample());
} else if (upper <= 0) {
} else if (b <= 0) {
// Return the lower-half of the Gaussian
gen = () -> -Math.abs(sampler.sample());
} else {
// Return the full range of the Gaussian
gen = sampler::sample;
}
// Map the bounds to a standard normal distribution
final double u = parentNormal.getMean();
final double s = parentNormal.getStandardDeviation();
final double a = (lower - u) / s;
final double b = (upper - u) / s;
// Sample in [a, b] using rejection
return () -> {
double x = gen.getAsDouble();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.commons.statistics.distribution;

import java.time.Duration;
import org.apache.commons.numbers.gamma.Erf;
import org.apache.commons.numbers.gamma.Erfcx;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -351,6 +353,44 @@ void testAdditionalMoments(double lower, double upper,
}
}

/**
 * Verifies sampling from a truncation range that is strictly positive yet lies
 * entirely below the mean of the parent normal distribution. The parameters are
 * chosen so the rejection sampler is selected. If that sampler wrongly takes the
 * upper-half branch it never accepts a value and loops forever, which is detected
 * here as a timeout.
 */
@Test
void testSamplerPositiveBelowMeanWithRejection() {
    // These parameters select the rejection-sampler path in
    // TruncatedNormalDistribution.createSampler(...)
    final TruncatedNormalDistribution truncated =
        TruncatedNormalDistribution.of(1d, 0.1, 0.7, 0.99);
    final Duration limit = Duration.ofSeconds(1);

    // Create the sampler and draw a single value under a preemptive timeout
    final double sample = Assertions.assertTimeoutPreemptively(limit, () -> {
        return truncated.createSampler(RandomSource.XO_SHI_RO_256_PP.create(123456789L)).sample();
    });

    // The sample must lie inside the closed support interval
    Assertions.assertTrue(
        sample >= truncated.getSupportLowerBound() && sample <= truncated.getSupportUpperBound(),
        () -> "Sample outside support: " + sample);
}

/**
 * Verifies sampling from a truncation range that is strictly negative yet lies
 * entirely above the mean of the parent normal distribution. The parameters are
 * chosen so the rejection sampler is selected. If that sampler wrongly takes the
 * lower-half branch it never accepts a value and loops forever, which is detected
 * here as a timeout.
 */
@Test
void testSamplerNegativeAboveMeanWithRejection() {
    // These parameters select the rejection-sampler path in
    // TruncatedNormalDistribution.createSampler(...)
    final TruncatedNormalDistribution truncated =
        TruncatedNormalDistribution.of(-1d, 0.1, -0.99, -0.7);
    final Duration limit = Duration.ofSeconds(1);

    // Create the sampler and draw a single value under a preemptive timeout
    final double sample = Assertions.assertTimeoutPreemptively(limit, () -> {
        return truncated.createSampler(RandomSource.XO_SHI_RO_256_PP.create(123456789L)).sample();
    });

    // The sample must lie inside the closed support interval
    Assertions.assertTrue(
        sample >= truncated.getSupportLowerBound() && sample <= truncated.getSupportUpperBound(),
        () -> "Sample outside support: " + sample);
}

/**
* Assert the mean of the truncated normal distribution is within the provided relative error.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Original (unstandardized) bounds that are fully positive; the standardized bounds are approximately [-3, 3]
# STATISTICS-92
parameters = 1.0, 0.1, 0.7, 1.3
# Computed using Python with SciPy v1.16.3:
# mean, std, clip_a, clip_b = 1.0, 0.1, 0.7, 1.3
# a, b = (clip_a - mean) / std, (clip_b - mean) / std
# truncnorm.var(a, b, loc=mean, scale=std)
mean = 1.0
variance = 0.009733369246625417
lower = 0.7
upper = 1.3
cdf.points = \
0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25, 1.35
cdf.values = \
0. , 0. , 0.00487292319299906,\
0.06563450301006861, 0.30801922979837476, 0.6919807702016253 ,\
0.9343654969899313 , 0.995127076807001 , 1.
pdf.values = \
0. , 0. , 0.17575751438109988,\
1.298682133570557 , 3.530184044629268 , 3.530184044629268 ,\
1.2986821335705594 , 0.17575751438109988, 0.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Original (unstandardized) bounds that are fully negative; the standardized bounds are approximately [-4, 2]
# STATISTICS-92
parameters = -1.0, 0.1, -1.4, -0.8
# Computed using Python with SciPy v1.16.3:
# mean, std, clip_a, clip_b = -1.0, 0.1, -1.4, -0.8
# a, b = (clip_a - mean) / std, (clip_b - mean) / std
# truncnorm.var(a, b, loc=mean, scale=std)
mean = -1.0055112703041382
variance = 0.00885915482691343
lower = -1.4
upper = -0.8
cdf.points = \
-1.45 , -1.3722222222222222, -1.2944444444444443,\
-1.2166666666666666, -1.1388888888888888, -1.0611111111111111,\
-0.9833333333333333, -0.9055555555555554, -0.8277777777777777,\
-0.75
cdf.values = \
0.0000000000000000e+00, 6.8630844903959877e-05,\
1.6229782566954216e-03, 1.5450458063193950e-02,\
8.4322619964621676e-02, 2.7683821471716175e-01,\
5.7935081767532270e-01, 8.4678840615650053e-01,\
9.7977472840069957e-01, 1.0000000000000000e+00
pdf.values = \
0.0000000000000000e+00, 4.0027355759704591e-03,\
5.3494059692712141e-02, 3.9042072314005954e-01,\
1.5561046905609728e+00, 3.3870640463590616e+00,\
4.0261194212948235e+00, 2.6135363401653997e+00,\
9.2650780123937682e-01, 0.0000000000000000e+00
Loading