
Commit 93097f4

[Test] Add torch.no_grad(), change to use torch.nn.ReLU

1 parent 9d05ca2 · commit 93097f4

3 files changed: +19 −16 lines changed

tests/test_activation.py

Lines changed: 3 additions & 2 deletions
@@ -23,9 +23,10 @@ def test_ReLU(device, size=(128, 128)):
     input = torch.randn(size)
     x1 = input.to(device=device)
     x2 = input.to("cpu")
-    opt_fn = torch.compile(dynamic=False)(torch.nn.functional.relu)
+    ReLU = torch.nn.ReLU()
+    opt_fn = torch.compile(dynamic=False)(ReLU)
     y = opt_fn(x1)
-    cpu_y = torch.nn.functional.relu(x2)
+    cpu_y = ReLU(x2)
     test_result("ReLU", y, cpu_y)

 def test_GeLU(device, size=(128, 128), approximate='none'):
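
The change swaps the functional relu for a torch.nn.ReLU module so the same callable is both compiled and used for the CPU reference. A minimal standalone sketch of that pattern, with a plain allclose check as a hypothetical stand-in for the repo's test_result helper and "cpu" as a placeholder device:

import torch

def check_relu(device="cpu", size=(128, 128)):
    # Module form of ReLU; the same instance is compiled and run eagerly.
    relu = torch.nn.ReLU()
    opt_fn = torch.compile(dynamic=False)(relu)

    input = torch.randn(size)
    x1 = input.to(device=device)
    x2 = input.to("cpu")

    y = opt_fn(x1)       # compiled path on the target device
    cpu_y = relu(x2)     # eager CPU reference
    # Hypothetical stand-in for test_result("ReLU", y, cpu_y).
    assert torch.allclose(y.to("cpu"), cpu_y)

check_relu()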

tests/test_conv2d.py

Lines changed: 13 additions & 12 deletions
@@ -44,15 +44,16 @@ def custom_conv2d(a, b, bias):
 module = ExecutionEngine.setup_device()
 device = module.custom_device()
 torch._dynamo.config.cache_size_limit = 64
-test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0)
-test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3)
-test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3)
-test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3)
-test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3)
-test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2)
-test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3)
-test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1)
-test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1)
-test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0)
-test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0)
-test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16, stride=16, padding=0)
+with torch.no_grad():
+    test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0)
+    test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3)
+    test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3)
+    test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3)
+    test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3)
+    test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2)
+    test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3)
+    test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1)
+    test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1)
+    test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0)
+    test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0)
+    test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16, stride=16, padding=0)
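
Wrapping the calls in torch.no_grad() runs these forward-only conv2d checks with autograd recording disabled, so no graph is built for the comparisons. A rough, self-contained sketch of the pattern, with a hypothetical conv2d_case standing in for the repo's test_conv2d:

import torch

def conv2d_case(device, batch_size, in_channels, out_channels,
                input_size, kernel_size, stride, padding):
    # Hypothetical stand-in for test_conv2d: a single forward pass.
    x = torch.randn(batch_size, in_channels, input_size, input_size, device=device)
    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                           stride=stride, padding=padding).to(device)
    return conv(x)

# no_grad keeps autograd from tracking these inference-only forward passes.
with torch.no_grad():
    out = conv2d_case("cpu", batch_size=8, in_channels=3, out_channels=32,
                      input_size=32, kernel_size=1, stride=1, padding=0)
    assert not out.requires_grad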

tests/test_layernorm.py

Lines changed: 3 additions & 2 deletions
@@ -44,5 +44,6 @@ def test_LayerNorm(device, size=(64, 64)):
 from Scheduler.scheduler import ExecutionEngine
 module = ExecutionEngine.setup_device()
 device = module.custom_device()
-#test_LayerNorm(device)
-test_LayerNorm(device, shape)
+with torch.no_grad():
+    #test_LayerNorm(device)
+    test_LayerNorm(device, shape)
