
Conversation

@mhauru
Member

@mhauru mhauru commented Nov 20, 2025

I decided that rather than take over VarInfo like in #1074, the first use case of VarNamedTuple should be replacing the NamedTuple/Dict combo in FastLDF. That's what this PR does.

This is still work in progress:

  • Documentation is lacking/out of date
  • There's dead code, and unnecessarily complex code
  • Performance on Julia v1.11 needs fixing
  • There's type piracy
  • This doesn't handle Colons in VarNames.

However, tests seem to pass, so I'm putting this up. I ran the familiar FastLDF benchmarks from #1132, adapted a bit. Source code:

Details
module VNTBench

using DynamicPPL, Distributions, LogDensityProblems, Chairmarks, LinearAlgebra, Statistics
using ADTypes, ForwardDiff, ReverseDiff
@static if VERSION < v"1.12"
    using Enzyme, Mooncake
end

const adtypes = @static if VERSION < v"1.12"
    [
        ("FD", AutoForwardDiff()),
        ("RD", AutoReverseDiff()),
        ("MC", AutoMooncake()),
        ("EN", AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const)),
    ]
else
    [
        ("FD", AutoForwardDiff()),
        ("RD", AutoReverseDiff()),
    ]
end

function benchmark_ldfs(model; skip=Union{})
    vi = VarInfo(model)
    x = vi[:]
    ldf_no = DynamicPPL.LogDensityFunction(model, getlogjoint, vi)
    fldf_no = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi)
    @assert LogDensityProblems.logdensity(ldf_no, x) ≈ LogDensityProblems.logdensity(fldf_no, x)
    median_new = median(@be LogDensityProblems.logdensity(fldf_no, x))
    print("           FastLDF: eval      ----  ")
    display(median_new)
    for name_adtype in adtypes
        name, adtype = name_adtype
        adtype isa skip && continue
        ldf = DynamicPPL.LogDensityFunction(model, getlogjoint, vi; adtype=adtype)
        fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi; adtype=adtype)
        ldf_grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
        fldf_grad = LogDensityProblems.logdensity_and_gradient(fldf, x)
        @assert ldf_grad[2] ≈ fldf_grad[2]
        median_new = median(@be LogDensityProblems.logdensity_and_gradient(fldf, x))
        print("           FastLDF: grad ($name) ----  ")
        display(median_new)
    end
end

println("Trivial model")
@model f() = x ~ Normal()
benchmark_ldfs(f())

println("Eight schools")
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(y)), tau^2 * I)
    for i in eachindex(y)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end
benchmark_ldfs(eight_schools(y, sigma))

println("IndexLenses, dim=1_000")
@model function badvarnames()
    N = 1_000
    x = Vector{Float64}(undef, N)
    for i in 1:N
        x[i] ~ Normal()
    end
end
benchmark_ldfs(badvarnames())

println("Submodel")
@model function inner()
    m ~ Normal(0, 1)
    s ~ Exponential()
    return (m=m, s=s)
end
@model function withsubmodel()
    params ~ to_submodel(inner())
    y ~ Normal(params.m, params.s)
    1.0 ~ Normal(y)
end
benchmark_ldfs(withsubmodel())

end

Results on Julia v1.12:

Details
On base (breaking):
Trivial model
           FastLDF: eval      ----  18.047 ns
           FastLDF: grad (FD) ----  51.805 ns (3 allocs: 96 bytes)
           FastLDF: grad (RD) ----  3.157 μs (45 allocs: 1.531 KiB)
Eight schools
           FastLDF: eval      ----  165.723 ns (4 allocs: 256 bytes)
           FastLDF: grad (FD) ----  685.846 ns (11 allocs: 2.594 KiB)
           FastLDF: grad (RD) ----  39.959 μs (562 allocs: 20.562 KiB)
IndexLenses, dim=1_000
           FastLDF: eval      ----  24.250 μs (14 allocs: 8.312 KiB)
           FastLDF: grad (FD) ----  6.296 ms (1516 allocs: 11.197 MiB)
           FastLDF: grad (RD) ----  2.577 ms (38029 allocs: 1.321 MiB)
Submodel
           FastLDF: eval      ----  57.568 ns
           FastLDF: grad (FD) ----  179.448 ns (3 allocs: 112 bytes)
           FastLDF: grad (RD) ----  10.750 μs (145 allocs: 5.062 KiB)

On this branch:
Trivial model
           FastLDF: eval      ----  11.869 ns
           FastLDF: grad (FD) ----  53.264 ns (3 allocs: 96 bytes)
           FastLDF: grad (RD) ----  3.273 μs (45 allocs: 1.531 KiB)
Eight schools
           FastLDF: eval      ----  203.159 ns (4 allocs: 256 bytes)
           FastLDF: grad (FD) ----  718.750 ns (11 allocs: 2.594 KiB)
           FastLDF: grad (RD) ----  39.792 μs (562 allocs: 20.562 KiB)
IndexLenses, dim=1_000
           FastLDF: eval      ----  9.181 μs (2 allocs: 8.031 KiB)
           FastLDF: grad (FD) ----  4.235 ms (508 allocs: 11.174 MiB)
           FastLDF: grad (RD) ----  2.560 ms (38017 allocs: 1.321 MiB)
Submodel
           FastLDF: eval      ----  49.660 ns
           FastLDF: grad (FD) ----  221.359 ns (3 allocs: 112 bytes)
           FastLDF: grad (RD) ----  10.667 μs (148 allocs: 5.219 KiB)

Same thing but in Julia v1.11:

Details
On base (breaking):
Trivial model
           FastLDF: eval      ----  11.082 ns
           FastLDF: grad (FD) ----  53.747 ns (3 allocs: 96 bytes)
           FastLDF: grad (RD) ----  3.069 μs (46 allocs: 1.562 KiB)
           FastLDF: grad (MC) ----  221.910 ns (2 allocs: 64 bytes)
           FastLDF: grad (EN) ----  128.970 ns (2 allocs: 64 bytes)
Eight schools
           FastLDF: eval      ----  164.326 ns (4 allocs: 256 bytes)
           FastLDF: grad (FD) ----  690.049 ns (11 allocs: 2.594 KiB)
           FastLDF: grad (RD) ----  39.250 μs (562 allocs: 20.562 KiB)
           FastLDF: grad (MC) ----  1.082 μs (10 allocs: 656 bytes)
           FastLDF: grad (EN) ----  733.325 ns (13 allocs: 832 bytes)
IndexLenses, dim=1_000
           FastLDF: eval      ----  33.458 μs (15 allocs: 8.344 KiB)
           FastLDF: grad (FD) ----  6.652 ms (1516 allocs: 11.197 MiB)
           FastLDF: grad (RD) ----  2.488 ms (38028 allocs: 1.321 MiB)
           FastLDF: grad (MC) ----  89.583 μs (21 allocs: 24.469 KiB)
           FastLDF: grad (EN) ----  92.833 μs (20 allocs: 102.531 KiB)
Submodel
           FastLDF: eval      ----  70.884 ns
           FastLDF: grad (FD) ----  135.958 ns (3 allocs: 112 bytes)
           FastLDF: grad (RD) ----  10.563 μs (148 allocs: 5.188 KiB)
           FastLDF: grad (MC) ----  481.250 ns (2 allocs: 80 bytes)
           FastLDF: grad (EN) ----  344.612 ns (2 allocs: 80 bytes)

On this branch:
Trivial model
           FastLDF: eval      ----  1.309 μs (27 allocs: 800 bytes)
           FastLDF: grad (FD) ----  1.522 μs (30 allocs: 960 bytes)
           FastLDF: grad (RD) ----  4.667 μs (71 allocs: 2.344 KiB)
           FastLDF: grad (MC) ----  358.143 ns (7 allocs: 224 bytes)
           FastLDF: grad (EN) ----  130.768 ns (2 allocs: 64 bytes)
Eight schools
           FastLDF: eval      ----  164.326 ns (4 allocs: 256 bytes)
           FastLDF: grad (FD) ----  645.378 ns (11 allocs: 2.594 KiB)
           FastLDF: grad (RD) ----  39.541 μs (562 allocs: 20.562 KiB)
           FastLDF: grad (MC) ----  1.043 μs (10 allocs: 656 bytes)
           FastLDF: grad (EN) ----  747.925 ns (13 allocs: 832 bytes)
IndexLenses, dim=1_000
           FastLDF: eval      ----  9.430 μs (3 allocs: 8.062 KiB)
           FastLDF: grad (FD) ----  4.616 ms (508 allocs: 11.174 MiB)
           FastLDF: grad (RD) ----  2.467 ms (38016 allocs: 1.321 MiB)
           FastLDF: grad (MC) ----  73.292 μs (9 allocs: 24.188 KiB)
           FastLDF: grad (EN) ----  72.875 μs (8 allocs: 102.250 KiB)
Submodel
           FastLDF: eval      ----  52.213 ns
           FastLDF: grad (FD) ----  107.166 ns (3 allocs: 112 bytes)
           FastLDF: grad (RD) ----  10.521 μs (142 allocs: 5.078 KiB)
           FastLDF: grad (MC) ----  453.493 ns (2 allocs: 80 bytes)
           FastLDF: grad (EN) ----  320.367 ns (2 allocs: 80 bytes)

So on 1.12 all looks good: this branch is a bit faster than the old version, and substantially faster when there are a lot of IndexLenses, as it should be. On 1.11 performance is destroyed, probably because type inference fails or gives up, and I need to fix that.

The main point of this PR is not performance, but having a general data structure for storing information keyed by VarNames, so I'm happy as long as performance doesn't degrade. Next up would be using this same data structure for ConditionContext (hoping to fix #1148), ValuesAsInModelAcc, maybe some other Accumulators, InitFromParams, GibbsContext, and finally to implement an AbstractVarInfo type.

I'll update the docs page with more information about the current design, but the one-sentence summary is: nested NamedTuples, and whenever we meet IndexLenses, an Array for the values together with a mask Array that marks which values are valid and which are just placeholders.
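To illustrate the array-plus-mask idea, here's a toy sketch in plain Julia. Note that MaskedValues and its functions are made-up names for illustration, not the actual VarNamedTuple internals:

```julia
# Toy sketch of storing IndexLens-keyed values as an array plus a validity mask.
struct MaskedValues{T}
    values::Vector{T}
    mask::Vector{Bool}  # true where a value has actually been set
end

MaskedValues{T}() where {T} = MaskedValues(T[], Bool[])

function setval!(mv::MaskedValues{T}, i::Integer, v) where {T}
    if i > length(mv.values)
        # Grow the storage; the new slots are mere placeholders.
        resize!(mv.values, i)
        old_len = length(mv.mask)
        resize!(mv.mask, i)
        mv.mask[(old_len + 1):end] .= false
    end
    mv.values[i] = v
    mv.mask[i] = true
    return mv
end

getval(mv::MaskedValues, i::Integer) =
    mv.mask[i] ? mv.values[i] : error("index $i has not been set")

mv = MaskedValues{Float64}()
setval!(mv, 2, 1.0)
setval!(mv, 4, 2.5)
getval(mv, 4)  # returns 2.5; getval(mv, 3) would error
```

Setting x[2] and x[4] grows a single values array to length 4, with slots 1 and 3 marked invalid rather than stored sparsely.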

I think I know how to fix all the current shortcomings, except for Colons in VarNames. Setting a value in a VNT with a Colon could be done, but getting one seems ill-defined, at least without further information about what size the value should be.

vnt = VarNamedTuple()
vnt = setindex!!(vnt, 1, @varname(x[2]))
vnt = setindex!!(vnt, 1, @varname(x[4]))
getindex(vnt, @varname(x[:]))  # What should this return?

cc @penelopeysm, though this isn't ready for reviews yet.

penelopeysm and others added 21 commits October 21, 2025 18:08
* Remove NodeTrait

* Changelog

* Fix exports

* docs

* fix a bug

* Fix doctests

* Fix test

* tweak changelog
* Fast Log Density Function

* Make it work with AD

* Optimise performance for identity VarNames

* Mark `get_range_and_linked` as having zero derivative

* Update comment

* make AD testing / benchmarking use FastLDF

* Fix tests

* Optimise away `make_evaluate_args_and_kwargs`

* const func annotation

* Disable benchmarks on non-typed-Metadata-VarInfo

* Fix `_evaluate!!` correctly to handle submodels

* Actually fix submodel evaluate

* Document thoroughly and organise code

* Support more VarInfos, make it thread-safe (?)

* fix bug in parsing ranges from metadata/VNV

* Fix get_param_eltype for TSVI

* Disable Enzyme benchmark

* Don't override _evaluate!!, that breaks ForwardDiff (sometimes)

* Move FastLDF to experimental for now

* Fix imports, add tests, etc

* More test fixes

* Fix imports / tests

* Remove AbstractFastEvalContext

* Changelog and patch bump

* Add correctness tests, fix imports

* Concretise parameter vector in tests

* Add zero-allocation tests

* Add Chairmarks as test dep

* Disable allocations tests on multi-threaded

* Fast InitContext (#1125)

* Make InitContext work with OnlyAccsVarInfo

* Do not convert NamedTuple to Dict

* remove logging

* Enable InitFromPrior and InitFromUniform too

* Fix `infer_nested_eltype` invocation

* Refactor FastLDF to use InitContext

* note init breaking change

* fix logjac sign

* workaround Mooncake segfault

* fix changelog too

* Fix get_param_eltype for context stacks

* Add a test for threaded observe

* Export init

* Remove dead code

* fix transforms for pathological distributions

* Tidy up loads of things

* fix typed_identity spelling

* fix definition order

* Improve docstrings

* Remove stray comment

* export get_param_eltype (unfortunatley)

* Add more comment

* Update comment

* Remove inlines, fix OAVI docstring

* Improve docstrings

* Simplify InitFromParams constructor

* Replace map(identity, x[:]) with [i for i in x[:]]

* Simplify implementation for InitContext/OAVI

* Add another model to allocation tests

Co-authored-by: Markus Hauru <markus@mhauru.org>

* Revert removal of dist argument (oops)

* Format

* Update some outdated bits of FastLDF docstring

* remove underscores

---------

Co-authored-by: Markus Hauru <markus@mhauru.org>
* print output

* fix

* reenable

* add more lines to guide the eye

* reorder table

* print tgrad / trel as well

* forgot this type
@github-actions
Contributor

github-actions bot commented Nov 20, 2025

Benchmark Report

  • this PR's head: 51b399aeb1f3c4ee29e1029215668b47847e0a15
  • base branch: f3f866b73a494173a15c82571e0f81f79ce6c100

Computer Information

Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────────────────────┬───────────────────────────┬─────────────────────────────────┐
│                       │       │             │                   │        │        t(eval) / t(ref)        │     t(grad) / t(eval)     │        t(grad) / t(ref)         │
│                       │       │             │                   │        │ ──────────┬──────────┬──────── │ ──────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │      base │  this PR │ speedup │  base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼──────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │             typed │   true │    369.31 │   394.29 │    0.94 │  9.86 │    9.07 │    1.09 │   3639.86 │   3577.79 │    1.02 │
│                   LDA │    12 │ reversediff │             typed │   true │   2637.47 │  2952.80 │    0.89 │  5.08 │    4.46 │    1.14 │  13396.58 │  13164.17 │    1.02 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │ 103223.54 │ 59021.43 │    1.75 │  3.82 │    6.04 │    0.63 │ 394478.61 │ 356275.88 │    1.11 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼──────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │   7950.34 │  5854.19 │    1.36 │  4.61 │    6.03 │    0.76 │  36614.54 │  35301.89 │    1.04 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │  32929.05 │ 31832.44 │    1.03 │  9.99 │   10.31 │    0.97 │ 328800.83 │ 328151.69 │    1.00 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │   3760.75 │  3652.19 │    1.03 │ 12.40 │    9.17 │    1.35 │  46643.06 │  33491.67 │    1.39 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼──────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │      2.67 │     2.59 │    1.03 │  3.98 │    3.88 │    1.03 │     10.61 │     10.05 │    1.06 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │   1216.95 │  1097.41 │    1.11 │ 63.06 │  134.85 │    0.47 │  76740.06 │ 147982.41 │    0.52 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │       err │      err │     err │   err │     err │     err │       err │       err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼──────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │       err │      err │     err │   err │     err │     err │       err │       err │     err │
│           Smorgasbord │   201 │      enzyme │             typed │   true │   1681.64 │  1496.96 │    1.12 │  6.53 │    6.65 │    0.98 │  10974.31 │   9962.13 │    1.10 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │   1668.31 │  1504.86 │    1.11 │  5.25 │    5.80 │    0.91 │   8754.55 │   8722.10 │    1.00 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼──────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ reversediff │             typed │   true │   1652.63 │  1478.48 │    1.12 │ 92.30 │  104.81 │    0.88 │ 152541.54 │ 154953.75 │    0.98 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │   1669.09 │  1483.80 │    1.12 │ 61.01 │   62.70 │    0.97 │ 101835.03 │  93038.90 │    1.09 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │   1668.31 │  1513.33 │    1.10 │ 59.02 │   63.62 │    0.93 │  98468.85 │  96271.98 │    1.02 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼──────────┼─────────┼───────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │   1676.15 │  1502.22 │    1.12 │ 59.07 │   61.54 │    0.96 │  99013.88 │  92450.40 │    1.07 │
│              Submodel │     1 │    mooncake │             typed │   true │      7.01 │     3.28 │    2.14 │  5.24 │   11.90 │    0.44 │     36.75 │     39.03 │    0.94 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴───────────┴──────────┴─────────┴───────┴─────────┴─────────┴───────────┴───────────┴─────────┘

@codecov

codecov bot commented Nov 20, 2025

Codecov Report

❌ Patch coverage is 92.52336% with 32 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.12%. Comparing base (62e0da8) to head (fac8641).
⚠️ Report is 4 commits behind head on breaking.

Files with missing lines   Patch %   Missing lines
src/varnamedtuple.jl       92.15%    31 ⚠️
src/utils.jl               75.00%    1 ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1150      +/-   ##
============================================
+ Coverage     78.90%   80.12%   +1.22%     
============================================
  Files            41       42       +1     
  Lines          3910     4302     +392     
============================================
+ Hits           3085     3447     +362     
- Misses          825      855      +30     


@penelopeysm
Member

penelopeysm commented Nov 20, 2025

It looks to me like the 1.11 perf is only substantially worse on the trivial model. In my experience (I ran into this exact issue with Enzyme once; see also https://github.com/TuringLang/DynamicPPL.jl/pull/877/files), trivial models with one variable can be quite susceptible to changes in inlining strategy. It may be that a judicious @inline or @noinline somewhere will fix this.
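As a reminder of what that kind of fix looks like, here's a generic sketch of steering the compiler's inlining decisions. The function names are made up for illustration and are not from this PR:

```julia
# A tiny accessor that should melt into its caller:
@inline get_range(meta, i) = meta.ranges[i]

# A larger, type-unstable helper that should NOT be inlined into the hot
# path, so that its instability doesn't poison the caller's type inference:
@noinline function slow_fallback(xs)
    return sum(float(x) for x in xs)
end

meta = (ranges = [1:2, 3:5],)
get_range(meta, 2)  # returns 3:5
```

Which annotation helps (if either) depends on the surrounding call graph, so this usually takes some experimentation with `@code_warntype` or profiling.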

… and also `bundle_samples` (#1129)

* Implement `ParamsWithStats` for `FastLDF`

* Add comments

* Implement `bundle_samples` for ParamsWithStats -> MCMCChains

* Remove redundant comment

* don't need Statistics?
@mhauru mhauru mentioned this pull request Nov 24, 2025
penelopeysm and others added 3 commits November 25, 2025 11:41
* Make FastLDF the default

* Add miscellaneous LogDensityProblems tests

* Use `init!!` instead of `fast_evaluate!!`

* Rename files, rebalance tests
@mhauru
Member Author

mhauru commented Dec 15, 2025

I have merged in main, which should introduce tests that fail because of the above issue with non-Array block variables. I then fix said issue in #1180, using a wrapper type as discussed above. That PR also changes the return type of keys to be a Vector, and fixes some other things.

Further development of this will take place in #1180, to make reviewing easier. Eventually we should get to a point where #1180 is merged into this, and this should then be immediately ready to go into breaking.

@mhauru
Member Author

mhauru commented Dec 15, 2025

Just to summarise the conclusion from the team meeting just now, on Colons:

  • I should implement what was discussed above, using the value of x to concretise VarNames in the tilde pipeline before they hit VNT. This should allow x[1,:] ~ blah in a model to work just fine.
  • The harder case of condition(model, Dict(@varname(x[1,:]) => val)) would be solved eventually by having size information of all variables available as part of model (this is the point of @of), but for now we would disallow it.
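The concretisation in the first bullet can presumably lean on AbstractPPL's existing machinery. A sketch, assuming AbstractPPL.concretize keeps its current behaviour of replacing Colons using a runtime value:

```julia
using AbstractPPL

x = rand(3, 4)
vn = @varname(x[1, :])  # the optic still contains a Colon
# Use the runtime value of x to replace the Colon with concrete indices,
# before the VarName ever reaches the VarNamedTuple:
cvn = AbstractPPL.concretize(vn, x)
```

In the tilde pipeline the value of x is in scope at the `~` site, so this concretisation step has everything it needs; the condition/Dict case in the second bullet lacks that value, which is why it stays disallowed for now.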

Member

@penelopeysm penelopeysm left a comment


We agreed that VNT things will be coalesced into this branch and then we'll merge this into breaking when we're happy with the whole thing.

mhauru and others added 3 commits January 6, 2026 14:21
Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
@mhauru mhauru force-pushed the mhauru/vnt-for-fastldf branch from fac8641 to 44be19d on January 6, 2026 16:13