Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 43 additions & 34 deletions src/RegionCosts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,61 +244,70 @@ class ExprCost : public IRVisitor {
}
}

void visit(const Shuffle *op) override {
arith += 1;
}

void visit(const Let *let) override {
let->value.accept(this);
let->body.accept(this);
}

// None of the following IR nodes should be encountered when traversing the
// IR at the level at which the auto scheduler operates.
void visit(const Load *) override {
internal_error;
// Report an Expr that this cost visitor does not expect to encounter.
// The offending Expr is streamed into the internal_error message so the
// failure identifies which node was seen while computing region costs.
// NOTE(review): internal_error is defined elsewhere; presumably it does not
// return normally -- confirm before relying on that in callers.
void fail(const Expr &e) {
internal_error << "Unexpected Expr while computing region costs: " << e << "\n"
<< "Expected front-end Exprs only.";
}
// Report a Stmt that this cost visitor does not expect to encounter.
// The Stmt is printed starting on its own line (note the "\n" before it in
// the message), followed by a reminder that only front-end Exprs are expected.
// NOTE(review): internal_error is defined elsewhere; presumably it does not
// return normally -- confirm before relying on that in callers.
void fail(const Stmt &s) {
internal_error << "Unexpected Stmt while computing region costs:\n"
<< s << "\n"
<< "Expected front-end Exprs only.";
}
void visit(const Ramp *) override {
internal_error;

void visit(const Load *op) override {
fail(op);
}
void visit(const Ramp *op) override {
fail(op);
}
void visit(const Shuffle *op) override {
fail(op);
}
void visit(const Broadcast *) override {
internal_error;
void visit(const Broadcast *op) override {
fail(op);
}
void visit(const LetStmt *) override {
internal_error;
void visit(const LetStmt *op) override {
fail(op);
}
void visit(const AssertStmt *) override {
internal_error;
void visit(const AssertStmt *op) override {
fail(op);
}
void visit(const ProducerConsumer *) override {
internal_error;
void visit(const ProducerConsumer *op) override {
fail(op);
}
void visit(const For *) override {
internal_error;
void visit(const For *op) override {
fail(op);
}
void visit(const Store *) override {
internal_error;
void visit(const Store *op) override {
fail(op);
}
void visit(const Provide *) override {
internal_error;
void visit(const Provide *op) override {
fail(op);
}
void visit(const Allocate *) override {
internal_error;
void visit(const Allocate *op) override {
fail(op);
}
void visit(const Free *) override {
internal_error;
void visit(const Free *op) override {
fail(op);
}
void visit(const Realize *) override {
internal_error;
void visit(const Realize *op) override {
fail(op);
}
void visit(const Block *) override {
internal_error;
void visit(const Block *op) override {
fail(op);
}
void visit(const IfThenElse *) override {
internal_error;
void visit(const IfThenElse *op) override {
fail(op);
}
void visit(const Evaluate *) override {
internal_error;
void visit(const Evaluate *op) override {
fail(op);
}

public:
Expand Down
1 change: 1 addition & 0 deletions src/SplitTuples.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class SplitTuples : public IRMutator {
could_alias(op->args, store_args)) {
deps.insert(op->value_index);
}
IRVisitor::visit(op);
}

bool could_alias(const vector<Expr> &a, const vector<Expr> &b) {
Expand Down
3 changes: 1 addition & 2 deletions src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1452,8 +1452,7 @@ class FindVectorizableExprsInAtomicNode : public IRMutator {
Stmt visit(const Store *op) override {
// A store poisons all subsequent loads, but loads before the
// first store can be lifted.
mutate(op->index);
mutate(op->value);
IRMutator::visit(op);
poisoned_names.push(op->name);
return op;
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/custom_lowering_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class CheckForFloatDivision : public IRMutator {
std::cerr << "Found floating-point division by constant: " << Expr(op) << "\n";
exit(1);
}
IRMutator::visit(op);
return op;
}
};
Expand Down
40 changes: 40 additions & 0 deletions test/correctness/tuple_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,46 @@ int main(int argc, char **argv) {
}
}

{
// A case which requires tuple updates to be atomic, but hides a
// dependence in a way that triggered a bug in the past.
Func f, g;
Var x, y;

// Pure definition: tuple element 0 is x + 17, element 1 is x + 1.
f(x) = Tuple(x + 17, x + 1);
constexpr int w = 100;

// Update definition over the reduction domain r (the reference loop
// below walks the same sites as r = 0 .. w - 1).
//  - Element 0 is incremented by 5.
//  - Element 1 is read from element 1 of f at an index given by the
//    current value of element 0 at r, clamped into [0, w - 1]. The
//    dependence on element 0 is therefore hidden inside the index
//    expression of a call to f rather than appearing as a direct f(r)[0]
//    operand of the stored value.
RDom r(0, w);
f(r) = Tuple(f(r)[0] + 5, f(clamp(f(r)[0], 0, w - 1))[1]);
// g lays the two tuple elements out along y so both can be checked from a
// single 2-D output: y == 0 selects element 0, y == 1 selects element 1.
g(x, y) = mux(y, {f(x)[0], f(x)[1]});

f.compute_root();

// Realize a w x 2 output and build the expected result in plain C++.
Buffer<int> buf = g.realize({w, 2});
Buffer<int> correct(w, 2);
// Reference for the pure definition. Note that the int loop variable x
// shadows the Halide Var x declared above within this loop.
for (int x = 0; x < w; x++) {
correct(x, 0) = x + 17;
correct(x, 1) = x + 1;
}
// Reference for the update definition, applied in increasing r. Both new
// values are computed from the pre-update state before either is written.
// (The int r here likewise shadows the RDom r.)
for (int r = 0; r < w; r++) {
int new_0 = correct(r, 0) + 5;
int new_1 = correct(std::min(std::max(correct(r, 0), 0), w - 1), 1);
// Tuple element 1 might depend on the old value of tuple element
// zero. The new values must be both computed *then* assigned.
correct(r, 0) = new_0;
correct(r, 1) = new_1;
}

// Compare every output element against the reference; on the first
// mismatch, print both values and fail the test by returning -1.
for (int x = 0; x < w; x++) {
for (int y = 0; y < 2; y++) {
if (buf(x, y) != correct(x, y)) {
printf("buf(%d, %d) = %d instead of %d\n", x, y, buf(x, y), correct(x, y));
return -1;
}
}
}
}

printf("Success!\n");
return 0;
}
Loading