Fix over-masking of cyclic longitude axis in shifted-window attention#1600
Fix over-masking of cyclic longitude axis in shifted-window attention#1600awikner wants to merge 4 commits into
Conversation
… mask get_shift_window_mask was partitioning the attention mask along all three spatial axes (Pl, Lat, Lon) in 3D, and along both axes (Lat, Lon) in 2D. Because longitude is cyclic on the global ERA5 grid, the cyclic torch.roll shift is sufficient to handle wrap-around windows; adding a hard longitude mask additionally prevents tokens that are physically adjacent across the dateline from attending to one another. The Pangu-Weather paper (Bi et al., arXiv:2211.02556) states explicitly that wrap-around longitude windows "are directly merged into one window" and treats longitude as cyclic throughout (M_lon is omitted from the position bias for this reason). The fix removes the Lon + shift_lon padding, the lon_slices loop, and the trailing :Lon crop, leaving region IDs assigned by a Pl x Lat double loop only. The 2D path (FengWu) receives the symmetric change: its attention module is directly derived from Pangu-Weather and inherits the same cyclic-longitude design. This is a follow-up to NVIDIA#1492, which fixed the shift_lat typo in the forward cyclic roll; this PR fixes the mask that roll was supposed to be working with. Closes NVIDIA#1599. Signed-off-by: Alexander Wikner <awikner@uchicago.edu> Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Remove unused win_lon and shift_lon local variables from TestGetShiftWindowMask2D.test_longitude_unmasked_region_count; the test only needs win_lat and shift_lat to construct the latitude slices. Signed-off-by: Alexander Wikner <awikner@uchicago.edu> Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Greptile SummaryThis PR fixes the over-masking of the cyclic longitude axis in
Important Files Changed
Reviews (1): Last reviewed commit: "Fix ruff lint errors in test_shift_windo..." | Re-trigger Greptile |
| @@ -105,13 +109,13 @@ def get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3): | |||
| win_pl, win_lat, win_lon = window_size | |||
| shift_pl, shift_lat, shift_lon = shift_size | |||
There was a problem hiding this comment.
Unused
shift_lon variable after fix
shift_lon is unpacked here but is no longer used anywhere in the 3D branch — the lon_slices that consumed it were removed. Same in the 2D branch (line 116). Leaving it in is harmless but misleading; it signals that longitude shift plays a role in masking, which is precisely what this PR is removing. Consider replacing the unpacking with a discard or a _ suffix.
| def test_longitude_unmasked_region_count(self): | ||
| """Longitude must not be partitioned: only Pl x Lat region IDs (9).""" | ||
| input_resolution = (8, 24, 48) | ||
| window_size = (2, 6, 12) | ||
| shift_size = (1, 3, 6) | ||
| Pl, Lat, Lon = input_resolution | ||
| win_pl, win_lat, win_lon = window_size | ||
| shift_pl, shift_lat, shift_lon = shift_size | ||
|
|
||
| # Reconstruct the underlying region-ID map directly | ||
| img_mask = torch.zeros((1, Pl, Lat, Lon, 1)) | ||
| pl_slices = ( | ||
| slice(0, -win_pl), | ||
| slice(-win_pl, -shift_pl), | ||
| slice(-shift_pl, None), | ||
| ) | ||
| lat_slices = ( | ||
| slice(0, -win_lat), | ||
| slice(-win_lat, -shift_lat), | ||
| slice(-shift_lat, None), | ||
| ) | ||
| cnt = 0 | ||
| for pl in pl_slices: | ||
| for lat in lat_slices: | ||
| img_mask[:, pl, lat, :, :] = cnt | ||
| cnt += 1 | ||
|
|
||
| n_regions = len(torch.unique(img_mask)) | ||
| # 3 Pl bands x 3 Lat bands = 9; longitude must NOT add more partitions | ||
| assert n_regions == 9, ( | ||
| f"Expected 9 region IDs (Pl x Lat only), got {n_regions}. " | ||
| "Longitude axis must not be partitioned in the mask." | ||
| ) |
There was a problem hiding this comment.
Region-count test duplicates implementation internals
test_longitude_unmasked_region_count reconstructs img_mask from scratch with the same slice logic as get_shift_window_mask, then counts unique IDs. This is testing a hand-rolled copy of the code rather than the function under test. If the implementation drifts, the test will silently pass because both copies drift together. Consider driving the assertion through the public API instead — e.g. call get_shift_window_mask and check that the number of distinct finite values in the resulting attention mask is consistent with a Pl × Lat-only partition. The same applies to TestGetShiftWindowMask2D.test_longitude_unmasked_region_count.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Two changes based on greptile review of NVIDIA#1600: 1. Mark shift_lon as intentionally unused in get_shift_window_mask (both 3D and 2D branches) by replacing the variable name with _ . Leaving it named shift_lon implied it still played a role in longitude masking, which is the opposite of what this PR intends. 2. Rewrite test_longitude_unmasked_region_count (3D and 2D) to drive the assertion through the public API rather than a hand-rolled copy of the implementation. The new tests call get_shift_window_mask and assert that the returned mask is identical for every longitude window index — the direct invariant of an unmasked cyclic axis. Signed-off-by: Alexander Wikner <awikner@uchicago.edu> Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Fix: 3D shifted-window attention over-masks the cyclic longitude axis (Pangu, FengWu)
Summary
physicsnemo.nn.module.utils.get_shift_window_maskpartitions theshifted-window attention mask along all three spatial axes (pressure
level, latitude, longitude) in the 3D path, and along both spatial axes
(latitude, longitude) in the 2D path. For both Pangu-Weather (3D) and FengWu
(2D), the longitude axis is cyclic: the cyclic shift (
torch.rollon thelongitude dim) is itself sufficient to handle wrap-around windows. The current
code additionally adds a hard mask along longitude that prevents tokens which
were physically adjacent across the dateline from attending to one another.
The original Pangu-Weather paper explicitly says these wrap-around windows
should be merged, not separated.
This PR removes the longitude masking (and the associated
Lon + shift_lonpadding / trailing crop) so the mask now partitions only the (Pl, Lat) plane
in 3D and only the Lat axis in 2D.
This is a bugfix for both the 3D path (Pangu) and the 2D path (FengWu), for
the same root cause. It is the natural follow-up to PR #1492, which fixed the
shift_lat-vs-shift_lontypo in the cyclic shift itself; with #1492 theforward and reverse rolls match, but the mask built around them was still
constructed against an over-partitioned longitude axis.
Why this is wrong (citations)
Pangu-Weather (3D path)
From Bi et al., "Pangu-Weather: A 3D high-resolution model for fast and
accurate global weather forecast" (arXiv:2211.02556 / Nature, 2023):
actually close to each other. If half windows appear at both leftmost and
rightmost positions, they are directly merged into one window."
(verbatim — this exact sentence is the docstring of
get_shift_window_mask,but the implementation does the opposite of what it says).
M_londoes not appear because different longitudes share the same bias— the longitude indices are cyclic and spacing is evenly distributed along
this axis." Confirms longitude is treated as cyclic in the architectural
design.
The official pseudocode at
github.com/198808xc/Pangu-Weather/pseudocode.py
applies
roll3Dto all three dims (which physicsnemo also does, after #1492)but does not publish a complete
gen_mask. The verbal specification in thepaper is unambiguous, and the standard reference implementations against which
this can be validated mask only
(Pl, Lat).FengWu (2D path)
The FengWu paper (Chen et al., arXiv:2304.02948) uses the same global ERA5
grid at 0.25° resolution (
img_size=(721, 1440)) — longitude is manifestlycyclic on this grid.
The FengWu 2D attention module in physicsnemo is architecturally derived
directly from Pangu-Weather. This is confirmed in two ways:
Position-bias implementation.
get_earth_position_index(called forboth
EarthAttention3DandEarthAttention2D) is annotated in the sourcewith
"implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py".The 2D position bias uses
coords_w = torch.arange(win_lon)for bothquery and key (symmetric longitude treatment), the same construction that
the Pangu paper describes as cyclic (
M_londropped because longitude iscyclic and evenly distributed). This is
physicsnemo/nn/module/utils/utils.pylines 50–71.
Window-grouping structure. Both
window_partition(ndim=2)andEarthAttention2Dgroup feature maps by longitude (B*num_lonbatchdimension,
num_latwindow dimension), identical to the 3D case. The maskshape
(num_lon, num_lat, Wlat*Wlon, Wlat*Wlon)parallels the 3D(num_lon, num_pl*num_lat, …)layout. Any longitude masking on thisstructure would erroneously partition windows that span the dateline, for
the exact same reason as in 3D.
Taken together: the 2D path inherits the same cyclic-longitude design
assumptions as the 3D path, and the fix is identical in character and
justified by the same evidence.
Mechanism of the bug
Consider
pad_resolution=(8, 24, 48),window_size=(2, 6, 12),shift_size=(1, 3, 6)(the default Pangu config used intest/models/pangu/test_pangu.py).Before this PR.
get_shift_window_maskbuilds a mask of shape(1, 8, 24, 48 + 6, 1)and assigns 27 distinct region IDs across all threeaxes via three nested slice loops. After the trailing
[:, :, :, :48, :]cropthe third (
slice(-shift_lon, None)) longitude region is discarded, leaving18 distinct region IDs in the visible mask: 3 (Pl) × 3 (Lat) × 2 (Lon).
In any window that straddles the dateline after the cyclic shift, tokens that
the paper says should be "merged into one window" are split into two
non-attending sub-windows by the mask.
After this PR. The mask is built directly at shape
(1, 8, 24, 48, 1)with a
Pl × Latdouble loop that assigns 9 region IDs (3 × 3). Thelongitude axis is left unpartitioned.
window_partitionthen groups windowsby longitude (
x.permute(0, 5, 1, 3, 2, 4, 6, 7)) and the per-windowattention mask is uniform along longitude, so wrap-around windows attend
freely. This matches the paper's spec.
Diff
physicsnemo/nn/module/utils/shift_window_mask.py—get_shift_window_mask:if ndim == 3: Pl, Lat, Lon = input_resolution win_pl, win_lat, win_lon = window_size shift_pl, shift_lat, shift_lon = shift_size - img_mask = torch.zeros((1, Pl, Lat, Lon + shift_lon, 1)) + img_mask = torch.zeros((1, Pl, Lat, Lon, 1)) elif ndim == 2: Lat, Lon = input_resolution win_lat, win_lon = window_size shift_lat, shift_lon = shift_size - img_mask = torch.zeros((1, Lat, Lon + shift_lon, 1)) + img_mask = torch.zeros((1, Lat, Lon, 1)) if ndim == 3: pl_slices = ( slice(0, -win_pl), slice(-win_pl, -shift_pl), slice(-shift_pl, None), ) lat_slices = ( slice(0, -win_lat), slice(-win_lat, -shift_lat), slice(-shift_lat, None), ) - lon_slices = ( - slice(0, -win_lon), - slice(-win_lon, -shift_lon), - slice(-shift_lon, None), - ) cnt = 0 if ndim == 3: for pl in pl_slices: for lat in lat_slices: - for lon in lon_slices: - img_mask[:, pl, lat, lon, :] = cnt - cnt += 1 - img_mask = img_mask[:, :, :, :Lon, :] + img_mask[:, pl, lat, :, :] = cnt + cnt += 1 elif ndim == 2: for lat in lat_slices: - for lon in lon_slices: - img_mask[:, lat, lon, :] = cnt - cnt += 1 - img_mask = img_mask[:, :, :Lon, :] + img_mask[:, lat, :, :] = cnt + cnt += 1The downstream
window_partition/ mask-difference /masked_fillblock isunchanged. Public signature is unchanged. Output tensor shape is unchanged
(
(n_lon, n_pl·n_lat, W, W)in 3D /(n_lon, n_lat, W, W)in 2D); only itscontents change.
Test plan
Sanity check (CPU, no GPU needed)
Verified locally: post-fix the underlying region-ID map contains exactly 9
distinct IDs (vs. 18 pre-fix), as expected from a
Pl × Lat-only partition.Existing test suite
test/models/pangu/test_pangu.py::test_pangu_forwardandtest/models/fengwu/test_fengwu.py::test_fengwu_forwardvalidate forward-passoutputs against checked-in reference tensors at
test/models/pangu/data/pangu_output.pthandtest/models/fengwu/data/fengwu_output.pth. Those references were lastregenerated by PR #1492 when the cyclic-shift typo was fixed, and will need to
be regenerated again after this PR for the same reason — the corrected
attention mask changes activations, so the reference output values change.
Suggested workflow (matching #1492):
pangu_output.pthandfengwu_output.pth..pthfiles alongside the source change.I have not regenerated those
.pthblobs in this PR because the change wasprepared in a CPU-only environment.
Backwards compatibility
get_shift_window_maskis unchanged.produce slightly different activations after this fix. The same applies to
any pretrained FengWu checkpoint. This is the same compatibility caveat as
PR Fix window shift in pangu, fengwu #1492 — checkpoints already needed re-evaluation after that PR. Models
retrained from scratch after this fix will produce correct outputs.
Maintainers may want to flag the change in release notes.
Files changed
physicsnemo/nn/module/utils/shift_window_mask.py— fix (3D and 2Dbranches symmetrically).
test/nn/module/test_shift_window_mask.py— new unit tests forget_shift_window_mask(shape, binary values, region-ID count for both 3Dand 2D) and round-trip tests for
window_partition/window_reverse.CHANGELOG.md— add aFixedentry under the2.1.0a0section.Checklist notes
(
Signed-off-by: Alexander Wikner <awikner@uchicago.edu>).test/nn/module/test_shift_window_mask.pycovers thechanged function directly (shape, values, region-ID count, round-trip).
The existing
test_pangu_forwardandtest_fengwu_forwardnon-regressiontests will need their reference
.pthfiles regenerated before they pass(see Test plan above).
get_shift_window_maskupdated toexplain the cyclic-longitude design rationale.
2.1.0a0 ### Fixed.changed; no new required parameters added; output tensor shape unchanged.
ruff check,ruff format,interrogate,license,import-linter). Two ruff F841 lint errors(unused
win_lon/shift_lonvariables in the test) were fixed in afollow-up commit. Note:
import-linterrequireslint-importsonPATH;install
import-linterinto your Python environment(
pip install import-linter) and rungit commitwith that environmentactive.
Linked issue / related PR
partitions the cyclic longitude axis in Pangu/FengWu).
shift; this PR fixes the mask the cyclic shift was supposed to be working
with.