Add support for meta->a_seq_length_dim >= 0 #1060

@KateUnger

Description

lib/runtime/src/ops/batch_matmul.cc:

static optional<float> backward_task_impl(TaskArgumentAccessor const &acc) {
...
  for (int i = 2; i < a_input.shape.dims.num_dims();
       i++) {
    int dim_size = a_input.shape[legion_dim_t(i)];
    assert(dim_size == b_input.shape[legion_dim_t(i)]);
    assert(dim_size == output.shape[legion_dim_t(i)]);
    batch *= dim_size;
  }

  // TODO: add support for meta->a_seq_length_dim >= 0
  // or meta->b_seq_length_dim >= 0
  assert((meta->a_seq_length_dim >= a_len) || (iter_config.seq_length == 0));
  assert((meta->b_seq_length_dim >= b_len) || (iter_config.seq_length == 0));
}
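
For reference, here is a minimal standalone sketch of what the excerpt above checks today. It is not the runtime's code: `Shape`, `compute_batch`, and `seq_length_guard_holds` are hypothetical names, plain `std::vector<int>` stands in for the real shape type, and `len` stands in for `a_len` / `b_len`, whose definition falls in the elided part of the function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the runtime's shape type (assumption): dimension
// sizes indexed the same way as legion_dim_t(i) in the excerpt.
using Shape = std::vector<int>;

// Mirrors the loop in backward_task_impl: every dimension from index 2 upward
// must match across a, b, and the output, and their product is the batch size.
int compute_batch(Shape const &a, Shape const &b, Shape const &out) {
  assert(a.size() == b.size());
  assert(a.size() == out.size());
  int batch = 1;
  for (std::size_t i = 2; i < a.size(); i++) {
    int dim_size = a[i];
    assert(dim_size == b[i]);
    assert(dim_size == out[i]);
    batch *= dim_size;
  }
  return batch;
}

// Mirrors the condition inside the two asserts: the task currently only runs
// when the sequence-length dimension is out of range (>= len) or when
// seq_length is 0. Any other combination is the unsupported case this issue
// asks for.
bool seq_length_guard_holds(int seq_length_dim, int len, int seq_length) {
  return (seq_length_dim >= len) || (seq_length == 0);
}
```

With shapes `{4, 3, 2, 5}`, `{3, 4, 2, 5}`, and `{3, 3, 2, 5}`, `compute_batch` returns 10, since only the last two dimensions contribute. A call such as `seq_length_guard_holds(0, 3, 16)` returns false, which corresponds to the situation where the existing asserts would fire.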