@@ -772,7 +772,9 @@ def quantize_per_tensor_meta(
772772 quant_max : int ,
773773 dtype : torch .dtype ,
774774) -> torch .Tensor :
775- torch ._check (input .dtype == torch .float32 , lambda : "expected float32" )
775+ torch ._check (
776+ input .dtype in (torch .float32 , torch .bfloat16 ), lambda : "expected float dtype"
777+ )
776778 return input .new_empty (input .size (), dtype = dtype )
777779
778780
@@ -785,7 +787,9 @@ def quantize_per_tensor_asym8s_meta(
785787 quant_max : int ,
786788 dtype : torch .dtype ,
787789) -> torch .Tensor :
788- torch ._check (input .dtype == torch .float32 , lambda : "expected float32" )
790+ torch ._check (
791+ input .dtype in (torch .float32 , torch .bfloat16 ), lambda : "expected float dtype"
792+ )
789793 return input .new_empty (input .size (), dtype = dtype )
790794
791795
@@ -798,7 +802,9 @@ def quantize_per_tensor_asym8u_meta(
798802 quant_max : int ,
799803 dtype : torch .dtype ,
800804) -> torch .Tensor :
801- torch ._check (input .dtype == torch .float32 , lambda : "expected float32" )
805+ torch ._check (
806+ input .dtype in (torch .float32 , torch .bfloat16 ), lambda : "expected float dtype"
807+ )
802808 return input .new_empty (input .size (), dtype = dtype )
803809
804810
@@ -811,7 +817,9 @@ def quantize_per_tensor_asym16s_meta(
811817 quant_max : int ,
812818 dtype : torch .dtype ,
813819) -> torch .Tensor :
814- torch ._check (input .dtype == torch .float32 , lambda : "expected float32" )
820+ torch ._check (
821+ input .dtype in (torch .float32 , torch .bfloat16 ), lambda : "expected float dtype"
822+ )
815823 return input .new_empty (input .size (), dtype = dtype )
816824
817825
@@ -824,7 +832,9 @@ def quantize_per_tensor_asym16u_meta(
824832 quant_max : int ,
825833 dtype : torch .dtype ,
826834) -> torch .Tensor :
827- torch ._check (input .dtype == torch .float32 , lambda : "expected float32" )
835+ torch ._check (
836+ input .dtype in (torch .float32 , torch .bfloat16 ), lambda : "expected float dtype"
837+ )
828838 return input .new_empty (input .size (), dtype = dtype )
829839
830840
@@ -837,7 +847,9 @@ def quantize_per_tensor_asym32s_meta(
837847 quant_max : int ,
838848 dtype : torch .dtype ,
839849) -> torch .Tensor :
840- torch ._check (input .dtype == torch .float32 , lambda : "expected float32" )
850+ torch ._check (
851+ input .dtype in (torch .float32 , torch .bfloat16 ), lambda : "expected float dtype"
852+ )
841853 return input .new_empty (input .size (), dtype = dtype )
842854
843855
@@ -2601,7 +2613,9 @@ def fully_connected_meta(
26012613 weight : torch .Tensor ,
26022614 bias : Optional [torch .Tensor ] = None ,
26032615) -> torch .Tensor :
2604- torch ._check (src .dtype == torch .float32 , lambda : "expected float32" )
2616+ torch ._check (
2617+ src .dtype in (torch .float32 , torch .bfloat16 ), lambda : "expected float dtype"
2618+ )
26052619 torch ._check (src .size (0 ) == 1 , lambda : "expected batch size of 1" )
26062620 # src comes in shape [leading_dims, in_dim]
26072621 # weight comes in shape [out_dim, in_dim]
0 commit comments