Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 56 additions & 10 deletions src/relax/transform/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ class ConstantFolder : public ExprMutator {
* constant shape and get runtime shape tuple from it.
* \param struct_info The given struct info whose shape inside is to be casted.
* \return The runtime shape tuple, or nullopt if it is not a constant shape.
* \note Only TensorStructInfo is supported at this moment. Return std::nullopt
* \note Only TensorStructInfo is supported. Returns std::nullopt
* if the input struct info is not TensorStructInfo.
*/
static ffi::Optional<ffi::Shape> MatchConstShape(const StructInfo& struct_info) {
// Only support single output for call_tir at this moment.
const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>();
if (tensor_sinfo == nullptr) {
return std::nullopt;
Expand Down Expand Up @@ -143,8 +142,8 @@ class ConstantFolder : public ExprMutator {
return true;
}

// Try constant evaluate the function call
// if failed return std::nullopt
// Try constant evaluate a call_tir with a single tensor output.
// Returns std::nullopt on failure.
ffi::Optional<Expr> ConstEvaluateCallTIR(tir::PrimFunc tir_func,
ffi::Array<runtime::Tensor> arr_args, ffi::Shape shape,
DataType ret_type) {
Expand Down Expand Up @@ -175,25 +174,72 @@ class ConstantFolder : public ExprMutator {
return Constant(ret_tensor);
}

  // Try constant evaluate a call_tir with tuple outputs (multiple output tensors).
  // Returns std::nullopt on failure.
  //
  // \param tir_func The PrimFunc referenced by the call_tir; built (and cached)
  //        through GetCachedBuild before being executed on the host CPU.
  // \param arr_args The constant input tensors, in call order.
  // \param tuple_sinfo Struct info describing each output field. Every field
  //        must be a TensorStructInfo with a constant shape and a known dtype,
  //        otherwise folding is abandoned and std::nullopt is returned.
  ffi::Optional<Expr> ConstEvaluateCallTIRTuple(tir::PrimFunc tir_func,
                                                ffi::Array<runtime::Tensor> arr_args,
                                                const TupleStructInfoNode* tuple_sinfo) {
    // Compile the PrimFunc (or fetch a previously cached build); give up on
    // folding if the build is unavailable.
    ffi::Optional<ffi::Function> func = GetCachedBuild(tir_func);
    if (!func) return std::nullopt;

    DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0};
    size_t num_outputs = tuple_sinfo->fields.size();

    // Match shapes and dtypes for all output fields.
    // Pre-allocate one empty CPU tensor per output; bail out if any field has
    // a non-constant shape or an unknown dtype.
    std::vector<runtime::Tensor> ret_tensors;
    for (size_t i = 0; i < num_outputs; ++i) {
      ffi::Optional<ffi::Shape> shape = MatchConstShape(tuple_sinfo->fields[i]);
      if (!shape) return std::nullopt;
      // MatchConstShape succeeding implies the field is TensorStructInfo, so
      // this Downcast cannot fail.
      auto tensor_sinfo = Downcast<TensorStructInfo>(tuple_sinfo->fields[i]);
      if (tensor_sinfo->IsUnknownDtype()) return std::nullopt;
      ret_tensors.push_back(runtime::Tensor::Empty(shape.value(), tensor_sinfo->dtype, cpu_dev));
    }

    // Pack input args + all output tensors.
    // Inputs come first, followed by the pre-allocated output tensors that the
    // built function writes into.
    // NOTE(review): temp_args appears to exist so that owning Tensor handles
    // stay alive in a contiguous vector while packed_args holds non-owning
    // AnyViews into them for the duration of CallPacked — confirm before
    // simplifying this copy away.
    std::vector<runtime::Tensor> temp_args(arr_args.begin(), arr_args.end());
    std::vector<AnyView> packed_args;
    packed_args.reserve(temp_args.size() + num_outputs);
    for (const auto& arg : temp_args) {
      packed_args.push_back(arg);
    }
    for (const auto& out_tensor : ret_tensors) {
      packed_args.push_back(out_tensor);
    }

    // Execute the built function; results are written into ret_tensors in
    // place (the ffi::Any return slot is unused by this calling convention).
    ffi::Any ret;
    func.value().CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &ret);

    // Wrap each computed output tensor as a relax Constant and return the
    // results as a Tuple expression.
    ffi::Array<Expr> fields;
    for (size_t i = 0; i < num_outputs; ++i) {
      fields.push_back(Constant(ret_tensors[i]));
    }
    return Tuple(fields);
  }

// Returns the folded expr if the call is successfully folded to constant, otherwise null.
ffi::Optional<Expr> VisitCallTIR(Call call) {
// call_tir needs to have at least three arguments
// call_tir needs to have at least two arguments
ICHECK_GE(call->args.size(), 2);
ffi::Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]);
ICHECK(call->args[1].as<TupleNode>()) << "call_tir.args[1] must be Tuple";
ffi::Optional<ffi::Array<runtime::Tensor>> arr_args =
MatchConstArrayArgs(call->args[1].as<TupleNode>()->fields);
ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg";

if (!func || !arr_args) return {};

// Handle tuple output: sinfo_args[0] is a TupleStructInfo.
if (const auto* tuple_sinfo = call->sinfo_args[0].as<TupleStructInfoNode>()) {
return ConstEvaluateCallTIRTuple(func.value(), arr_args.value(), tuple_sinfo);
}

// Handle single tensor output.
ffi::Optional<ffi::Shape> shape = MatchConstShape(call->sinfo_args[0]);
bool output_not_tuple = call->sinfo_args.size() == 1;
// Pattern 0: call constant function, const argument with const shape.
if (func && arr_args && shape && output_not_tuple) {
if (shape) {
TensorStructInfo ret_sinfo = Downcast<TensorStructInfo>(call->struct_info_);
// value_or will return value if it is not null, otherwise return or
return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_sinfo->dtype)
.value_or({});
}
// TODO(hongyi): support const-fold tuple outputs
return {};
}

Expand Down
51 changes: 51 additions & 0 deletions tests/python/relax/test_transform_fold_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,5 +442,56 @@ def expected(
tvm.ir.assert_structural_equal(after, expected)


def test_fold_tuple_output():
    """FoldConstant should fold a call_tir with tuple (multi-tensor) outputs.

    The prim_func ``split`` writes the top half of a constant 4x4 input into B
    and the bottom half into C; after folding, the call_tir should be replaced
    by a tuple of two precomputed constants.
    """

    @tvm.script.ir_module
    class Module:
        @T.prim_func
        def split(
            A: T.Buffer((4, 4), "float32"),
            B: T.Buffer((2, 4), "float32"),
            C: T.Buffer((2, 4), "float32"),
        ) -> None:
            # Copy rows 0-1 of A into B.
            for i, j in T.grid(2, 4):
                # Fix: TVMScript blocks are declared with T.block, not T.sblock.
                with T.block("upper"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj]
            # Copy rows 2-3 of A into C.
            for i, j in T.grid(2, 4):
                with T.block("lower"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    C[vi, vj] = A[vi + 2, vj]

        @R.function
        def before(c0: R.Tensor((4, 4), "float32")):
            cls = Module
            lv0 = relax.call_tir(
                cls.split,
                (c0,),
                out_sinfo=[
                    R.Tensor((2, 4), dtype="float32"),
                    R.Tensor((2, 4), dtype="float32"),
                ],
            )
            return lv0

        @R.function
        def expected(
            c1: R.Tensor((2, 4), "float32"), c2: R.Tensor((2, 4), "float32")
        ) -> R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")):
            lv0: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = (
                c1,
                c2,
            )
            return lv0

    # Constant input plus the two halves the fold is expected to produce.
    c0_np = np.arange(16).astype("float32").reshape(4, 4)
    c1_np = c0_np[:2]
    c2_np = c0_np[2:]
    before = gen_mod(Module, "before", {"c0": c0_np})
    expected = gen_mod(Module, "expected", {"c1": c1_np, "c2": c2_np})

    after = relax.transform.FoldConstant()(before)
    tvm.ir.assert_structural_equal(after, expected)


# Allow running this test file directly (outside of pytest).
if __name__ == "__main__":
    tvm.testing.main()
Loading