GLOBE Performance Improvements #1633

Open

peterdsharpe wants to merge 2 commits into NVIDIA:main from peterdsharpe:psharpe/GLOBE-perf-to-merge

PhysicsNeMo Pull Request

This PR is a focused performance pass on the experimental GLOBE model and its training scripts. No model-architecture changes, no checkpoint-compatibility changes; all numerics are unchanged at fp32-epsilon precision (a couple of internal aggregations now happen in fp64 - details in §5).

Net result on 16x B200s (DrivAerML training, standalone recipe): per-sample time (load, forward, backward, step) drops from ~7.8 s to ~5.5 s (-30%).

TL;DR

  • CPU↔GPU synchronization is the dominant win. The hot path used to issue ~60+ implicit syncs per training step (mem_get_info, .item(), if X.any():, tensor[bool_mask], …). Most have been eliminated.
  • Two O(tree_depth) Python loops are gone. Per-node source aggregation and per-node strength accumulation are now single cumsum + range-subtract ops, exploiting the morton-sorted source ordering.
  • The four-phase BarnesHutKernel scatter is packed. One index_add_ per phase instead of one scatter_add_ per phase per output field. This more than halves the backward-pass scatter cost on DrivAerML, where indexing_backward was ~23 % of total GPU time per step.
  • GradScaler is gone from the training scripts (it's a no-op with bf16 autocast and forces a per-step inf/NaN-check sync).
  • CUDA allocator is now expandable_segments:True to avoid synchronizing cudaMalloc / cudaFree round-trips under the chunked kernel workload.
  • Numerical-stability fix: centroid / node-strength aggregations now run their cumsum in fp64. fp32 was producing leaf centroids with ~100% relative error at the p99 level at DrivAerML scale (N=1M, coords ~5 m, leaf_size=1).
  • Several torch.compile graph breaks are fixed, mostly as a side-effect of the changes above.
  • New optional tree_build_device knob ("cpu" / "cuda" / None) for the small-problem regime where CUDA-launch latency dominates tree work.

Why these changes matter (hot path context)

A GLOBE training step is dominated by two things:

  1. Per-sample geometry preprocessing. Build one ClusterTree per BC type, then run a dual-tree Barnes-Hut traversal for every (src BC, dst BC) pair to produce a DualInteractionPlan describing which target/source point pairs interact in the near, far, and mixed-rank phases.
  2. BarnesHutKernel evaluation. Apply MLPs to those pairs and scatter_add the results back to per-target buffers, four times (one for each near/far phase), once per multiscale branch, once per communication hyperlayer.

Profiling showed both phases were CPU-bound on launch/sync overhead rather than GPU-bound on the actual MLP and indexing work. The changes below attack that bottleneck.


1. Eliminating GPU↔CPU synchronization

Each entry below removes a sync point that profiling identified as significant. These are the same change in spirit (avoid materializing int(tensor) or branching on tensor contents in the hot path), applied to several different call sites.

1.1 _auto_chunk_size: mem_get_info → cached total_memory

BarnesHutKernel._auto_chunk_size ran on every kernel evaluation and called torch.cuda.mem_get_info(device), which is a synchronizing driver query. Across the four phases × multiple branches × multiple hyperlayers, this was ~60 syncs per training step. Replaced with a cached lookup of total device memory and a fixed-fraction budget (25 %):

_CHUNK_MEMORY_BUDGET_PERCENT: Final[int] = 25


def _ceil_div(a: int, b: int) -> int:
    # ... pure-integer ceiling division ...
    return -(-a // b)


@torch.compiler.disable
@cache
def _device_total_memory_bytes(device: torch.device) -> int:
    # CUDA: torch.cuda.get_device_properties(device).total_memory
    # CPU:  psutil.virtual_memory().total
    ...

@torch.compiler.disable
def _device_chunk_budget_bytes(device: torch.device) -> int:
    return _device_total_memory_bytes(device) * _CHUNK_MEMORY_BUDGET_PERCENT // 100

The same swap is applied to MultiscaleKernel.forward's branch-checkpoint decision.

The trade-off is "exactly fits in currently-free memory" → "fits in 25% of total device memory". That is fine on 100+ GB devices (e.g. GB200), where the non-kernel resident state sits well under the remaining 75%.
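
For illustration, here is a minimal sketch of how the cached budget could feed chunk sizing. The bytes-per-pair estimate and the even-split step are assumptions for this example, not the exact PR code; it reuses _ceil_div and _device_chunk_budget_bytes from above.

def _auto_chunk_size_sketch(
    n_pairs: int, bytes_per_pair: int, device: torch.device
) -> int:
    # Budget is a fixed fraction of *total* device memory (cached lookup),
    # so there is no torch.cuda.mem_get_info() driver query and no implicit sync.
    budget_bytes = _device_chunk_budget_bytes(device)
    pairs_per_chunk = max(1, budget_bytes // max(1, bytes_per_pair))
    n_chunks = _ceil_div(n_pairs, pairs_per_chunk)
    return _ceil_div(n_pairs, n_chunks)  # spread pairs roughly evenly across chunks

Everything here is plain Python integer arithmetic, which is also what keeps the chunk size traceable under torch.compile (§6).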

1.2 Smoothing radius: Python float → registered buffer

Kernel._evaluate_interactions was rebuilding a tensor on every forward pass:

smoothing_radius = torch.tensor(self.smoothing_radius, device=device, dtype=dtype)
... (vectors * vectors).sum(dim=-1).apply(lambda x: x + smoothing_radius**2)

Now stored as a registered buffer and used directly with TensorDict broadcasting:

        ### Pre-squared smoothing radius as a registered tensor buffer.
        ### A buffer (rather than a Python float) is required because
        ### ``_evaluate_interactions`` adds this scalar to a TensorDict
        ### inside a checkpoint sub-graph; Python free variables cannot
        ### be lifted across nested Dynamo SubgraphTracers (the lift
        ### chain bottoms out at the root tracer with no parent and
        ### asserts ``lift_tracked_freevar_to_input should not be
        ### called on root SubgraphTracer``).  As a buffer, the
        ### tensor is a tracked module attribute that Dynamo treats as
        ### a graph leaf, so no lift is needed.
        self.register_buffer(
            "_smoothing_radius_sq",
            torch.tensor(smoothing_radius**2, dtype=torch.float32),
            persistent=False,
        )

Eager: skips a tensor allocation and a .apply(lambda) per forward. Compile: closes the torch.compile graph break.
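
The use site then reads as a direct TensorDict broadcast (a sketch reconstructed from the old expression above, not copied from the PR):

        ### TensorDict broadcasting adds the scalar buffer to every leaf,
        ### with no per-forward tensor allocation and no .apply(lambda).
        squared_distances = (vectors * vectors).sum(dim=-1) + self._smoothing_radius_sq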

1.3 Cluster tree construction: per-level sync → end-of-loop compaction

ClusterTree.from_points used to do, on every level of every tree,

leaf_indices = torch.where(is_leaf_seg)[0]
if len(leaf_indices) > 0:
    ...

That's a CPU-GPU sync per level × per tree (the in-code comment notes ~16 levels × 4 trees × 28 samples for DrivAerML). The loop now accumulates (seg_node_ids, seg_starts, seg_sizes, is_leaf_seg) per level and pays one boolean compaction after the loop:

            ### Defer leaf-segment processing to a single end-of-loop pass.
            # Per-iter ``torch.where(is_leaf_seg)[0]`` was a CPU-GPU sync
            # point on every level of every tree (~16 levels x 4 trees x
            # 28 samples).  We instead accumulate (node_id, start, size,
            # validity) for each segment seen during the loop and pay one
            # boolean compaction at the end.  ``torch.where(is_internal_seg)``
            # remains in-loop because the next iteration's active segments
            # are derived from this iteration's internals.
            leaf_seg_node_ids: list[torch.Tensor] = []
            leaf_seg_starts: list[torch.Tensor] = []
            leaf_seg_sizes: list[torch.Tensor] = []
            leaf_seg_validity: list[torch.Tensor] = []

The previous helpers _fill_leaf_aabbs and _fill_leaf_total_areas also each called int(leaf_sizes.sum()); they're merged into a single _fill_leaf_aggregates that shares the _ragged_arange mapping between AABB and area aggregation, removing the redundant .sum() sync.
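
A sketch of the end-of-loop compaction itself (variable names follow the accumulators above; what consumes the compacted tensors is elided):

            ### One boolean compaction for all levels, after the build loop.
            node_ids = torch.cat(leaf_seg_node_ids)
            starts = torch.cat(leaf_seg_starts)
            sizes = torch.cat(leaf_seg_sizes)
            validity = torch.cat(leaf_seg_validity)

            keep = validity.nonzero(as_tuple=True)[0]  # the single sync point
            compacted_node_ids = node_ids[keep]
            compacted_starts = starts[keep]
            compacted_sizes = sizes[keep]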

1.4 Dual-tree traversal: ~12 .any()-gated indexings → one compaction

ClusterTree.find_dual_interaction_pairs is the hot loop of geometry prep. The previous body did 12 separate if X.any(): early-exits and tensor[bool_mask] selections per iteration - each one a sync.

The new body classifies every active node-pair into one of three cases (case_T_only, case_S_only, case_both), computes all eight potential child slots with validity masks, stacks them, and pays one boolean compaction per iteration:

                ### 3. Generate next iteration's active set.
                # We compute children over the FULL active set (n_active
                # entries) and use validity masks per (T,S) child slot to
                # encode the case-A / case-B / case-C splitting rules from
                # the original implementation.  After unioning the eight
                # potential child slots we pay ONE boolean compaction
                # instead of the original ~12 ``.any()``-gated indexings.
                ...
                slot_t = torch.stack([...])
                slot_s = torch.stack([...])
                slot_v = torch.stack([...])
                ...
                flat_v = slot_v.reshape(-1)
                active_tgt_nodes = slot_t.reshape(-1)[flat_v]
                active_src_nodes = slot_s.reshape(-1)[flat_v]

Far-field outputs are similarly accumulated unfiltered into per-iteration lists with a parallel is_far validity mask, then compacted ONCE at the very end of the traversal.
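
The far-field bookkeeping follows the same pattern; a sketch (names illustrative, not copied from the PR):

                ### Inside the loop: append unfiltered - no sync.
                far_tgt_chunks.append(pair_tgt_nodes)
                far_src_chunks.append(pair_src_nodes)
                far_valid_chunks.append(is_far)
                ...
            ### After the loop: ONE compaction for the whole traversal.
            is_far_all = torch.cat(far_valid_chunks)
            far_idx = is_far_all.nonzero(as_tuple=True)[0]
            far_target_nodes = torch.cat(far_tgt_chunks)[far_idx]
            far_source_nodes = torch.cat(far_src_chunks)[far_idx]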

1.5 Leaf-pair expansion: dedup nonzero calls

In _expand_dual_leaf_hits, tensor[bool_mask] lowers to a synchronizing aten::nonzero that sizes the output. When several indexings share the same mask, the previous code paid one sync per indexing. They're now deduped:

    ### (near, far) output.  ``target_is_far`` is consumed by two
    ### indexings; doing one ``nonzero`` and reusing the integer index
    ### saves one sync (each ``tensor[bool_mask]`` lowers to a
    ### synchronizing ``aten::nonzero`` to size the output).
    far_idx_t = target_is_far.nonzero(as_tuple=True)[0]
    nf_target_ids = target_point_ids[far_idx_t]
    nf_source_node_ids = src_leaf_per_target[far_idx_t]

The function also drops a pile of "early return on empty input" branches that were each a sync; the downstream ops (scatter, bincount, repeat_interleave, _ragged_arange) all handle zero-element inputs correctly.
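
The empty-input safety that makes those early returns droppable is standard PyTorch behavior; for example:

    empty = torch.empty(0, dtype=torch.long)
    torch.bincount(empty, minlength=4)                      # tensor([0, 0, 0, 0])
    torch.repeat_interleave(empty, empty)                   # tensor([], dtype=torch.int64)
    torch.zeros(3).scatter_add_(0, empty, torch.empty(0))   # stays all-zero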

1.6 Traversal loop bound: max_depth.item() → bit_length

The dual-traversal loop bound was

max_iters = int(target_tree.max_depth.item()) + int(source_tree.max_depth.item()) + 1

i.e. two .item() calls before the loop starts. Replaced with a static upper bound that needs no GPU read:

            ### Loop bound: every iteration descends at least one tree level
            ### on at least one side, so ``2 * total_levels + safety`` is a
            ### hard upper bound that requires no GPU->CPU read.  Using
            ### ``int(max_depth.item())`` as before would force two syncs
            ### per call before we even start the loop.
            n_src_levels = max(1, int(source_tree.n_sources).bit_length())
            n_tgt_levels = max(1, int(target_tree.n_sources).bit_length())
            max_iters = 2 * (n_src_levels + n_tgt_levels) + 4

n_sources is a Python int (a .shape[0] lookup), not a tensor.

1.7 Tree-build wrapper: no_grad + minor cleanups

GLOBE._build_trees_and_plans and _build_prediction_plans now run inside torch.no_grad() (the outputs are integer indices and a non-grad-tracked area divisor, so autograd bookkeeping on dozens of tensor ops was wasted) and inline the per-plan debug logging into the build loop instead of re-iterating afterward.


2. Prefix-sum aggregation: replacing Python loops over tree levels

A cluster tree's morton-sorted source ordering has a key property: every node covers a contiguous range of sources, [node_range_start, node_range_start + node_range_count). So any node-subtree sum is a prefix-sum range subtract: prefix[end] - prefix[start].

This kills two O(tree_depth) Python loops that were previously the dominant cost in their respective functions.

2.1 ClusterTree.compute_source_aggregates

The previous implementation did:

  • Per-leaf weighted-sum via segmented scatter.
  • A Python-level for level_ids in reversed(depth_levels): propagating centroids and per-source features bottom-up.

Combined cost in profiling: ~2 s per training step of mixed CPU (level-loop overhead) and GPU (kernel launches per level).

Replaced with a single cumsum + range-subtract over the morton-sorted sources:

        ### Range-sum aggregation via morton-sorted prefix sums.
        # Each node covers a contiguous range
        # [node_range_start, node_range_start + node_range_count) in
        # morton-sorted source order, so any node-subtree sum is just
        # ``prefix[end] - prefix[start]``.  This replaces the old
        # leaf-aggregation + bottom-up Python loop, which were the
        # dominant CPU + GPU costs in ``compute_source_aggregates``
        # (~2 s combined per training step in profiling).
        ...
        sorted_points = source_points[self.sorted_source_order]
        sorted_areas = areas[self.sorted_source_order]
        weighted_points_64 = (sorted_points * sorted_areas.unsqueeze(-1)).double()
        cumsum_weighted_points = torch.nn.functional.pad(
            torch.cumsum(weighted_points_64, dim=0), (0, 0, 1, 0)
        )
        ...
        node_total_weighted_pts = (
            cumsum_weighted_points[ends] - cumsum_weighted_points[starts]
        )

Per-source feature aggregation gets the same treatment via an inner _aggregate_via_prefix_sum, applied across TensorDict leaves. Trailing feature dims are flattened so cumsum sees a single feature axis (avoiding a per-feature kernel chain inside cumsum).
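
A hedged sketch of that per-leaf treatment (the helper below is reconstructed from the description, not copied from the PR; sorted_features_td, starts, and ends are assumed names):

def _aggregate_via_prefix_sum(
    sorted_values: torch.Tensor, starts: torch.Tensor, ends: torch.Tensor
) -> torch.Tensor:
    # Flatten trailing feature dims so cumsum runs over a single feature axis.
    flat = sorted_values.reshape(sorted_values.shape[0], -1).double()
    prefix = torch.nn.functional.pad(torch.cumsum(flat, dim=0), (0, 0, 1, 0))
    node_sums = prefix[ends] - prefix[starts]
    return node_sums.reshape(ends.shape[0], *sorted_values.shape[1:]).to(sorted_values.dtype)

node_features = sorted_features_td.apply(
    lambda leaf: _aggregate_via_prefix_sum(leaf, starts, ends),
    batch_size=[starts.shape[0]],
)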

The internal_level_ids / internal_level_offsets / leaf_node_ids / leaf_seg_ids tensors stored on ClusterTree purely to support the old bottom-up loop are dropped.

2.2 BarnesHutKernel._compute_node_strengths

Same story, on a single scalar field (source_strengths). The old implementation did leaf-summing via scatter then a bottom-up Python loop; the new one is one cumsum + range-subtract:

        ### Cumsum and range-subtract in fp64 to avoid catastrophic
        ### cancellation when ``cumsum_total >> range_sum`` - the regime
        ### of small leaves in a large tree built over offset coordinates.
        ### See the matching note in :meth:`ClusterTree.compute_source_aggregates`.
        sorted_strengths_64 = source_strengths[tree.sorted_source_order].double()
        ### Pad with a leading zero so that ``cumsum[i]`` is the sum of
        ### sorted_strengths[:i] - both endpoints index identically.
        prefix_sum = torch.nn.functional.pad(
            torch.cumsum(sorted_strengths_64, dim=0), (1, 0)
        )

        starts = tree.node_range_start
        ends = starts + tree.node_range_count
        return (prefix_sum[ends] - prefix_sum[starts]).to(source_strengths.dtype)

3. Packed scatter for backward-pass efficiency

BarnesHutKernel scatters into per-target output buffers four times per forward (near, far-node, near-far, far-near), once per output field. In profiling, indexing_backward_kernel_stride_1 (the backward of scatter_add_) was ~1 s out of ~4.5 s total GPU time per training step (~23 %) - the largest single GPU kernel by time.

Two changes in BarnesHutKernel.forward:

  1. Pack output fields into a single (n_targets, total_features) buffer instead of per-field buffers.
  2. Use index_add_ instead of scatter_add_ for the per-target accumulation.

        ### Packed output buffer.  All four phases scatter into this single
        ### tensor; the per-phase loops over output keys (one ``scatter_add_``
        ### per key) are replaced with one packed ``scatter_add_`` per phase.
        ### ``indexing_backward`` was the top GPU kernel by time in profiling
        ### (~23% of total GPU time); packing cuts this by the number of
        ### output fields (~4x for ``C_p`` + ``C_f`` on DrivAerML).
        ...
        packed_buf = torch.zeros(
            (n_targets, total_features), dtype=buffer_dtype, device=device
        )

The pack/scatter helper is shared across all four phases:

    def _pack_and_scatter(
        self,
        chunk_result: TensorDict,
        weights: Float[torch.Tensor, " n_pairs"],
        tgt_ids: Int[torch.Tensor, " n_pairs_or_expanded"],
        packed_buf: Float[torch.Tensor, "n_targets total_features"],
        *,
        broadcast_pair_ids: Int[torch.Tensor, " n_pairs_or_expanded"] | None = None,
    ) -> None:
        ...
        ### ``index_add_`` rather than ``scatter_add_`` with broadcasted
        ### indices: equivalent semantics, but ``index_add_`` takes a 1-D
        ### ``index`` and avoids the ``unsqueeze`` + ``expand_as`` overhead
        ### that ``scatter_add_`` requires.  In addition to the direct
        ### speedup, ``index_add_`` has a more compact backward (no index
        ### broadcasting in the saved tensors), which compounds with the
        ### packing optimization above.
        packed_buf.index_add_(0, tgt_ids, weighted)

The counterpart _unpack_buf slices the packed buffer back into a per-field TensorDict. The canonical-ordered field list and per-key feature widths are cached as Kernel._output_packing.

Net effect on DrivAerML (C_p 1ch + C_f 3ch = 4 fields, 4 phases): 16 scatter_add_ calls → 4 index_add_ calls per kernel evaluation.
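
A hedged sketch of the pack/unpack pair (field names, widths, and the exact weighting are illustrative; the real helpers are _pack_and_scatter above and _unpack_buf):

from tensordict import TensorDict

keys, widths = self._output_packing            # e.g. (("C_p", "C_f"), (1, 3))

### Pack: concatenate per-field chunks on the feature axis, weight, accumulate.
weighted = torch.cat([chunk_result[k] for k in keys], dim=-1) * weights.unsqueeze(-1)
packed_buf.index_add_(0, tgt_ids, weighted)

### Unpack: slice the packed buffer back into a per-field TensorDict.
def _unpack_buf_sketch(packed_buf, keys, widths):
    fields = torch.split(packed_buf, list(widths), dim=-1)
    return TensorDict(dict(zip(keys, fields)), batch_size=[packed_buf.shape[0]])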


4. Memory management

4.1 CUDA allocator: expandable_segments:True

Both run.sh scripts now export

export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"

The chunked kernel evaluations stress the cache enough that the default segment allocator falls back to synchronous cudaMalloc / cudaFree round-trips. expandable_segments lets the allocator grow existing segments instead.

4.2 Gradient checkpointing applies in both train and eval

The four if self.training and self.use_gradient_checkpointing branches across BarnesHutKernel.forward are hoisted into one helper:

    def _maybe_checkpointed_evaluate(self, *args: object) -> TensorDict:
        """Run :meth:`_gather_and_evaluate` with optional gradient checkpointing.

        ...

        The wrap fires in **both** training and eval, not just training,
        for two reasons:

        1. Memory savings (the original purpose).  In eval no autograd
           tape is active (``train.py`` wraps validation in
           ``torch.no_grad()``), so ``checkpoint(use_reentrant=False)``
           degenerates to a near-no-op forward call - no recompute, no
           extra memory, no extra compute.
        2. Workaround for a Dynamo+CUDA TensorDict bug.  When the same
           ``_gather_and_evaluate`` body is inlined into the parent
           graph (i.e. without the checkpoint wrapper), FakeTensor
           tracing fails to propagate ``TensorDict.batch_size`` through
           ``(vectors * vectors).sum(dim=-1)`` ...
        """
        if self.use_gradient_checkpointing:
            return checkpoint(
                self._gather_and_evaluate, *args, use_reentrant=False
            )
        return self._gather_and_evaluate(*args)

The self.use_gradient_checkpointing flag is also piped through MultiscaleKernel → Kernel → BarnesHutKernel and exposed as a constructor kwarg on GLOBE (default True).

4.3 Packed output buffer

BarnesHutKernel.forward used to lazily allocate per-field, per-phase output buffers (output_bufs[k] = torch.zeros(...) if-key-not-seen pattern). Now allocated up-front:

        ### Buffer dtype must match the dtype of ``weighted = chunk * weights``
        ### that the four phases will scatter in (``index_add_`` requires
        ### exact dtype match).  ``chunk`` comes from the MLP at the
        ### autocast dtype (or ``source_points.dtype`` outside autocast);
        ### ``weights`` carries ``source_strengths.dtype``.  Their product
        ### is the type-promoted dtype, which is what the previous
        ### lazy-allocated buffers used to capture - we now compute it
        ### eagerly so a single buffer can be allocated up front.
        ...
        packed_buf = torch.zeros(
            (n_targets, total_features), dtype=buffer_dtype, device=device
        )

This composes with §3 (packed scatter) and lets us drop the if n_near == 0 and n_nf == 0 and n_fn == 0 and n_far_nodes == 0 empty-result short-circuit at the top of forward: an all-zero packed_buf unpacks to all-zero per-field tensors, which is the correct result anyway.
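
A sketch of the eager dtype computation (the autocast query shown is an assumption about how the promoted dtype could be derived, not the exact PR code):

### Promote the MLP output dtype against the strengths dtype up front,
### so a single packed buffer with a matching dtype can be allocated.
mlp_dtype = (
    torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else source_points.dtype
)
buffer_dtype = torch.promote_types(mlp_dtype, source_strengths.dtype)
packed_buf = torch.zeros((n_targets, total_features), dtype=buffer_dtype, device=device)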


5. Numerical stability via fp64 cumsum

The cumsum-based aggregations in §2 are done in float64 and cast back to source_points.dtype at the end. The rationale, from the in-code comment:

        # The cumsum and the range subtract are done in fp64 because fp32
        # suffers catastrophic cancellation when ``range_sum << cumsum_total``,
        # which is the regime of small leaves (``leaf_size=1``) in a large
        # tree built over offset (e.g. all-positive) coordinates.  At
        # drivaer scale (``N=1M``, coords ~5 m), fp32 leaf centroids had
        # median ~2 % relative error and p99 ~100 % wrong.  Lifting the
        # cumsum to fp64 brings this back to fp32 epsilon (~1e-7) and adds
        # <1 % wall-clock to the training step (cumsum is ~2.3x slower in
        # fp64, but cumsum is a tiny fraction of step time).  CUDA fp32
        # cumsum is also non-deterministic across runs (pytorch#75240);
        # fp64 cumsum is much less affected.

Listed under "performance" because it's a prerequisite for §2 - without it the prefix-sum trick can't be used.
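
A toy reproduction of the cancellation regime (illustrative only, not the PR's regression test):

### ~1M all-positive coordinates at ~5 m magnitude, single-element "leaf" ranges.
x = torch.rand(1_000_000) * 5.0

prefix32 = torch.nn.functional.pad(torch.cumsum(x, dim=0), (1, 0))
prefix64 = torch.nn.functional.pad(torch.cumsum(x.double(), dim=0), (1, 0))

i = 900_000  # a leaf deep in the sorted order, where cumsum_total >> range_sum
leaf32 = prefix32[i + 1] - prefix32[i]             # large relative error in fp32
leaf64 = (prefix64[i + 1] - prefix64[i]).float()   # recovers x[i] to ~fp32 epsilon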


6. torch.compile compatibility

Most of the syncs above (§1) and the packed-scatter rewrite (§3) also fix torch.compile graph breaks. A few changes here are specifically for torch.compile:

  • _ceil_div for chunk sizing instead of math.ceil(a / b). Under specialize_float=False, a Python float gets traced as an unbacked symbolic float that propagates into the chunk size and crashes Dynamo at the checkpoint boundary in _gather_and_evaluate. Pure-integer ceiling division (-(-a // b)) sidesteps this.
  • _CHUNK_MEMORY_BUDGET_PERCENT: Final[int] = 25 instead of e.g. 0.25 - same reason.
  • @torch.compiler.disable on the device-memory helpers and _auto_chunk_size, so Dynamo doesn't try to evaluate torch.cuda.get_device_properties as a graph constant (which crashes on CPU-only hosts and produces unbacked symbolic ints on CUDA hosts).
  • The smoothing-radius buffer from §1.2 also fixes the lift-across-SubgraphTracers issue.
  • view(-1, 1) → reshape(-1, 1) in GLOBE.forward's output calibration loop, for robustness when upstream ops produce non-contiguous tensors.

The hierarchical_acceleration.md design doc is updated to reflect the new "static fraction of total device memory" chunk-size strategy.


7. Training scripts

Both examples/cfd/external_aerodynamics/globe/airfrans/train.py and .../drivaer/train.py have the same three changes.

7.1 Removed GradScaler

GradScaler is only useful for fp16 autocast (to combat gradient underflow). The training scripts use torch.bfloat16 autocast, which has the same dynamic range as fp32, so there is no underflow to protect against. With enabled=amp, GradScaler still forces a per-step inf/NaN-check sync on the gradients for no benefit.

-    scaler = torch.amp.GradScaler(device=device.type, enabled=amp)
-    ...
-    scaler.scale(batch_loss).backward()
+    batch_loss.backward()
     if gradient_clip_norm is not None:
-        scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(
             model.parameters(), max_norm=gradient_clip_norm
         )
-    scaler.step(optimizer)
-    scaler.update()
+    optimizer.step()

(Removed from checkpoint save/load plumbing too.)

7.2 New tree_build_device knob

Exposed as a Literal["cpu", "cuda"] | None argument:

        leaf_size=leaf_size,
        network_type=network_type,
        self_regularization_beta=self_regularization_beta,
        latent_compression_scale=latent_compression_scale,
        expand_far_targets=expand_far_targets,
        tree_build_device=tree_build_device,
    ).to(device)

Forwarded to GLOBE.__init__, which transfers boundary centroids/areas to build_device for from_points + find_dual_interaction_pairs, then moves the resulting ClusterTree and DualInteractionPlan (both @tensorclass, so one .to() call each) back to the input's device at the end of _build_trees_and_plans / _build_prediction_plans.

Default None preserves prior behavior (build on the input's device). "cpu" can win for small problems (~a few thousand boundary cells) where CUDA launch latency + cudaStreamSynchronize round-trips dominate the actual tree work.
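
A hedged sketch of the build-device hop inside _build_trees_and_plans (argument lists and variable names are illustrative; ClusterTree.from_points and find_dual_interaction_pairs are the real entry points named above):

build_device = self.tree_build_device or centroids.device

tree = ClusterTree.from_points(
    centroids.to(build_device), areas.to(build_device), leaf_size=self.leaf_size
)
plan = target_tree.find_dual_interaction_pairs(source_tree)

### ClusterTree and DualInteractionPlan are @tensorclass, so a single .to()
### moves every member tensor back to the input's device.
tree = tree.to(centroids.device)
plan = plan.to(centroids.device)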

7.3 CUDA allocator env var (run.sh)

See §4.1.


8. Dependencies (pyproject.toml)

-    "tensordict>=0.10.0",
+    "tensordict>=0.12.2",
     "omegaconf>=2.3.0",
     "importlib-metadata>=8.7.1",
+    "psutil>=6.0.0",

  • tensordict>=0.12.2 (was >=0.10.0) - bumped alongside the BarnesHutKernel refactor; the codebase exercises TensorDict broadcasting and batch-size propagation more heavily than before (see §4.2 for one such case in _maybe_checkpointed_evaluate), and this update brings better torch.compile support.
  • psutil>=6.0.0 - used by _device_total_memory_bytes for the CPU branch (psutil.virtual_memory().total). The CPU build path is debug-only; production runs on CUDA, where the CUDA branch is used.

What's intentionally not in this PR

  • No model-architecture changes. No new layers, no new losses, no hyperparameter changes.
  • No checkpoint format changes - existing .mdlus checkpoints load unmodified.
  • The two new constructor kwargs on GLOBE (use_gradient_checkpointing, tree_build_device) both have defaults matching the prior behavior, so this is a non-breaking change for downstream users.

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

commit 6a6629f1c4f444e20859ab905ba2d950d66bbb0b
Merge: 2fb3d047 3670381b
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon May 11 14:42:49 2026 -0400

    Merge branch 'main' into psharpe/GLOBE-dev

commit 2fb3d0473e8d66536c04e2b9b78965a8c5a7d3bf
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon May 11 12:02:10 2026 -0400

    Allows CPU-based tree build, if desired

commit ada167d9c52b45487dafdba3e6d2872c1ac925fd
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon May 11 11:24:36 2026 -0400

    bump tensordict version

commit 34e14584b2092ab49c724b75a861989d66910f35
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon May 11 11:17:12 2026 -0400

    Makes `psutil` a core dependency

commit 7e31b1c5e4e033e50c380f594ae206d3115596b6
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon May 11 01:36:17 2026 -0400

    Optimize indexing in _expand_dual_leaf_hits to reduce GPU synchronization by consolidating boolean mask operations into single nonzero calls, improving performance in the ClusterTree implementation.

commit ebb2b0a611e57e6804b7f3a8cfad0294da68f38b
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon May 11 05:35:46 2026 +0000

    Refactor gradient checkpointing in BarnesHutKernel to apply in both training and evaluation phases. This change improves memory efficiency and addresses a TensorDict bug related to batch size propagation during tracing.

commit ce4ab79d649c674d24165aaedccf953984b23feb
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 10 18:29:41 2026 +0000

    Adds chunking support on CPU too.

commit 2b326c1ddb498cb7358be78c794360f6a9f73348
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 10 14:14:58 2026 -0400

    Refactor chunk memory budget calculation to use integer arithmetic and introduce integer ceiling division function to enhance compatibility with torch.compile. This change prevents crashes related to symbolic float handling during chunk size derivation.

commit fce4b188a930f3431c44fd686a807c44a9c332ab
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 10 13:48:04 2026 -0400

    eliminates graph break

commit c5cff850b6d00122d866c225d354e9166c3f2462
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 10 16:26:25 2026 +0000

    fixes smoothing_radius graph break using buffer.

commit 3fb11547a2fa2b59143efed608e1004fd470fd9f
Merge: 835ea142 aa6d2208
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 10 15:36:50 2026 +0000

    Merge branch 'psharpe/GLOBE-dev-perf' of https://github.com/peterdsharpe/physicsnemo into psharpe/GLOBE-dev-perf

commit aa6d220864ae24a81e68273aa25a3af1ca60c387
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 10 11:36:06 2026 -0400

    Enhance numerical stability in ClusterTree and BarnesHutKernel by switching cumulative sum operations to double precision (fp64) to prevent catastrophic cancellation in small-leaf scenarios. Added tests to validate precision at scale for both implementations, ensuring consistency between fp32 and fp64 results.

commit 835ea142a27961d2a3af9609d7c911ac5ff9bbd0
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat May 9 05:53:52 2026 +0000

    further simplification to smoothing_radius

commit ccf4d85c179617557349e44b8ae70b0cae069a71
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat May 9 05:27:28 2026 +0000

    removes torch.tensor() break in smoothing_radius

commit de1b8c93a2603bd1c228e71dd76ea436f351371a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat May 9 04:49:02 2026 +0000

    Large refactor to reduce GPU syncs in field_kernel and cluster_tree; plus tests

commit 8d5e6f2604be7852a909830ad392b84d0914dfa5
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat May 9 02:14:40 2026 +0000

    Removes GPU syncs during chunking

commit b0f1db4410ea591a5b90ad15eec0ac323e76bff4
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat May 9 01:35:22 2026 +0000

    pipe through gradient checkpointing option, still default to true

commit 9d5515c0539ad05ac162be8c57e8c10da9e75495
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat May 9 01:32:01 2026 +0000

    Removes GradScaler

commit 57f2c6ce2dfb405b9a5c704aabb1b355fd1f8f03
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat May 9 01:29:49 2026 +0000

    env var setting

commit 59e4b6c94292ce20c8641e7e65417de270c7fae3
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 8 20:07:04 2026 +0000

    Squashed commit of the following:

    commit 6bb41cc5ff9a79e97478b1a1e478245599929c7e
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 19:29:35 2026 +0000

        Add validation for data ranks in GLOBE model

        - Introduced `validate_data_contains_ranks` function to ensure that incoming data contains all declared leaves with matching ranks, enhancing error handling for missing or mismatched data.
        - Updated the GLOBE model to utilize this validation during the forward pass, ensuring robustness against incorrect input configurations.
        - Added comprehensive tests to verify the behavior of the model with respect to extra and missing data keys, as well as rank mismatches.

    commit 97b1d3e9fffec31af513dcc5df724b94ae51f9ee
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 19:10:54 2026 +0000

        fixes normal handling

    commit fdca989212fced6cabdf9da5eaebea9d3dd203c4
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 17:18:21 2026 +0000

        fixes issue where precision was casting the data in addition to the model, not just the model internals

    commit 80c76380c8b4ca158e5cbbb21f53140e5df77377
    Merge: 96b17855 c8f21f16
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 17:14:06 2026 +0000

        Merge branch 'psharpe/add-GLOBE-to-unified-recipe' of https://github.com/peterdsharpe/physicsnemo into psharpe/add-GLOBE-to-unified-recipe

    commit 96b178550b819cce9b99e758e353ebb9e9d1f409
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 17:08:34 2026 +0000

        autocast minor fixes

    commit c8f21f1680dce763347099fab0a040aec03fbd3d
    Merge: a0a2bc70 5f2940f9
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 12:30:27 2026 -0400

        Merge branch 'main' into psharpe/add-GLOBE-to-unified-recipe

    commit 2b5106d70fce298d77727b63afd3a8835a07c8e5
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 15:28:53 2026 +0000

        sync to latest run

    commit a0a2bc70b8a7157a788bd33fff07567e9ceec475
    Merge: e4c6b52f 121e0b31
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 10:55:54 2026 -0400

        Merge branch 'main' into psharpe/add-GLOBE-to-unified-recipe

    commit e4c6b52f00de81bb50f4f753180d1f299494f54e
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 10:54:55 2026 -0400

        Squashed commit of the following:

        commit 95cfcd5e653cb691d034b8a1c7fc05ef6f843938
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Fri May 8 10:52:56 2026 -0400

            pre-commit formatting

        commit 881c2f67954a52b8441c4c682d9d9fc5648ae874
        Merge: 224eef57 121e0b31
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Fri May 8 10:52:22 2026 -0400

            Merge branch 'main' into psharpe/update-unified-recipe

        commit 224eef57dab884308c72e11b2f0d273528083ddf
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Fri May 8 10:50:50 2026 -0400

            Adds tests

        commit 2a191773f9354660c09db0feb04a6c9ceb56f6d5
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:39:56 2026 -0400

            Refactor freestream scale calculations to use the `freestream_scales` function for consistency and precision. Update docstring for TensorBoard logging function to clarify usage and improve readability.

        commit f677107cd7793542d99b52634281925e18b105e3
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:38:27 2026 -0400

            Cleanup pass

        commit f9ac280ab8644f0a653859f92d2369789bb72d51
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:29:13 2026 -0400

            inlines function

        commit f8773b2b2a24a993954472a54a031b024fc33a08
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:25:51 2026 -0400

            strips out apply_to_tensordict_mesh (now dead code; easily a one-liner with td.apply())

        commit 705e46bd2d556df6500a903aa7c654f2590fb276
        Merge: d6b96f6c e33028cd
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Fri May 8 03:15:23 2026 +0000

            Merge branch 'psharpe/update-unified-recipe' of https://github.com/peterdsharpe/physicsnemo into psharpe/update-unified-recipe

        commit e33028cd47c94b8b0c010abc331fb2113069e5ff
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:13:29 2026 -0400

            Switches to be tensordict-native

        commit e302f3c0515f34edae293c5fdf1c2ef6ec048076
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:10:08 2026 -0400

            more conversions to "associations"

        commit d705007ba22c0cc5ae2284aea6f89c7b6666d9d5
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:07:19 2026 -0400

            Standardize on the name "association"

        commit 9c5c5282623fa9b43269ba1c739051b01fa7717e
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:01:40 2026 -0400

            pipes through new imports

        commit 0b7268926a2fd43d45ace9917efb10cb09576804
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:00:59 2026 -0400

            docs

        commit eb9807d023685e2e7613115a71629ad781f90dc0
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:00:17 2026 -0400

            adds docstring

        commit 51c7d2ceb38b8f36fa5bc521eaa345d6a8570051
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 22:55:27 2026 -0400

            Migrates functions per review

        commit 8a1aea97e58caee97278a89f9d0734aeb9b9d610
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 22:52:57 2026 -0400

            Enhance error handling in normalize_output_to_tensordict to raise explicit ValueError for all-None tensor outputs, improving clarity on model misconfiguration.

        commit 3cf398188221317f2b71604f638386db02c76a95
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 22:52:23 2026 -0400

            fixes forward_kwargs bool coercion

        commit d6b96f6c0119ffe7d9a0392cb986637b473df49e
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Fri May 8 02:49:03 2026 +0000

            syncs with new name

        commit 1e86dae3a1aeb4b2f3cbb62e89356a1c23637403
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Fri May 8 02:45:12 2026 +0000

            Update surface pattern in drivaer_ml_surface.yaml to reflect vehicle boundaries

        commit 121e0b31e8ee4bd796f70f0ce96e9c313972c03f
        Author: Corey adams <6619961+coreyjadams@users.noreply.github.com>
        Date:   Thu May 7 15:54:35 2026 -0500

            Testing and validating ci with data from huggingface (#1601)

            * Testing and validating ci with data from huggingface

            * Update ci data download action

            * Remove get data command

            * Make sure the singleton states clear out to fix tests.

        commit bb0e19a45476f650aef4fcda5179cb19aca26c84
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 15:58:55 2026 -0400

            Partial merge from psharpe/add-GLOBE-to-unified-recipe

    commit a44697351dbad6e2d9feeb711408a95e4abee6bd
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 10:49:04 2026 -0400

        sync tests

    commit 10beae03c2ee04dd844cdbe89257ef30bfa5c66c
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 8 10:48:37 2026 -0400

        readme updates

    commit 26a91d0f5655b07399808d7af9021b6eb0fa62c0
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Thu May 7 23:54:54 2026 -0400

        Squashed commit of the following:

        commit 2a191773f9354660c09db0feb04a6c9ceb56f6d5
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:39:56 2026 -0400

            Refactor freestream scale calculations to use the `freestream_scales` function for consistency and precision. Update docstring for TensorBoard logging function to clarify usage and improve readability.

        commit f677107cd7793542d99b52634281925e18b105e3
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:38:27 2026 -0400

            Cleanup pass

        commit f9ac280ab8644f0a653859f92d2369789bb72d51
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:29:13 2026 -0400

            inlines function

        commit f8773b2b2a24a993954472a54a031b024fc33a08
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:25:51 2026 -0400

            strips out apply_to_tensordict_mesh (now dead code; easily a one-liner with td.apply())

        commit 705e46bd2d556df6500a903aa7c654f2590fb276
        Merge: d6b96f6c e33028cd
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Fri May 8 03:15:23 2026 +0000

            Merge branch 'psharpe/update-unified-recipe' of https://github.com/peterdsharpe/physicsnemo into psharpe/update-unified-recipe

        commit e33028cd47c94b8b0c010abc331fb2113069e5ff
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:13:29 2026 -0400

            Switches to be tensordict-native

        commit e302f3c0515f34edae293c5fdf1c2ef6ec048076
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:10:08 2026 -0400

            more conversions to "associations"

        commit d705007ba22c0cc5ae2284aea6f89c7b6666d9d5
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:07:19 2026 -0400

            Standardize on the name "association"

        commit 9c5c5282623fa9b43269ba1c739051b01fa7717e
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:01:40 2026 -0400

            pipes through new imports

        commit 0b7268926a2fd43d45ace9917efb10cb09576804
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:00:59 2026 -0400

            docs

        commit eb9807d023685e2e7613115a71629ad781f90dc0
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 23:00:17 2026 -0400

            adds docstring

        commit 51c7d2ceb38b8f36fa5bc521eaa345d6a8570051
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 22:55:27 2026 -0400

            Migrates functions per review

        commit 8a1aea97e58caee97278a89f9d0734aeb9b9d610
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 22:52:57 2026 -0400

            Enhance error handling in normalize_output_to_tensordict to raise explicit ValueError for all-None tensor outputs, improving clarity on model misconfiguration.

        commit 3cf398188221317f2b71604f638386db02c76a95
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 22:52:23 2026 -0400

            fixes forward_kwargs bool coercion

        commit d6b96f6c0119ffe7d9a0392cb986637b473df49e
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Fri May 8 02:49:03 2026 +0000

            syncs with new name

        commit 1e86dae3a1aeb4b2f3cbb62e89356a1c23637403
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Fri May 8 02:45:12 2026 +0000

            Update surface pattern in drivaer_ml_surface.yaml to reflect vehicle boundaries

        commit bb0e19a45476f650aef4fcda5179cb19aca26c84
        Author: Peter Sharpe <peterdsharpe@gmail.com>
        Date:   Thu May 7 15:58:55 2026 -0400

            Partial merge from psharpe/add-GLOBE-to-unified-recipe

    commit a513e024635a4fd89437c524141ae31b8c772be9
    Merge: 8e98aaee b15dae48
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Thu May 7 15:53:49 2026 -0400

        Merge branch 'main' into psharpe/add-GLOBE-to-unified-recipe

    commit 8e98aaee30cf9e753d209287db28b2647aec4880
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 19:44:43 2026 -0400

        formatting

    commit 596f527a1d38e680f80d7db37990c1eb819af984
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 19:06:43 2026 -0400

        edits comments

    commit ceda0b3757218fb347af3330eb2791419980e388
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 13:51:14 2026 -0400

        type hints

    commit dfafc47c5a7d06dc23173e7e071c87de285c0a4c
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 13:38:40 2026 -0400

        simplifications

    commit 49827b15fedf68a9cd1828ec234775b550e6fcec
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 13:20:03 2026 -0400

        Adds progress

    commit 4756955f21a82b2be309d05d170a8149575c21c5
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 10:55:38 2026 -0400

        review pass

    commit e6b7345c5e70bcf5bc867f5467d32c113c4b301a
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 09:17:46 2026 -0400

        simplifications throughout

    commit 9e0c1e86f12e0f6beb49742aec01159be45deeac
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 08:27:10 2026 -0400

        Notes domino is WIP

    commit 8c2374a4fdbd30981233f74048aa71dacbbd5c4d
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 08:26:58 2026 -0400

        Notes DoMINO is WIP

    commit 9982de0205903857e2c7cf63f52e0a8c8f630c2e
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 08:26:46 2026 -0400

        add nontrivial default

    commit 66c3c3a361ef8af2e94a2d32a1bc38005f42683b
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Wed May 6 08:26:16 2026 -0400

        dead code sweep

    commit bb4eb8ee381169b44e7ad58ff875b9da135ad6c2
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Tue May 5 17:51:58 2026 -0400

        formatting

    commit cd9a95a3787ca516d9c0dc6aa41a24715f7aa80d
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Tue May 5 17:51:46 2026 -0400

        cleanup

    commit abba05443c2a6230f9c1329adeebb700f677eac2
    Merge: 2ee96eb4 b5f075a9
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Tue May 5 17:44:01 2026 -0400

        Merge remote-tracking branch 'origin/psharpe/DomainMesh-apply-fix' into psharpe/add-GLOBE-to-unified-recipe

    commit b5f075a9366d471066bd6a3fa9a878352a29d19c
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Tue May 5 17:39:50 2026 -0400

        renames DomainMesh.apply to DomainMesh.apply_to_meshes, to mitigate shadowing.

    commit 2ee96eb40c9dc73a0adaaaf6a4a88333088d02b0
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Tue May 5 16:50:42 2026 -0400

        adds tests

    commit 64352f1e50b0bada49a061027a0f8c4e1ebf23b0
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Tue May 5 15:59:38 2026 -0400

        Refactor external aerodynamics recipe to use TensorDict interface

        - Updated README.md to reflect changes in output normalization and loss/metric calculators, transitioning from dict-based to TensorDict-based structures.
        - Modified training configuration files to standardize boundary naming as `boundaries.vehicle` across surface and volume datasets.
        - Adjusted dataset YAML files to ensure consistent boundary definitions and removed outdated comments.
        - Enhanced collate functions to support TensorDict, ensuring proper batch dimension handling for targets.
        - Updated loss and metric calculators to operate on TensorDict inputs, improving compatibility and performance.
        - Revised tests to validate TensorDict functionality and ensure expected behavior across various configurations.

    commit 3b8760985a8ccf45a47679ad8fc8db948cb95237
    Merge: c7b27c09 70701e18
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Tue May 5 15:27:27 2026 -0400

        Merge branch 'main' into psharpe/add-GLOBE-to-unified-recipe

    commit c7b27c093a00c3bf2bf4c6cbc1b3ad2a0cc829bb
    Merge: 03d91803 64fecb91
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Tue May 5 11:05:57 2026 -0400

        Merge branch 'psharpe/add-GLOBE-drivaerml-standalone' into psharpe/add-GLOBE-to-unified-recipe

    commit 03d9180347a68bf1071efee03bf90ef7b0aad883
    Merge: 4f79a781 cbb1bf16
    Author: Mehdi Ataei <ataei8@gmail.com>
    Date:   Mon May 4 16:22:50 2026 -0400

        Merge branch 'main' into psharpe/add-GLOBE-to-unified-recipe

    commit 4f79a781ceb87868a7f9f6b1723e4a2fea6951ae
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Mon May 4 16:21:38 2026 -0400

        Adds first cut of updating recipe to support GLOBE, and to be Mesh-native.

    commit 4e33dc0b2c2be47ee098fa5b738d2f7cce00b16c
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Mon May 4 10:44:27 2026 -0400

        import sort

commit 80b760788e617482bddb62f8b0d8301c921584b8
Merge: 5716659f 260ddfb1
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 8 19:54:07 2026 +0000

    Merge branch 'main' into psharpe/GLOBE-dev

commit 5716659f3209a9b4c5e35ffa076320081ffb0163
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 8 12:40:27 2026 -0400

    fixes old renames

commit 73312c90befd8567e3c9bcd9827b8907769ea994
Merge: 76f2a747 fb402c99
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed May 6 13:24:25 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-3D-BarnesHut

commit 76f2a74752f66c6d41b2f032cdd29836036ef209
Merge: bbaad544 8cbd6cb6
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 12:26:44 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-3D-BarnesHut

commit bbaad544f6ba6de16d0005906f9f34a9673b9c89
Merge: 560f5ff7 64fecb91
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 11:15:08 2026 -0400

    Merge branch 'psharpe/add-GLOBE-drivaerml-standalone' into psharpe/add-GLOBE-3D-BarnesHut

commit 64fecb919d16758e087c2551f4b15ef79297b73c
Merge: 87c64c22 9f279dbc
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 11:04:51 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-drivaerml-standalone

commit 87c64c228869d731173fb0e29e9fd1a2d8c7dfa6
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 11:02:19 2026 -0400

    Adds a edge case guard

commit 2958dcea5e0a7b6ba4dd797f56aa20f57c52d783
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 10:43:59 2026 -0400

    docs on sample IDs

commit a1b8337a56fdd950796978aedc11d2e6bfe01f2c
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 10:42:27 2026 -0400

    fixes patience edge case

commit 29813c2f450029cf8147790582076509dcd7708b
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 10:09:23 2026 -0400

    fix old arg name

commit c4da8f27ae6321e5da50f497423fb4e5bd4b29f7
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 10:00:53 2026 -0400

    Add expected training behavior section to drivaer README

    - Included details on reference hardware, wall time per epoch, peak VRAM usage, and training/validation loss metrics at various epochs.
    - This addition aims to provide users with a benchmark for evaluating their training runs.

commit f9ac3c633d86c4da588913f8b151afbb8d71a766
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 09:51:54 2026 -0400

    Integrate Profiler into training scripts for enhanced performance monitoring

    - Added Profiler import and initialization in train.py for both airfrans and drivaer examples.
    - Updated profiling logic to enable and configure the profiler based on conditions.
    - Adjusted the training loop to ensure profiler steps are called correctly.

    This enhancement allows for better performance tracking during model training.

commit 733e5a0f456b46d35c73ae8859cd6b18f027ae64
Merge: 5f78ba6e 93d13424
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue May 5 09:13:47 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-drivaerml-standalone

commit 5f78ba6e877d4a944be4093fd210a390c42bf30e
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon May 4 18:59:09 2026 -0400

    adds dataset link

commit 560f5ff743c4abdd41b50283bacd3340a153f5f7
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 21:16:06 2026 -0400

    Squashed commit of the following:

    commit 91a942b6e0e75951907d2a21b49d78a3c837e78f
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 21:12:32 2026 -0400

        Adds Greptile minor fixes

    commit b24f9b6e99e2abed1e9b6f530a108cbe1a99ca44
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 20:56:24 2026 -0400

        Back-merges dataset interrogate fix

    commit 6ddfb5ae569e59b7872184fe0f4838b3039a3735
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 20:43:01 2026 -0400

        Removes accidentally-commited benchmarks; these will come later

    commit 9fa0b5d2744471d5b33baf7a6927ec19f2aa0153
    Merge: 3e67057c 4c52a453
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 20:37:57 2026 -0400

        Merge branch 'main' into psharpe/add-GLOBE-drivaerml-standalone

    commit 3e67057c596760d9aecf0505b5f9de19e98cc9bd
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 20:35:32 2026 -0400

        Partial merge from add-GLOBE-3D-BarnesHut

    commit 4c52a4534c87290a4e0f9149c2942cbe097f95ad
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 16:38:28 2026 -0400

        Synchronizes `GLOBE` model progress for 26.05 (#1595)

        * Migrate cached_dataset.py

        * verified model arch new features (self_regularization_beta)

        * minor formatting syncs

        * Adds nonregression testing

        * Adds compile_logging utilities and prefetching utilities

        * Adds self to pade.py codeowners

        * Syncs AirFRANS updates

        * corrects a docstring

        * Strips out broken ram caching

        * Adds helpful error messages

        * Adds helpful error messages

        * docs

        * Refactor compile logging in training script

        - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
        - Updated the training script to call this new function, improving log clarity during distributed training.
        - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.

        * Enhance DataLoader worker configuration for distributed training

        - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
        - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
        - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.

        * Partial merge from add-GLOBE-3D-BarnesHut

    commit 4cb586ab3750c2a79fa23c39a6929ccf62fcbc53
    Merge: 645701fc ed855da4
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 15:23:50 2026 -0400

        Merge branch 'main' into psharpe/add-GLOBE-model-progress

    commit 645701fc7af2fd0ba8ffc97da9b9c8698d9dec64
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 15:06:55 2026 -0400

        Partial merge from add-GLOBE-3D-BarnesHut

    commit 15d7913a1623f5f9afe61e0257007898babf4a72
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 13:45:06 2026 -0400

        Enhance DataLoader worker configuration for distributed training

        - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
        - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
        - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.

    commit 65675a42e1b4617f088fd079f999bec3efe3a4b0
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 1 17:53:47 2026 -0400

        Refactor compile logging in training script

        - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
        - Updated the training script to call this new function, improving log clarity during distributed training.
        - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.

    commit 948da86713ea492c599d14e9f55981ab1e3d7549
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 1 15:40:31 2026 -0400

        docs

    commit ed855da4116d87174a17db0911413648f383cef3
    Author: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
    Date:   Thu Apr 30 20:44:37 2026 -0700

        Implements Predictor specialization for multi-diffusion (#1573)

        * Implements Predictor specialization for multi-diffusion

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Compile denoiser in multi-diffusion sampling compile tests

        Compiling the predictor instance directly was producing divergent results
        under torch 2.10 in the sample() loop (euler cases only). Follow the same
        pattern as test_samplers.py::TestSampleCompile and compile the denoiser
        closure instead — tracing through it still verifies that the predictor's
        __call__ path is compile-compatible.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Avoid fullgraph compile in multi-diffusion sampling test

        torch 2.10 Dynamo crashes with Fatal Python error: Aborted when tracing
        the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
        inside sample() with fullgraph=True. Allow graph breaks here; the
        predictor compile contract is still tested in isolation by
        test_multi_diffusion_predictor.py::TestCompile.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Flatten MultiDiffusionPredictor hot path for torch.compile

        Dispatch on pos_embd presence and model_kwargs is now resolved once at
        __init__ into a specialized closure, so __call__ is branch-free and the
        no-kwargs path avoids ** expansion. This keeps fullgraph=True compile
        cleanly traceable under torch 2.10 (which was hitting a Dynamo abort on
        the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
        when the denoiser closure was compiled in the sample() loop).

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Loosen torch.compile euler check in multi-diffusion sampling tests

        Reverts the two earlier CI-fix attempts (compile-denoiser switch, predictor
        hot-path flatten) since neither actually fixed the divergence. The
        underlying issue is an upstream torch>=2.10 Dynamo bug: euler + compiled
        MultiDiffusionPredictor produces numerically divergent results. Heun works,
        predictor compiles correctly in isolation. For euler we now assert only
        shape + isfinite until the upstream bug is resolved.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Force contiguous t_cur/t_next in Euler solvers

        sample() passes t.expand(B) (a stride-0 non-contiguous tensor) into
        solver.step(). HeunSolver already forces .contiguous() on both tensors to
        prevent torch.compile from specializing on the stride pattern of the first
        call and then either mis-firing guards or silently recompiling on
        subsequent calls with different underlying storage.

        EulerSolver and EDMStochasticEulerSolver had no such guard, which was a
        latent bug exposed by torch 2.10 (stricter stride tracking) in the
        multi-diffusion compiled sample loop — producing 90%+ element divergence
        vs eager on the first call and a Dynamo abort on the second call. Apply
        the same fix uniformly across all four solver steps.

        Also revert the temporary loosened euler assertion in
        test_multi_diffusion_sampling.py now that the real fix is in place.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
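
The stride-0 behavior described in this commit is easy to reproduce in isolation. A minimal sketch (the `t.expand(B)` call mirrors the commit text above; this is not the repository's solver code):

```python
import torch

B = 4
t = torch.tensor(0.5)

# t.expand(B) is a stride-0 view: all B elements alias a single storage element,
# and torch.compile can specialize its guards on that unusual stride pattern.
t_cur = t.expand(B)
print(t_cur.stride())   # (0,)

# .contiguous() materializes an owned, normally-strided copy, giving every call
# the same layout regardless of how the caller produced the input.
t_cur = t_cur.contiguous()
print(t_cur.stride())   # (1,)
```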

        * Drop dead is_compiling guard and inherit from Predictor in MultiDiffusionPredictor

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Narrow _patching type and tighten multi-diffusion tests

        Move the _patching None check out of the is_compiling guard in
        MultiDiffusionModel2D so the type checker narrows self._patching
        to RandomPatching2D | GridPatching2D for the rest of each method,
        and route fuse/reset_patch_indices through isinstance.

        Streamline TestConstructor to only exercise the public contract
        (.fuse, .model, setter round-trip) and drop assertions on private
        caches. Compile the denoiser instead of the predictor in
        TestMultiDiffusionSampleCompile and add TestMultiDiffusionFullSamplerCompile
        mirroring test_samplers.py.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Force contiguous pos_embd before patching

        pos_embd.unsqueeze(0).expand(B, -1, -1, -1) produces a stride-0 view
        (all B copies share storage). Passing this through nn.ReflectionPad2d
        and F.unfold inside image_batching triggers a glibc heap corruption
        on torch 2.10 (CI, not locally on torch 2.8) when the first non-regression
        posembd_sin test runs. Same class of fix as the earlier euler solver.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Use functional F.pad in image_batching

        Instantiating torch.nn.ReflectionPad2d inside image_batching on every
        call creates a fresh nn.Module each time, which torch.compile / AOT
        autograd struggles to trace cleanly under fullgraph=True on torch 2.10.
        Switch to torch.nn.functional.pad which is a plain functional call and
        traces without allocating a module. Same result semantically.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
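
For reference, the two padding paths are numerically identical; a minimal sketch (tensor shape and pad width are arbitrary, not the example's configuration):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
pad = 2

# Module path: constructs a fresh nn.Module at every call site.
module_out = torch.nn.ReflectionPad2d(pad)(x)

# Functional path: a plain op with no module allocation, which traces cleanly.
functional_out = F.pad(x, (pad, pad, pad, pad), mode="reflect")

assert torch.equal(module_out, functional_out)
```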

        * Replace einops.rearrange with native torch reshape+permute

        einops.rearrange goes through a pattern-matched lowering path that
        torch.compile / inductor on torch 2.10 handles fragilely in the
        image_batching / image_fuse hot paths. The underlying transform is a
        plain view + permute + view, so express it directly: this gives inductor
        a straightforward sequence of ops to trace, and drops the einops
        dependency from this module.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
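
To illustrate the view + permute + view equivalence, here is a stand-in patch-extraction rearrange (not necessarily the exact pattern used in image_batching):

```python
import torch

# einops equivalent: rearrange(x, "b c (nh ph) (nw pw) -> (b nh nw) c ph pw", ph=ph, pw=pw)
B, C, nh, nw, ph, pw = 2, 3, 4, 4, 8, 8
x = torch.randn(B, C, nh * ph, nw * pw)

patches = (
    x.view(B, C, nh, ph, nw, pw)       # split H and W into (tile, patch) axes
     .permute(0, 2, 4, 1, 3, 5)        # (b, nh, nw, c, ph, pw)
     .reshape(B * nh * nw, C, ph, pw)  # fold the tile axes into the batch dim
)
```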

        * Materialise returned tensors in multi-diffusion fuse path

        Under torch.compile / inductor on torch 2.10, a compiled sample() call
        through MultiDiffusionPredictor was returning a tensor whose metadata
        was valid but whose data pointer was dangling (use-after-free) — the
        caller SIGABRTed on the first read of the tensor data. Add .contiguous()
        at the two boundaries that returned a view: image_fuse returns
        x_folded[...] / overlap_count[...], and MultiDiffusionModel2D.forward
        returns the (possibly fused) inner-model output. Forcing fresh storage
        on each boundary prevents the returned tensor from aliasing a buffer
        whose lifetime ends with the compiled frame.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Use clone instead of contiguous at fuse boundary

        The second torch.compile call of a fused MultiDiffusionPredictor was
        segfaulting (SIGSEGV) while the first succeeded. .contiguous() is a
        no-op when the tensor is already contiguous, so inductor could still
        see the returned tensor as aliasing an internal buffer across calls.
        .clone() always allocates fresh storage, so successive compiled calls
        get independent outputs. Also drop the redundant .contiguous() added
        earlier in MultiDiffusionModel2D.forward now that image_fuse owns that
        boundary.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Revert speculative fuse-boundary copies and xfail full-sampler compile on torch>=2.10

        Revert commits 3dfcdb51, 746518ff and a007c469 (native-torch rearrange
        in image_batching/image_fuse, .contiguous() on returned tensors, .clone()
        at fuse boundary) since they did not resolve the torch 2.10 inductor
        codegen segfault in TestMultiDiffusionFullSamplerCompile. Keep commits
        7e1db11c (pos_embd .contiguous() for the glibc heap corruption in
        posembd_sin non-regression tests) and feb0d9e4 (ReflectionPad2d → F.pad).

        Gate TestMultiDiffusionFullSamplerCompile with xfail(run=False) when
        torch>=2.10 so the SIGSEGV does not bring down the pytest process.
        TestMultiDiffusionSampleCompile (per-step denoiser compile) still runs.

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Minor updates to predictor.py

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        * Drop redundant _patching_type and add test-time-only docstring warning

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

        ---------

        Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    commit 16a336f9ac267ffa6b2fee8e02d0d21dd104007a
    Author: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
    Date:   Thu Apr 30 00:03:24 2026 -0700

        FSDP optimizer state channels last fix (#1597)

        * Fix channels last FSDP optimizer state load bug

        * lint

        * Catch use_orig_params=True case

    commit 845906f4d0846061b516749d0f1ee450c16aca91
    Author: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
    Date:   Wed Apr 29 23:59:29 2026 -0700

        Add HealDA dataloader protocols and init recipe (#1555)

        * Add healda protocols and loaders to experimental

        * Cleanup and address imports

        * Update precommit for examples tests

        * integrate restartable sampler, other updates, migrate tests

        * move imports, cleanup

        * ruff check fix

        * skip prefetch on CPU

        * Rename to local_platform

        * Revert precommit change

        * greptile feedback

        * Migrate CSVs and deps to example

        * lockfile fix

commit 91a942b6e0e75951907d2a21b49d78a3c837e78f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 21:12:32 2026 -0400

    Adds Greptile minor fixes

commit b24f9b6e99e2abed1e9b6f530a108cbe1a99ca44
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:56:24 2026 -0400

    Back-merges dataset interrogate fix

commit f03f7218c02cfbd517d176ffc722accdbf343501
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:54:30 2026 -0400

    Refactor DrivAerMLSample class to inherit from TensorClass and clean up unused imports. This change enhances type handling and maintains code clarity.

commit 6ddfb5ae569e59b7872184fe0f4838b3039a3735
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:43:01 2026 -0400

    Removes accidentally-committed benchmarks; these will come later

commit 9fa0b5d2744471d5b33baf7a6927ec19f2aa0153
Merge: 3e67057c 4c52a453
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:37:57 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-drivaerml-standalone

commit 3e67057c596760d9aecf0505b5f9de19e98cc9bd
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:35:32 2026 -0400

    Partial merge from add-GLOBE-3D-BarnesHut

commit 9d582ad73a7204797c859bf0fa195bee63c20064
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 22:38:53 2026 +0000

    default to expand far targets

commit a297fb9f3e332a442649a09c0621915089c43abe
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 22:35:13 2026 +0000

    sync with airfrans

commit 0c0c0280fd86f499403c0171fa371eeada4b4e38
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 22:33:34 2026 +0000

    sync with airfrans

commit 4cb586ab3750c2a79fa23c39a6929ccf62fcbc53
Merge: 645701fc ed855da4
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 15:23:50 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-model-progress

commit 645701fc7af2fd0ba8ffc97da9b9c8698d9dec64
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 15:06:55 2026 -0400

    Partial merge from add-GLOBE-3D-BarnesHut

commit b0fb9d299046fb65c24e781433d849f3c644b58a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 19:00:02 2026 +0000

    Squashed commit of the following:

    commit 15d7913a1623f5f9afe61e0257007898babf4a72
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Sun May 3 13:45:06 2026 -0400

        Enhance DataLoader worker configuration for distributed training

        - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
        - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
        - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.

    commit 65675a42e1b4617f088fd079f999bec3efe3a4b0
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 1 17:53:47 2026 -0400

        Refactor compile logging in training script

        - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
        - Updated the training script to call this new function, improving log clarity during distributed training.
        - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.

    commit 948da86713ea492c599d14e9f55981ab1e3d7549
    Author: Peter Sharpe <peterdsharpe@gmail.com>
    Date:   Fri May 1 15:40:31 2026 -0400

        docs

commit 4f2254d0de5991a835f5617111ac11ce588efff4
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 18:42:52 2026 +0000

    Refactor Kernel class in field_kernel.py by removing the add_semantics method and optimizing the output processing. The changes streamline the handling of tensor semantics and improve performance by constructing per-field outputs in a single pass, addressing potential issues with PyTorch's tensor operations.

commit 459e3c0a3e63ae21a08c05c639665ff96c19a47e
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 17:47:39 2026 +0000

    sync onto td 0.12.2

commit 15d7913a1623f5f9afe61e0257007898babf4a72
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 13:45:06 2026 -0400

    Enhance DataLoader worker configuration for distributed training

    - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
    - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
    - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.

commit 65675a42e1b4617f088fd079f999bec3efe3a4b0
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 1 17:53:47 2026 -0400

    Refactor compile logging in training script

    - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
    - Updated the training script to call this new function, improving log clarity during distributed training.
    - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.

commit 948da86713ea492c599d14e9f55981ab1e3d7549
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 1 15:40:31 2026 -0400

    docs

commit 7992acc211470de1e21d446a9659a4fb70901a21
Merge: dd801555 aec5a3aa
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 29 16:11:47 2026 -0400

    Merge branch 'psharpe/add-GLOBE-model-progress' into psharpe/add-GLOBE-3D-BarnesHut

commit aec5a3aa5dfae3717c5daf97b24a196ed056286a
Merge: 973bdd53 645fb0aa
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 29 16:10:53 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-model-progress

commit 973bdd53c4cee8b575e75172085f33bf7435a485
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 29 15:57:31 2026 -0400

    Adds helpful error messages

commit 65a9c54418be92af19324d66062e53406a6b1137
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 29 15:56:58 2026 -0400

    Adds helpful error messages

commit 8bde77a319a64744c2fc6028a58b7fd43726a5d8
Merge: 9a8f056c a8a0739a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 29 10:00:32 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-model-progress

commit dd8015555b4a5d368a816e29a79d1d5e283bb921
Merge: 9d63b124 9a8f056c
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 17:13:46 2026 -0400

    Merge branch 'psharpe/add-GLOBE-model-progress' into psharpe/add-GLOBE-3D-BarnesHut

commit 9a8f056cde5589601568fd2ed784f0f515e79e4e
Merge: 61f5cc00 fc088545
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 17:13:26 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-model-progress

commit 9d63b124c33074ec5447afc76389ba78d55b9791
Merge: 56db224a 61f5cc00
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 16:49:27 2026 -0400

    Merge psharpe/add-GLOBE-model-progress into psharpe/add-GLOBE-3D-BarnesHut

    Brings in the corrected Pade docstring, the AirFRANSSample TensorClass-
    inheritance refactor, and the removal of the broken RAM caching path.

    Note: git's recursive auto-merge of cached_dataset.py duplicated the
    'sample_paths is empty' validation block (line aligner confused by both
    branches removing self.use_ram_caching above it). One copy was removed
    during conflict resolution. Resulting blob is byte-identical to MERGE_HEAD.

commit 61f5cc00b27e848a6a48efc2c03d52affdf29cc9
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 16:27:06 2026 -0400

    Strips out broken ram caching

commit dd331c9e0aeadd520fa5aa282f11483cedd61875
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 15:25:09 2026 -0400

    corrects a docstring

commit 7d4a113d5312a8170a7d9f7374f7f83aa9bcec6a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 12:05:29 2026 -0400

    Syncs AirFRANS updates

commit ebaa00ea1f1cb05bb651bc5b95bcd8039cd02e55
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 12:00:05 2026 -0400

    Adds self to pade.py codeowners

commit fd7d4695a6384551c5903f6c707ec61bb3a06088
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 11:55:32 2026 -0400

    Adds compile_logging utilities and prefetching utilities

commit 783b849378313126a01e0012dc5c8ea7f882f7ed
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 11:50:48 2026 -0400

    Adds nonregression testing

commit d1209c0225d4c75c8b3551ea636bea0dc052ec0e
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 11:47:49 2026 -0400

    minor formatting syncs

commit 83cd333133cc8ced126abdd54f2bbd282965bf45
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 11:47:33 2026 -0400

    verified model arch new features (self_regularization_beta)

commit f69179b5442226b3bb6e4317f03c92f3d8010995
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 11:45:59 2026 -0400

    Migrate cached_dataset.py

commit 56db224aed395535ab0b790babb9846d78d0bcff
Merge: 096d3913 3c396d43
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 27 11:12:59 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-3D-BarnesHut

commit 096d3913862726bb4210edb5e3c1b7e78b1b0a8d
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 15:28:29 2026 +0000

    default to 128x3

commit 69db12e4ed2046f016f8b8e00460a8594684df46
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 15:27:32 2026 +0000

    forward expand far targets

commit 86757d757ca801870cd5d01787b165621b475a04
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 15:26:43 2026 +0000

    removes smooth_pade

commit 34dcff5cc4e77a0410e9a274d29024044df8afdf
Merge: ce0dbb85 8a02bced
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 15:22:55 2026 +0000

    Merge branch 'main' into psharpe/add-GLOBE-3D-BarnesHut

commit ce0dbb8500f3d085e2ac193d744308a6b27fef08
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 15:21:42 2026 +0000

    syncs changes to train.py

commit af594ccc9b84d91d143f1085af86dd5efff45451
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 15:20:10 2026 +0000

    Adds `prediction_chunk_size` parameter to GLOBE model for memory-efficient evaluation. This allows chunked processing of prediction points, improving performance when handling large datasets. Updates documentation to reflect new parameter and its usage.
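
The chunked-evaluation pattern this enables looks roughly like the following sketch (the function and parameter names here are illustrative, not the GLOBE API):

```python
import torch

def predict_chunked(model_fn, points: torch.Tensor, prediction_chunk_size: int) -> torch.Tensor:
    # Evaluate prediction points in fixed-size slices so peak memory scales with
    # the chunk size rather than with the full number of prediction points.
    outputs = []
    for start in range(0, points.shape[0], prediction_chunk_size):
        outputs.append(model_fn(points[start : start + prediction_chunk_size]))
    return torch.cat(outputs, dim=0)
```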

commit 62e870e79a7ded5444de303b81d0b60a302181b9
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 15:01:01 2026 +0000

    formatting

commit b15a414d44826ed7d918b43bb43c430e7de2b66b
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 15:00:33 2026 +0000

    formatting

commit 002cc44d2729746304499075d4fd8880fdc7be94
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 14:59:32 2026 +0000

    adds new params to airfrans

commit 99955a114de94731e43d1c0d25c24e24c305949f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 14:51:07 2026 +0000

    formatting

commit fc4761d8187863d10e0475a01a51bcded8224886
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri Apr 24 14:49:21 2026 +0000

    Adds in more pade kernel capabilities

commit fd68c8f624caaed8fa186bfa6a9c35b355f9a88d
Merge: 6a8ab360 59321561
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue Apr 21 09:50:44 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-3D-BarnesHut

commit 6a8ab36035c2640013fae55fdb55cae57b46ac0f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 20 16:25:00 2026 +0000

    switch airfrans to theta 0

commit 7dfb8cafc1bfef610d396bfd5e0995f2732cabd0
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 20 15:12:26 2026 +0000

    formatting

commit 57e06a48ad13c5ef1d30a37b3d77612210f7fc82
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 20 15:11:29 2026 +0000

    formatting

commit ceeee52be6c448b342cb4bffd5a471e39b6eae74
Merge: 625e43e2 8b1e3e8b
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 8 11:42:27 2026 -0400

    Merge branch 'psharpe/add-GLOBE-3D-BarnesHut' of https://github.com/peterdsharpe/physicsnemo into psharpe/add-GLOBE-3D-BarnesHut

commit 625e43e24da748556927df9701795a64103279b2
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 6 14:38:40 2026 -0400

    Update theta effect visualization script to enhance configuration and error evaluation

    - Increased the number of source points from 20 to 60 and grid resolution from 128 to 256 for improved accuracy.
    - Refactored mode configurations to use tuples for better clarity and added new labels and colors for modes.
    - Updated error evaluation logic to accommodate the new mode structure, ensuring consistent handling across all configurations.
    - Enhanced plotting logic to reflect changes in mode handling and improve visualization clarity.

commit 8b1e3e8b23bfccc86a5f9f650c1ef1b4d0c2b984
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Apr 6 18:27:17 2026 +0000

    shutdown earlier

commit cf58894e0f78e622231b14d51a6232bc621cb7a0
Merge: 21a99804 f124d964
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 1 14:32:13 2026 +0000

    Merge branch 'psharpe/add-GLOBE-3D-BarnesHut' of https://github.com/peterdsharpe/physicsnemo into psharpe/add-GLOBE-3D-BarnesHut

commit 21a99804427a75f2eea06c1b4a0b4043be88f522
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 1 14:31:58 2026 +0000

    Refactor hyperparameter logging in GLOBE examples

    - Removed reliance on `inspect` for constructor parameters in `log_hyperparameters`.
    - Updated to use canonical constructor arguments from `model._args["__args__"]`, improving robustness and avoiding silent drops of untracked parameters.

commit ca805ad45240e8ba51cd8d6880af5b3e677ce5f7
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 1 14:18:49 2026 +0000

    Adds a backwards-compat self-regularization-beta parameter.

commit 06482befa2c6d7f47a108a09e682303253096dd8
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Apr 1 14:18:18 2026 +0000

    Adds benchmark_accuracy

commit f124d964ff1c57ccc7d7e6015e70dae69f66325b
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue Mar 31 22:06:11 2026 -0400

    Refactor 3D visualization setup in theta effect scripts

    - Updated the figure creation process in `theta_effect.py` and `theta_effect_nonuniform_strength.py` to use `p.figure3d` for improved aspect ratio handling.
    - Removed redundant box aspect setting to streamline the visualization code.

commit 7436d7026c522cabbca2f3d0f11962848bb6997a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue Mar 31 22:04:27 2026 -0400

    Add expand_far_targets option to ClusterTree and kernel classes

    - Introduced an optional parameter `expand_far_targets` in the `ClusterTree`, `BarnesHutKernel`, and `MultiscaleKernel` classes to allow for the expansion of far-field target nodes to individual points. This change enhances the accuracy of interactions by eliminating target-side centroid approximations at the cost of additional kernel evaluations.
    - Updated relevant documentation to reflect the new parameter and its implications on performance and accuracy.

commit c5c942af35f33a86b35e9070d87a3352f6b18568
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue Mar 31 21:51:20 2026 -0400

    Add visualization script for the effect of the Barnes-Hut theta parameter on GLOBE's field kernel

    - Introduced a new Python script that generates multiple visualizations, including ClusterTree spatial hierarchies and kernel scalar field comparisons.
    - The script evaluates the impact of varying theta values on computational cost and approximation errors, enhancing understanding of the Barnes-Hut algorithm's performance.

commit a631ef89047d93a76a875fce7ff56814cd1d8b47
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue Mar 31 20:44:17 2026 -0400

    Adds visualization scripts

commit 379c02197d9dbdc5fdc8c0825276f863bc329b5f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Thu Mar 26 15:22:27 2026 +0000

    docs

commit c6e0b9de2c7c461b3b0dcd01ccc924d9f29938a0
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Thu Mar 26 15:19:44 2026 +0000

    Update MLflow launch scripts to specify ports for UI instances

    - Added port specifications to the MLflow UI launch commands in both `mlflow_launch.sh` scripts for `airfrans` and `drivaer`, ensuring that each instance runs on a unique port (5001 for airfrans and 5002 for drivaer).

commit bf43029c1588c93a7b96ffd3224425cdbcbf48ac
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Wed Mar 25 19:21:31 2026 -0400

    Adds Mesh dimensionality annotations for readability.

commit 3dac89498109d1684bbbf744e78c0a0512cfdbdf
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Tue Mar 24 16:34:36 2026 +0000

    Moves benchmarks into explicitly resource-focused benchmarking, to distinguish them from accuracy benchmarks.

commit 07ce3efb1f3231e1877cc4ba149918d17b99adf3
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Mon Mar 23 00:45:11 2026 +0000

    Disable bytecode-compile; let this happen lazily

commit 0f7ac796387cf2a3ca7dec2cb1976b8b993bbc41
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 16:16:14 2026 -0400

    Refactor dataclass usage to tensorclass in ClusterTree

    - Replaced `@dataclass` with `@tensorclass` for `DualInteractionPlan` and `SourceAggregates` to enhance performance and compatibility with tensor operations.
    - This change aligns with the ongoing optimization efforts in the GLOBE model's data handling.

commit 1f5d7656d8815b6de72f834a998e3f34179eb55e
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 20:14:37 2026 +0000

    Optimize source scalar and vector gathering in BarnesHutKernel

    - Refactored the gathering of source scalars and vectors to reduce GPU kernel calls, improving performance.
    - Implemented pre-flattening of source scalars and vectors using `concatenate_leaves`, followed by efficient indexing.
    - Adjusted the handling of gathered vectors to ensure compatibility with the feature engineering pipeline in `_evaluate_interactions`.

commit 6dd835845b4edae2f22ab04b392040900f07cd37
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 12:38:04 2026 -0400

    EXPERIMENTAL changes that disable lazy compilation from GLOBE kernels

commit 7948120728cc8604fe58e79dbcc957704aec8ab0
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 16:31:33 2026 +0000

    Enhance CompileDiagnosticsCollector for improved logging of graph breaks

    - Introduced a new internal class `_BreakRecord` to encapsulate details of graph break events, including the reason, user caller, and raw message.
    - Updated the `filter` method to parse and store detailed information about graph breaks and recompiles, improving the clarity of logged messages.
    - Added a new method `_extract_user_caller` to identify the nearest user code caller from the traceback, enhancing the context provided in logs.
    - Modified the `summary` method to display user caller information alongside graph break reasons, aiding in debugging and analysis.

commit f223f790633c5a55db5c8dbeb8a88dcc3f3c3941
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 15:49:49 2026 +0000

    Add non-regression tests for GLOBE model

    - Introduced a new test suite in `test_nonregression.py` to validate the GLOBE model's forward pass against saved reference outputs, ensuring numerical consistency.
    - Implemented functionality to generate and save reference data on first run, with subsequent tests comparing outputs to this reference.
    - Added a gradient flow test to verify that gradients propagate correctly through the model during training.
    - Created a binary reference output file `globe_nonregression_output.pth` for comparison in tests.

commit 851ce5f2d85d8e430a2b137d46020163493e83bc
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 15:47:58 2026 +0000

    save every 1

commit 9f4d49f8037030c3fda1beb2a0e8465f6764b9a2
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 11:34:56 2026 -0400

    Add validation method to DualInteractionPlan and corresponding tests

    - Implemented a `validate` method in the `DualInteractionPlan` class to check for internal consistency, including shape pairing, non-negativity, and bounds for broadcast tensors.
    - Added unit tests for the `validate` method to ensure correct behavior, including checks for valid plans, shape mismatches, out-of-bounds errors, and negative counts.
    - Integrated validation call in the `find_dual_interaction_pairs` method to enforce checks during execution.

commit 53f6a1a9ccd46b306f91e547b5d6290cfb1308f2
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 11:09:07 2026 -0400

    Add fixture to disable torch.compile during tests

    - Introduced a new pytest fixture `_disable_torch_compile` to prevent per-test compilation overhead by disabling torch.compile.
    - This change ensures that tests run with zero compilation cost while maintaining identical numerical behavior.

commit 652b0c28be48bbe70cbb39629fa89bf0b57a656a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 01:22:53 2026 -0400

    Remove unused total parameter from _ragged_arange call in BarnesHutKernel class

commit b42927132780a65457cd8cbb73a35eb56b5f9aa8
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 05:02:43 2026 +0000

    Add leaf node ID and segment ID attributes to ClusterTree

    - Introduced `leaf_node_ids` and `leaf_seg_ids` to optimize source aggregation by precomputing leaf indices during tree construction.
    - Updated relevant methods to utilize these new attributes, improving performance and maintaining compatibility with `torch.compile`.
    - Enhanced documentation to reflect the new properties and their usage.

commit 0ad3271344bb9aaa11f87d6f629d3564707aafa4
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 04:57:32 2026 +0000

    sets a default LR of 1e-2.

commit 71ab33296b3ea74e34597d3ef54de95f3942ee95
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 04:56:46 2026 +0000

    Expands out kwarg splat, which was causing a torch.compile graph break.

commit 56607fc99dfcaf43e0aaaa905f39c1d88ddf26ee
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 04:52:10 2026 +0000

    Add precomputed leaf field consistency tests for ClusterTree

    - Introduced tests for verifying leaf_node_ids and leaf_seg_ids against expected values.
    - Added checks for empty trees and single-point trees to ensure correct behavior.
    - Utilized _ragged_arange for leaf_seg_ids computation consistency.

commit fd65d17815d9b6d6ed5cc707b0278065ea52c9fe
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun Mar 22 04:51:54 2026 +0000

    Fixes ragged_arange to allow for counts of 0 while still preserving compile compatibility
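
For context, a compile-friendly ragged arange that tolerates zero counts can be built from a cumsum and repeat_interleave alone; this is a generic sketch, not the repository's _ragged_arange:

```python
import torch

def ragged_arange(counts: torch.Tensor) -> torch.Tensor:
    # Concatenates 0..c-1 for each count c; zero counts contribute nothing.
    starts = torch.cumsum(counts, dim=0) - counts                        # start of each segment
    seg_ids = torch.repeat_interleave(torch.arange(len(counts)), counts) # segment id per output slot
    total = int(counts.sum())  # host sync here; a fully compile-friendly variant avoids this
    return torch.arange(total) - starts[seg_ids]

print(ragged_arange(torch.tensor([3, 0, 2])))   # tensor([0, 1, 2, 0, 1])
```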

commit 082a9c7fc65c5343a9e5669689fd8ce30450215c
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 22:00:26 2026 +0000

    Tweaks compile diagnostics + reduced log verbosity

commit e704d4bb245316ee4c3e498a2c5aea8fe04481c2
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 21:32:28 2026 +0000

    Fixes LR scaling from Moonshot paper

commit 5ab066bb1de1d07ce007e1741c632772978ab149
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 21:30:56 2026 +0000

    update paths

commit 713674cb6b540262bb7c7a5e52dbe309c2b4aa01
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 21:30:24 2026 +0000

    update paths

commit 179bf7ed75fb0d27c406f36d33bfff3e96edd23f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 21:29:22 2026 +0000

    Moves a lot of non-GLOBE-specific code to experimental.utils

commit 59139e4ced1d956f8914edc0a15325c546916ef6
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 21:10:21 2026 +0000

    Cleans up verbose logs

commit d7a6b3015dabd13ce8524841c5a4423fd22e639f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 21:03:15 2026 +0000

    Adds CompileDiagnosticsCollector

commit eb8d912ece968656fdd039a91e01914c71e4337a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 21:00:19 2026 +0000

    Filters warnings by default

commit d479c39f96689ab6df25dce39c1a9bfddf218d39
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 20:57:42 2026 +0000

    Fixes a graph break in ragged

commit fea587889b934526d3a067eebb9990260205dfde
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 20:40:50 2026 +0000

    reverts erroneous rdzv changes

commit c76b78dec29d2f3d416d70b90db7490b7ed3797f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sat Mar 21 19:41:25 2026 +0000

    Add `patience_steps` parameter to training scripts for dynamic learning rate adjustment. Update `ReduceLROnPlateau` scheduler to use calculated patience based on training data size. Refactor world-size learning rate adjustment for consistency across different GPU configurati…
@copy-pr-bot

copy-pr-bot Bot commented May 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

A focused performance pass on the GLOBE model achieving ~30% per-step speedup on 16× B200s. The dominant wins are eliminating ~60 GPU↔CPU sync points per training step (deferred boolean compactions, static memory budget, bit_length-based loop bound), replacing two O(tree_depth) Python loops with O(N) prefix-sum range subtracts, and packing the four-phase Barnes-Hut scatter into a single index_add_ per phase. A real numerical bug (fp32 cumsum catastrophic cancellation giving ~100% wrong leaf centroids at DrivAer scale) is also fixed as a prerequisite.
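
To make the prefix-sum range-subtract concrete, here is a minimal, self-contained sketch; the starts/ends arrays are hypothetical stand-ins for the contiguous per-node ranges that the Morton-sorted tree layout provides:

```python
import torch

values = torch.arange(1.0, 9.0)           # 8 per-point quantities: 1..8
starts = torch.tensor([0, 3, 5])          # hypothetical per-node range starts
ends   = torch.tensor([3, 5, 8])          # per-node range ends (exclusive)

# One exclusive prefix sum (in fp64, mirroring the PR's stability fix), then a
# single gather-and-subtract replaces the per-level Python loop.
prefix = torch.zeros(len(values) + 1, dtype=torch.float64)
prefix[1:] = torch.cumsum(values.double(), dim=0)
node_sums = prefix[ends] - prefix[starts]   # tensor([ 6.,  9., 21.])
```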

  • Sync elimination (cluster_tree.py, field_kernel.py): per-iteration .any() / .item() / boolean-mask indexing ops are replaced with one deferred compaction per traversal-loop iteration; the per-kernel-call mem_get_info query is replaced with a cached 25%-of-total-memory budget.
  • Packed scatter (field_kernel.py): all four BH phases now scatter into a single (n_targets, total_features) buffer via index_add_, cutting indexing_backward from ~23% to ~6% of GPU time (see the sketch after this list).
  • Training script cleanup (train.py ×2): GradScaler removed (correct no-op for bfloat16 autocast), tree_build_device kwarg added, PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True exported.
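
The packed-scatter idea from the second bullet, sketched with hypothetical shapes and names (not the repository API):

```python
import torch

n_targets, n_pairs = 5, 12
target_idx = torch.randint(0, n_targets, (n_pairs,))

field_a = torch.randn(n_pairs, 3)                  # e.g. a vector-valued contribution
field_b = torch.randn(n_pairs, 1)                  # e.g. a scalar contribution

# Concatenate the fields along the feature dim and issue one index_add_ per phase,
# instead of one scatter_add_ per phase per field.
packed = torch.cat([field_a, field_b], dim=-1)     # (n_pairs, total_features)
out = torch.zeros(n_targets, packed.shape[-1])
out.index_add_(0, target_idx, packed)

out_a, out_b = out.split([3, 1], dim=-1)           # unpack per-field results afterwards
```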

Important Files Changed

  • physicsnemo/experimental/models/globe/cluster_tree.py: Major refactor eliminating ~60+ GPU↔CPU sync points: deferred leaf compaction, single-compaction dual traversal, and O(N) prefix-sum aggregation replacing O(tree_depth) Python loops. Logic verified correct; the new fp64 cumsum fixes a real centroid-precision bug at DrivAer scale.
  • physicsnemo/experimental/models/globe/field_kernel.py: Packed scatter buffer (index_add_ replacing per-field scatter_add_), smoothing radius moved to a registered buffer, and a static memory budget. Two P2 concerns: gradient checkpointing now applies in eval mode by default, and the fixed 25% memory budget may exceed available VRAM on already-loaded devices.
  • physicsnemo/experimental/models/globe/model.py: Adds tree_build_device and use_gradient_checkpointing kwargs, wraps tree-build passes in torch.no_grad(), and removes the overly defensive .exclude("strengths") filter. reference_area is a registered buffer, so the .device fallback is valid; the view→reshape fix is correct.
  • examples/cfd/external_aerodynamics/globe/airfrans/train.py: GradScaler removed (a correct no-op for bfloat16 autocast), tree_build_device kwarg threaded through. Checkpoint forward-compat note: new checkpoints without scaler state cannot be loaded by old code.
  • examples/cfd/external_aerodynamics/globe/drivaer/train.py: Same GradScaler removal and tree_build_device addition as in airfrans/train.py; the same checkpoint forward-compat concern applies.
  • test/models/globe/test_barnes_hut_kernel.py: Removes tests for the now-deleted leaf_node_ids/leaf_seg_ids attributes; adds precision regression guards for the fp64 cumsum fix, a bf16 autocast forward-correctness test, and a CUDA sync-budget regression test. Coverage is well targeted at the riskiest new code paths.
  • pyproject.toml: Bumps the tensordict floor to 0.12.2 and adds psutil>=6.0.0 for the CPU memory-query fallback path. Both additions are justified by the PR changes.

Comments Outside Diff (1)

  1. examples/cfd/external_aerodynamics/globe/airfrans/train.py, lines 591-597

    P2 Checkpoint forward-compatibility break

    Checkpoints saved with the new code (no scaler key) will fail to load with the old training script, which still calls load_checkpoint(..., scaler=scaler, ...) and expects that key in the saved state dict. While the PR guarantees backward compat (old checkpoint → new code), users who need to roll back (e.g., a bugfix release reverting to the old code) or who share a checkpoint with a colleague running an older version will hit a hard failure. Worth documenting this one-way migration in the training README or a comment near save_checkpoint.

Reviews (1): Last reviewed commit: "Squashed commit of the following:"

Comment threads (2): physicsnemo/experimental/models/globe/field_kernel.py