93 changes: 74 additions & 19 deletions test/quantization/pt2e/test_x86inductor_fusion.py
@@ -3275,42 +3275,64 @@ def test_fp8_q_attention_block(self):
annotate_matmul=annotate_matmul, is_fp8=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
reason="cpp kernels not built",
)
def test_fp8_scaled_embedding_bag(self):
dtype = torch.float8_e4m3fn

def _test_scaled_embedding_bag_helper(self, dtype, with_output_quant=False):
class FP8QDQEmbeddingBag(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight_scale = 2.0
self.output_scale = 3.0

def _dequantize(self, weight):
if dtype == torch.float8_e4m3fn:
res = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
tensor=weight.data,
scale=torch.tensor([self.weight_scale]),
output_dtype=torch.float,
)
else:
res = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
weight.data,
self.weight_scale,
0,
-128,
127,
torch.int8,
)
return res

def _quantize(self, x):
if dtype == torch.float8_e4m3fn:
qx = (
torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
tensor=x,
scale=torch.tensor([self.output_scale]),
float8_dtype=dtype,
)
)
else:
qx = torch.ops.quantized_decomposed.quantize_per_tensor.default(
x, self.output_scale, 0, -128, 127, torch.int8
)
return qx

def forward(
self,
weight,
input,
offsets=None,
):
weight = (
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
tensor=weight.data,
scale=torch.tensor([self.weight_scale]),
output_dtype=torch.float,
)
)
weight = self._dequantize(weight)

return torch.nn.functional.embedding_bag(
res = torch.nn.functional.embedding_bag(
input,
weight,
offsets,
mode="sum",
include_last_offset=True,
)
if with_output_quant:
res = self._quantize(res)
return res

EMBEDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
EMBEDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
@@ -3337,8 +3359,11 @@ def forward(
)

def matcher_check_fn():
counter_name = "scaled_embedding_bag"
if with_output_quant:
counter_name += "_with_output_quant"
self.assertEqual(
counters["inductor"]["scaled_embedding_bag_matcher_count"], 1
counters["inductor"][f"{counter_name}_matcher_count"], 1
)

self._test_common(
@@ -3347,6 +3372,36 @@ def matcher_check_fn():
matcher_check_fn,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
reason="cpp kernels not built",
)
def test_fp8_scaled_embedding_bag(self):
self._test_scaled_embedding_bag_helper(torch.float8_e4m3fn)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
reason="cpp kernels not built",
)
def test_int8_scaled_embedding_bag(self):
self._test_scaled_embedding_bag_helper(torch.int8)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
reason="cpp kernels not built",
)
def test_int8_scaled_embedding_bag_with_output_quant(self):
self._test_scaled_embedding_bag_helper(torch.int8, True)


instantiate_parametrized_tests(TestPatternMatcher)
if __name__ == "__main__":
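For reference, a minimal self-contained sketch of the int8 dequantize -> embedding_bag -> quantize round trip that the helper above builds; the shapes and scale values here are arbitrary placeholders, and only the op signatures mirror the test code.

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (registers the quantized_decomposed ops)

# Arbitrary table/bag sizes and scales, chosen only for illustration.
weight_scale, output_scale = 2.0, 3.0
qweight = torch.randint(-128, 127, (16, 8), dtype=torch.int8)
indices = torch.randint(0, 16, (6,))
offsets = torch.tensor([0, 3, 6])

# Dequantize the int8 weight table to float, as the helper's _dequantize branch does.
weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    qweight, weight_scale, 0, -128, 127, torch.int8
)
# Sum-pool the bags; include_last_offset=True matches the test module.
out = torch.nn.functional.embedding_bag(
    indices, weight, offsets, mode="sum", include_last_offset=True
)
# Re-quantize the output, which is what the with_output_quant=True path exercises.
qout = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    out, output_scale, 0, -128, 127, torch.int8
)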
64 changes: 51 additions & 13 deletions torchao/quantization/pt2e/inductor_passes/x86.py
@@ -2911,7 +2911,11 @@ def _register_scaled_embedding_bag_pass(pattern, pass_number, dtype=torch.float3
def scaled_embedding_bag(match: Match, *args, **kwargs):
assert dtype in [torch.float32, torch.bfloat16]

getitem_node = match.output_node()
if "o_dtype" in kwargs:
quant_node = match.output_node()
getitem_node = quant_node.args[0]
else:
getitem_node = match.output_node()
embedding_bag_node = getitem_node.args[0]
assert embedding_bag_node.target is aten._embedding_bag_forward_only.default

@@ -2940,11 +2944,22 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
kwargs["mode"],
kwargs["include_last_offset"],
)
# only support fp32 output, next step to support more dtype
output_type = torch.float
o_scale = 1.0
if "o_dtype" in kwargs:
output_type = kwargs["o_dtype"]
o_scale = kwargs["o_inv_scale"]

graph = match.graph
with graph.inserting_before(getitem_node):
# _scaled_embedding_bag does not accept a plain float scale,
# so convert the float scale into a 1-element tensor
if type(w_scale) is float:
w_scale = graph.call_function(
torch.ops.aten.full.default,
args=([1], w_scale),
kwargs={"dtype": torch.float},
)
new_args: tuple[Any, ...] = (
qw,
indices,
@@ -2953,13 +2968,18 @@
o_scale,
mode,
include_last_offset,
torch.float,
output_type,
)

new_embedding_bag_node = graph.call_function(
torch.ops.torchao._scaled_embedding_bag.default, args=new_args
)

# Erase the quant pattern
if output_type == torch.int8:
quant_node.replace_all_uses_with(getitem_node)
getitem_node.meta.update(quant_node.meta)
graph.erase_node(quant_node)
getitem_node.replace_all_uses_with(new_embedding_bag_node)
new_embedding_bag_node.meta.update(embedding_bag_node.meta)

@@ -2970,8 +2990,11 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
# Erase the dequant pattern
graph.erase_node(dequant_node)

counters["inductor"]["scaled_embedding_bag_matcher_count"] += 1
counters["inductor"]["scaled_embedding_bag_matcher_nodes"] += len(match.nodes)
counter_name = "scaled_embedding_bag"
if "o_dtype" in kwargs:
counter_name += "_with_output_quant"
counters["inductor"][f"{counter_name}_matcher_count"] += 1
counters["inductor"][f"{counter_name}_matcher_nodes"] += len(match.nodes)
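For context, the split counter keys incremented here can be inspected the same way the tests do; a hedged sketch follows, in which the compiled model call is only a placeholder and not part of this change.

from torch._dynamo.utils import counters

counters.clear()
# compiled_fn would be a torch.compile'd model containing the
# dequant -> embedding_bag -> (quant) pattern targeted above, e.g.:
# compiled_fn(qweight, indices, offsets)
print(counters["inductor"]["scaled_embedding_bag_matcher_count"])
print(counters["inductor"]["scaled_embedding_bag_with_output_quant_matcher_count"])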


def _generate_scaled_embedding_bag_patterns(dq_pattern):
@@ -2994,20 +3017,35 @@ def _generate_scaled_embedding_bag_patterns(dq_pattern):


def _register_quantization_embeddingbag_pass():
for dtype in [torch.float32, torch.bfloat16]:
_register_scaled_embedding_bag_pass(
_generate_scaled_embedding_bag_patterns(
for is_fp8 in [True, False]:
for dtype in [torch.float32, torch.bfloat16]:
embeddingbag_pattern = _generate_scaled_embedding_bag_patterns(
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(
is_tensor_overload=False, is_fp8=True
is_tensor_overload=False, is_fp8=is_fp8
),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
),
pass_number=1,
dtype=dtype,
) # pass_number=0 to run before weight prepack
)

_register_scaled_embedding_bag_pass(
embeddingbag_pattern, pass_number=1, dtype=dtype
)

# fp8 output will be supported later
if not is_fp8:
embeddingbag_with_qoutput_pattern = generate_pattern_with_output_quant(
embeddingbag_pattern,
dtype == torch.bfloat16,
is_fp8,
)

_register_scaled_embedding_bag_pass(
embeddingbag_with_qoutput_pattern,
pass_number=0,
dtype=dtype,
)


@functools.lru_cache(None)
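The quant-output splice above ("Erase the quant pattern") relies on the standard torch.fx editing idiom: redirect users of the dead node, carry over its meta, then erase it. Below is a generic, self-contained sketch of that idiom on a toy traced module; it is not the actual inductor pass, and torch.neg merely stands in for the trailing quant op.

import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return torch.neg(y)  # stand-in for the trailing quant node

gm = fx.symbolic_trace(Toy())
graph = gm.graph
quant_node = next(n for n in graph.nodes if n.target is torch.neg)
getitem_node = quant_node.args[0]  # the node feeding the "quant" op

quant_node.replace_all_uses_with(getitem_node)  # consumers now read the pre-quant value
getitem_node.meta.update(quant_node.meta)       # preserve the output metadata
graph.erase_node(quant_node)                    # the quant node is now dead

graph.lint()
gm.recompile()
print(gm.code)  # forward now returns relu(x) directly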