2 changes: 1 addition & 1 deletion infini_train/include/autograd/function.h
@@ -34,7 +34,7 @@ class Function : public std::enable_shared_from_this<Function> {
     virtual std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) = 0;
 
     std::vector<std::shared_ptr<Tensor>> Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors);
-    virtual void BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int idx);
+    virtual void BackwardPartial(std::shared_ptr<Tensor> grad_output, int idx);
 
     void IncreaseDependenciesNumber();
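Aside: the header change above turns grad_output into a sink parameter, so callers can std::move their reference in instead of copying it, avoiding an atomic refcount bump on the hot backward path. A minimal standalone illustration of that idiom follows; nothing in it is repo code.

// sink_idiom.cc — standalone sketch of the by-value + std::move idiom;
// hypothetical example, not part of this PR.
#include <cstdio>
#include <memory>
#include <utility>

// By-value shared_ptr: the callee owns one reference outright.
void Consume(std::shared_ptr<int> p) {
    std::printf("inside Consume: use_count=%ld\n", p.use_count()); // prints 1
} // p drops its reference here

int main() {
    auto grad = std::make_shared<int>(42);
    Consume(std::move(grad));                // move in: no refcount traffic
    std::printf("after move-out: grad is %s\n", grad ? "set" : "null");
    return 0;
}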
32 changes: 27 additions & 5 deletions infini_train/src/autograd/function.cc
@@ -96,7 +96,7 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
     return output_tensors;
 }
 
-void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int grad_output_idx) {
+void Function::BackwardPartial(std::shared_ptr<Tensor> grad_output, int grad_output_idx) {
     auto device = grad_output->GetDevice();
     core::DeviceGuard guard(device);
@@ -106,7 +106,7 @@ void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int grad_output_idx) {
         grad_outputs_.resize(1, nullptr);
     }
     if (!grad_outputs_.at(grad_output_idx)) {
-        grad_outputs_[grad_output_idx] = grad_output;
+        grad_outputs_[grad_output_idx] = std::move(grad_output);
         ++grad_outputs_reached_;
     } else {
         auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"});
@@ -144,13 +144,35 @@ void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int grad_output_idx) {
         dependencies_reached_ = 0;
 
         CHECK_EQ(grad_inputs.size(), next_functions_.size());
-        for (int idx = 0; idx < grad_inputs.size(); ++idx) {
-            auto &grad_input = grad_inputs[idx];
+        auto propagate_grad_input = [&](size_t idx) {
+            auto grad_input = std::move(grad_inputs[idx]);
             auto &[next_function, output_idx] = next_functions_[idx];
             if (grad_input && next_function) {
-                next_function->BackwardPartial(grad_input, output_idx);
+                next_function->BackwardPartial(std::move(grad_input), output_idx);
             }
-        }
+            grad_inputs[idx].reset();
+        };
+
+        // Send leaf gradients out first. This recursive engine keeps the
+        // current function's full grad_inputs vector alive while traversing
+        // earlier inputs; for ops like Linear(input, weight, bias), visiting
+        // input first would retain weight/bias gradients across all preceding
+        // layers. PyTorch's non-recursive engine does not have that stack
+        // retention pattern, so flush AccumulateGrad edges before recursing
+        // into non-leaf activation edges.
+        for (size_t idx = 0; idx < grad_inputs.size(); ++idx) {
+            const auto &next_function = next_functions_[idx].first;
+            if (next_function && std::dynamic_pointer_cast<AccumulateGrad>(next_function)) {
+                propagate_grad_input(idx);
+            }
+        }
+        for (size_t idx = 0; idx < grad_inputs.size(); ++idx) {
+            const auto &next_function = next_functions_[idx].first;
+            if (next_function && !std::dynamic_pointer_cast<AccumulateGrad>(next_function)) {
+                propagate_grad_input(idx);
+            }
+        }
         next_functions_.clear();
     }
 }
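The comment in the hunk argues that input-first recursion would pin parameter gradients on the stack. A standalone sketch of that effect follows, assuming a simple chain of layers; all names in it are illustrative, not from this repo.

// peak_liveness.cc — standalone sketch (not repo code) of the retention
// effect described above. Simulated gradient buffers count how many are
// alive at once during a recursive backward over a chain of layers, each
// producing one activation grad and two parameter grads.
#include <algorithm>
#include <cstdio>
#include <memory>
#include <vector>

static int live_buffers = 0;
static int peak_live = 0;

struct Grad { // stand-in for a gradient tensor
    Grad() { peak_live = std::max(peak_live, ++live_buffers); }
    ~Grad() { --live_buffers; }
};

void BackwardChain(int depth, bool leaf_first) {
    if (depth == 0) return;
    std::vector<std::shared_ptr<Grad>> grad_inputs = {
        std::make_shared<Grad>(), // [0] activation grad -> recurses (non-leaf)
        std::make_shared<Grad>(), // [1] weight grad     -> leaf edge
        std::make_shared<Grad>(), // [2] bias grad       -> leaf edge
    };
    auto propagate = [&](size_t idx) {
        auto g = std::move(grad_inputs[idx]); // hand off, like BackwardPartial
        if (idx == 0) {
            BackwardChain(depth - 1, leaf_first); // g stays alive across recursion
        } // leaf grads are simply released to their accumulators here
    };
    if (leaf_first) { propagate(1); propagate(2); propagate(0); }
    else            { propagate(0); propagate(1); propagate(2); }
}

int main() {
    // Input-first pins each layer's parameter grads across the whole
    // remaining recursion (~3 live per layer); leaf-first flushes them
    // before recursing, leaving only the activation chain (~1 per layer).
    for (bool leaf_first : {false, true}) {
        live_buffers = peak_live = 0;
        BackwardChain(/*depth=*/8, leaf_first);
        std::printf("leaf_first=%d -> peak live grads=%d\n",
                    (int)leaf_first, peak_live);
    }
    return 0;
}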
Review thread on the two-pass loops:

Contributor:
Do we need to traverse twice here? If we later want an implementation closer to PyTorch's queue-based autograd engine, should we add a function now that orders next_functions_, and keep refining that function's logic afterwards?

Contributor (Author):
That is hard to do cleanly as things stand: grad_inputs[idx] and next_functions_[idx] are aligned by input index, and the input index depends on how the module/model is constructed. Changing this would trigger a chain of changes elsewhere; changing only this one spot would break the alignment.

If we did want an ordering, the most we could do is maintain an extra list (rather than modifying next_functions_ in place), derive a new order from this rule, and then traverse that. In practice grad_inputs holds at most two or three entries, so traversing twice adds very little overhead.
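On the thread's extra-list idea: the fragment below is a hypothetical sketch, not part of this PR, of how an explicit dispatch order could replace the two loops inside BackwardPartial without mutating next_functions_, preserving the index alignment with grad_inputs. It assumes the surrounding names from the diff (next_functions_, propagate_grad_input, AccumulateGrad) plus <algorithm> and <numeric>.

// Hypothetical replacement for the two dispatch loops; sketch only.
std::vector<size_t> order(next_functions_.size());
std::iota(order.begin(), order.end(), size_t{0});
// Stable-partition leaf (AccumulateGrad) edges to the front; relative
// order within each group is preserved. The ordering rule lives here,
// so a later queue-based engine could swap in a different policy.
std::stable_partition(order.begin(), order.end(), [&](size_t idx) {
    const auto &fn = next_functions_[idx].first;
    return fn && std::dynamic_pointer_cast<AccumulateGrad>(fn) != nullptr;
});
for (size_t idx : order) {
    propagate_grad_input(idx); // single traversal in the precomputed order
}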
