4 changes: 2 additions & 2 deletions csrc/engine/infer_engine.cpp
@@ -63,14 +63,14 @@ infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
// Trigger each worker to run inference
for (auto &worker : workers_) {
worker->run(input.to_model_input());
worker->run(input);
}
// Wait for all workers
for (auto &worker : workers_) {
worker->wait();
}

return {workers_[0]->get_output().logits};
return workers_[0]->get_output();
}

//------------------------------------------------------
23 changes: 2 additions & 21 deletions csrc/engine/infer_engine.hpp
@@ -13,28 +13,9 @@ namespace infinilm::engine {

class InferEngine {
public:
struct Input {
/// Token IDs tensor of shape `[batch, seq_len]`.
std::optional<infinicore::Tensor> input_ids;
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std::optional<infinicore::Tensor> position_ids;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
std::optional<infinicore::Tensor> cache_lengths;
/// Input Lengths of each request in a continuous-batched sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> input_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> input_offsets;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache.
std::optional<infinicore::Tensor> slot_mapping;
using Input = RankWorker::Input;

infinilm::InfinilmModel::Input to_model_input() const;
};

struct Output {
infinicore::Tensor logits;
};
using Output = RankWorker::Output;

// Updated constructor: accept CacheConfig instead of CacheType
InferEngine(
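With `Input` and `Output` now aliased to the `RankWorker` types, `InferEngine::forward()` returns the rank-0 worker's output, which carries sampled token ids rather than raw logits. A minimal before/after sketch of the caller's side, assuming `forward()` and the bound `Input`/`Output` classes are reachable from Python (the binding of `forward()` itself is not part of this diff):

```python
# Sketch only: `engine`, `inp`, and `sample_from_logits` are assumptions.

# Before this change: the caller received logits and sampled in Python.
out = engine.forward(inp)
next_ids = sample_from_logits(out.logits)   # hypothetical Python-side sampler

# After this change: sampling runs inside RankWorker on tp_rank 0,
# so the caller reads one sampled token id per request directly.
out = engine.forward(inp)
next_ids = out.output_ids
```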
33 changes: 28 additions & 5 deletions csrc/engine/rank_worker.cpp
@@ -2,6 +2,8 @@

#include "../models/model_factory.hpp"

#include "infinicore/ops.hpp"

#include <iostream>
#include <spdlog/spdlog.h>
#include <stdexcept>
@@ -95,7 +97,7 @@ std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dict() {
//------------------------------------------------------
// run -- asynchronous
//------------------------------------------------------
void RankWorker::run(const InfinilmModel::Input &args) {
void RankWorker::run(const Input &args) {
std::lock_guard<std::mutex> lock(mutex_);

if (should_exit_) {
@@ -156,7 +158,7 @@ void RankWorker::close() {
//------------------------------------------------------
// get_output (thread safe)
//------------------------------------------------------
InfinilmModel::Output RankWorker::get_output() {
RankWorker::Output RankWorker::get_output() {
std::lock_guard<std::mutex> lock(mutex_);
return output_;
}
@@ -204,7 +206,7 @@ void RankWorker::thread_loop() {
local_param_name = pending_param_name_;
local_param = pending_param_;
} else if (local_cmd == Command::RUN) {
local_args = pending_args_;
local_args = pending_args_.to_model_input();
} else if (local_cmd == Command::RESET_CACHE) {
if (pending_cache_config_ != nullptr) {
local_cache_config = pending_cache_config_->unique_copy();
@@ -239,10 +241,31 @@

} else if (local_cmd == Command::RUN) {
try {
auto out = model_->forward(local_args);
auto logits{model_->forward(local_args).logits};
infinicore::context::syncStream();

{
if (rank_info_.tp_rank == 0) {
// Perform random sampling
auto temperature{pending_args_.temperature};
auto top_p{pending_args_.top_p};
auto top_k{pending_args_.top_k};
auto random_val{pending_args_.random_val};

const auto &logits_shape{logits->shape()};
const auto &batch_size{logits_shape[0]};
const auto &vocab_size{logits_shape[2]};

auto output_ids{infinicore::Tensor::empty({batch_size}, infinicore::DataType::I32, rank_info_.device)};

for (auto i{decltype(batch_size)(0)}; i < batch_size; ++i) {
auto score{logits->narrow({{0, i, 1}})->view({vocab_size})};
auto out{output_ids->narrow({{0, i, 1}})->view({})};
infinicore::op::random_sample_(
out, score, random_val, top_p, top_k, temperature);
}

auto out{Output{output_ids}};

std::lock_guard<std::mutex> lk(mutex_);
output_ = std::move(out);
job_done_ = true;
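In the RUN branch above, decoding now finishes on `tp_rank == 0`: the worker takes `logits_shape[0]` as the batch size and `logits_shape[2]` as the vocabulary size, narrows the logits to a per-request `[vocab_size]` score vector, and calls `infinicore::op::random_sample_` with that request's temperature, top-k, top-p, and random value. The kernel itself is not shown in this diff; the NumPy sketch below only illustrates the conventional temperature/top-k/top-p sampling it is assumed to implement:

```python
import numpy as np

def sample_token(logits, temperature=1.0, top_k=50, top_p=1.0, random_val=0.1):
    """Reference sketch of temperature / top-k / top-p sampling; assumes
    random_sample_ follows the usual scheme (not verified against the kernel)."""
    scores = logits.astype(np.float64) / max(temperature, 1e-6)
    probs = np.exp(scores - scores.max())
    probs /= probs.sum()

    order = np.argsort(-probs)            # token ids, most likely first
    if top_k > 0:
        order = order[:top_k]             # top-k cut
    kept = probs[order]
    cumulative = np.cumsum(kept) / kept.sum()
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1   # nucleus (top-p) cut
    order, kept = order[:cutoff], kept[:cutoff]
    kept = kept / kept.sum()

    # random_val plays the role of the uniform draw in [0, 1).
    idx = int(np.searchsorted(np.cumsum(kept), random_val))
    return int(order[min(idx, len(order) - 1)])
```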
39 changes: 35 additions & 4 deletions csrc/engine/rank_worker.hpp
@@ -23,6 +23,37 @@ class RankWorker {
};

public:
struct Input {
/// Token IDs tensor of shape `[batch, seq_len]`.
std::optional<infinicore::Tensor> input_ids;
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std::optional<infinicore::Tensor> position_ids;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
std::optional<infinicore::Tensor> cache_lengths;
/// Input Lengths of each request in a continuous-batched sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> input_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> input_offsets;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache.
std::optional<infinicore::Tensor> slot_mapping;

float temperature{1};

int top_k{50};

float top_p{1};

float random_val{0.1};

infinilm::InfinilmModel::Input to_model_input() const;
};

struct Output {
infinicore::Tensor output_ids;
};

RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config);
@@ -35,7 +66,7 @@
std::unordered_map<std::string, infinicore::nn::Parameter> state_dict();

// Submit a run (forward) job.
void run(const InfinilmModel::Input &args);
void run(const Input &args);

// Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig *new_config);
@@ -47,7 +78,7 @@
void close();

// Thread-safe accessor for last output produced by RUN.
InfinilmModel::Output get_output();
Output get_output();

std::string info() const;

@@ -73,11 +104,11 @@
// Task payloads (protected by mutex)
std::string pending_param_name_;
infinicore::Tensor pending_param_;
InfinilmModel::Input pending_args_;
Input pending_args_;
std::unique_ptr<cache::CacheConfig> pending_cache_config_;

// Output (protected by mutex)
InfinilmModel::Output output_;
Output output_;

// Thread sync
std::thread thread_;
2 changes: 1 addition & 1 deletion csrc/models/infinilm_model.hpp
@@ -35,7 +35,7 @@ class InfinilmModel : public infinicore::nn::Module {
};

struct Output {
/// Output tensor of shape [batch, seq_len, vocab_size].
/// Logits.
infinicore::Tensor logits;
};

23 changes: 19 additions & 4 deletions csrc/pybind11/engine/engine.hpp
@@ -84,13 +84,28 @@ inline void bind_infer_engine(py::module &m) {
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) {
return InferEngine::Input{
std::optional<infinicore::Tensor> slot_mapping,
py::kwargs kwargs) {
auto input{InferEngine::Input{
std::move(input_ids),
std::move(position_ids),
std::move(cache_lengths),
std::move(block_tables),
std::move(slot_mapping)};
std::move(slot_mapping)}};

if (kwargs) {
if (kwargs.contains("temperature")) {
input.temperature = kwargs["temperature"].cast<float>();
}
if (kwargs.contains("top_k")) {
input.top_k = kwargs["top_k"].cast<int>();
}
if (kwargs.contains("top_p")) {
input.top_p = kwargs["top_p"].cast<float>();
}
}

return input;
}),
py::arg("input_ids") = std::nullopt,
py::arg("position_ids") = std::nullopt,
@@ -108,7 +123,7 @@
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping);

py::class_<InferEngine::Output>(infer_engine, "Output")
.def_readwrite("logits", &InferEngine::Output::logits, "Output tensor");
.def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor");
}

} // namespace infinilm::engine
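The binding now accepts arbitrary keyword arguments and copies `temperature`, `top_k`, and `top_p` into the new sampling fields, so the existing tensor arguments keep their signature and the sampling parameters ride along as kwargs (`random_val` keeps its default here). A hedged usage sketch from Python; only the `Input` kwargs and the `output_ids` field come from this diff, everything else is an assumption:

```python
# Sketch only: how `engine` is constructed and where the tensors come from
# are assumptions, as is the Python-side name `InferEngine`.
inp = InferEngine.Input(
    input_ids=input_ids,        # infinicore.Tensor of shape [batch, seq_len]
    position_ids=position_ids,
    temperature=0.7,            # picked up from **kwargs by the binding
    top_k=20,
    top_p=0.9,
)
out = engine.forward(inp)       # assumed to be bound; not shown in this diff
print(out.output_ids)           # sampled token ids; logits are no longer exposed
```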
17 changes: 10 additions & 7 deletions examples/bench.py
@@ -1,14 +1,15 @@
import infinicore
from transformers import AutoTokenizer
from infinilm.modeling_utils import load_model_state_dict_by_file
import infinilm
from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
import argparse
import sys
import time
import os
import json
from collections import OrderedDict
import numpy as np
from tqdm import tqdm

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))
@@ -205,10 +206,9 @@ def __init__(
# ---------------------------------------------------------------------------- #
# Create the model
# ---------------------------------------------------------------------------- #
model = infinilm.AutoLlamaModel.from_pretrained(
model = InferEngine(
model_path,
device=infini_device,
backend="cpp",
distributed_config=DistConfig(tp),
)

@@ -257,14 +257,17 @@

t1 = time.time()
print("=================== start generate ====================")
self.model.generate(
output_ids = self.model.generate(
input_ids_infini,
max_new_tokens=output_len,
tokenizer=self.tokenizer,
stop_on_eos=False,
GenerationConfig(max_new_tokens=output_len, eos_token_id=[]),
)
t2 = time.time()

numpy_output_ids = np.array(
[output_id.to_numpy()[0] for output_id in output_ids]
)
print(self.tokenizer.decode(numpy_output_ids, skip_special_tokens=True))

print(
f"total_time: {round((t2 - t1) * 1000, 2)} ms",
)
21 changes: 13 additions & 8 deletions examples/jiuge.py
@@ -2,12 +2,13 @@
from transformers import AutoTokenizer
from tokenizers import decoders as _dec
from infinilm.modeling_utils import load_model_state_dict_by_file
import infinilm
from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
import argparse
import sys
import time
import os
import numpy as np

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))

@@ -90,17 +91,15 @@ def test(
model_path,
max_new_tokens=100,
infini_device=infinicore.device("cpu", 0),
backend="python",
tp=1,
):
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
# Create the model
# ---------------------------------------------------------------------------- #
model = infinilm.AutoLlamaModel.from_pretrained(
model = InferEngine(
model_path,
device=infini_device,
backend=backend,
distributed_config=DistConfig(tp),
)

@@ -165,13 +164,17 @@ def test(

t1 = time.time()
print("=================== start generate ====================")
model.generate(
output_ids = model.generate(
input_ids_infini,
max_new_tokens=max_new_tokens,
tokenizer=tokenizer,
GenerationConfig(
max_new_tokens=max_new_tokens, temperature=1, top_k=1, top_p=0.8
),
)
t2 = time.time()

numpy_output_ids = np.array([output_id.to_numpy()[0] for output_id in output_ids])
print(tokenizer.decode(numpy_output_ids, skip_special_tokens=True))

print(
f"total_time: {round((t2 - t1) * 1000, 2)} ms",
)
@@ -208,13 +211,15 @@ def test(
backend = args.backend
tp = args.tp

if backend != "cpp":
raise ValueError(f"Unsupported backend: {backend}.")

infini_device = infinicore.device(device_str, 0)

test(
prompts,
model_path,
max_new_tokens,
infini_device=infini_device,
backend=backend,
tp=tp,
)
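As a side note on the `GenerationConfig` used above: under the usual top-k semantics, `top_k=1` keeps only the single most likely token, so that configuration decodes greedily regardless of `temperature` and `top_p`. A sketch of the two styles side by side (field names taken from this diff; the sampled values are illustrative):

```python
# Greedy decoding, as configured in the example above.
greedy = GenerationConfig(max_new_tokens=100, temperature=1, top_k=1, top_p=0.8)
# Stochastic decoding with illustrative (assumed) values.
sampled = GenerationConfig(max_new_tokens=100, temperature=0.7, top_k=50, top_p=0.9)
```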
6 changes: 3 additions & 3 deletions examples/llama.py
@@ -78,7 +78,6 @@ def test(
model_path,
max_new_tokens=100,
infini_device=infinicore.device("cpu", 0),
backend="python",
):
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
@@ -87,7 +86,6 @@
model = infinilm.AutoLlamaModel.from_pretrained(
model_path,
device=infini_device,
backend=backend,
)

# ---------------------------------------------------------------------------- #
@@ -192,12 +190,14 @@ def test(
max_new_tokens = args.max_new_tokens
backend = args.backend

if backend != "python":
raise ValueError(f"Unsupported backend: {backend}.")

infini_device = infinicore.device(device_str, 0)

test(
prompts,
model_path,
max_new_tokens,
infini_device=infini_device,
backend=backend,
)