Skip to content
2 changes: 1 addition & 1 deletion python_bindings/src/halide/_generator_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ def generator(name: str = ""):
_check(sys.version_info >= (3, 7), "Halide Generators require Python 3.7 or later.")

def generator_impl(cls):
n = name if name else _fqname(cls)
n = name or _fqname(cls)
_check_generator_name_in_use(n)
_check(isclass(cls), "@generator can only be used on classes.")
# Allow (but don't require) explicit inheritance from hl.Generator;
Expand Down
23 changes: 10 additions & 13 deletions src/AssociativeOpsTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ void populate_ops_table_double_general_sub(const vector<Type> &types, vector<Ass

// Appends the two-element "select" patterns to `table`: argmax and argmin
// reductions whose first tuple element is the selected index and whose
// second tuple element is the running max/min value.
// NOTE(review): x0, y0, x1, y1, zero_0, tmin_1 and tmax_1 are presumably
// introduced by declare_vars_double(types) -- confirm against its definition.
// NOTE(review): the trailing `true` in each entry presumably sets
// AssociativePattern::is_commutative -- confirm against the struct layout.
void populate_ops_table_double_general_select(const vector<Type> &types, vector<AssociativePattern> &table) {
declare_vars_double(types);
// Argmax with index as first tuple element
// (the four entries cover strict and non-strict comparisons, written with
// either operand on the left-hand side)
table.push_back({{select(x1 > y1, x0, y0), max(x1, y1)}, {zero_0, tmin_1}, true});
table.push_back({{select(x1 >= y1, x0, y0), max(x1, y1)}, {zero_0, tmin_1}, true});
table.push_back({{select(y1 < x1, x0, y0), max(x1, y1)}, {zero_0, tmin_1}, true});
table.push_back({{select(y1 <= x1, x0, y0), max(x1, y1)}, {zero_0, tmin_1}, true});
// Argmin with index as first tuple element
// (same four comparison spellings, mirrored for min)
table.push_back({{select(x1 < y1, x0, y0), min(x1, y1)}, {zero_0, tmax_1}, true});
table.push_back({{select(x1 <= y1, x0, y0), min(x1, y1)}, {zero_0, tmax_1}, true});
table.push_back({{select(y1 > x1, x0, y0), min(x1, y1)}, {zero_0, tmax_1}, true});
table.push_back({{select(y1 >= x1, x0, y0), min(x1, y1)}, {zero_0, tmax_1}, true});
}

void populate_ops_table_single_uint1_and(const vector<Type> &types, vector<AssociativePattern> &table) {
Expand Down Expand Up @@ -326,19 +336,6 @@ const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types
return table_it->second;
}

std::string print_types(const vector<Type> &types) {
std::ostringstream stream;
stream << "{";
for (size_t i = 0; i < types.size(); ++i) {
if (i > 0) {
stream << ", ";
}
stream << types[i];
}
stream << "}";
return stream.str();
}

} // anonymous namespace

const vector<AssociativePattern> &get_ops_table(const vector<Expr> &exprs) {
Expand Down
107 changes: 50 additions & 57 deletions src/Associativity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ namespace Halide {
namespace Internal {

using std::map;
using std::pair;
using std::set;
using std::string;
using std::vector;
Expand Down Expand Up @@ -103,8 +102,14 @@ bool associative_op_pattern_match(const Expr &e,
<< "Expr has type " << e.type() << ", while pattern has type " << op.type() << "\n";
map<string, Expr> result;
if (expr_match(op, e, result)) {
debug(5) << "Found associative ops for " << e << " -> " << op
<< ", y_part: " << result["y0"] << "\n";
debug(5) << "Found associative ops for " << e << " -> " << op << ":\n"
<< [&] {
std::stringstream ss;
for (const auto &[var, val] : result) {
ss << " " << var << " -> " << val << "\n";
}
return ss.str();
}();
Comment on lines +105 to +112
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Fun" fact: without this patch, we get a segfault when HL_DEBUG_CODEGEN is set to print this message. When the associative op is in the second tuple component, only y1 is in the result, yet the `result["y0"]` lookup itself will add an entry for y0, with an undefined Expr for the value. Later, the expr_uses_vars check will try to visit the undefined Expr and a null pointer gets dereferenced.


for (size_t i = 0; i < x_names.size(); ++i) {
const auto &iter = result.find("x" + std::to_string(i));
Expand Down Expand Up @@ -187,7 +192,6 @@ bool find_match(const vector<AssociativePattern> &table, const vector<string> &o
continue;
}

vector<pair<Expr, Expr>> replacement; // find -> replacement
for (size_t index = 0; index < op_y_names.size(); ++index) {
const auto &y_iter = pattern_match.find("y" + std::to_string(index));
if (y_iter == pattern_match.end()) {
Expand All @@ -202,20 +206,25 @@ bool find_match(const vector<AssociativePattern> &table, const vector<string> &o

assoc_op.xs[index] = {op_x_names[index], x_parts[index]};
assoc_op.ys[index] = {op_y_names[index], y_part};
replacement.emplace_back(y_part, Variable::make(y_part.type(), op_y_names[index]));
}
if (!matched) {
continue;
}
for (size_t index = 0; index < exprs.size(); ++index) {
Expr e = exprs[index];
// Order of substitution matters, e.g. in the argmin case, _y_0 -> g(rx)[0]
// and _y_1 -> rx. If we substitute the 2nd element rx first, substitution
// of g(rx)[0] will fail.
for (const auto &iter : replacement) {
e = substitute(iter.first, iter.second, e);
// Build the concrete ops by renaming the pattern's abstract
// wildcard variables (x0, y0, k0, ...) to the actual variable
// names used in the expressions.
map<string, Expr> replacement;
for (size_t index = 0; index < op_x_names.size(); ++index) {
replacement["x" + std::to_string(index)] = Variable::make(exprs[index].type(), op_x_names[index]);
replacement["y" + std::to_string(index)] = Variable::make(exprs[index].type(), op_y_names[index]);
}
for (const auto &[wildcard, identity] : pattern_match) {
if (wildcard[0] == 'k') {
replacement[wildcard] = identity;
}
assoc_op.pattern.ops[index] = e;
}
for (size_t index = 0; index < pattern.ops.size(); ++index) {
assoc_op.pattern.ops[index] = substitute(replacement, pattern.ops[index]);
assoc_op.pattern.identities[index] = pattern.identities[index];
}
assoc_op.pattern.is_commutative = pattern.is_commutative;
Expand All @@ -225,7 +234,7 @@ bool find_match(const vector<AssociativePattern> &table, const vector<string> &o
}

// Return a pair of booleans indicating if an operator is associative.
// 'assoc_op' contains the the equivalent associative binary/unary operator
// 'assoc_op' contains the equivalent associative binary/unary operator
// for that operator. If the operator is non-associative, 'assoc_op' is not valid.
bool extract_associative_op(const vector<Expr> &exprs, const vector<string> &op_x_names,
const vector<string> &op_y_names, const vector<Expr> &x_parts,
Expand All @@ -236,7 +245,7 @@ bool extract_associative_op(const vector<Expr> &exprs, const vector<string> &op_
// An update that just assigns some value is not associative,
// because there's no good identity. An identity is necessary
// because things like rfactor will combine the identity with
// partially-computed values and expect it to do nothing. For an
// partially computed values and expect it to do nothing. For an
// example, see https://github.com/halide/Halide/issues/7893
return false;
} else if (equal(exprs[0], Variable::make(t, op_x_names[0]))) {
Expand All @@ -256,58 +265,44 @@ bool extract_associative_op(const vector<Expr> &exprs, const vector<string> &op_
x_parts, exprs, assoc_op);
}

void add_transitive_dependencies(vector<set<int>> &dependencies) {
// TODO(psuriana): there might be a better way to find all the transitive dependencies
bool change = true;
while (change) {
change = false;
// Returns true iff every element of `a` is also contained in `b`.
// The empty set is a subset of every set, including another empty set.
bool is_subset_of(const std::set<int> &a, const std::set<int> &b) {
    // A set with more elements can never be contained in a smaller one.
    if (a.size() > b.size()) {
        return false;
    }
    return std::all_of(a.begin(), a.end(),
                       [&b](int element) { return b.count(element) != 0; });
}

// Compute the dependency subgraphs for a tuple reduction. First closes the
// dependency relation transitively, then returns only the earliest (by index)
// maximal dependency sets, clearing any set contained in a dominating one.
vector<set<int>> compute_subgraphs(vector<set<int>> dependencies) {
// Compute the transitive closure using Warshall's algorithm.
for (size_t k = 0; k < dependencies.size(); ++k) {
for (size_t i = 0; i < dependencies.size(); ++i) {
for (size_t j = 0; j < dependencies.size(); ++j) {
if (i == j) {
continue;
}
if (dependencies[i].count(j)) {
for (const auto &idx : dependencies[j]) {
if (dependencies[i].count(idx) == 0) {
dependencies[i].insert(idx);
change = true;
}
}
if (dependencies[i].count(k)) {
for (int j : dependencies[k]) {
dependencies[i].insert(j);
}
}
}
}
}

// Given dependencies of each tuple element, compute the set of subgraphs:
// all vertices that are reachable from a given vertex. If a subgraph is fully
// contained in another subgraph, remove it from the final output.
vector<set<int>> compute_subgraphs(vector<set<int>> dependencies) {
// Keep only maximal dependency sets. A set is removed if another
// set strictly contains it or is identical but has a lower index.
vector<set<int>> subgraphs(dependencies.size());
for (size_t i = 0; i < dependencies.size(); ++i) {
// Check if the current subgraph is a subset of another
const auto &current = dependencies[i];
if (current.empty()) {
if (dependencies[i].empty()) {
continue;
}
bool should_remove = false;
bool is_maximal = true;
for (size_t j = 0; j < dependencies.size(); ++j) {
const auto &other = dependencies[j];
if ((i == j) || (current.size() > other.size()) || (j < i && subgraphs[i].empty())) {
continue;
}
vector<int> diff;
// Compute the vertices in the current set that are not contained in the other
std::set_difference(current.begin(), current.end(), other.begin(), other.end(),
std::inserter(diff, diff.begin()));
if (diff.empty()) {
// 'current' is fully contained in 'other'
should_remove = true;
const bool can_dominate =
(dependencies[j].size() > dependencies[i].size()) ||
(dependencies[j].size() == dependencies[i].size() && j < i);
if (can_dominate && is_subset_of(dependencies[i], dependencies[j])) {
is_maximal = false;
break;
}
}
if (!should_remove) {
subgraphs[i] = current;
if (is_maximal) {
subgraphs[i] = dependencies[i];
}
}
return subgraphs;
Expand Down Expand Up @@ -353,8 +348,8 @@ AssociativeOp prove_associativity(const string &f, vector<Expr> args, vector<Exp
}
x_parts[idx] = csr.x_part;
dependencies[idx] = csr.x_dependencies;
// Add dependency on itself (regardless whether it actually depends on
// its previous values) for the purpose of computing the subgraph
// Add a dependency on itself (regardless of whether it actually
// depends on its previous values) to compute the subgraph
dependencies[idx].insert(idx);

exprs[idx] = common_subexpression_elimination(exprs[idx]);
Expand All @@ -367,8 +362,6 @@ AssociativeOp prove_associativity(const string &f, vector<Expr> args, vector<Exp
vector<set<int>> subgraphs;
if (!all_independent) {
debug(5) << "There are cross-dependencies. Need to prove associativity in bulk.\n";
// Find all transitive dependencies and add them to the graph
add_transitive_dependencies(dependencies);
// Decompose the tuple into subgraphs and solve for each separately
subgraphs = compute_subgraphs(dependencies);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/Associativity.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ struct AssociativeOp {
/**
* Given an update definition of a Func 'f', determine its equivalent
* associative binary/unary operator if there is any. 'is_associative'
* indicates if the operation was successfuly proven as associative.
* indicates if the operation was successfully proven as associative.
*/
AssociativeOp prove_associativity(
const std::string &f, std::vector<Expr> args, std::vector<Expr> exprs);
Expand Down
95 changes: 87 additions & 8 deletions test/correctness/rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -831,12 +831,70 @@ int argmin_rfactor_test() {
return 0;
}

// Selects which inline reduction inline_reductions_test() exercises.
enum class InlineReductionVariant {
ArgMin,
ArgMax,
};

// Checks that rfactor works on the update stage of an inline argmin/argmax
// reduction. Returns 0 on success and 1 (after printing to stderr) when the
// realized index/value pair does not match the expected extremum.
template<InlineReductionVariant variant>
int inline_reductions_test() {
using namespace ConciseCasts;
constexpr float pi = M_PI;

Func f{"f"};
Var x("x");
f(x) = sin(f32(x) / 8 * pi); // argmax should be f(4) = 1.0, argmin should be f(12) = -1.0
f.compute_root();

RDom r(0, 32);

// NOTE(review): `g` is handed to argmin/argmax and then scheduled below, so it
// presumably receives the Func that implements the inline reduction -- confirm
// against the argmin/argmax overloads.
Func g{"reduction"};
Func output{"g"};

if constexpr (variant == InlineReductionVariant::ArgMin) {
output() = argmin(f(r), g);
} else {
output() = argmax(f(r), g);
}

// Split the reduction variable by a factor of 2 and rfactor the outer part
// into the pure variable u of the intermediate Func.
RVar ro("rxo"), ri("rxi");
g.update(0).split(r, ro, ri, 2);

Var u("u");
Func intm = g.update(0).rfactor(ro, u);
intm.compute_root();
intm.update(0).vectorize(u, 2);

// The realization is read as (index, value): element 0 as int, element 1 as float.
Realization rn = output.realize();
Buffer<int> sch_idx(rn[0]);
Buffer<float> sch_val(rn[1]);

if constexpr (variant == InlineReductionVariant::ArgMin) {
if (sch_val() != -1.0f || sch_idx() != 12) {
fprintf(stderr, "Expected argmin to be f(12) = -1.0, got f(%d) = %f\n", sch_idx(), sch_val());
return 1;
}
} else {
if (sch_val() != 1.0f || sch_idx() != 4) {
fprintf(stderr, "Expected argmax to be f(4) = 1.0, got f(%d) = %f\n", sch_idx(), sch_val());
return 1;
}
}

return 0;
}

// Selects how argmax_rfactor_test() writes the update definition:
// Explicit builds the Tuple from a separate max() and select();
// TupleSelect uses a single select() over whole Tuples.
enum class ArgMaxVariant {
Explicit,
TupleSelect
};

template<ArgMaxVariant variant>
// Selects which tuple element of the argmax reduction holds the index and
// which holds the value in argmax_rfactor_test().
enum class ArgMaxTupleOrder {
IndexFirst,
ValueFirst,
};

template<ArgMaxVariant variant, ArgMaxTupleOrder order>
int argmax_rfactor_test() {
using namespace ConciseCasts;
constexpr float pi = M_PI;
Expand All @@ -849,12 +907,29 @@ int argmax_rfactor_test() {
RDom r(0, 32);

Func g{"g"};
g() = Tuple(f.type().min(), r.x.min());

int value_tup = order == ArgMaxTupleOrder::ValueFirst ? 0 : 1;
int index_tup = order == ArgMaxTupleOrder::ValueFirst ? 1 : 0;

if constexpr (order == ArgMaxTupleOrder::ValueFirst) {
g() = Tuple(f.type().min(), r.x.min());
} else {
g() = Tuple(r.x.min(), f.type().min());
}

if constexpr (variant == ArgMaxVariant::Explicit) {
g() = Tuple(max(f(r), g()[0]), select(g()[0] < f(r), r, g()[1]));
if constexpr (order == ArgMaxTupleOrder::ValueFirst) {
g() = Tuple(max(f(r), g()[value_tup]), select(g()[value_tup] < f(r), r, g()[index_tup]));
} else {
g() = Tuple(select(g()[value_tup] < f(r), r, g()[index_tup]), max(f(r), g()[value_tup]));
}
} else {
static_assert(variant == ArgMaxVariant::TupleSelect);
g() = select(g()[0] < f(r), Tuple(f(r), r), g());
if constexpr (order == ArgMaxTupleOrder::ValueFirst) {
g() = select(g()[value_tup] < f(r), Tuple(f(r), r), g());
} else {
g() = select(g()[value_tup] < f(r), Tuple(r, f(r)), g());
}
}

RVar ro("rxo"), ri("rxi");
Expand All @@ -866,8 +941,8 @@ int argmax_rfactor_test() {
intm.update(0).vectorize(u, 2);

Realization rn = g.realize();
Buffer<float> sch_val(rn[0]);
Buffer<int> sch_idx(rn[1]);
Buffer<float> sch_val(rn[value_tup]);
Buffer<int> sch_idx(rn[index_tup]);

if (sch_val() != 1.0f || sch_idx() != 4) {
fprintf(stderr, "Expected argmax to be f(4) = 1.0, got f(%d) = %f\n", sch_idx(), sch_val());
Expand Down Expand Up @@ -1208,8 +1283,12 @@ int main(int argc, char **argv) {
{"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test},
{"complex multiply rfactor test", complex_multiply_rfactor_test},
{"argmin rfactor test", argmin_rfactor_test},
{"argmax rfactor test (explicit)", argmax_rfactor_test<ArgMaxVariant::Explicit>},
{"argmax rfactor test (tuple)", argmax_rfactor_test<ArgMaxVariant::TupleSelect>},
{"inline reductions test (argmin)", inline_reductions_test<InlineReductionVariant::ArgMin>},
{"inline reductions test (argmax)", inline_reductions_test<InlineReductionVariant::ArgMax>},
{"argmax rfactor test (explicit, index first)", argmax_rfactor_test<ArgMaxVariant::Explicit, ArgMaxTupleOrder::IndexFirst>},
{"argmax rfactor test (tuple, index first)", argmax_rfactor_test<ArgMaxVariant::TupleSelect, ArgMaxTupleOrder::IndexFirst>},
{"argmax rfactor test (explicit, value first)", argmax_rfactor_test<ArgMaxVariant::Explicit, ArgMaxTupleOrder::ValueFirst>},
{"argmax rfactor test (tuple, value first)", argmax_rfactor_test<ArgMaxVariant::TupleSelect, ArgMaxTupleOrder::ValueFirst>},
{"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test},
{"rfactor bounds tests", rfactor_precise_bounds_test},
{"isnan max rfactor test (bitwise or)", isnan_max_rfactor_test<BitwiseOr>},
Expand Down
Loading