1 change: 1 addition & 0 deletions coremltools/converters/mil/mil/passes/__init__.py
@@ -37,6 +37,7 @@
    optimize_activation_quantization,
    optimize_conv,
    optimize_elementwise_binary,
    optimize_gelu_sigmoid,
    optimize_linear,
    optimize_normalization,
    optimize_quantization,
139 changes: 139 additions & 0 deletions coremltools/converters/mil/mil/passes/defs/optimize_gelu_sigmoid.py
@@ -0,0 +1,139 @@
# Copyright (c) 2024, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import (
    _check_var_scalar_value,
    block_context_manager,
)
from coremltools.converters.mil.mil.passes.pass_registry import register_pass


@register_pass(namespace="common")
class fuse_gelu_sigmoid_approximation(AbstractGraphPass):
"""
Detect the pattern that corresponds to the sigmoid approximation version of ``gelu``,
and replace it with a single ``gelu`` layer with ``mode=SIGMOID_APPROXIMATION``.

The sigmoid approximation of GELU is: ``y = x * sigmoid(1.702 * x)``

.. code-block::

Input graph:
[...] ----> mul (1.702) ---> sigmoid ---> mul ---> [...]
| ^
| |
|----------------------------------------

Output graph:
[...] ----> gelu (mode=SIGMOID_APPROXIMATION) ---> [...]
"""

    # 1.702 is the scalar from the Hendrycks & Gimpel sigmoid approximation,
    # GELU(x) ~= x * sigmoid(1.702 * x).
    GELU_SIGMOID_CONST = 1.702

    def apply(self, prog):
        for f in prog.functions.values():
            block_changed = True
            while block_changed:
                block_changed = self._fuse_gelu_sigmoid_block(f)

    @block_context_manager
    def _fuse_gelu_sigmoid_block(self, block):
        fusion_occurred = False
        for op in list(block.operations):
            # Skip ops already detached by an earlier fusion in this sweep.
            if op.enclosing_block is None:
                continue

            # Recurse into nested blocks (e.g. cond/while bodies) first.
            for b in op.blocks:
                nested_changed = True
                while nested_changed:
                    nested_changed = self._fuse_gelu_sigmoid_block(b)

            if len(op.blocks) > 0:
                continue

            # The final elementwise mul anchors the pattern match.
            if op.op_type == "mul":
                if self._try_match_and_transform_pattern1(op, block):
                    fusion_occurred = True
        return fusion_occurred
    def _try_match_and_transform_pattern1(self, mul_op, block):
        """
        Match pattern: x -> mul(1.702) -> sigmoid -> mul(x) -> output

        Where the final mul combines x with sigmoid(1.702 * x).
        """
        mul_x = mul_op.x
        mul_y = mul_op.y

        sigmoid_var = None
        root_var = None

        # The final mul must take the sigmoid output on one side and the
        # root input x on the other; accept either operand order.
        if mul_x.op is not None and mul_x.op.op_type == "sigmoid":
            sigmoid_var = mul_x
            root_var = mul_y
        elif mul_y.op is not None and mul_y.op.op_type == "sigmoid":
            sigmoid_var = mul_y
            root_var = mul_x
        else:
            return False

        sigmoid_op = sigmoid_var.op

        # The sigmoid output must not escape the pattern: it cannot be a
        # block output, and the final mul must be its only consumer, since
        # remove_ops would otherwise leave dangling uses.
        if sigmoid_op.outputs[0] in block.outputs:
            return False
        if len(sigmoid_var.child_ops) != 1:
            return False

        sigmoid_input_op = sigmoid_op.x.op
        if sigmoid_input_op is None or sigmoid_input_op.op_type != "mul":
            return False

        scale_mul_op = sigmoid_input_op

        # One operand of the inner mul must be the 1.702 scalar constant.
        is_x_const = _check_var_scalar_value(scale_mul_op.x, self.GELU_SIGMOID_CONST, tol=0.01)
        is_y_const = _check_var_scalar_value(scale_mul_op.y, self.GELU_SIGMOID_CONST, tol=0.01)

        if not (is_x_const or is_y_const):
            return False

        scale_mul_input = scale_mul_op.y if is_x_const else scale_mul_op.x

        # Both muls must consume the same root var x (identity comparison).
        if scale_mul_input is not root_var:
            return False

        # Same escape checks for the inner mul's output.
        if scale_mul_op.outputs[0] in block.outputs:
            return False
        if len(scale_mul_op.outputs[0].child_ops) != 1:
            return False

        return self._transform_to_gelu(
            block=block,
            root_var=root_var,
            ops_to_remove=[scale_mul_op, sigmoid_op, mul_op],
            output_op=mul_op,
        )

    def _transform_to_gelu(self, block, root_var, ops_to_remove, output_op):
        """Replace the matched pattern with a single gelu op."""
        for op in ops_to_remove[:-1]:
            for out in op.outputs:
                if out in block.outputs:
                    return False

        out_name = output_op.outputs[0].name
        gelu_out = mb.gelu(
            x=root_var,
            mode="SIGMOID_APPROXIMATION",
            name=out_name,
            before_op=ops_to_remove[0],
        )

        output_op.enclosing_block.replace_uses_of_var_after_op(
            anchor_op=output_op,
            old_var=output_op.outputs[0],
            new_var=gelu_out,
        )

        block.remove_ops(ops_to_remove)
        return True
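A side note on the math this pass relies on: ``x * sigmoid(1.702 * x)`` is only an approximation of exact GELU, ``0.5 * x * (1 + erf(x / sqrt(2)))``. A minimal standalone sketch (not part of this PR; assumes only numpy and the Python standard library) of how closely the two agree:

    import math

    import numpy as np

    def gelu_exact(x):
        # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

    def gelu_sigmoid(x):
        # The form this pass matches: x * sigmoid(1.702 * x)
        return x / (1.0 + np.exp(-1.702 * x))

    x = np.linspace(-6.0, 6.0, 1001)
    print(np.max(np.abs(gelu_exact(x) - gelu_sigmoid(x))))  # on the order of 1e-2

The worst-case absolute deviation is roughly 0.02, which is why the matcher uses ``tol=0.01`` on the constant itself rather than on the activation values.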

1 change: 1 addition & 0 deletions coremltools/converters/mil/mil/passes/pass_pipeline.py
@@ -54,6 +54,7 @@
"common::fuse_linear_bias",
"common::fuse_gelu_tanh_approximation",
"common::fuse_gelu_exact",
"common::fuse_gelu_sigmoid_approximation",
"common::fuse_leaky_relu",
"common::rank0_expand_dims_swap",
"common::fuse_squeeze_expand_dims",
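With the registration above, the pass runs as part of the default pipeline, but it can also be invoked by name. A minimal sketch, assuming ``prog`` is a MIL program containing the pattern and using the PASS_REGISTRY lookup seen elsewhere in the coremltools pass tests:

    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    # `prog` is assumed to be a MIL Program containing mul -> sigmoid -> mul.
    graph_pass = PASS_REGISTRY["common::fuse_gelu_sigmoid_approximation"]
    graph_pass(prog)  # rewrites the pattern into a single gelu op, in place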
160 changes: 160 additions & 0 deletions coremltools/converters/mil/mil/passes/tests/test_gelu_sigmoid_pass.py
@@ -0,0 +1,160 @@
# Copyright (c) 2024, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import numpy as np

from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import types
from coremltools.converters.mil.testing_utils import (
    apply_pass_and_basic_check,
    assert_model_is_valid,
    get_op_types_in_program,
)


class TestFuseGeluSigmoidApproximation:
    """
    Test the fuse_gelu_sigmoid_approximation pass.

    Input pattern:
        x -> mul(1.702) -> sigmoid -> mul(x) -> output

    Output pattern:
        x -> gelu(mode=SIGMOID_APPROXIMATION) -> output
    """

    def test_basic_fusion(self):
        """Test basic GELU sigmoid approximation fusion."""

        @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))])
        def prog(x):
            scaled = mb.mul(x=x, y=np.float32(1.702))
            sigmoid_out = mb.sigmoid(x=scaled)
            return mb.mul(x=x, y=sigmoid_out)

        prev_prog, prev_block, block = apply_pass_and_basic_check(
            prog, "common::fuse_gelu_sigmoid_approximation"
        )

        assert get_op_types_in_program(prev_prog) == ["mul", "sigmoid", "mul"]
        assert get_op_types_in_program(prog) == ["gelu"]

        gelu_op = block.find_ops(op_type="gelu")[0]
        assert gelu_op.mode.val == "SIGMOID_APPROXIMATION"

        assert_model_is_valid(
            prog,
            {"x": (2, 3)},
            expected_output_shapes={block.outputs[0].name: (2, 3)},
        )

    def test_fusion_with_reversed_mul_order(self):
        """Test fusion when mul operands are in reversed order."""

        @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))])
        def prog(x):
            scaled = mb.mul(x=np.float32(1.702), y=x)
            sigmoid_out = mb.sigmoid(x=scaled)
            return mb.mul(x=sigmoid_out, y=x)

        prev_prog, prev_block, block = apply_pass_and_basic_check(
            prog, "common::fuse_gelu_sigmoid_approximation"
        )

        assert get_op_types_in_program(prev_prog) == ["mul", "sigmoid", "mul"]
        assert get_op_types_in_program(prog) == ["gelu"]

        gelu_op = block.find_ops(op_type="gelu")[0]
        assert gelu_op.mode.val == "SIGMOID_APPROXIMATION"

    def test_fusion_with_different_shapes(self):
        """Test fusion with different input shapes."""
        shapes_to_test = [
            (1,),
            (2, 3),
            (2, 3, 4),
            (1, 2, 3, 4),
        ]

        for shape in shapes_to_test:

            @mb.program(input_specs=[mb.TensorSpec(shape=shape)])
            def prog(x):
                scaled = mb.mul(x=x, y=np.float32(1.702))
                sigmoid_out = mb.sigmoid(x=scaled)
                return mb.mul(x=x, y=sigmoid_out)

            apply_pass_and_basic_check(prog, "common::fuse_gelu_sigmoid_approximation")

            # Assert inside the loop so every shape is actually verified.
            assert get_op_types_in_program(prog) == ["gelu"]

    def test_no_fusion_wrong_constant(self):
        """Test that fusion does not occur with wrong constant value."""

        @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))])
        def prog(x):
            scaled = mb.mul(x=x, y=np.float32(2.0))
            sigmoid_out = mb.sigmoid(x=scaled)
            return mb.mul(x=x, y=sigmoid_out)

        apply_pass_and_basic_check(prog, "common::fuse_gelu_sigmoid_approximation")

        assert get_op_types_in_program(prog) == ["mul", "sigmoid", "mul"]

    def test_no_fusion_output_used_elsewhere(self):
        """Test that fusion does not occur when an intermediate output is a block output."""

        @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))])
        def prog(x):
            scaled = mb.mul(x=x, y=np.float32(1.702))
            sigmoid_out = mb.sigmoid(x=scaled)
            gelu_approx = mb.mul(x=x, y=sigmoid_out)
            return gelu_approx, sigmoid_out

        apply_pass_and_basic_check(prog, "common::fuse_gelu_sigmoid_approximation")

        assert "sigmoid" in get_op_types_in_program(prog)

    def test_no_fusion_different_input_var(self):
        """Test that fusion does not occur when sigmoid input differs from final mul input."""

        @mb.program(
            input_specs=[mb.TensorSpec(shape=(2, 3)), mb.TensorSpec(shape=(2, 3))]
        )
        def prog(x, y):
            scaled = mb.mul(x=y, y=np.float32(1.702))
            sigmoid_out = mb.sigmoid(x=scaled)
            return mb.mul(x=x, y=sigmoid_out)

        apply_pass_and_basic_check(prog, "common::fuse_gelu_sigmoid_approximation")

        assert get_op_types_in_program(prog) == ["mul", "sigmoid", "mul"]

    def test_fusion_fp16(self):
        """Test fusion works with fp16 data type."""

        @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3), dtype=types.fp16)])
        def prog(x):
            scaled = mb.mul(x=x, y=np.float16(1.702))
            sigmoid_out = mb.sigmoid(x=scaled)
            return mb.mul(x=x, y=sigmoid_out)

        apply_pass_and_basic_check(prog, "common::fuse_gelu_sigmoid_approximation")

        assert get_op_types_in_program(prog) == ["gelu"]

    def test_numerical_correctness(self):
        """Test that the fused program stays valid and preserves output shape.

        Note: this asserts on op types and model validity rather than output
        values; see the value-level sketch after this file.
        """

        @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))])
        def prog(x):
            scaled = mb.mul(x=x, y=np.float32(1.702))
            sigmoid_out = mb.sigmoid(x=scaled)
            return mb.mul(x=x, y=sigmoid_out)

        prev_prog, prev_block, block = apply_pass_and_basic_check(
            prog, "common::fuse_gelu_sigmoid_approximation"
        )

        assert get_op_types_in_program(prog) == ["gelu"]
        assert_model_is_valid(
            prog,
            {"x": (2, 3)},
            expected_output_shapes={block.outputs[0].name: (2, 3)},
        )
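The tests above stop at graph checks; an actual value-level comparison would need to execute the program. One possible sketch (not part of this PR; assumes macOS with a Core ML runtime, that ``prog`` is the fused program from a test above, and that its input is named ``x`` with a single output):

    import coremltools as ct
    import numpy as np

    model = ct.convert(prog, convert_to="mlprogram")
    x = np.random.rand(2, 3).astype(np.float32)
    # Reference: x * sigmoid(1.702 * x), which SIGMOID_APPROXIMATION computes.
    expected = x / (1.0 + np.exp(-1.702 * x))
    out = list(model.predict({"x": x}).values())[0]
    np.testing.assert_allclose(out, expected, atol=1e-3, rtol=1e-3)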