Skip to content

Commit ba2bad5

Browse files
committed
Document design decisions and improve sumprod
- math_1/math_2: document why errno handling differs from CPython (platform-specific unreliability, output checks sufficient, verified by proptest)
- math.log: document EDOM substitution for ZeroDivisionError
- math.remainder: document libm delegation rationale
- sumprod: return Result for length mismatch instead of panic, improve overflow fallback to continue from where the fast path stopped instead of restarting from scratch
1 parent 0ad0cb0 commit ba2bad5

File tree

6 files changed

+129
-27
lines changed

6 files changed

+129
-27
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "pymath"
33
authors = ["Jeong, YunWon <jeong@youknowone.org>"]
44
repository = "https://github.com/RustPython/pymath"
55
description = "A binary representation compatible Rust implementation of Python's math library."
6-
version = "0.1.5"
6+
version = "0.2.0"
77
edition = "2024"
88
license = "PSF-2.0"
99

src/math.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,26 @@ macro_rules! libm_simple {
4949

5050
pub(crate) use libm_simple;
5151

52-
/// math_1: wrapper for 1-arg functions
52+
/// Wrapper for 1-arg libm functions, corresponding to FUNC1/is_error in
53+
/// mathmodule.c.
54+
///
5355
/// - isnan(r) && !isnan(x) -> domain error
5456
/// - isinf(r) && isfinite(x) -> overflow (can_overflow=true) or domain error (can_overflow=false)
5557
/// - isfinite(r) && errno -> check errno (unnecessary on most platforms)
58+
///
59+
/// CPython's approach: clear errno, call libm, then inspect both the result
60+
/// and errno to classify errors. We rely primarily on output inspection
61+
/// (NaN/Inf checks) because:
62+
///
63+
/// - On macOS and Windows, libm functions do not reliably set errno for
64+
/// edge cases, so CPython's own is_error() skips the errno check there
65+
/// too (it only uses it as a fallback on other Unixes).
66+
/// - The NaN/Inf output checks are sufficient to detect all domain and
67+
/// range errors on every platform we test against (verified by proptest
68+
/// and edgetest against CPython via pyo3).
69+
/// - The errno-only branch (finite result with errno set) is kept for
70+
/// non-macOS/non-Windows Unixes where libm might signal an error
71+
/// without producing a NaN/Inf result.
5672
#[inline]
5773
pub(crate) fn math_1(x: f64, func: fn(f64) -> f64, can_overflow: bool) -> crate::Result<f64> {
5874
crate::err::set_errno(0);
@@ -75,9 +91,17 @@ pub(crate) fn math_1(x: f64, func: fn(f64) -> f64, can_overflow: bool) -> crate:
7591
Ok(r)
7692
}
7793

78-
/// math_2: wrapper for 2-arg functions
94+
/// Wrapper for 2-arg libm functions, corresponding to FUNC2 in
95+
/// mathmodule.c.
96+
///
7997
/// - isnan(r) && !isnan(x) && !isnan(y) -> domain error
8098
/// - isinf(r) && isfinite(x) && isfinite(y) -> range error
99+
///
100+
/// Unlike math_1, this does not set/check errno at all. CPython's FUNC2
101+
/// does clear and check errno, but the NaN/Inf output checks already
102+
/// cover all error cases for the 2-arg functions we wrap (atan2, fmod,
103+
/// copysign, remainder, pow). This is verified by bit-exact proptest
104+
/// and edgetest against CPython.
81105
#[inline]
82106
pub(crate) fn math_2(x: f64, y: f64, func: fn(f64, f64) -> f64) -> crate::Result<f64> {
83107
let r = func(x, y);

src/math/aggregate.rs

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ pub fn vector_norm(vec: &[f64], max: f64, found_nan: bool) -> f64 {
231231
///
232232
/// The points are given as sequences of coordinates.
233233
/// Uses high-precision vector_norm algorithm.
234+
///
235+
/// Panics if `p` and `q` have different lengths. CPython raises ValueError
236+
/// for mismatched dimensions, but in this Rust API the caller is expected
237+
/// to guarantee equal-length slices. A length mismatch is a programming
238+
/// error, not a runtime condition.
234239
pub fn dist(p: &[f64], q: &[f64]) -> f64 {
235240
assert_eq!(
236241
p.len(),
@@ -261,24 +266,52 @@ pub fn dist(p: &[f64], q: &[f64]) -> f64 {
261266

262267
/// Return the sum of products of values from two sequences (float version).
263268
///
264-
/// Uses TripleLength arithmetic for high precision.
265-
/// Equivalent to sum(p[i] * q[i] for i in range(len(p))).
266-
pub fn sumprod(p: &[f64], q: &[f64]) -> f64 {
267-
assert_eq!(p.len(), q.len(), "Inputs are not the same length");
269+
/// Uses TripleLength arithmetic for the fast path, then falls back to
270+
/// ordinary floating-point multiply/add starting at the first pair whose
271+
/// TripleLength running total is no longer finite, matching Python's staged `math.sumprod` behavior for float inputs.
272+
///
273+
/// CPython's math_sumprod_impl is a 3-stage state machine that handles
274+
/// int/float/generic Python objects. This function only covers the float
275+
/// path (`&[f64]`). The int accumulation and generic PyNumber fallback
276+
/// stages are Python type-system concerns and should be handled by the
277+
/// caller (e.g. RustPython) before delegating here.
278+
///
279+
/// Returns `Err(Error::EDOM)` if `p` and `q` do not have the same length.
280+
pub fn sumprod(p: &[f64], q: &[f64]) -> crate::Result<f64> {
281+
if p.len() != q.len() {
282+
return Err(crate::Error::EDOM);
283+
}
268284

285+
let mut total = 0.0;
269286
let mut flt_total = TL_ZERO;
287+
let mut flt_path_enabled = true;
288+
let mut i = 0;
270289

271-
for (&pi, &qi) in p.iter().zip(q.iter()) {
272-
let new_flt_total = tl_fma(pi, qi, flt_total);
273-
if new_flt_total.hi.is_finite() {
274-
flt_total = new_flt_total;
275-
} else {
276-
// Overflow or special value, fall back to simple sum
277-
return p.iter().zip(q.iter()).map(|(a, b)| a * b).sum();
290+
while i < p.len() {
291+
let pi = p[i];
292+
let qi = q[i];
293+
294+
if flt_path_enabled {
295+
let new_flt_total = tl_fma(pi, qi, flt_total);
296+
if new_flt_total.hi.is_finite() {
297+
flt_total = new_flt_total;
298+
i += 1;
299+
continue;
300+
}
301+
302+
flt_path_enabled = false;
303+
total += tl_to_d(flt_total);
278304
}
305+
306+
total += pi * qi;
307+
i += 1;
279308
}
280309

281-
tl_to_d(flt_total)
310+
Ok(if flt_path_enabled {
311+
tl_to_d(flt_total)
312+
} else {
313+
total
314+
})
282315
}
283316

284317
/// Return the sum of products of values from two sequences (integer version).
@@ -427,14 +460,27 @@ mod tests {
427460
crate::test::with_py_math(|py, math| {
428461
let py_p = pyo3::types::PyList::new(py, p).unwrap();
429462
let py_q = pyo3::types::PyList::new(py, q).unwrap();
430-
let py: f64 = math
431-
.getattr("sumprod")
432-
.unwrap()
433-
.call1((py_p, py_q))
434-
.unwrap()
435-
.extract()
436-
.unwrap();
437-
crate::test::assert_f64_eq(py, rs, format_args!("sumprod({p:?}, {q:?})"));
463+
let py_result = math.getattr("sumprod").unwrap().call1((py_p, py_q));
464+
match py_result {
465+
Ok(py_val) => {
466+
let py: f64 = py_val.extract().unwrap();
467+
let rs = rs.unwrap_or_else(|e| {
468+
panic!("sumprod({p:?}, {q:?}): py={py} but rs returned error {e:?}")
469+
});
470+
crate::test::assert_f64_eq(py, rs, format_args!("sumprod({p:?}, {q:?})"));
471+
}
472+
Err(e) => {
473+
if e.is_instance_of::<pyo3::exceptions::PyValueError>(py) {
474+
assert_eq!(
475+
rs.as_ref().err(),
476+
Some(&crate::Error::EDOM),
477+
"sumprod({p:?}, {q:?}): py raised ValueError but rs={rs:?}"
478+
);
479+
} else {
480+
panic!("sumprod({p:?}, {q:?}): py raised unexpected error {e}");
481+
}
482+
}
483+
}
438484
});
439485
}
440486

@@ -444,6 +490,9 @@ mod tests {
444490
test_sumprod_impl(&[], &[]);
445491
test_sumprod_impl(&[1.0], &[2.0]);
446492
test_sumprod_impl(&[1e100, 1e100], &[1e100, -1e100]);
493+
test_sumprod_impl(&[1.0, 1e308, -1e308], &[1.0, 2.0, 2.0]);
494+
test_sumprod_impl(&[1e-16, 1e308, -1e308], &[1.0, 2.0, 2.0]);
495+
test_sumprod_impl(&[1.0], &[]);
447496
}
448497

449498
fn test_prod_impl(values: &[f64], start: Option<f64>) {

src/math/bigint.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,14 @@ pub fn comb_bigint(n: &BigInt, k: u64) -> BigUint {
6363
/// - mantissa is in [0.5, 1.0) for positive n
6464
/// - n ~= mantissa * 2^exponent
6565
///
66-
/// See: _PyLong_Frexp in CPython longobject.c
66+
/// `_PyLong_Frexp` extracts digits one-by-one into a fixed-size
67+
/// accumulator and applies a `half_even_correction` lookup table for
68+
/// rounding. We instead extract the top 55 bits via a single right
69+
/// shift and use a sticky-bit to mark whether any discarded bits were
70+
/// non-zero, then delegate to `BigInt::to_f64()` which performs
71+
/// IEEE 754 round-half-to-even. The two approaches are equivalent
72+
/// because the sticky bit preserves the same rounding information
73+
/// that the digit-by-digit extraction would.
6774
fn frexp_bigint(n: &BigInt) -> (f64, i64) {
6875
let bits = n.bits();
6976
if bits == 0 {
@@ -87,8 +94,15 @@ fn frexp_bigint(n: &BigInt) -> (f64, i64) {
8794

8895
// Sticky bit: if any shifted-out bits were non-zero, set the LSB.
8996
// This ensures correct IEEE round-half-to-even when converting to f64.
90-
// See _PyLong_Frexp in longobject.c.
91-
if (&mantissa_int << shift as u64) != *n {
97+
//
98+
// `_PyLong_Frexp` checks the remainder from `v_rshift` first, then
99+
// iterates shifted-out digits top-down. We use `trailing_zeros()`
100+
// which scans digits bottom-up instead. The worst-case traversal
101+
// order differs (e.g. exact powers of two), but for typical inputs
102+
// both terminate in O(1). If you observe a performance regression
103+
// from this, please file a bug report.
104+
let tz = n.magnitude().trailing_zeros().unwrap(); // n != 0 here
105+
if tz < shift as u64 {
92106
mantissa_int |= BigInt::from(1);
93107
}
94108

src/math/exponential.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ pub fn log(x: f64, base: Option<f64>) -> Result<f64> {
113113
if den.is_infinite() && b.is_finite() {
114114
return Err(crate::Error::EDOM);
115115
}
116-
// log(x, 1) -> division by zero
116+
// log(x, 1) -> division by zero.
117+
// CPython raises ZeroDivisionError here (via PyNumber_TrueDivide),
118+
// but we return EDOM since our error type has no ZeroDivisionError
119+
// variant. The caller (e.g. RustPython) may remap this if needed.
117120
if den == 0.0 {
118121
return Err(crate::Error::EDOM);
119122
}

src/math/misc.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ super::libm_simple!(@1 ceil, floor, trunc);
1010
/// manipulation on the IEEE 754 representation. Steps that overshoot y
1111
/// are clamped so the result never passes y.
1212
///
13+
/// CPython's math_nextafter_impl accepts a Python integer for steps,
14+
/// rejects negative values, and saturates overflows to UINT64_MAX. This
15+
/// Rust API takes `Option<u64>`, so negative rejection and big-int
16+
/// saturation are structurally unnecessary. The caller (e.g. RustPython)
17+
/// should handle Python int conversion and negative checks before calling.
18+
///
1319
/// See math_nextafter_impl in mathmodule.c.
1420
#[inline]
1521
pub fn nextafter(x: f64, y: f64, steps: Option<u64>) -> f64 {
@@ -219,6 +225,12 @@ pub fn fmod(x: f64, y: f64) -> Result<f64> {
219225
}
220226

221227
/// Return the IEEE 754-style remainder of x with respect to y.
228+
///
229+
/// CPython implements this from scratch using fmod (m_remainder in
230+
/// mathmodule.c) rather than calling the C library's remainder().
231+
/// We delegate to libm's remainder() which is correct on all platforms
232+
/// where it conforms to IEEE 754. If you find a platform where the
233+
/// results differ from CPython, please file a bug.
222234
#[inline]
223235
pub fn remainder(x: f64, y: f64) -> Result<f64> {
224236
super::math_2(x, y, crate::m::remainder)

0 commit comments

Comments
 (0)