Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/AddAtomicMutex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ class AddAtomicMutex : public IRMutator {
std::string name = unique_name('t');
index_let = index;
index = Variable::make(index.type(), name);
body = ReplaceStoreIndexWithVar(op->producer_name, index).mutate(body);
body = ReplaceStoreIndexWithVar(op->producer_name, index)(body);
}
// This generates a pointer to the mutex array
Expr mutex_array = Variable::make(
Expand Down Expand Up @@ -454,8 +454,8 @@ Stmt add_atomic_mutex(Stmt s, const std::vector<Function> &outputs) {
CheckAtomicValidity check;
s.accept(&check);
if (check.any_atomic) {
s = RemoveUnnecessaryMutexUse().mutate(s);
s = AddAtomicMutex(outputs).mutate(s);
s = RemoveUnnecessaryMutexUse()(s);
s = AddAtomicMutex(outputs)(s);
}
return s;
}
Expand Down
13 changes: 8 additions & 5 deletions src/AddImageChecks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class TrimStmtToPartsThatAccessBuffers : public IRMutator {
bool touches_buffer = false;
const map<string, FindBuffers::Result> &buffers;

protected:
using IRMutator::visit;

Expr visit(const Call *op) override {
Expand Down Expand Up @@ -185,10 +186,10 @@ Stmt add_image_checks_inner(Stmt s,

// Add the input buffer(s) and annotate which output buffers are
// used on host.
s.accept(&finder);
finder(s);

Scope<Interval> empty_scope;
Stmt sub_stmt = TrimStmtToPartsThatAccessBuffers(bufs).mutate(s);
Stmt sub_stmt = TrimStmtToPartsThatAccessBuffers(bufs)(s);
map<string, Box> boxes = boxes_touched(sub_stmt, empty_scope, fb);

// Now iterate through all the buffers, creating a list of lets
Expand Down Expand Up @@ -225,7 +226,7 @@ Stmt add_image_checks_inner(Stmt s,
string extent_name = concat_strings(name, ".extent.", i);
string stride_name = concat_strings(name, ".stride.", i);
replace_with_required[min_name] = Variable::make(Int(32), min_name + ".required");
replace_with_required[extent_name] = simplify(Variable::make(Int(32), extent_name + ".required"));
replace_with_required[extent_name] = Variable::make(Int(32), extent_name + ".required");
replace_with_required[stride_name] = Variable::make(Int(32), stride_name + ".required");
}
}
Expand Down Expand Up @@ -737,6 +738,7 @@ Stmt add_image_checks(const Stmt &s,
// Checks for images go at the marker deposited by computation
// bounds inference.
class Injector : public IRMutator {
protected:
using IRMutator::visit;

Expr visit(const Variable *op) override {
Expand Down Expand Up @@ -794,9 +796,10 @@ Stmt add_image_checks(const Stmt &s,
bool will_inject_host_copies)
: outputs(outputs), t(t), order(order), env(env), fb(fb), will_inject_host_copies(will_inject_host_copies) {
}
} injector(outputs, t, order, env, fb, will_inject_host_copies);
};
Injector injector(outputs, t, order, env, fb, will_inject_host_copies);

return injector.mutate(s);
return injector(s);
}

} // namespace Internal
Expand Down
2 changes: 1 addition & 1 deletion src/AlignLoads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class AlignLoads : public IRMutator {
} // namespace

Stmt align_loads(const Stmt &s, int alignment, int min_bytes_to_align) {
return AlignLoads(alignment, min_bytes_to_align).mutate(s);
return AlignLoads(alignment, min_bytes_to_align)(s);
}

} // namespace Internal
Expand Down
4 changes: 2 additions & 2 deletions src/AllocationBoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ class StripDeclareBoxTouched : public IRMutator {
Stmt allocation_bounds_inference(Stmt s,
const map<string, Function> &env,
const FuncValueBounds &fb) {
s = AllocationInference(env, fb).mutate(s);
s = StripDeclareBoxTouched().mutate(s);
s = AllocationInference(env, fb)(s);
s = StripDeclareBoxTouched()(s);
return s;
}

Expand Down
2 changes: 1 addition & 1 deletion src/Associativity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ AssociativeOp prove_associativity(const string &f, vector<Expr> args, vector<Exp

// Replace any self-reference to Func 'f' with a Var
ConvertSelfRef csr(f, args, idx, op_x_names);
exprs[idx] = csr.mutate(exprs[idx]);
exprs[idx] = csr(exprs[idx]);
if (!csr.is_solvable) {
return AssociativeOp();
}
Expand Down
31 changes: 21 additions & 10 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class NoOpCollapsingMutator : public IRMutator {
};

class GenerateProducerBody : public NoOpCollapsingMutator {
protected:
const string &func;
vector<Expr> sema;
std::set<string> producers_dropped;
Expand Down Expand Up @@ -285,6 +286,7 @@ class GenerateProducerBody : public NoOpCollapsingMutator {
};

class GenerateConsumerBody : public NoOpCollapsingMutator {
protected:
const string &func;
vector<Expr> sema;

Expand Down Expand Up @@ -342,6 +344,7 @@ class GenerateConsumerBody : public NoOpCollapsingMutator {
};

class CloneAcquire : public IRMutator {
protected:
using IRMutator::visit;

const string &old_name;
Expand Down Expand Up @@ -390,6 +393,7 @@ class CountConsumeNodes : public IRVisitor {
};

class ForkAsyncProducers : public IRMutator {
protected:
using IRMutator::visit;

const map<string, Function> &env;
Expand All @@ -414,8 +418,8 @@ class ForkAsyncProducers : public IRMutator {
sema_vars.push_back(Variable::make(type_of<halide_semaphore_t *>(), sema_names.back()));
}

Stmt producer = GenerateProducerBody(name, sema_vars, cloned_acquires).mutate(body);
Stmt consumer = GenerateConsumerBody(name, sema_vars).mutate(body);
Stmt producer = GenerateProducerBody(name, sema_vars, cloned_acquires)(body);
Stmt consumer = GenerateConsumerBody(name, sema_vars)(body);

// Recurse on both sides
producer = mutate(producer);
Expand All @@ -434,7 +438,7 @@ class ForkAsyncProducers : public IRMutator {
// of the producer and consumer.
const vector<string> &clones = cloned_acquires[sema_name];
for (const auto &i : clones) {
body = CloneAcquire(sema_name, i).mutate(body);
body = CloneAcquire(sema_name, i)(body);
body = LetStmt::make(i, sema_space, body);
}

Expand Down Expand Up @@ -493,6 +497,7 @@ class ForkAsyncProducers : public IRMutator {
// simple failure case, error_async_require_fail. One has not been
// written for the complex nested case yet.)
class InitializeSemaphores : public IRMutator {
protected:
using IRMutator::visit;

const Type sema_type = type_of<halide_semaphore_t *>();
Expand Down Expand Up @@ -558,6 +563,7 @@ class InitializeSemaphores : public IRMutator {
// A class to support stmt_uses_vars queries that repeatedly hit the same
// sub-stmts. Used to support TightenProducerConsumerNodes below.
class CachingStmtUsesVars : public IRMutator {
protected:
const Scope<> &query;
bool found_use = false;
std::map<Stmt, bool> cache;
Expand Down Expand Up @@ -613,6 +619,7 @@ class CachingStmtUsesVars : public IRMutator {

// Tighten the scope of consume nodes as much as possible to avoid needless synchronization.
class TightenProducerConsumerNodes : public IRMutator {
protected:
using IRMutator::visit;

Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<> &scope, CachingStmtUsesVars &uses_vars) {
Expand Down Expand Up @@ -703,6 +710,7 @@ class TightenProducerConsumerNodes : public IRMutator {

// Update indices to add ring buffer.
class UpdateIndices : public IRMutator {
protected:
using IRMutator::visit;

Stmt visit(const Provide *op) override {
Expand Down Expand Up @@ -734,6 +742,7 @@ class UpdateIndices : public IRMutator {

// Inject ring buffering.
class InjectRingBuffering : public IRMutator {
protected:
using IRMutator::visit;

struct Loop {
Expand Down Expand Up @@ -768,7 +777,7 @@ class InjectRingBuffering : public IRMutator {
}
current_index = current_index % f.schedule().ring_buffer();
// Adds an extra index for to the all of the references of f.
body = UpdateIndices(op->name, current_index).mutate(body);
body = UpdateIndices(op->name, current_index)(body);

if (f.schedule().async()) {
Expr sema_var = Variable::make(type_of<halide_semaphore_t *>(), f.name() + ".folding_semaphore.ring_buffer");
Expand Down Expand Up @@ -816,6 +825,7 @@ class InjectRingBuffering : public IRMutator {
// Broaden the scope of acquire nodes to pack trailing work into the
// same task and to potentially reduce the nesting depth of tasks.
class ExpandAcquireNodes : public IRMutator {
protected:
using IRMutator::visit;

Stmt visit(const Block *op) override {
Expand Down Expand Up @@ -918,6 +928,7 @@ class ExpandAcquireNodes : public IRMutator {
};

class TightenForkNodes : public IRMutator {
protected:
using IRMutator::visit;

Stmt make_fork(const Stmt &first, const Stmt &rest) {
Expand Down Expand Up @@ -1005,12 +1016,12 @@ class TightenForkNodes : public IRMutator {
} // namespace

Stmt fork_async_producers(Stmt s, const map<string, Function> &env) {
s = TightenProducerConsumerNodes(env).mutate(s);
s = InjectRingBuffering(env).mutate(s);
s = ForkAsyncProducers(env).mutate(s);
s = ExpandAcquireNodes().mutate(s);
s = TightenForkNodes().mutate(s);
s = InitializeSemaphores().mutate(s);
s = TightenProducerConsumerNodes(env)(s);
s = InjectRingBuffering(env)(s);
s = ForkAsyncProducers(env)(s);
s = ExpandAcquireNodes()(s);
s = TightenForkNodes()(s);
s = InitializeSemaphores()(s);
return s;
}

Expand Down
4 changes: 2 additions & 2 deletions src/AutoScheduleUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ Expr substitute_var_estimates(Expr e) {
if (!e.defined()) {
return e;
}
return simplify(SubstituteVarEstimates().mutate(e));
return simplify(SubstituteVarEstimates()(e));
}

Stmt substitute_var_estimates(Stmt s) {
if (!s.defined()) {
return s;
}
return simplify(SubstituteVarEstimates().mutate(s));
return simplify(SubstituteVarEstimates()(s));
}

int string_to_int(const string &s) {
Expand Down
3 changes: 2 additions & 1 deletion src/BoundConstantExtentLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace Internal {

namespace {
class BoundLoops : public IRMutator {
protected:
using IRMutator::visit;

std::vector<std::pair<std::string, Expr>> lets;
Expand Down Expand Up @@ -128,7 +129,7 @@ class BoundLoops : public IRMutator {
} // namespace

Stmt bound_constant_extent_loops(const Stmt &s) {
return BoundLoops().mutate(s);
return BoundLoops()(s);
}

} // namespace Internal
Expand Down
2 changes: 1 addition & 1 deletion src/BoundSmallAllocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class BoundSmallAllocations : public IRMutator {
} // namespace

Stmt bound_small_allocations(const Stmt &s) {
return BoundSmallAllocations().mutate(s);
return BoundSmallAllocations()(s);
}

} // namespace Internal
Expand Down
17 changes: 9 additions & 8 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class Bounds : public IRVisitor {

#endif // DO_TRACK_BOUNDS_INTERVALS

private:
protected:
// Compute the intrinsic bounds of a function.
void bounds_of_func(const string &name, int value_index, Type t) {
// if we can't get a good bound from the function, fall back to the bounds of the type.
Expand Down Expand Up @@ -1799,7 +1799,7 @@ Interval bounds_of_expr_in_scope_with_indent(const Expr &expr, const Scope<Inter
#if DO_TRACK_BOUNDS_INTERVALS
b.log_indent = indent + 1;
#endif
expr.accept(&b);
b(expr);
#if DO_TRACK_BOUNDS_INTERVALS
debug(0) << spaces << " mn=" << simplify(b.interval.min) << "\n"
<< spaces << " mx=" << simplify(b.interval.max) << "\n"
Expand Down Expand Up @@ -2023,6 +2023,7 @@ class FindInnermostVar : public IRVisitor {

// Place innermost vars in an IfThenElse's condition as far to the left as possible.
class SolveIfThenElse : public IRMutator {
protected:
// Scope of variable names and their depths. Higher depth indicates
// variable defined more innermost.
Scope<int> vars_depth;
Expand Down Expand Up @@ -2255,7 +2256,7 @@ class BoxesTouched : public IRGraphVisitor {

#endif // DO_TRACK_BOUNDS_INTERVALS

private:
protected:
struct VarInstance {
string var;
int instance;
Expand Down Expand Up @@ -3107,7 +3108,7 @@ map<string, Box> boxes_touched(const Expr &e, Stmt s, bool consider_calls, bool
// as possible, so that BoxesTouched can prune the variable scope tighter
// when encountering the IfThenElse.
if (s.defined()) {
s = SolveIfThenElse().mutate(s);
s = SolveIfThenElse()(s);
}

// Do calls and provides separately, for better simplification.
Expand All @@ -3116,18 +3117,18 @@ map<string, Box> boxes_touched(const Expr &e, Stmt s, bool consider_calls, bool

if (consider_calls) {
if (e.defined()) {
e.accept(&calls);
calls(e);
}
if (s.defined()) {
s.accept(&calls);
calls(s);
}
}
if (consider_provides) {
if (e.defined()) {
e.accept(&provides);
provides(e);
}
if (s.defined()) {
s.accept(&provides);
provides(s);
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ class BoundsInference : public IRMutator {
} select_to_if_then_else;

for (auto &e : exprs) {
e.value = select_to_if_then_else.mutate(e.value);
e.value = select_to_if_then_else(e.value);
}
}

Expand Down Expand Up @@ -1382,8 +1382,7 @@ Stmt bounds_inference(Stmt s,
s = For::make("<outermost>", 0, 0, ForType::Serial, Partition::Never, DeviceAPI::None, s);

s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups,
outputs, func_bounds, target)
.mutate(s);
outputs, func_bounds, target)(s);
return s.as<For>()->body;
}

Expand Down
5 changes: 5 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,11 @@ target_compile_definitions(Halide PRIVATE WITH_SPIRV)
target_compile_definitions(Halide PRIVATE WITH_VULKAN)
target_compile_definitions(Halide PRIVATE WITH_WEBGPU)

if (WITH_COMPILER_PROFILING)
target_compile_definitions(Halide PRIVATE WITH_COMPILER_PROFILING)
endif()


##
# Flatbuffers and Serialization dependencies.
##
Expand Down
Loading
Loading