Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/flexflow/ffconst.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ enum OperatorType {
OP_RESIDUAL_RMS_NORM,
OP_BEAM_TOPK,
OP_ARGMAX,
OP_DECODING,
OP_INC_MULTIHEAD_SELF_ATTENTION,
OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION,
OP_TREE_INC_MULTIHEAD_SELF_ATTENTION,
Expand Down
8 changes: 8 additions & 0 deletions include/flexflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ enum TaskIDs {
ARGMAX_INIT_TASK_ID,
ARGMAX_BEAM_INF_TASK_ID,
ARGMAX_NORM_INF_TASK_ID,
DECODING_INIT_TASK_ID,
DECODING_BEAM_INF_TASK_ID,
DECODING_NORM_INF_TASK_ID,
DECODING_PEFT_BWD_TASK_ID,
TRANSPOSE_INIT_TASK_ID,
TRANSPOSE_FWD_TASK_ID,
TRANSPOSE_BWD_TASK_ID,
Expand Down Expand Up @@ -375,6 +379,7 @@ class BeamTopK;
class SpecIncMultiHeadSelfAttention;
class Sampling;
class ArgMax;
class Decoding;
class Combine;
class Repartition;
class Reduction;
Expand Down Expand Up @@ -720,6 +725,7 @@ class FFModel {
bool speculative_decoding,
char const *name = NULL);
Tensor argmax(const Tensor input, bool beam_search, char const *name = NULL);
Tensor decoding(const Tensor input, bool beam_search, char const *name = NULL);
Tensor sampling(const Tensor input, float top_p, char const *name = NULL);
Tensor multihead_attention(const Tensor query,
const Tensor key,
Expand Down Expand Up @@ -1221,6 +1227,8 @@ class FFModel {
Sampling *>,
std::unordered_map<std::pair<ParallelTensorShape, ArgMaxParams>,
ArgMax *>,
std::unordered_map<std::pair<ParallelTensorShape, DecodingParams>,
Decoding *>,
std::unordered_map<
std::pair<ParallelTensorShape, SpecIncMultiHeadSelfAttentionParams>,
SpecIncMultiHeadSelfAttention *>,
Expand Down
2 changes: 2 additions & 0 deletions include/flexflow/operator_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "flexflow/ops/aggregate_spec_params.h"
#include "flexflow/ops/arg_topk_params.h"
#include "flexflow/ops/argmax_params.h"
#include "flexflow/ops/decoding_params.h"
#include "flexflow/ops/attention_params.h"
#include "flexflow/ops/batch_matmul_params.h"
#include "flexflow/ops/beam_topk_params.h"
Expand Down Expand Up @@ -85,6 +86,7 @@ using OperatorParameters = mp::variant<AggregateParams,
ArgTopKParams,
SamplingParams,
ArgMaxParams,
DecodingParams,
SoftmaxParams,
TransposeParams,
RepartitionParams,
Expand Down
143 changes: 143 additions & 0 deletions include/flexflow/ops/decoding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#ifndef _FLEXFLOW_DECODING_H
#define _FLEXFLOW_DECODING_H

#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/layer.h"
#include "flexflow/node.h"
#include "flexflow/operator.h"
#include "flexflow/ops/decoding_params.h"
#include "flexflow/utils/memory_allocator.h"
#include "flexflow/fftype.h"
#include "flexflow/device.h"

namespace FlexFlow {

// forward declaration
class DecodingMeta;
struct SoftmaxShardedContext;

/**
 * Operator declaration for the "decoding" layer.
 *
 * The class derives from Op and exposes the usual operator entry points
 * (init / init_inference / forward / inference / peft_bwd / backward),
 * the static Legion task bodies, (de)serialization hooks and the static
 * kernel entry points. Judging from the kernel signature below (one input,
 * a softmax output and an argmax output) this operator presumably combines
 * softmax and argmax in one step -- NOTE(review): confirm against the
 * implementation file, which is not visible here.
 */
class Decoding : public Op {
public:
  using Params = DecodingParams;
  using Input = ParallelTensor;

  // ---- Construction ------------------------------------------------------

  /// Build the operator from an explicit layer id, input tensor and
  /// beam-search flag.
  Decoding(FFModel &model,
           LayerID const &_layer_guid,
           const ParallelTensor input,
           bool beam_search,
           char const *name);

  /// Build the operator from a DecodingParams bundle; `name` is optional.
  Decoding(FFModel &model,
           Params const &params,
           const Input input,
           char const *name = nullptr);

  // ---- Op interface overrides --------------------------------------------

  void init(FFModel const &) override;
  void init_inference(FFModel const &,
                      std::vector<ParallelTensor> const &,
                      std::vector<ParallelTensor> const &,
                      MachineView const *mv = nullptr) override;
  void forward(FFModel const &) override;
  Legion::FutureMap inference(FFModel const &,
                              BatchConfigFuture const &,
                              std::vector<ParallelTensor> const &,
                              std::vector<ParallelTensor> const &,
                              MachineView const *mv = nullptr) override;
  Legion::FutureMap peft_bwd(FFModel const &,
                             BatchConfigFuture const &,
                             std::vector<ParallelTensor> const &,
                             std::vector<ParallelTensor> const &,
                             MachineView const *mv = nullptr) override;
  void backward(FFModel const &) override;

  /// Not supported for this operator: trips an assertion when called.
  void print_layer(FFModel const &model) override {
    assert(0);
  }

  // ---- Layer -> operator factory -----------------------------------------

  static Op *
      create_operator_from_layer(FFModel &model,
                                 Layer const *layer,
                                 std::vector<ParallelTensor> const &inputs);

  // ---- Legion task bodies ------------------------------------------------

  static OpMeta *init_task(Legion::Task const *task,
                           std::vector<Legion::PhysicalRegion> const &regions,
                           Legion::Context ctx,
                           Legion::Runtime *runtime);

  /// Inference task variant returning a BeamInferenceResult.
  static BeamInferenceResult
      inference_task_beam(Legion::Task const *task,
                          std::vector<Legion::PhysicalRegion> const &regions,
                          Legion::Context ctx,
                          Legion::Runtime *runtime);

  /// Inference task variant returning a plain InferenceResult.
  static InferenceResult
      inference_task_norm(Legion::Task const *task,
                          std::vector<Legion::PhysicalRegion> const &regions,
                          Legion::Context ctx,
                          Legion::Runtime *runtime);

  static bool peft_bwd_task(Legion::Task const *task,
                            std::vector<Legion::PhysicalRegion> const &regions,
                            Legion::Context ctx,
                            Legion::Runtime *runtime);

  // ---- Cost model, serialization, parameters -----------------------------

  bool measure_operator_cost(Simulator *sim,
                             MachineView const &pc,
                             CostMetrics &cost_metrics) const override;
  void serialize(Legion::Serializer &) const override;
  static PCG::Node deserialize(FFModel &ff,
                               Legion::Deserializer &d,
                               ParallelTensor inputs[],
                               int num_inputs);
  Op *materialize(FFModel &ff,
                  ParallelTensor inputs[],
                  int num_inputs) const override;
  Params get_params() const;

  // ---- Kernel entry points -----------------------------------------------

  /// Typed inference kernel. Reads `input_ptr` and writes to
  /// `softmax_output_ptr` / `argmax_output_ptr`; `loss` and `stream` are
  /// passed through to the implementation (defined elsewhere).
  template <typename DT>
  static void inference_kernel(DecodingMeta const *m,
                               BatchConfig const *bc,
                               DT const *input_ptr,
                               DT *softmax_output_ptr,
                               int *argmax_output_ptr,
                               int num_classes,
                               int vocab_offset,
                               float *loss,
                               ffStream_t stream);

  /// Accessor-based wrapper around the typed inference kernel.
  static void inference_kernel_wrapper(DecodingMeta *m,
                                       BatchConfig const *bc,
                                       GenericTensorAccessorR const &input,
                                       GenericTensorAccessorW const &softmax_output,
                                       GenericTensorAccessorW const &argmax_output);

  /// Typed PEFT backward kernel writing into `input_grad_ptr`.
  template <typename DT>
  static void peft_bwd_kernel(DecodingMeta const *m,
                              BatchConfig const *bc,
                              DT *input_grad_ptr,
                              int num_classes,
                              int shard_id,
                              ffStream_t stream);

  /// Accessor-based wrapper around the typed PEFT backward kernel.
  static void peft_bwd_kernel_wrapper(DecodingMeta *m,
                                      BatchConfig const *bc,
                                      int shard_id,
                                      GenericTensorAccessorW const &input_grad);

  // ---- Public data members -----------------------------------------------

  LayerID layer_guid; // id of the layer this operator belongs to
  bool beam_search;   // beam-search flag supplied at construction
};

/**
 * Per-operator metadata for Decoding.
 *
 * Constructed from a handler, the owning Decoding operator, the input
 * domain and a GPU memory allocator (the constructor and destructor are
 * defined elsewhere). Every scalar/pointer field carries a default member
 * initializer so that a field the constructor does not assign is a
 * well-defined false / nullptr / 0 instead of an indeterminate value.
 */
class DecodingMeta : public OpMeta {
public:
  DecodingMeta(FFHandler handler,
               Decoding const *decoding,
               Legion::Domain const &input_domain,
               MemoryAllocator &gpu_mem_allocator);
  ~DecodingMeta(void);

  // Beam-search flag (Decoding exposes a field of the same name).
  bool beam_search = false;
  // NOTE(review): presumably device buffers for probabilities and the
  // finetuning loss -- confirm against the constructor implementation.
  float *probs = nullptr;
  float *d_loss = nullptr;
  // Temporary buffers
  int *parent_output_buffer = nullptr;
  // Sharded softmax context
  SoftmaxShardedContext *softmax_context = nullptr;
  // PEFT related fields
  void *output_grad_ptr = nullptr;
  size_t allocated_peft_buffer_size = 0;
  Realm::RegionInstance reserveInst;
  BatchConfig::TokenId peft_token_ids[BatchConfig::MAX_NUM_TOKENS];
};

}; // namespace FlexFlow

#endif // _FLEXFLOW_DECODING_H
26 changes: 26 additions & 0 deletions include/flexflow/ops/decoding_params.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef _FLEXFLOW_DECODING_PARAMS_H
#define _FLEXFLOW_DECODING_PARAMS_H

#include "flexflow/ffconst.h"
#include "flexflow/parallel_tensor.h"

namespace FlexFlow {

/**
 * Parameter bundle describing a Decoding operator.
 *
 * Member order is part of the struct's layout and of any aggregate
 * initialization performed by callers, so it must not be rearranged.
 */
struct DecodingParams {
  // Identifier of the layer the operator is created for.
  LayerID layer_guid;
  // Beam-search flag forwarded to the operator.
  bool beam_search;
  // Validity check of these parameters against an input tensor shape
  // (defined elsewhere).
  bool is_valid(ParallelTensorShape const &) const;
  // Operator name, stored inline in a fixed MAX_OPNAME-sized buffer.
  char name[MAX_OPNAME];
};

/// Equality comparison for two parameter bundles (defined elsewhere).
bool operator==(DecodingParams const &, DecodingParams const &);

} // namespace FlexFlow

namespace std {

/// std::hash specialization for FlexFlow::DecodingParams, so the type can be
/// hashed by the standard unordered containers. Only the call operator is
/// declared here; its definition lives elsewhere.
template <>
struct hash<FlexFlow::DecodingParams> {
  size_t operator()(FlexFlow::DecodingParams const &params) const;
};

} // namespace std

#endif // _FLEXFLOW_DECODING_PARAMS_H
5 changes: 3 additions & 2 deletions inference/models/llama.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,9 @@ void LLAMA::create_llama_model(FFModel &ff,
output = ff.sampling(softmax, generation_config.topp);
} else {
// output = ff.arg_top_k(dense, /*k=*/1, false);
Tensor softmax = ff.softmax(dense, -1);
output = ff.argmax(softmax, /*beam_Search*/ false);
// Tensor softmax = ff.softmax(dense, -1);
// output = ff.argmax(softmax, /*beam_Search*/ false);
output = ff.decoding(dense, /*beam_search*/ false, "decoding");
}
}

Expand Down
32 changes: 32 additions & 0 deletions src/ops/argmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ bool operator==(ArgMaxParams const &lhs, ArgMaxParams const &rhs) {
return lhs.beam_search == rhs.beam_search;
}

/**
 * Strip the trailing uid suffix from an operator name.
 *
 * Scans the name from its last character backwards while the characters are
 * decimal digits or underscores, and truncates the name at the left-most
 * underscore seen during that scan. Examples:
 *   "decoding_123"          -> "decoding"
 *   "layers_0_attention_1"  -> "layers_0_attention"
 *   "abc123"                -> "abc123"   (no underscore in the suffix)
 *
 * The first character (index 0) is never examined, so it is always kept.
 *
 * @param op_name NUL-terminated operator name; a null pointer yields "".
 * @return The name without its trailing "_<digits>" suffix(es).
 */
static std::string remove_uid(char const *op_name) {
  // Constructing std::string from a null pointer is undefined behaviour.
  if (op_name == nullptr) {
    return std::string();
  }
  std::string name(op_name);
  size_t cut = name.length();
  // Unsigned reverse loop visiting indices length-1 .. 1 (never 0); avoids
  // narrowing the size_t length into an int.
  for (size_t i = name.length(); i-- > 1;) {
    char const c = name[i];
    if (c == '_') {
      cut = i;
    } else if (!std::isdigit(static_cast<unsigned char>(c))) {
      // std::isdigit requires a value representable as unsigned char; a
      // plain (possibly negative) char would be undefined behaviour.
      break;
    }
  }
  if (cut < name.length()) {
    name.erase(cut);
  }
  return name;
}

ArgMax::ArgMax(FFModel &model,
const ParallelTensor _input,
bool _beam_search,
Expand Down Expand Up @@ -136,6 +152,10 @@ ArgMax::ArgMax(FFModel &model,
outputs[1] = model.create_parallel_tensor_legion_ordering(
numdim, dims, DT_INT32, this, 1 /*owner_idx*/);
}
std::string const &input_label = std::string("Argmax input tensor");
_input->print(input_label);
std::string const &label = std::string("Argmax output tensor");
outputs[0]->print(label);
}

ArgMax::ArgMax(FFModel &model, ArgMax const &other, const ParallelTensor input)
Expand Down Expand Up @@ -397,6 +417,18 @@ InferenceResult

ArgMax::inference_kernel_wrapper(m, bc, input, indices, parent, &loss);

if (task->index_point.point_data[0] == 0) {
int in_dim0 = input.domain.hi()[0] - input.domain.lo()[0] + 1;
int in_dim1 = input.domain.hi()[1] - input.domain.lo()[1] + 1;
int out_dim0 = indices.domain.hi()[0] - indices.domain.lo()[0] + 1;
int out_dim1 = indices.domain.hi()[1] - indices.domain.lo()[1] + 1;
std::string op_name_without_uid = remove_uid(m->op_name);
printf("Argmax(%s): in=[%i, bz=%i/%i] -> out=[%i,bz=%i/%i]\n",
op_name_without_uid.c_str(),
in_dim0, bc->num_tokens, in_dim1,
out_dim0, bc->num_tokens, out_dim1);
}

InferenceResult ir;
ir.finetuning_loss = loss;

Expand Down
Loading
Loading