Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions lib/kernels/include/kernels/element_binary_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ void forward_kernel(
float *out_ptr,
OperatorType op_type,
bool broadcast_inputLHS,
device_handle_t const &handle);
device_handle_t const &handle,
size_t const num_elements = 0); // optional only used for CPU

void backward_kernel(
device_stream_t const &stream,
Expand All @@ -44,7 +45,8 @@ void backward_kernel(
OperatorType op_type,
bool broadcast_inputLHS,
bool broadcast_inputRHS,
device_handle_t const &handle);
device_handle_t const &handle,
size_t const num_elements = 0); // optional only used for CPU

void cleanup_kernel(
DeviceType device_type,
Expand Down
6 changes: 4 additions & 2 deletions lib/kernels/include/kernels/element_binary_kernels_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ void cpu_forward_kernel(float const *lhs_ptr,
float const *rhs_ptr,
float *out_ptr,
OperatorType op_type,
bool broadcast_inputLHS);
bool broadcast_inputLHS,
size_t const num_elements);

void cpu_backward_kernel(float const *out_grad_ptr,
float const *lhs_ptr,
Expand All @@ -18,7 +19,8 @@ void cpu_backward_kernel(float const *out_grad_ptr,
float *rhs_grad_ptr,
OperatorType op_type,
bool broadcast_inputLHS,
bool broadcast_inputRHS);
bool broadcast_inputRHS,
size_t const num_elements);

} // namespace FlexFlow::Kernels::ElementBinary

Expand Down
27 changes: 16 additions & 11 deletions lib/kernels/include/kernels/profiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ std::optional<milliseconds_t> profiling_wrapper(F const &f,
Ts &&...ts) {
if (enable_profiling) {
ProfilingSettings settings = ProfilingSettings{
/*warmup_iters=*/0,
/*measure_iters=*/1,
/*warmup_iters=*/0_n,
/*measure_iters=*/1_p,
};
return profiling_wrapper<F, Ts...>(f, settings, std::forward<Ts>(ts)...);
} else {
Expand All @@ -33,7 +33,7 @@ std::optional<milliseconds_t>
ProfilingSettings const &settings,
DeviceType device_type,
Ts &&...ts) {
if (settings.measure_iters <= 0) {
if (settings.measure_iters.int_from_positive_int() <= 0) {
return std::nullopt;
}

Expand All @@ -49,7 +49,7 @@ template <typename F, typename... Ts>
milliseconds_t cpu_profiling_wrapper(F const &f,
ProfilingSettings const &settings,
Ts &&...ts) {
ASSERT(settings.measure_iters > 0);
ASSERT(settings.measure_iters.int_from_positive_int() > 0);

device_stream_t stream = get_cpu_device_stream();

Expand All @@ -58,16 +58,19 @@ milliseconds_t cpu_profiling_wrapper(F const &f,
std::optional<TimePoint> start = std::nullopt;
std::optional<TimePoint> end = std::nullopt;

for (int i = 0; i < settings.warmup_iters + settings.measure_iters; i++) {
if (i == settings.warmup_iters) {
for (int i = 0; i < static_cast<int>(settings.warmup_iters) +
settings.measure_iters.int_from_positive_int();
i++) {
if (i == static_cast<int>(settings.warmup_iters)) {
start = std::chrono::steady_clock::now();
}
f(stream, std::forward<Ts>(ts)...);
}
end = std::chrono::steady_clock::now();

std::chrono::duration<double, std::milli> avg_duration =
(end.value() - start.value()) / settings.measure_iters;
(end.value() - start.value()) /
settings.measure_iters.int_from_positive_int();

return milliseconds_t{
static_cast<float>(avg_duration.count()),
Expand All @@ -78,16 +81,18 @@ template <typename F, typename... Ts>
milliseconds_t gpu_profiling_wrapper(F const &f,
ProfilingSettings const &settings,
Ts &&...ts) {
ASSERT(settings.measure_iters > 0);
ASSERT(settings.measure_iters.int_from_positive_int() > 0);

device_stream_t stream = get_gpu_device_stream();

ffEvent_t t_start, t_end;
checkCUDA(ffEventCreate(&t_start));
checkCUDA(ffEventCreate(&t_end));

for (int i = 0; i < settings.warmup_iters + settings.measure_iters; i++) {
if (i == settings.warmup_iters) {
for (int i = 0; i < static_cast<int>(settings.warmup_iters) +
settings.measure_iters.int_from_positive_int();
i++) {
if (i == static_cast<int>(settings.warmup_iters)) {
checkCUDA(ffEventRecord(t_start, stream.require_gpu()));
}
f(stream, std::forward<Ts>(ts)...);
Expand All @@ -100,7 +105,7 @@ milliseconds_t gpu_profiling_wrapper(F const &f,
checkCUDA(ffEventDestroy(t_start));
checkCUDA(ffEventDestroy(t_end));
return milliseconds_t{
elapsed / settings.measure_iters,
elapsed / settings.measure_iters.int_from_positive_int(),
};
}

Expand Down
10 changes: 8 additions & 2 deletions lib/kernels/include/kernels/profiling_settings.dtg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,16 @@ features = [
"fmt",
]

includes = [
"utils/nonnegative_int/nonnegative_int.h",
"utils/positive_int/positive_int.h",
]

[[fields]]
name = "warmup_iters"
type = "int"
type = "::FlexFlow::nonnegative_int"


[[fields]]
name = "measure_iters"
type = "int"
type = "::FlexFlow::positive_int"
16 changes: 12 additions & 4 deletions lib/kernels/src/kernels/element_binary_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ void forward_kernel(
float *out_ptr,
OperatorType op_type,
bool broadcast_inputLHS,
device_handle_t const &handle) {
device_handle_t const &handle,
size_t const num_elements) {
if (stream.is_gpu()) {
gpu_forward_kernel(
/*stream=*/stream.require_gpu(),
Expand All @@ -53,12 +54,15 @@ void forward_kernel(
ASSERT(stream.is_cpu());
ASSERT(per_device_state == std::nullopt);
ASSERT(handle.is_for_cpu());
ASSERT(num_elements > 0,
"num_elements must be provided for CPU element_binary kernel");
cpu_forward_kernel(
/*lhs_ptr=*/lhs_ptr,
/*rhs_ptr=*/rhs_ptr,
/*out_ptr=*/out_ptr,
/*op_type=*/op_type,
/*broadcast_inputLHS=*/broadcast_inputLHS);
/*broadcast_inputLHS=*/broadcast_inputLHS,
/*num_elements=*/num_elements);
}
}

Expand All @@ -73,7 +77,8 @@ void backward_kernel(
OperatorType op_type,
bool broadcast_inputLHS,
bool broadcast_inputRHS,
device_handle_t const &handle) {
device_handle_t const &handle,
size_t const num_elements) {
if (stream.is_gpu()) {
gpu_backward_kernel(
/*stream=*/stream.require_gpu(),
Expand All @@ -91,6 +96,8 @@ void backward_kernel(
ASSERT(stream.is_cpu());
ASSERT(per_device_state == std::nullopt);
ASSERT(handle.is_for_cpu());
ASSERT(num_elements > 0,
"num_elements must be provided for CPU element_binary kernel");
cpu_backward_kernel(
/*out_grad_ptr=*/out_grad_ptr,
/*lhs_ptr=*/lhs_ptr,
Expand All @@ -99,7 +106,8 @@ void backward_kernel(
/*rhs_grad_ptr=*/rhs_grad_ptr,
/*op_type=*/op_type,
/*broadcast_inputLHS=*/broadcast_inputLHS,
/*broadcast_inputRHS=*/broadcast_inputRHS);
/*broadcast_inputRHS=*/broadcast_inputRHS,
/*num_elements=*/num_elements);
}
}

Expand Down
52 changes: 47 additions & 5 deletions lib/kernels/src/kernels/element_binary_kernels_cpu.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernels/element_binary_kernels_cpu.h"
#include "op-attrs/operator_type.dtg.h"
#include "utils/exception.h"

namespace FlexFlow::Kernels::ElementBinary {
Expand All @@ -7,8 +8,32 @@ void cpu_forward_kernel(float const *lhs_ptr,
float const *rhs_ptr,
float *out_ptr,
OperatorType op_type,
bool broadcast_inputLHS) {
NOT_IMPLEMENTED();
bool broadcast_inputLHS,
size_t num_elements) {
switch (op_type) {
case OperatorType::EW_ADD:
for (size_t i = 0; i < num_elements; i++) {
out_ptr[i] = lhs_ptr[i] + rhs_ptr[i];
}
break;
case OperatorType::EW_SUB:
for (size_t i = 0; i < num_elements; i++) {
out_ptr[i] = lhs_ptr[i] - rhs_ptr[i];
}
break;
case OperatorType::EW_MUL:
for (size_t i = 0; i < num_elements; i++) {
out_ptr[i] = lhs_ptr[i] * rhs_ptr[i];
}
break;
case OperatorType::EW_DIV:
for (size_t i = 0; i < num_elements; i++) {
out_ptr[i] = lhs_ptr[i] / rhs_ptr[i];
}
break;
default:
NOT_IMPLEMENTED();
}
}

void cpu_backward_kernel(float const *out_grad_ptr,
Expand All @@ -18,8 +43,25 @@ void cpu_backward_kernel(float const *out_grad_ptr,
float *rhs_grad_ptr,
OperatorType op_type,
bool broadcast_inputLHS,
bool broadcast_inputRHS) {
NOT_IMPLEMENTED();
bool broadcast_inputRHS,
size_t num_elements) {
switch (op_type) {
case OperatorType::EW_ADD:
case OperatorType::EW_SUB:
for (size_t i = 0; i < num_elements; i++) {
lhs_grad_ptr[i] += out_grad_ptr[i];
rhs_grad_ptr[i] += (op_type == OperatorType::EW_SUB) ? -out_grad_ptr[i]
: out_grad_ptr[i];
}
break;
case OperatorType::EW_MUL:
for (size_t i = 0; i < num_elements; i++) {
lhs_grad_ptr[i] += out_grad_ptr[i] * rhs_ptr[i];
rhs_grad_ptr[i] += out_grad_ptr[i] * lhs_ptr[i];
}
break;
default:
NOT_IMPLEMENTED();
}
}

} // namespace FlexFlow::Kernels::ElementBinary
31 changes: 29 additions & 2 deletions lib/kernels/src/kernels/element_unary_kernels_cpu.cc
Original file line number Diff line number Diff line change
@@ -1,19 +1,46 @@
#include "kernels/element_unary_kernels_cpu.h"
#include "kernels/map_tensor_accessors.h"
#include "kernels/tensor_accessor_unary_ops.h"
#include "op-attrs/ops/element_unary_attrs.dtg.h"
#include "utils/exception.h"

namespace FlexFlow::Kernels::ElementUnary {

/// @brief Forward pass of an element-wise unary op on CPU.
///
/// Only RELU is currently handled; every other operator fails fast with
/// NOT_IMPLEMENTED.
void cpu_forward_kernel(ElementUnaryAttrs const &attrs,
                        GenericTensorAccessorR const &input,
                        GenericTensorAccessorW const &output) {
  if (attrs.op_type != OperatorType::RELU) {
    NOT_IMPLEMENTED();
  }

  tensor_accessor_relu_to(input, output);
}

/// @brief Backward pass of an element-wise unary op on CPU.
///
/// Only RELU is currently handled; every other operator fails fast with
/// NOT_IMPLEMENTED. The `input` accessor is not read on the RELU path —
/// the gradient can be computed from the (post-activation) output alone.
void cpu_backward_kernel(ElementUnaryAttrs const &attrs,
                         GenericTensorAccessorR const &output,
                         GenericTensorAccessorR const &output_grad,
                         GenericTensorAccessorR const &input,
                         GenericTensorAccessorW const &input_grad) {
  if (attrs.op_type != OperatorType::RELU) {
    NOT_IMPLEMENTED();
  }

  // ReLU gradient: pass the upstream gradient through wherever the
  // output is positive, zero elsewhere (relu(x) > 0 iff x > 0).
  auto relu_grad = [](auto upstream, auto activated) {
    if (activated > static_cast<decltype(activated)>(0)) {
      return upstream;
    }
    return static_cast<decltype(upstream)>(0);
  };

  map_tensor_accessors2_to(output_grad,
                           output,
                           output_grad.shape.data_type,
                           relu_grad,
                           input_grad);
}

} // namespace FlexFlow::Kernels::ElementUnary
30 changes: 18 additions & 12 deletions lib/kernels/src/kernels/linear_kernels_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ void linear_cpu_forward_kernel(
}
}

// template <typename T>
static float single_element_relu_bwd(float elem) {
if (elem > 0) {
return 1;
} else {
return 0;
}
}

void linear_cpu_backward_kernel(
LinearAttrs const &attrs,
GenericTensorAccessorR const &output,
Expand All @@ -65,11 +56,26 @@ void linear_cpu_backward_kernel(
std::optional<GenericTensorAccessorR> processed_output_grad = std::nullopt;
if (attrs.activation.has_value()) {
switch (attrs.activation.value()) {
case Activation::RELU:
case Activation::RELU: {
// relu backward: output_grad * (output > 0)
// output here is POST-activation (relu output)
// output > 0 iff pre-activation > 0 since relu(x) > 0 iff x > 0
GenericTensorAccessorW grad_buf =
cpu_allocator.allocate_tensor(output_grad.shape);
map_tensor_accessors2_to(
output_grad,
output,
output_grad.shape.data_type,
[](auto grad, auto out) {
return out > static_cast<decltype(out)>(0)
? grad
: static_cast<decltype(grad)>(0);
},
grad_buf);
processed_output_grad =
read_only_accessor_from_write_accessor(map_tensor_accessor(
output_grad, single_element_relu_bwd, cpu_allocator));
read_only_accessor_from_write_accessor(grad_buf);
break;
}
default:
PANIC("Unhandled activation function", attrs.activation.value());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor {
PCGOperatorAttrs get_op_attrs() const override;
LossAttrs get_loss_attrs() const override;
PerDeviceOpState get_per_device_op_state() const override;
bool has_per_device_op_state() const override;
FFIterationConfig get_iteration_config() const override;
OptimizerAttrs get_optimizer_attrs() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ PerDeviceOpState LocalTaskArgumentAccessor::get_per_device_op_state() const {
return assert_unwrap(this->per_device_op_state);
}

// Reports whether a per-device op state was supplied to this accessor;
// get_per_device_op_state() unwraps the optional, so callers should
// check this first.
bool LocalTaskArgumentAccessor::has_per_device_op_state() const {
  return static_cast<bool>(this->per_device_op_state);
}

FFIterationConfig LocalTaskArgumentAccessor::get_iteration_config() const {
return this->iteration_config;
}
Expand Down
Loading