Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/constant/constant_anchor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ Create an anchored query for ultra-fast constant interpolation at a fixed point.
- `x`: Grid points (must match grid used for interpolant construction)
- `xq`: Query point (scalar)
- `::Val{:constant}`: Type tag to distinguish from other anchor types
- `wrap`: If true, wrap `xq` to domain [x[1], x[end]) before anchoring.
- `wrap`: If true, wrap `xq` to closed domain [x[1], x[end]] before anchoring.
Used for `extrap=WrapExtrap()` mode.

# Returns
Expand Down Expand Up @@ -136,7 +136,7 @@ the grid used for interpolant construction.
- `x`: Grid points (must match interpolant's grid)
- `xq`: Query points (any Real type, auto-promoted to T)
- `::Val{:constant}`: Type tag
- `wrap`: If true, wrap query points to domain [x[1], x[end]) before anchoring.
- `wrap`: If true, wrap query points to closed domain [x[1], x[end]] before anchoring.

# Example
```julia
Expand Down Expand Up @@ -179,7 +179,7 @@ the caller reuses `buffer`. Writes `length(xq)` entries.
- `x::AbstractVector{T}`: Grid points (must match interpolant's grid)
- `xq::AbstractVector`: Query points (any Real type, auto-promoted to T)
- `::Val{:constant}`: Type tag for constant interpolation
- `wrap::Bool=false`: If true, wrap query points to domain [x[1], x[end])
- `wrap::Bool=false`: If true, wrap query points to closed domain [x[1], x[end]]

# Returns
The same `buffer` object, filled with anchored queries.
Expand Down Expand Up @@ -278,7 +278,9 @@ end
y::AbstractVector, x_last, aq::_ConstantAnchoredQuery,
op::AbstractEvalOp, side_param::AbstractSide, ::AbstractExtrap
)
aq.xq == x_last && return (op isa EvalValue ? (@inbounds y[end]) : 0 * first(y))
# Right-edge short-circuit: `y[aq.idxR]` resolves to `y[end]` for non-periodic
# (idxR == n) and to the cyclic `y[1]` for `_ExclusivePeriodicAxis` (idxR == 1).
aq.xq == x_last && return (op isa EvalValue ? (@inbounds y[aq.idxR]) : 0 * first(y))
@inbounds return _constant_kernel(op, y[aq.idxL], y[aq.idxR], aq.h, aq.dL, side_param)
end

Expand All @@ -288,7 +290,9 @@ end
op::AbstractEvalOp, side_param::AbstractSide, ::NoExtrap
)
aq.state != IN_DOMAIN && throw(DomainError(aq.xq, "query point outside domain"))
aq.xq == x_last && return (op isa EvalValue ? (@inbounds y[end]) : 0 * first(y))
# Right-edge short-circuit: `y[aq.idxR]` resolves to `y[end]` for non-periodic
# (idxR == n) and to the cyclic `y[1]` for `_ExclusivePeriodicAxis` (idxR == 1).
aq.xq == x_last && return (op isa EvalValue ? (@inbounds y[aq.idxR]) : 0 * first(y))
@inbounds return _constant_kernel(op, y[aq.idxL], y[aq.idxR], aq.h, aq.dL, side_param)
end

Expand All @@ -301,7 +305,9 @@ end
y_bnd = aq.state == OOB_LEFT ? first(y) : last(y)
return _eval_extrapolation(op, y_bnd, extrap, aq.xq)
end
aq.xq == x_last && return (op isa EvalValue ? (@inbounds y[end]) : 0 * first(y))
# Right-edge short-circuit: `y[aq.idxR]` resolves to `y[end]` for non-periodic
# (idxR == n) and to the cyclic `y[1]` for `_ExclusivePeriodicAxis` (idxR == 1).
aq.xq == x_last && return (op isa EvalValue ? (@inbounds y[aq.idxR]) : 0 * first(y))
@inbounds return _constant_kernel(op, y[aq.idxL], y[aq.idxR], aq.h, aq.dL, side_param)
end

Expand All @@ -328,8 +334,10 @@ end
x_min, x_max = first(itp.x), last(itp.x)
throw(DomainError(aq.xq, "query point outside domain [$x_min, $x_max]"))
end
# Right-edge short-circuit: `idxR` resolves to `n` for non-periodic and to
# `n+1` (cyclic `y[1]`) on `:exclusive` PeriodicBC's extended-grid persistent path.
if aq.xq == last(itp.x)
return op isa EvalValue ? (@inbounds itp.y[end]) : zero(T)
return op isa EvalValue ? (@inbounds itp.y[aq.idxR]) : zero(T)
end
@inbounds return _constant_kernel(op, itp.y[aq.idxL], itp.y[aq.idxR], aq.h, aq.dL, itp.side)
end
Expand Down
12 changes: 10 additions & 2 deletions src/constant/constant_oneshot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ end
return _constant_eval_at_point(x, y, xi, InBounds(), side, op, searcher)
end

# WrapExtrap: wrap query to domain → search + kernel.
# WrapExtrap: wrap query to domain → right-edge short-circuit → search + kernel.
# `_wrap_to_domain(xi, x, ::WrapExtrap)` reads `(first(x), last(x))` from the
# axis — `_ExclusivePeriodicAxis` exposes the precomputed virtual endpoint via
# `last(g)`, so the wrap domain naturally spans one period for `:exclusive`.
Expand All @@ -101,6 +101,14 @@ end
searcher::S
) where {Tg, Tv, Tq <: Real, S <: Searcher}
xi_wrapped = _wrap_to_domain(xi, x)
# Right-edge short-circuit (closed-domain): `xi == last(x)` collapses
# uniformly to `last(y)`, bypassing side semantics. Mirrors the InBounds
# core's identical guard and the persistent anchor path's `aq.xq == x_last`
# short-circuit so scalar oneshot agrees with the persistent interpolant at
# the exact boundary. `last(_ExclusivePeriodicData) = inner[1]` so `:exclusive`
# cyclic wrap is preserved; raw Vector yields `y[n]`.
_extract_primal(xi_wrapped) == _extract_primal(last(x)) &&
return op isa EvalValue ? last(y) : 0 * first(y)
idx, idx_R, xL, xR = search_interval(searcher, x, xi_wrapped)
dL = xi_wrapped - xL
# Unwrap data once: `search_interval` already resolved the seam (idx_R = 1
Expand Down Expand Up @@ -137,7 +145,7 @@ Constant (step/piecewise constant) interpolation at a single point.
- `NoExtrap()` (default): throws DomainError if outside domain
- `ClampExtrap()`: clamp to boundary values
- `ExtendExtrap()`: same as ClampExtrap (slope=0)
- `WrapExtrap()`: wrap to [x_min, x_max)
- `WrapExtrap()`: wrap to closed domain [x_min, x_max] (xq==x_max returns y[end])
- `side::AbstractSide`: Side selection
- `NearestSide()` (default): nearest neighbor (left tie-breaking at midpoint)
- `LeftSide()`: always use left value
Expand Down
7 changes: 5 additions & 2 deletions src/core/anchor_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ Dual type. The interval search uses `_extract_primal(xq)` for comparisons.
xq_primal = _extract_primal(xq)

# Handle wrapping (for extrap=WrapExtrap() or periodic mode)
# Generic _wrap_to_domain handles AD primal extraction and returns Tg
if wrap && (xq_primal < x_min || xq_primal >= x_max)
# Generic _wrap_to_domain handles AD primal extraction and returns Tg.
# Closed-domain convention: `xq == x_max` is in-domain — no wrap needed.
# Only strictly-OOB queries (`xq < x_min` or `xq > x_max`) take the slow
# `mod()` path inside `_wrap_to_domain` (which itself uses `xi <= x_max`).
if wrap && (xq_primal < x_min || xq_primal > x_max)
xq = _wrap_to_domain(xq, x_min, x_max)
xq_primal = xq # xq is now Tg, no need for _extract_primal
end
Expand Down
17 changes: 14 additions & 3 deletions src/core/eval_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,19 @@ struct ExtendExtrap <: AbstractExtrap end
"""
WrapExtrap <: AbstractExtrap

Wrap extrapolation — wraps queries into the domain `[first(x), last(x))` using
modular arithmetic. For periodic data.
Wrap extrapolation — wraps queries into the closed domain `[first(x), last(x)]`
using modular arithmetic. For periodic data.

Closed-domain convention: `xq == last(x)` is an in-domain boundary query
(returns the right-corner value, e.g. `y[end]` for non-periodic data); only
strictly-OOB queries take the `mod()` path. Matches `ClampExtrap`/`FillExtrap`'s
closed convention. `:inclusive` PeriodicBC: forward **value** is invariant
(validated `y[1] ≈ y[end]`), but **adjoint** sensitivity at `xq == last(x)`
now scatters to slot `n` instead of slot `1` — delta-equivalent under the
`:inclusive` cycle constraint, but observably different if downstream code
does not enforce `y[1] == y[end]` on `f_bar`. `:exclusive` PeriodicBC is
fully invariant (forward + adjoint) via the seam-aware
`_ExclusivePeriodicAxis.search_interval` returning `idx_R = 1` at `xq >= inner[n]`.
Comment thread
mgyoo86 marked this conversation as resolved.

Tag struct with no fields: the wrap domain is read directly from the axis at
query time via `first(x)` / `last(x)`. After the surface-API axis resolution
Expand All @@ -307,7 +318,7 @@ itp = linear_interp(x, y; extrap=WrapExtrap())
struct WrapExtrap <: AbstractExtrap end

# Backward-compat: previous API was `WrapExtrap(x)` materializing the wrap
# domain `[first(x), last(x))` into the struct's `_x_min`/`_x_max` fields.
# domain `[first(x), last(x)]` into the struct's `_x_min`/`_x_max` fields.
# After the tag-struct refactor (axis IS the source of truth for the wrap
# domain), the axis-passing form is redundant — the kernel reads `(first(x),
# last(x))` directly via `_wrap_to_domain(xq, x)`. This shim accepts and
Expand Down
15 changes: 12 additions & 3 deletions src/core/periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,18 @@
"""
_wrap_to_domain(xi::FT, x_min::FT, x_max::FT) where {FT<:AbstractFloat}

Wrap a query point `xi` to the domain [x_min, x_max).
Wrap a query point `xi` to the domain [x_min, x_max].
Used for periodic boundary conditions and extrap=WrapExtrap().

Closed-domain convention: `xi == x_max` is an in-domain boundary query
(returns `xi` unchanged); only strictly-OOB queries (`xi < x_min` or
`xi > x_max`) take the cold `mod()` path. `PeriodicBC{:inclusive}` is
forward-**value**-invariant because `y[1] ≈ y[end]` by construction; the
adjoint sensitivity at the seam now scatters to slot `n` instead of slot `1`
(delta-equivalent under the cycle constraint). `:exclusive` is fully invariant
(forward + adjoint) because `_ExclusivePeriodicAxis.search_interval` already
returns `idx_R = 1` for `xq >= inner[n]` at the seam.

Optimized: skips expensive `mod()` when xi is already in domain.
"""
@inline function _wrap_to_domain(xi::Tg, x_min::Tg, x_max::Tg) where {Tg}
Expand All @@ -25,7 +34,7 @@ Optimized: skips expensive `mod()` when xi is already in domain.
# bloat the caller (every WrapExtrap eval kernel) with mod-related
# asm. On constant rng+perEx persistent (3-4 ns baseline, 138 lines
# before split), this collapses the eval kernel to ~75 lines.
if (xi >= x_min) && (xi < x_max)
if (xi >= x_min) && (xi <= x_max)
return xi
end
return _wrap_to_domain_slow(xi, x_min, x_max)
Expand All @@ -37,7 +46,7 @@ end
@inline function _wrap_to_domain(xi::Real, x_min::Tg, x_max::Tg) where {Tg}
xi_primal = _extract_primal(xi)
# Fast path: already in domain, return original xi (preserves Dual type for AD)
if (xi_primal >= x_min) && (xi_primal < x_max)
if (xi_primal >= x_min) && (xi_primal <= x_max)
return xi
end
return _wrap_to_domain_slow(xi, x_min, x_max)
Expand Down
29 changes: 4 additions & 25 deletions src/core/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,24 +491,16 @@ end
"No-op vector domain check for non-NoExtrap modes: pass-through extrap."
@inline _check_domain(::AbstractVector, ::AbstractVector{<:Real}, extrap::AbstractExtrap) = extrap

# Clamp / Fill batch fast path: closed `[first, last]` — `last` is in-domain
# for clamp/fill semantics (no clamping or filling at the boundary).
# Closed-domain batch fast path: every OOB policy (`ClampExtrap`, `FillExtrap`,
# `WrapExtrap`) treats `[first(x), last(x)]` as the in-domain interval, so they
# share one batch promotion to `InBounds()`.
@inline function _check_domain(
x::AbstractVector, xi::AbstractVector{<:Real},
e::Union{ClampExtrap, FillExtrap}
e::Union{ClampExtrap, FillExtrap, WrapExtrap}
)
return _is_all_inbounds(x, xi) ? InBounds() : e
end

# WrapExtrap batch fast path: half-open `[first, last)` — `last` wraps to
# `first` per `_wrap_to_domain`'s `xi < x_max` fast-path check. Promoting a
# batch containing `last(x)` to InBounds would skip that wrap and return
# `y[last]` instead of `y[first]` (silent semantic regression). The
# half-open variant uses strict `<` to keep `last` in the wrap-needed set.
@inline function _check_domain(x::AbstractVector, xi::AbstractVector{<:Real}, e::WrapExtrap)
return _is_all_inbounds_halfopen(x, xi) ? InBounds() : e
end

"""
True iff every element of `queries` lies in the closed domain
`[first(x), last(x)]`. Enables batch-level fast paths that elide per-query
Expand Down Expand Up @@ -546,19 +538,6 @@ end
return minimum(queries) >= x.domain_lo && maximum(queries) <= x.domain_hi
end

# Half-open variant for WrapExtrap: `last(x)` belongs to the wrap-needed
# set because `_wrap_to_domain`'s fast path uses strict `xi < x_max`.
@inline function _is_all_inbounds_halfopen(x::AbstractVector, queries::AbstractVector{<:Real})
isempty(queries) && return true
return minimum(queries) >= _extract_primal(first(x)) &&
maximum(queries) < _extract_primal(last(x))
end

@inline function _is_all_inbounds_halfopen(x::_CachedRange, queries::AbstractVector{<:Real})
isempty(queries) && return true
return minimum(queries) >= x.domain_lo && maximum(queries) < x.domain_hi
end

# ========================================
# Extrapolation value helpers (shared by all interpolation methods)
# ========================================
Expand Down
2 changes: 1 addition & 1 deletion src/cubic/cubic_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ function _build_cubic_adjoint_periodic(
# promoted :exclusive, :inclusive for direct user input.
cache = _get_cubic_cache(x_ext, _bc_after_extend(bc), _effective_autocache(autocache, Tg))

# Build anchored queries with wrapping (queries outside domain → wrap to [x[1], x[end]))
# Build anchored queries with wrapping (queries outside closed domain → wrap to [x[1], x[end]])
anchors = _anchor_query(cache.x, xq, Val(:cubic), true)

return CubicAdjoint(cache, anchors, cache.bc)
Expand Down
8 changes: 4 additions & 4 deletions src/cubic/cubic_anchor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ Create an anchored query for ultra-fast cubic spline evaluation at a fixed point
- `x`: Grid points (must match grid used for interpolant construction)
- `xq`: Query point (scalar, can be Float or ForwardDiff.Dual for AD)
- `::Val{:cubic}`: Type tag to distinguish from other anchor types
- `wrap`: If true, wrap `xq` to domain [x[1], x[end]) before anchoring.
- `wrap`: If true, wrap `xq` to closed domain [x[1], x[end]] before anchoring.
Used for `extrap=WrapExtrap()` mode. Distinct from `PeriodicBC` (boundary condition).

# Returns
Expand Down Expand Up @@ -243,7 +243,7 @@ the grid used for interpolant construction.
- `x`: Grid points (must match interpolant's grid)
- `xq`: Query points (any Real type, auto-promoted to T)
- `::Val{:cubic}`: Type tag to distinguish from other anchor types
- `wrap`: If true, wrap query points to domain [x[1], x[end]) before anchoring.
- `wrap`: If true, wrap query points to closed domain [x[1], x[end]] before anchoring.
Used for `extrap=WrapExtrap()` mode. Distinct from `PeriodicBC` (boundary condition).

# Example
Expand Down Expand Up @@ -292,7 +292,7 @@ In-place version of `_anchor_query(x, xq, Val(:cubic))` for zero-allocation pool
- `x::AbstractVector{Tg}`: Grid points (must match interpolant's grid)
- `xq::AbstractVector{Tq}`: Query points (must match buffer's query type)
- `::Val{:cubic}`: Type tag for cubic interpolation
- `wrap::Bool=false`: If true, wrap query points to domain [x[1], x[end])
- `wrap::Bool=false`: If true, wrap query points to closed domain [x[1], x[end]]

# Returns
The same `buffer` object, filled with anchored queries.
Expand Down Expand Up @@ -363,7 +363,7 @@ while preserving the full Dual value for weight computation.
# `_anchor_loc` discards `idx_R` from `search_interval`'s 4-tuple, so this
# path always assumes `idxR = idxL + 1`. Valid only for:
# - non-periodic queries (no seam dispatch in the searcher),
# - `WrapExtrap` queries (wrap maps into `[first(x), last(x))`, no seam),
# - `WrapExtrap` queries (wrap maps into `[first(x), last(x)]`, no seam),
# - periodic queries on a *post-extension* (n+1) grid (idxL+1 ≤ n+1).
# Periodic-exclusive callers on a raw n-size grid MUST bypass this path
# and build the anchor from `search_interval`'s 4-tuple to preserve the
Expand Down
4 changes: 2 additions & 2 deletions src/linear/linear_anchor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ Create an anchored query for ultra-fast linear interpolation at a fixed point.
- `x`: Grid points (must match grid used for interpolant construction)
- `xq`: Query point (scalar)
- `::Val{:linear}`: Type tag to distinguish from cubic anchor
- `wrap`: If true, wrap `xq` to domain [x[1], x[end]) before anchoring.
- `wrap`: If true, wrap `xq` to closed domain [x[1], x[end]] before anchoring.
Used for `extrap=WrapExtrap()` mode.

# Returns
Expand Down Expand Up @@ -239,7 +239,7 @@ the caller reuses `buffer`. Writes `length(xq)` entries.
- `x::AbstractVector{Tg}`: Grid points (must match interpolant's grid)
- `xq::AbstractVector`: Query points (any Real type)
- `::Val{:linear}`: Type tag for linear interpolation
- `wrap::Bool=false`: If true, wrap query points to domain [x[1], x[end])
- `wrap::Bool=false`: If true, wrap query points to closed domain [x[1], x[end]]

# Precision Preservation
The outer `_LinearAnchoredQuery` constructor promotes via `promote_type(S, Tg)`,
Expand Down
8 changes: 5 additions & 3 deletions src/linear/linear_oneshot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ end
op::O,
searcher::S
) where {Tg, O <: AbstractEvalOp, S <: Searcher}
# Wrap domain comes directly from the axis: `[first(x), last(x))`. For
# Wrap domain comes directly from the axis: `[first(x), last(x)]`. For
# `_ExclusivePeriodicAxis`, `last(x)` is the precomputed virtual endpoint
# (`inner[1] + period`), so the domain extends one period beyond the raw
# grid as required for `:exclusive` periodic. Hoisting once outside the
Expand All @@ -124,8 +124,10 @@ end
x_min, x_max = first(x), last(x)
qmin, qmax = minimum(x_targets), maximum(x_targets)

if qmin >= x_min && qmax < x_max
# Fast path: all queries inside domain — use extension (no wrap overhead)
if qmin >= x_min && qmax <= x_max
# Fast path: all queries inside the closed domain `[first(x), last(x)]`
# — use extension (no wrap overhead). Exact `qmax == x_max` is in-domain,
# so `_wrap_to_domain` would be a no-op anyway — route straight in.
@inbounds for i in eachindex(x_targets, output)
output[i] = _linear_eval_at_point(x, y, x_targets[i], ExtendExtrap(), op, searcher)
end
Expand Down
6 changes: 3 additions & 3 deletions src/quadratic/quadratic_anchor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Create an anchored query for ultra-fast quadratic interpolation at a fixed point
- `x`: Grid points (must match grid used for interpolant construction)
- `xq`: Query point (scalar, can be Float or ForwardDiff.Dual for AD)
- `::Val{:quadratic}`: Type tag to distinguish from other anchor types
- `wrap`: If true, wrap `xq` to domain [x[1], x[end]) before anchoring.
- `wrap`: If true, wrap `xq` to closed domain [x[1], x[end]] before anchoring.
Used for `extrap=WrapExtrap()` mode.

# Returns
Expand Down Expand Up @@ -116,7 +116,7 @@ the grid used for interpolant construction.
- `x`: Grid points (must match interpolant's grid)
- `xq`: Query points (any Real type, auto-promoted to T)
- `::Val{:quadratic}`: Type tag
- `wrap`: If true, wrap query points to domain [x[1], x[end]) before anchoring.
- `wrap`: If true, wrap query points to closed domain [x[1], x[end]] before anchoring.

# Example
```julia
Expand Down Expand Up @@ -161,7 +161,7 @@ In-place version of `_anchor_query(x, xq, Val(:quadratic))` for zero-allocation
- `x::AbstractVector{T}`: Grid points (must match interpolant's grid)
- `xq::AbstractVector`: Query points (any Real type, auto-promoted to T)
- `::Val{:quadratic}`: Type tag for quadratic interpolation
- `wrap::Bool=false`: If true, wrap query points to domain [x[1], x[end])
- `wrap::Bool=false`: If true, wrap query points to closed domain [x[1], x[end]]

# Returns
The same `buffer` object, filled with anchored queries.
Expand Down
16 changes: 13 additions & 3 deletions test/test_anchor_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,22 @@
@test loc_right.state == FI.IN_DOMAIN
end

@testset "_anchor_loc — wrap at x_max" begin
@testset "_anchor_loc — closed boundary at x_max (no wrap)" begin
x = collect(range(0.0, 1.0, 11))
# xq == x_max with wrap → should wrap to x_min
# Closed semantics: xq == x_max stays in-domain, lands on the last cell.
loc = FI._anchor_loc(x, 1.0, true)
@test loc.state == FI.IN_DOMAIN
@test loc.xq ≈ 0.0
@test loc.xq ≈ 1.0
@test loc.idx == length(x) - 1 # last cell (n-1)
@test loc.xR ≈ 1.0
end

@testset "_anchor_loc — strictly OOB right still wraps" begin
x = collect(range(0.0, 1.0, 11))
# q = 1.25 = x_min + 1.25*period → wraps to 0.25
loc = FI._anchor_loc(x, 1.25, true)
@test loc.state == FI.IN_DOMAIN
@test loc.xq ≈ 0.25
end

# ========================================
Expand Down
Loading
Loading