Skip to content

Commit a6ee309

Browse files
authored
Fix 7 broken tests caused by D102624163 (pytorch#19273)
Differential Revision: D103598357 Pull Request resolved: pytorch#19273
1 parent 48a8d58 commit a6ee309

1 file changed

Lines changed: 21 additions & 7 deletions

File tree

backends/cadence/aot/ops_registrations.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,9 @@ def quantize_per_tensor_meta(
772772
quant_max: int,
773773
dtype: torch.dtype,
774774
) -> torch.Tensor:
775-
torch._check(input.dtype == torch.float32, lambda: "expected float32")
775+
torch._check(
776+
input.dtype in (torch.float32, torch.bfloat16), lambda: "expected float dtype"
777+
)
776778
return input.new_empty(input.size(), dtype=dtype)
777779

778780

@@ -785,7 +787,9 @@ def quantize_per_tensor_asym8s_meta(
785787
quant_max: int,
786788
dtype: torch.dtype,
787789
) -> torch.Tensor:
788-
torch._check(input.dtype == torch.float32, lambda: "expected float32")
790+
torch._check(
791+
input.dtype in (torch.float32, torch.bfloat16), lambda: "expected float dtype"
792+
)
789793
return input.new_empty(input.size(), dtype=dtype)
790794

791795

@@ -798,7 +802,9 @@ def quantize_per_tensor_asym8u_meta(
798802
quant_max: int,
799803
dtype: torch.dtype,
800804
) -> torch.Tensor:
801-
torch._check(input.dtype == torch.float32, lambda: "expected float32")
805+
torch._check(
806+
input.dtype in (torch.float32, torch.bfloat16), lambda: "expected float dtype"
807+
)
802808
return input.new_empty(input.size(), dtype=dtype)
803809

804810

@@ -811,7 +817,9 @@ def quantize_per_tensor_asym16s_meta(
811817
quant_max: int,
812818
dtype: torch.dtype,
813819
) -> torch.Tensor:
814-
torch._check(input.dtype == torch.float32, lambda: "expected float32")
820+
torch._check(
821+
input.dtype in (torch.float32, torch.bfloat16), lambda: "expected float dtype"
822+
)
815823
return input.new_empty(input.size(), dtype=dtype)
816824

817825

@@ -824,7 +832,9 @@ def quantize_per_tensor_asym16u_meta(
824832
quant_max: int,
825833
dtype: torch.dtype,
826834
) -> torch.Tensor:
827-
torch._check(input.dtype == torch.float32, lambda: "expected float32")
835+
torch._check(
836+
input.dtype in (torch.float32, torch.bfloat16), lambda: "expected float dtype"
837+
)
828838
return input.new_empty(input.size(), dtype=dtype)
829839

830840

@@ -837,7 +847,9 @@ def quantize_per_tensor_asym32s_meta(
837847
quant_max: int,
838848
dtype: torch.dtype,
839849
) -> torch.Tensor:
840-
torch._check(input.dtype == torch.float32, lambda: "expected float32")
850+
torch._check(
851+
input.dtype in (torch.float32, torch.bfloat16), lambda: "expected float dtype"
852+
)
841853
return input.new_empty(input.size(), dtype=dtype)
842854

843855

@@ -2601,7 +2613,9 @@ def fully_connected_meta(
26012613
weight: torch.Tensor,
26022614
bias: Optional[torch.Tensor] = None,
26032615
) -> torch.Tensor:
2604-
torch._check(src.dtype == torch.float32, lambda: "expected float32")
2616+
torch._check(
2617+
src.dtype in (torch.float32, torch.bfloat16), lambda: "expected float dtype"
2618+
)
26052619
torch._check(src.size(0) == 1, lambda: "expected batch size of 1")
26062620
# src comes in shape [leading_dims, in_dim]
26072621
# weight comes in shape [out_dim, in_dim]

0 commit comments

Comments
 (0)