Skip to content

Commit 5976cd2

Browse files
committed
fix: hasSameDimensions
1 parent 7e2c8ea commit 5976cd2

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/interface/Einsum.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ namespace mlc
6767
* @brief Executes the Einsum operation with the given inputs and output tensor.
6868
*/
6969
template <typename T> Error execute(const std::vector<T> &inputs, Tensor &output);
70-
template <typename T> Error hasSameDimensions(const std::vector<T> &inputs);
70+
template <typename T> Error hasSameDimensions(const std::vector<T> &inputs, const Tensor &output);
7171

7272
Error error;
7373
mini_jit::EinsumTree einsumTree;
@@ -95,16 +95,16 @@ namespace mlc
9595
template <typename T> inline Error EinsumOperation::hasSameDimensions(const std::vector<T> &inputs, const Tensor &output)
9696
{
9797
auto &sortedDimSizes = einsumTree.get_sorted_dim_sizes();
98-
const mini_jit::EinsumTree::EinsumNode *root = einsumTree.getRoot();
98+
const mini_jit::EinsumTree::EinsumNode *root = einsumTree.get_root();
9999

100-
if (output->dim_sizes.size() != root->output_dim_ids.size())
100+
if (output.dim_sizes.size() != root->output_dim_ids.size())
101101
{
102102
return {ErrorType::ExecuteWrongDimension, "The count of dimensions do not match in the output tensor."};
103103
}
104104

105105
for (size_t i = 0; i < root->output_dim_ids.size(); i++)
106106
{
107-
if (output->dim_sizes[i] != static_cast<uint64_t>(sortedDimSizes[root->output_dim_ids[i]]))
107+
if (output.dim_sizes[i] != static_cast<uint64_t>(sortedDimSizes[root->output_dim_ids[i]]))
108108
{
109109
return {ErrorType::ExecuteWrongDimension,
110110
"The output tensor dimension has a different size than the size than the tensor it was setup up with."};

0 commit comments

Comments
 (0)