Merged
31 commits
8586d6c  Example to use efficient attention (wujingyue, Dec 9, 2025)
6ad01f1  Builds (wujingyue, Dec 10, 2025)
c206b32  Fix test (wujingyue, Dec 10, 2025)
bfae52a  Fix test_sdpa.py (wujingyue, Dec 10, 2025)
3a87466  Fix test_alphafold3 (wujingyue, Dec 10, 2025)
f0862fa  Move to direct (wujingyue, Dec 10, 2025)
40c13fd  WIP (wujingyue, Dec 11, 2025)
460b683  Fix test_sdpa (wujingyue, Dec 11, 2025)
2f68957  Change bias and mask's argument order (wujingyue, Dec 11, 2025)
99a235e  Argument order (wujingyue, Dec 12, 2025)
19a13dd  Data attributes (wujingyue, Dec 13, 2025)
02a9ede  Use math (wujingyue, Dec 13, 2025)
5f72622  unnecessary ir_utils (wujingyue, Dec 13, 2025)
56a438b  Change batch_size and n_tokens to 3 and 5 for debugging (wujingyue, Dec 13, 2025)
67a9ace  Comment (wujingyue, Dec 13, 2025)
bdb7dc5  Comment (wujingyue, Dec 13, 2025)
c6f122d  Fix multidevice tests (wujingyue, Dec 13, 2025)
b3673b3  Review (wujingyue, Dec 20, 2025)
e4457c9  Merge branch 'main' into wjy/bias (wujingyue, Dec 20, 2025)
ecb515b  Fix comparison (wujingyue, Dec 20, 2025)
b0d4d4d  WIP (wujingyue, Dec 20, 2025)
2789485  Reference implementation for triangle updates (wujingyue, Dec 21, 2025)
b24839d  Update tests/python/direct/test_alphafold3.py (wujingyue, Dec 21, 2025)
e1f4f08  Add missing norms (wujingyue, Dec 24, 2025)
815c16d  Add missing mask (wujingyue, Dec 24, 2025)
56a9c93  Add missing layernorm in triangle attention (wujingyue, Dec 24, 2025)
e57a41f  Merge remote-tracking branch 'origin/main' into wjy/update (wujingyue, Jan 5, 2026)
62a7b31  Comment (wujingyue, Jan 5, 2026)
620e296  Remove redundant code (wujingyue, Jan 5, 2026)
063d944  Comment mask (wujingyue, Jan 6, 2026)
5502896  Fix test (wujingyue, Jan 6, 2026)
18 changes: 18 additions & 0 deletions python/python_direct/ir.cpp
@@ -231,6 +231,24 @@ Returns
-------
list of Val
The shape of this tensor.
)")
.def(
"dtype",
[](TensorView* self) -> PrimDataType {
DataType dt = self->dtype();
NVF_CHECK(
std::holds_alternative<PrimDataType>(dt.type),
"Expected PrimDataType but got type: ",
dt);
return std::get<PrimDataType>(dt.type);
},
R"(
Get the data type of this tensor.

Returns
-------
DataType
The data type of this tensor.
Comment on lines +249 to +251

Contributor:

syntax: Docstring incorrectly states return type as 'DataType' but should be 'PrimDataType' to match the actual return type.

Suggested change: replace

    -------
    DataType
    The data type of this tensor.

with

    -------
    PrimDataType
    The data type of this tensor.

Collaborator (author):

This is fine given

    py::enum_<PrimDataType>(nvfuser, "DataType", py::module_local())

The Python user expects to use DataType instead of PrimDataType.

)")
.def("has_root", &TensorView::hasRoot, R"(
Check if this tensor has a root domain.
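
To illustrate the exchange above, here is a minimal sketch of how the new dtype() binding is used from Python (not part of the diff; it mirrors the layer_norm helper in test_alphafold3.py below and assumes nvfuser_direct is built with this change):

from nvfuser_direct import FusionDefinition, DataType

with FusionDefinition() as fd:
    tv = fd.define_tensor(shape=[-1, -1], dtype=DataType.BFloat16, contiguity=True)
    # TensorView.dtype() returns the element type as the Python-side DataType enum
    # (PrimDataType on the C++ side), so it can be compared and passed back to cast.
    assert tv.dtype() == DataType.BFloat16
    out = fd.ops.cast(tv, dtype=DataType.Float)
    fd.add_output(out)
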
50 changes: 18 additions & 32 deletions python/python_direct/ops.cpp
@@ -10,6 +10,7 @@
#include <bindings.h>
#include <ops/all_ops.h>
#include <ops/arith.h>
#include <utils.h>

namespace nvfuser::python {

@@ -2418,46 +2419,31 @@ TensorView* expand_fn(TensorView* arg, ShapeType generic_new_shape) {

template <class ShapeType>
TensorView* broadcast_in_dim_fn(
TensorView* arg,
TensorView* input,
ShapeType generic_output_shape,
std::vector<int64_t>& broadcast_dims) {
const std::vector<int64_t>& nonbroadcast_dims) {
std::vector<Val*> output_shape = SequenceAsVector(generic_output_shape);
NVF_CHECK(
output_shape.size() >= broadcast_dims.size(),
"broadcast_dims vector size is too big for output shape!");
NVF_CHECK_GE(output_shape.size(), nonbroadcast_dims.size());

const auto arg_ndims = static_cast<size_t>(std::ranges::distance(
arg->getLoopDomain() | TensorDomain::kNoReductions));
NVF_CHECK(
output_shape.size() >= broadcast_dims.size(),
"The new shape is expected to be greater-then-or-equal to the input: ",
output_shape.size(),
" vs ",
arg_ndims);
NVF_CHECK(
arg_ndims == broadcast_dims.size(),
"The broadcast dimensions should match the input dimensions: ",
arg_ndims,
" vs ",
broadcast_dims.size(),
". arg = ",
arg->toString());
const auto input_ndim = std::ranges::distance(
input->getLogicalDomain() | TensorDomain::kNoReductions);
NVF_CHECK_GE(std::ssize(output_shape), input_ndim);
NVF_CHECK_EQ(input_ndim, std::ssize(nonbroadcast_dims));

std::vector<bool> is_broadcast_dim(output_shape.size(), true);
for (const auto idx : arange(broadcast_dims.size())) {
if (idx > 0) {
NVF_CHECK(
broadcast_dims[idx - 1] < broadcast_dims[idx],
"Broadcast dimension is not greater than the previous value.");
}
for (int64_t nonbroadcast_dim : nonbroadcast_dims) {
nonbroadcast_dim = wrapDim(nonbroadcast_dim, std::ssize(output_shape));
NVF_CHECK(
broadcast_dims[idx] < static_cast<int>(output_shape.size()),
"Invalid broadcast_dims value.");
is_broadcast_dim.at(broadcast_dims[idx]) = false;
is_broadcast_dim.at(nonbroadcast_dim),
"nonbroadcast_dim (",
nonbroadcast_dim,
") is specified more than once.");
is_broadcast_dim.at(nonbroadcast_dim) = false;
}

auto bcast_output = broadcast(arg, is_broadcast_dim);
return expand(bcast_output, output_shape);
TensorView* output = broadcast(input, is_broadcast_dim);
output = expand(output, output_shape);
return output;
}

template <class ShapeType>
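
From the Python side, the relaxed broadcast_in_dim contract looks like this (a sketch, not part of the diff; the concrete sizes are made up). Negative entries in broadcast_dims are now wrapped like ordinary dimension indices, the ascending-order requirement is gone, and repeating a dimension triggers the new "is specified more than once" check:

import torch
from nvfuser_direct import FusionDefinition, DataType

with FusionDefinition() as fd:
    x = fd.define_tensor(shape=[4], dtype=DataType.Float, contiguity=True)
    # -1 wraps to the last output dimension, as in the layer_norm helper below,
    # which broadcasts its weight with broadcast_dims=[-1].
    y = fd.ops.broadcast_in_dim(x, shape=[2, 3, 4], broadcast_dims=[-1])
    fd.add_output(y)

(y_out,) = fd.execute([torch.randn(4, device="cuda")])
assert y_out.shape == (2, 3, 4)
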
162 changes: 155 additions & 7 deletions tests/python/direct/test_alphafold3.py
@@ -10,7 +10,7 @@
from dataclasses import dataclass
from enum import Enum, auto

from nvfuser_direct import FusionDefinition, DataType
from nvfuser_direct import FusionDefinition, DataType, TensorView


@dataclass
@@ -28,14 +28,157 @@ class Direction(Enum):
OUTGOING = auto() # aka starting node


def layer_norm(
fd: FusionDefinition, x: TensorView, w: TensorView, b: TensorView
) -> TensorView:
io_dtype = x.dtype()
x = fd.ops.cast(x, dtype=DataType.Float)
var, mean = fd.ops.var_mean(x, dims=[-1], correction=0, keepdim=True)
y = fd.ops.sub(x, mean)
var = fd.ops.add(var, fd.define_scalar(1e-5))
y = fd.ops.mul(y, fd.ops.rsqrt(var))
shape = fd.ops.shape(x)
w = fd.ops.broadcast_in_dim(w, shape=shape, broadcast_dims=[-1])
y = fd.ops.mul(y, w)
b = fd.ops.broadcast_in_dim(b, shape=shape, broadcast_dims=[-1])
y = fd.ops.add(y, b)
y = fd.ops.cast(y, dtype=io_dtype)
return y


def gating(
fd: FusionDefinition,
z: TensorView,
w_p: TensorView,
z_in: TensorView,
w_g: TensorView,
) -> TensorView:
io_dtype = z.dtype()
p = fd.ops.linear(z, w_p)
g = fd.ops.linear(z_in, w_g)
g = fd.ops.sigmoid(g)
z = fd.ops.mul(p, g)
return fd.ops.cast(z, dtype=io_dtype)
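
For comparison, an equivalent PyTorch reference for these two helpers (a reader-facing sketch, not part of the diff; it assumes the weights are laid out [out_features, in_features] as in torch.nn.Linear, matching fd.ops.linear):

import torch
import torch.nn.functional as F

def layer_norm_reference(x, w, b):
    # Normalize over the channel dimension in fp32 with eps=1e-5 (population
    # variance, i.e. correction=0 as in fd.ops.var_mean), then cast back.
    y = F.layer_norm(x.float(), x.shape[-1:], w.float(), b.float(), eps=1e-5)
    return y.to(x.dtype)

def gating_reference(z, w_p, z_in, w_g):
    # Project z, then gate it with a sigmoid of a second projection of z_in.
    return (F.linear(z, w_p) * torch.sigmoid(F.linear(z_in, w_g))).to(z.dtype)
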


# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-updates
#
# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
# prediction with AlphaFold. Nature 596, 583–589 (2021).
# https://doi.org/10.1038/s41586-021-03819-2
# (see Supplementary Methods 1.6.5 for details)
@pytest.mark.parametrize(
"direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
)
def test_triangle_updates(direction):
pass
c_z = _DEFAULT_CONFIG.c_z

with FusionDefinition() as fd:
z_in = fd.define_tensor(
shape=[-1, -1, -1, c_z],
dtype=DataType.BFloat16,
contiguity=True,
) # [b, i, j, c_z]
w_norm_in = fd.define_tensor(
shape=[c_z], dtype=DataType.BFloat16, contiguity=True
)
b_norm_in = fd.define_tensor(
shape=[c_z], dtype=DataType.BFloat16, contiguity=True
)
w_p_in = fd.define_tensor(
shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
)
w_g_in = fd.define_tensor(
shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
)
w_norm_out = fd.define_tensor(
shape=[c_z], dtype=DataType.BFloat16, contiguity=True
)
b_norm_out = fd.define_tensor(
shape=[c_z], dtype=DataType.BFloat16, contiguity=True
)
w_p_out = fd.define_tensor(
shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
)
w_g_out = fd.define_tensor(
shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
)
# Masking is used in an internal implementation: http://nv/e-4
mask = fd.define_tensor(
shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
) # [b, i, j]

batch_size = fd.ops.size(z_in, 0)
n_tokens = fd.ops.size(z_in, 1)

z_in = layer_norm(fd, z_in, w_norm_in, b_norm_in)
z = gating(fd, z_in, w_p_in, z_in, w_g_in)
mask = fd.ops.broadcast_in_dim(
mask, shape=[batch_size, n_tokens, n_tokens, c_z], broadcast_dims=[0, 1, 2]
)
z = fd.ops.where(mask, z, 0.0)
a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z])
b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2])

match direction:
case Direction.OUTGOING:
# z_out = einsum("bikc,bjkc->bijc", a, b)
a = fd.ops.permute(a, [0, 3, 1, 2]) # [b, c, i, k]
b = fd.ops.permute(b, [0, 3, 2, 1]) # [b, c, k, j]
case Direction.INCOMING:
# z_out = einsum("bkic,bkjc->bijc", a, b)
a = fd.ops.permute(a, [0, 3, 2, 1]) # [b, c, i, k]
b = fd.ops.permute(b, [0, 3, 1, 2]) # [b, c, k, j]
z = fd.ops.matmul(a, b) # [b, c, i, j]
z = fd.ops.permute(z, [0, 2, 3, 1]) # [b, i, j, c]

z = layer_norm(fd, z, w_norm_out, b_norm_out)
z = gating(fd, z, w_p_out, z_in, w_g_out)
fd.add_output(z)

batch_size = 3
n_tokens = 5
z_in = torch.testing.make_tensor(
batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda"
)
w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
w_p_in = torch.testing.make_tensor(
c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
)
w_g_in = torch.testing.make_tensor(
c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
)
w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
mask = torch.testing.make_tensor(
batch_size, n_tokens, n_tokens, dtype=torch.bool, device="cuda"
)
(z_out,) = fd.execute(
[
z_in,
w_norm_in,
b_norm_in,
w_p_in,
w_g_in,
w_norm_out,
b_norm_out,
w_p_out,
w_g_out,
mask,
]
)
assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)
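
For reference, the permute/matmul sequence above computes exactly the einsums noted in the comments; a standalone PyTorch sketch (illustrative only, the helper name is made up):

import torch

def triangle_update_reference(a, b, outgoing):
    # a, b: [batch, i, j, channels], the two halves of the gated projection.
    if outgoing:
        # Outgoing edges (starting node): contract over the shared neighbor k of rows i and j.
        return torch.einsum("bikc,bjkc->bijc", a, b)
    # Incoming edges (ending node): contract over the shared neighbor k of columns i and j.
    return torch.einsum("bkic,bkjc->bijc", a, b)

Permuting a and b to [b, c, i, k] and [b, c, k, j], taking matmul, and permuting back to [b, i, j, c], as the fusion does, yields the same result.
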


# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-attention
#
# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
# prediction with AlphaFold. Nature 596, 583–589 (2021).
# https://doi.org/10.1038/s41586-021-03819-2
# (see Supplementary Methods 1.6.6 for details)
@pytest.mark.parametrize(
"direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
)
@@ -52,8 +195,8 @@ def test_triangle_attention(direction):
dtype=DataType.BFloat16,
contiguity=True,
) # [b, i, j, c_z]
if direction == Direction.INCOMING:
z_in = fd.ops.permute(z_in, [0, 2, 1, 3])
w_norm = fd.define_tensor(shape=[c_z], dtype=DataType.BFloat16, contiguity=True)
b_norm = fd.define_tensor(shape=[c_z], dtype=DataType.BFloat16, contiguity=True)
w_q = fd.define_tensor(
shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
)
@@ -64,8 +207,6 @@ def test_triangle_attention(direction):
mask = fd.define_tensor(
shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
) # [b, i, j]
if direction == Direction.INCOMING:
mask = fd.ops.permute(mask, [0, 2, 1])
w_v = fd.define_tensor(
shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
)
@@ -79,6 +220,9 @@
batch_size = fd.ops.size(z_in, 0)
n_tokens = fd.ops.size(z_in, 1)

if direction == Direction.INCOMING:
z_in = fd.ops.permute(z_in, [0, 2, 1, 3])
z_in = layer_norm(fd, z_in, w_norm, b_norm)
q = fd.ops.linear(z_in, w_q)
q_h = fd.ops.reshape(
q, [batch_size, n_tokens, n_tokens, h, -1]
@@ -99,6 +243,8 @@
broadcast_dims=[0, 2, 3, 4],
) # [b, 1, h, j, k]

if direction == Direction.INCOMING:
mask = fd.ops.permute(mask, [0, 2, 1])
mask = fd.ops.broadcast_in_dim(
mask,
shape=[batch_size, n_tokens, 1, 1, n_tokens],
@@ -142,6 +288,8 @@ def test_triangle_attention(direction):
z_in = torch.testing.make_tensor(
batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda"
)
w_norm = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
b_norm = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
w_q = torch.testing.make_tensor(
h * c_hidden, c_z, dtype=torch.bfloat16, device="cuda"
)
@@ -161,5 +309,5 @@ def test_triangle_attention(direction):
w_o = torch.testing.make_tensor(
c_z, h * c_hidden, dtype=torch.bfloat16, device="cuda"
)
(z_out,) = fd.execute([z_in, w_q, w_k, w_b, mask, w_v, w_g, w_o])
(z_out,) = fd.execute([z_in, w_norm, b_norm, w_q, w_k, w_b, mask, w_v, w_g, w_o])
assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)
14 changes: 3 additions & 11 deletions tests/python/opinfo/opinfo_input_generators.py
@@ -217,21 +217,14 @@ def broadcast_in_dim_error_generator(
"The new shape is expected to be greater-then-or-equal to the input",
)

# 3. broadcast_dimensions is an ascending sequence of integers.
descending_broadcast_dimensions = (
([2, 2], [2, 2], [1, 0]),
RuntimeError,
"Broadcast dimension is not greater than the previous value.",
)

# 4. Each broadcast dimension is within the new shape.
# 3. Each broadcast dimension is within the new shape.
out_of_bounds_broadcast_dimensions = (
([2, 2], [2, 2], [0, 2]),
RuntimeError,
"Invalid broadcast_dims value.",
)

# 5. The original tensor is not broadcastable to desired shape.
# 4. The original tensor is not broadcastable to desired shape.
# tensor.shape[idx] == 1 or tensor.shape[idx] == output_shape[new_idx]
#
# Jax Exception:
@@ -244,7 +237,7 @@ def broadcast_in_dim_error_generator(
"Invalid broadcast_dims value.",
)

# 6. TypeError: broadcast_in_dim shape must have every element be nonnegative, got (-1, 2, 3).
# 5. TypeError: broadcast_in_dim shape must have every element be nonnegative, got (-1, 2, 3).
negative_shape = (
([2, 3], [2, 3, -1], [0, 1]),
RuntimeError,
@@ -255,7 +248,6 @@ def broadcast_in_dim_error_generator(
error_cases = [
missing_axis_in_bcast_dims,
fewer_dims_in_output_shape,
descending_broadcast_dimensions,
out_of_bounds_broadcast_dimensions,
# not_broadcastable,
# negative_shape,
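
Not part of the diff: if this generator also wanted to exercise the new duplicate-dimension check added in ops.cpp, a hypothetical error case could follow the same tuple pattern as the cases above:

# Hypothetical; mirrors the "is specified more than once" NVF_CHECK in broadcast_in_dim_fn.
duplicate_broadcast_dimensions = (
    ([2, 2], [2, 2, 2], [0, 0]),
    RuntimeError,
    "is specified more than once",
)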