Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
"REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max),
"REDUCE_MIN": functools.partial(self._convert_reduce, relax_op=_op.min),
"REDUCE_PROD": functools.partial(self._convert_reduce, relax_op=_op.prod),
"REDUCE_WINDOW": self.convert_reduce_window,
"RELU": self.convert_relu,
"RELU6": self.convert_relu6,
"RELU_N1_TO_1": self.convert_relu_n1_to_1,
Expand Down Expand Up @@ -2462,6 +2463,169 @@ def _convert_reduce(self, relax_op, op):

return out

def convert_reduce_window(self, op):
    """Convert TFLite REDUCE_WINDOW to Relax.

    Lowers the op to ``topi.sliding_window`` (via ``call_dps_packed``)
    followed by a reduction over the window axes, then combines the
    reduced value with the scalar ``init_value`` using the reduction's
    combining operator (e.g. ``add`` for ADD, ``logical_and`` for ALL).

    Expected inputs (5 tensors): data, scalar init_value, and constant
    window_shape / window_strides / window_dilations of length == rank.

    Raises
    ------
    tvm.error.OpNotImplemented
        For quantized inputs, non-scalar init values, non-constant window
        parameters, or an unsupported reduce_function.
    tvm.error.OpAttributeUnImplemented
        For malformed attributes (rank/dtype mismatches, non-positive
        window parameters, inconsistent output shape).
    """
    from tflite.BuiltinOptions2 import BuiltinOptions2
    from tflite.ReduceWindowFunction import ReduceWindowFunction
    from tflite.ReduceWindowOptions import ReduceWindowOptions

    input_tensors = self.get_input_tensors(op)
    output_tensors = self.get_output_tensors(op)
    assert len(input_tensors) == 5, "input tensors length should be 5"
    assert len(output_tensors) == 1, "output tensors length should be 1"

    if op.BuiltinOptions2Type() != BuiltinOptions2.ReduceWindowOptions:
        raise tvm.error.OpAttributeUnImplemented(
            "TFLite REDUCE_WINDOW requires ReduceWindowOptions."
        )

    (
        input_tensor,
        init_tensor,
        window_shape_tensor,
        window_strides_tensor,
        window_dilations_tensor,
    ) = input_tensors
    output_tensor = output_tensors[0]

    # Window geometry must be known at compile time to build the Relax graph.
    if any(
        self.has_expr(tensor.tensor_idx)
        for tensor in [window_shape_tensor, window_strides_tensor, window_dilations_tensor]
    ):
        raise tvm.error.OpNotImplemented(
            "TFLite REDUCE_WINDOW requires constant window_shape, "
            "window_strides, and window_dilations."
        )

    input_shape = to_int_list(self.get_tensor_shape(input_tensor))
    output_shape = to_int_list(self.get_tensor_shape(output_tensor))
    input_dtype = self.get_tensor_type_str(input_tensor.tensor.Type())
    output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())

    if input_tensor.qnn_params or output_tensor.qnn_params:
        raise tvm.error.OpNotImplemented(
            "Quantized TFLite REDUCE_WINDOW is not yet supported in the Relax frontend."
        )

    if input_dtype != output_dtype:
        raise tvm.error.OpAttributeUnImplemented(
            "TFLite REDUCE_WINDOW requires input and output dtypes to match."
        )

    # init_value is applied once per window, so it must be a scalar.
    if len(to_int_list(self.get_tensor_shape(init_tensor))) != 0:
        raise tvm.error.OpNotImplemented(
            "TFLite REDUCE_WINDOW only supports scalar init_value."
        )

    options = ReduceWindowOptions()
    op_options = op.BuiltinOptions2()
    options.Init(op_options.Bytes, op_options.Pos)
    reduce_function = options.ReduceFunction()

    if reduce_function == ReduceWindowFunction.UNSUPPORTED:
        raise tvm.error.OpNotImplemented(
            "TFLite REDUCE_WINDOW with UNSUPPORTED reduce_function is not supported."
        )

    window_shape = to_int_list(self.get_tensor_value(window_shape_tensor))
    window_strides = to_int_list(self.get_tensor_value(window_strides_tensor))
    window_dilations = to_int_list(self.get_tensor_value(window_dilations_tensor))
    rank = len(input_shape)

    if not (len(window_shape) == len(window_strides) == len(window_dilations) == rank):
        raise tvm.error.OpAttributeUnImplemented(
            "TFLite REDUCE_WINDOW window_shape, window_strides, and window_dilations "
            "must match input rank."
        )

    if any(value <= 0 for value in window_shape + window_strides + window_dilations):
        raise tvm.error.OpAttributeUnImplemented(
            "TFLite REDUCE_WINDOW window dimensions, strides, and dilations must be positive."
        )

    # Effective window extent once dilation gaps are included.
    dilated_window_shape = [
        (window_dim - 1) * dilation + 1
        for window_dim, dilation in zip(window_shape, window_dilations)
    ]
    # Standard valid-padding output size; 0 when the window doesn't fit.
    expected_output_shape = [
        0 if input_dim < dilated_dim else (input_dim - dilated_dim) // stride + 1
        for input_dim, dilated_dim, stride in zip(
            input_shape, dilated_window_shape, window_strides
        )
    ]

    # Single source of truth for dispatch AND dtype validation:
    # reduce_function -> (window-axis reduction op, combiner with init_value).
    numeric_dispatch = {
        ReduceWindowFunction.ADD: (relax.op.sum, relax.op.add),
        ReduceWindowFunction.MUL: (relax.op.prod, relax.op.multiply),
        ReduceWindowFunction.MINIMUM: (relax.op.min, relax.op.minimum),
        ReduceWindowFunction.MAXIMUM: (relax.op.max, relax.op.maximum),
    }
    bool_dispatch = {
        ReduceWindowFunction.ALL: (relax.op.min, relax.op.logical_and),
        ReduceWindowFunction.ANY: (relax.op.max, relax.op.logical_or),
    }

    if reduce_function in numeric_dispatch and input_dtype == "bool":
        raise tvm.error.OpAttributeUnImplemented(
            "TFLite REDUCE_WINDOW numeric reductions expect numeric input."
        )
    if reduce_function in bool_dispatch and input_dtype != "bool":
        raise tvm.error.OpAttributeUnImplemented(
            "TFLite REDUCE_WINDOW boolean reductions expect bool input."
        )

    if output_shape != expected_output_shape:
        raise tvm.error.OpAttributeUnImplemented(
            "TFLite REDUCE_WINDOW output shape does not match input/window parameters."
        )

    # Degenerate windows produce an empty tensor; emit it directly.
    if any(output_dim == 0 for output_dim in output_shape):
        return relax.op.zeros(output_shape, output_dtype)

    data = self.get_tensor_expr(input_tensor)
    init_value = self.get_tensor_expr(init_tensor)

    # Extract windows of the dilated extent; dilation itself is applied below.
    windowed = relax.op.call_dps_packed(
        "topi.sliding_window",
        (
            data,
            0,
            relax.ShapeExpr(dilated_window_shape),
            relax.ShapeExpr(window_strides),
        ),
        out_sinfo=relax.TensorStructInfo(output_shape + dilated_window_shape, input_dtype),
    )

    # Subsample the window axes to realize dilation.
    if any(dilation != 1 for dilation in window_dilations):
        windowed = relax.op.strided_slice(
            windowed,
            axes=list(range(rank, 2 * rank)),
            begin=[0] * rank,
            end=dilated_window_shape,
            strides=window_dilations,
        )

    reduce_axes = list(range(rank, 2 * rank))

    if reduce_function in numeric_dispatch:
        reduce_op, combine_op = numeric_dispatch[reduce_function]
        return combine_op(reduce_op(windowed, axis=reduce_axes), init_value)

    if reduce_function in bool_dispatch:
        # min/max over an int8 view emulate all/any; cast back to bool
        # before folding in init_value.
        reduce_op, combine_op = bool_dispatch[reduce_function]
        reduced = reduce_op(relax.op.astype(windowed, "int8"), axis=reduce_axes)
        return combine_op(relax.op.astype(reduced, "bool"), init_value)

    raise tvm.error.OpNotImplemented(
        f"TFLite REDUCE_WINDOW reduce_function {reduce_function} is not supported."
    )

def _convert_reduce_bool(self, relax_op, op):
"""Convert TFLite REDUCE_ANY / REDUCE_ALL (bool-only ops).

Expand Down
Loading
Loading