Skip to content

Fix over-masking of cyclic longitude axis in shifted-window attention#1600

Open
awikner wants to merge 4 commits into
NVIDIA:mainfrom
awikner:fix-pangu-3d-mask-longitude
Open

Fix over-masking of cyclic longitude axis in shifted-window attention#1600
awikner wants to merge 4 commits into
NVIDIA:mainfrom
awikner:fix-pangu-3d-mask-longitude

Conversation

@awikner
Copy link
Copy Markdown

@awikner awikner commented Apr 28, 2026

Fix: 3D shifted-window attention over-masks the cyclic longitude axis (Pangu, FengWu)

Summary

physicsnemo.nn.module.utils.get_shift_window_mask partitions the
shifted-window attention mask along all three spatial axes (pressure
level, latitude, longitude) in the 3D path, and along both spatial axes
(latitude, longitude) in the 2D path. For both Pangu-Weather (3D) and FengWu
(2D), the longitude axis is cyclic: the cyclic shift (torch.roll on the
longitude dim) is itself sufficient to handle wrap-around windows. The current
code additionally adds a hard mask along longitude that prevents tokens which
were physically adjacent across the dateline from attending to one another.
The original Pangu-Weather paper explicitly says these wrap-around windows
should be merged, not separated.

This PR removes the longitude masking (and the associated Lon + shift_lon
padding / trailing crop) so the mask now partitions only the (Pl, Lat) plane
in 3D and only the Lat axis in 2D.

This is a bugfix for both the 3D path (Pangu) and the 2D path (FengWu), for
the same root cause. It is the natural follow-up to PR #1492, which fixed the
shift_lat-vs-shift_lon typo in the cyclic shift itself; with #1492 the
forward and reverse rolls match, but the mask built around them was still
constructed against an over-partitioned longitude axis.

Why this is wrong (citations)

Pangu-Weather (3D path)

From Bi et al., "Pangu-Weather: A 3D high-resolution model for fast and
accurate global weather forecast"
(arXiv:2211.02556 / Nature, 2023):

  • "Along the longitude dimension, the leftmost and rightmost indices are
    actually close to each other. If half windows appear at both leftmost and
    rightmost positions, they are directly merged into one window
    ."

    (verbatim — this exact sentence is the docstring of get_shift_window_mask,
    but the implementation does the opposite of what it says).
  • "M_lon does not appear because different longitudes share the same bias
    — the longitude indices are cyclic and spacing is evenly distributed along
    this axis."
    Confirms longitude is treated as cyclic in the architectural
    design.

The official pseudocode at
github.com/198808xc/Pangu-Weather/pseudocode.py
applies roll3D to all three dims (which physicsnemo also does, after #1492)
but does not publish a complete gen_mask. The verbal specification in the
paper is unambiguous, and the standard reference implementations against which
this can be validated mask only (Pl, Lat).

FengWu (2D path)

The FengWu paper (Chen et al., arXiv:2304.02948) uses the same global ERA5
grid at 0.25° resolution (img_size=(721, 1440)) — longitude is manifestly
cyclic on this grid.

The FengWu 2D attention module in physicsnemo is architecturally derived
directly from Pangu-Weather. This is confirmed in two ways:

  1. Position-bias implementation. get_earth_position_index (called for
    both EarthAttention3D and EarthAttention2D) is annotated in the source
    with "implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py".
    The 2D position bias uses coords_w = torch.arange(win_lon) for both
    query and key (symmetric longitude treatment), the same construction that
    the Pangu paper describes as cyclic (M_lon dropped because longitude is
    cyclic and evenly distributed). This is physicsnemo/nn/module/utils/utils.py
    lines 50–71.

  2. Window-grouping structure. Both window_partition(ndim=2) and
    EarthAttention2D group feature maps by longitude (B*num_lon batch
    dimension, num_lat window dimension), identical to the 3D case. The mask
    shape (num_lon, num_lat, Wlat*Wlon, Wlat*Wlon) parallels the 3D
    (num_lon, num_pl*num_lat, …) layout. Any longitude masking on this
    structure would erroneously partition windows that span the dateline, for
    the exact same reason as in 3D.

Taken together: the 2D path inherits the same cyclic-longitude design
assumptions as the 3D path, and the fix is identical in character and
justified by the same evidence.

Mechanism of the bug

Consider pad_resolution=(8, 24, 48), window_size=(2, 6, 12),
shift_size=(1, 3, 6) (the default Pangu config used in
test/models/pangu/test_pangu.py).

Before this PR. get_shift_window_mask builds a mask of shape
(1, 8, 24, 48 + 6, 1) and assigns 27 distinct region IDs across all three
axes via three nested slice loops. After the trailing [:, :, :, :48, :] crop
the third (slice(-shift_lon, None)) longitude region is discarded, leaving
18 distinct region IDs in the visible mask: 3 (Pl) × 3 (Lat) × 2 (Lon).

In any window that straddles the dateline after the cyclic shift, tokens that
the paper says should be "merged into one window" are split into two
non-attending sub-windows by the mask.

After this PR. The mask is built directly at shape (1, 8, 24, 48, 1)
with a Pl × Lat double loop that assigns 9 region IDs (3 × 3). The
longitude axis is left unpartitioned. window_partition then groups windows
by longitude (x.permute(0, 5, 1, 3, 2, 4, 6, 7)) and the per-window
attention mask is uniform along longitude, so wrap-around windows attend
freely. This matches the paper's spec.

Diff

physicsnemo/nn/module/utils/shift_window_mask.pyget_shift_window_mask:

     if ndim == 3:
         Pl, Lat, Lon = input_resolution
         win_pl, win_lat, win_lon = window_size
         shift_pl, shift_lat, shift_lon = shift_size

-        img_mask = torch.zeros((1, Pl, Lat, Lon + shift_lon, 1))
+        img_mask = torch.zeros((1, Pl, Lat, Lon, 1))
     elif ndim == 2:
         Lat, Lon = input_resolution
         win_lat, win_lon = window_size
         shift_lat, shift_lon = shift_size

-        img_mask = torch.zeros((1, Lat, Lon + shift_lon, 1))
+        img_mask = torch.zeros((1, Lat, Lon, 1))

     if ndim == 3:
         pl_slices = (
             slice(0, -win_pl),
             slice(-win_pl, -shift_pl),
             slice(-shift_pl, None),
         )
     lat_slices = (
         slice(0, -win_lat),
         slice(-win_lat, -shift_lat),
         slice(-shift_lat, None),
     )
-    lon_slices = (
-        slice(0, -win_lon),
-        slice(-win_lon, -shift_lon),
-        slice(-shift_lon, None),
-    )

     cnt = 0
     if ndim == 3:
         for pl in pl_slices:
             for lat in lat_slices:
-                for lon in lon_slices:
-                    img_mask[:, pl, lat, lon, :] = cnt
-                    cnt += 1
-        img_mask = img_mask[:, :, :, :Lon, :]
+                img_mask[:, pl, lat, :, :] = cnt
+                cnt += 1
     elif ndim == 2:
         for lat in lat_slices:
-            for lon in lon_slices:
-                img_mask[:, lat, lon, :] = cnt
-                cnt += 1
-        img_mask = img_mask[:, :, :Lon, :]
+            img_mask[:, lat, :, :] = cnt
+            cnt += 1

The downstream window_partition / mask-difference / masked_fill block is
unchanged. Public signature is unchanged. Output tensor shape is unchanged
((n_lon, n_pl·n_lat, W, W) in 3D / (n_lon, n_lat, W, W) in 2D); only its
contents change.

Test plan

Sanity check (CPU, no GPU needed)

from physicsnemo.nn.module.utils.shift_window_mask import get_shift_window_mask
import torch

mask = get_shift_window_mask(
    input_resolution=(8, 24, 48),
    window_size=(2, 6, 12),
    shift_size=(1, 3, 6),
    ndim=3,
)
assert tuple(mask.shape) == (4, 16, 144, 144)        # (n_lon, n_pl*n_lat, W, W)
assert sorted(torch.unique(mask).tolist()) == [-100.0, 0.0]

Verified locally: post-fix the underlying region-ID map contains exactly 9
distinct IDs (vs. 18 pre-fix), as expected from a Pl × Lat-only partition.

Existing test suite

test/models/pangu/test_pangu.py::test_pangu_forward and
test/models/fengwu/test_fengwu.py::test_fengwu_forward validate forward-pass
outputs against checked-in reference tensors at
test/models/pangu/data/pangu_output.pth and
test/models/fengwu/data/fengwu_output.pth. Those references were last
regenerated by PR #1492 when the cyclic-shift typo was fixed, and will need to
be regenerated again after this PR for the same reason — the corrected
attention mask changes activations, so the reference output values change.

Suggested workflow (matching #1492):

  1. Apply this PR.
  2. Run the forward-pass tests in capture mode to overwrite
    pangu_output.pth and fengwu_output.pth.
  3. Commit the regenerated .pth files alongside the source change.

I have not regenerated those .pth blobs in this PR because the change was
prepared in a CPU-only environment.

Backwards compatibility

  • Public API of get_shift_window_mask is unchanged.
  • Output tensor shape is unchanged.
  • Output tensor values change (this is the bugfix).
  • Any pretrained Pangu checkpoint trained with the over-masked attention will
    produce slightly different activations after this fix. The same applies to
    any pretrained FengWu checkpoint. This is the same compatibility caveat as
    PR Fix window shift in pangu, fengwu #1492 — checkpoints already needed re-evaluation after that PR. Models
    retrained from scratch after this fix will produce correct outputs.
    Maintainers may want to flag the change in release notes.

Files changed

  • physicsnemo/nn/module/utils/shift_window_mask.py — fix (3D and 2D
    branches symmetrically).
  • test/nn/module/test_shift_window_mask.py — new unit tests for
    get_shift_window_mask (shape, binary values, region-ID count for both 3D
    and 2D) and round-trip tests for window_partition / window_reverse.
  • CHANGELOG.md — add a Fixed entry under the 2.1.0a0 section.

Checklist notes

  • Contributing Guidelines — commit is signed off
    (Signed-off-by: Alexander Wikner <awikner@uchicago.edu>).
  • Teststest/nn/module/test_shift_window_mask.py covers the
    changed function directly (shape, values, region-ID count, round-trip).
    The existing test_pangu_forward and test_fengwu_forward non-regression
    tests will need their reference .pth files regenerated before they pass
    (see Test plan above).
  • Documentation — docstring of get_shift_window_mask updated to
    explain the cyclic-longitude design rationale.
  • CHANGELOG — entry added under 2.1.0a0 ### Fixed.
  • Issue linkedcloses 🐛[BUG]: Shifted-window attention mask incorrectly partitions the cyclic longitude axis (Pangu, FengWu) #1599.
  • Models coding standards — no public API signatures or return types
    changed; no new required parameters added; output tensor shape unchanged.
  • Pre-commit hooks — all hooks pass (ruff check, ruff format,
    interrogate, license, import-linter). Two ruff F841 lint errors
    (unused win_lon / shift_lon variables in the test) were fixed in a
    follow-up commit. Note: import-linter requires lint-imports on PATH;
    install import-linter into your Python environment
    (pip install import-linter) and run git commit with that environment
    active.

Linked issue / related PR

awikner and others added 2 commits April 28, 2026 04:47
… mask

get_shift_window_mask was partitioning the attention mask along all three
spatial axes (Pl, Lat, Lon) in 3D, and along both axes (Lat, Lon) in 2D.
Because longitude is cyclic on the global ERA5 grid, the cyclic torch.roll
shift is sufficient to handle wrap-around windows; adding a hard longitude
mask additionally prevents tokens that are physically adjacent across the
dateline from attending to one another.

The Pangu-Weather paper (Bi et al., arXiv:2211.02556) states explicitly that
wrap-around longitude windows "are directly merged into one window" and treats
longitude as cyclic throughout (M_lon is omitted from the position bias for
this reason). The fix removes the Lon + shift_lon padding, the lon_slices
loop, and the trailing :Lon crop, leaving region IDs assigned by a Pl x Lat
double loop only. The 2D path (FengWu) receives the symmetric change: its
attention module is directly derived from Pangu-Weather and inherits the same
cyclic-longitude design.

This is a follow-up to NVIDIA#1492, which fixed the shift_lat typo in the forward
cyclic roll; this PR fixes the mask that roll was supposed to be working with.

Closes NVIDIA#1599.

Signed-off-by: Alexander Wikner <awikner@uchicago.edu>
Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Remove unused win_lon and shift_lon local variables from
TestGetShiftWindowMask2D.test_longitude_unmasked_region_count; the test only
needs win_lat and shift_lat to construct the latitude slices.

Signed-off-by: Alexander Wikner <awikner@uchicago.edu>
Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
@awikner awikner requested a review from loliverhennigh as a code owner April 28, 2026 13:08
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR fixes the over-masking of the cyclic longitude axis in get_shift_window_mask for both the 3D (Pangu-Weather) and 2D (FengWu) shifted-window attention paths. The longitude dimension is cyclic — wrap-around windows should be merged, not separated — so the fix removes the longitude partition from the mask, leaving only (Pl, Lat) partitioning in 3D and (Lat) in 2D, which matches the Pangu-Weather paper specification.

  • The existing integration tests test_pangu_forward and test_fengwu_forward compare against checked-in .pth reference tensors that are not updated in this PR. Because the corrected mask changes activations, these tests will fail on CI until pangu_output.pth and fengwu_output.pth are regenerated and committed (as the PR description acknowledges).

Important Files Changed

Filename Overview
physicsnemo/nn/module/utils/shift_window_mask.py Core bugfix: removes longitude-axis masking from get_shift_window_mask so the mask partitions only (Pl, Lat) in 3D and (Lat) in 2D, matching the Pangu-Weather cyclic-longitude spec. Logic is correct; minor residual: shift_lon is still unpacked but no longer used.
test/nn/module/test_shift_window_mask.py New unit tests covering shape, binary-value invariant, region-count, and round-trip for both 3D and 2D paths. Region-count tests duplicate internal implementation logic rather than testing through the public API; all other tests are solid.
CHANGELOG.md Adds a clear Fixed entry under 2.1.0a0 describing the longitude over-masking bug and the cyclic treatment of longitude.

Reviews (1): Last reviewed commit: "Fix ruff lint errors in test_shift_windo..." | Re-trigger Greptile

@@ -105,13 +109,13 @@ def get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3):
win_pl, win_lat, win_lon = window_size
shift_pl, shift_lat, shift_lon = shift_size
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Unused shift_lon variable after fix

shift_lon is unpacked here but is no longer used anywhere in the 3D branch — the lon_slices that consumed it were removed. Same in the 2D branch (line 116). Leaving it in is harmless but misleading; it signals that longitude shift plays a role in masking, which is precisely what this PR is removing. Consider replacing the unpacking with a discard or a _ suffix.

Comment on lines +67 to +99
def test_longitude_unmasked_region_count(self):
"""Longitude must not be partitioned: only Pl x Lat region IDs (9)."""
input_resolution = (8, 24, 48)
window_size = (2, 6, 12)
shift_size = (1, 3, 6)
Pl, Lat, Lon = input_resolution
win_pl, win_lat, win_lon = window_size
shift_pl, shift_lat, shift_lon = shift_size

# Reconstruct the underlying region-ID map directly
img_mask = torch.zeros((1, Pl, Lat, Lon, 1))
pl_slices = (
slice(0, -win_pl),
slice(-win_pl, -shift_pl),
slice(-shift_pl, None),
)
lat_slices = (
slice(0, -win_lat),
slice(-win_lat, -shift_lat),
slice(-shift_lat, None),
)
cnt = 0
for pl in pl_slices:
for lat in lat_slices:
img_mask[:, pl, lat, :, :] = cnt
cnt += 1

n_regions = len(torch.unique(img_mask))
# 3 Pl bands x 3 Lat bands = 9; longitude must NOT add more partitions
assert n_regions == 9, (
f"Expected 9 region IDs (Pl x Lat only), got {n_regions}. "
"Longitude axis must not be partitioned in the mask."
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Region-count test duplicates implementation internals

test_longitude_unmasked_region_count reconstructs img_mask from scratch with the same slice logic as get_shift_window_mask, then counts unique IDs. This is testing a hand-rolled copy of the code rather than the function under test. If the implementation drifts, the test will silently pass because both copies drift together. Consider driving the assertion through the public API instead — e.g. call get_shift_window_mask and check that the number of distinct finite values in the resulting attention mask is consistent with a Pl × Lat-only partition. The same applies to TestGetShiftWindowMask2D.test_longitude_unmasked_region_count.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

awikner and others added 2 commits April 28, 2026 08:20
Two changes based on greptile review of NVIDIA#1600:

1. Mark shift_lon as intentionally unused in get_shift_window_mask (both 3D
   and 2D branches) by replacing the variable name with _ . Leaving it named
   shift_lon implied it still played a role in longitude masking, which is the
   opposite of what this PR intends.

2. Rewrite test_longitude_unmasked_region_count (3D and 2D) to drive the
   assertion through the public API rather than a hand-rolled copy of the
   implementation. The new tests call get_shift_window_mask and assert that
   the returned mask is identical for every longitude window index — the
   direct invariant of an unmasked cyclic axis.

Signed-off-by: Alexander Wikner <awikner@uchicago.edu>
Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🐛[BUG]: Shifted-window attention mask incorrectly partitions the cyclic longitude axis (Pangu, FengWu)

1 participant