-
Notifications
You must be signed in to change notification settings - Fork 248
Open
Milestone
Description
In `lib/runtime/src/ops/batch_matmul.cc`, `backward_task_impl` contains the following code:
static optional<float> backward_task_impl(TaskArgumentAccessor const &acc) {
...
for (int i = 2; i < a_input.shape.dims.num_dims();
i++) {
int dim_size = a_input.shape[legion_dim_t(i)];
assert(dim_size == b_input.shape[legion_dim_t(i)]);
assert(dim_size == output.shape[legion_dim_t(i)]);
batch *= dim_size;
}
// TODO: add support for meta->a_seq_length_dim >= 0
// or meta->b_seq_length_dim >= 0
assert((meta->a_seq_length_dim >= a_len) || (iter_config.seq_length == 0));
assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0));
}
Metadata
Assignees
Labels
No labels