Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions csrc/cache/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "../utils.hpp"
#include "infinicore/ops.hpp"
#include <iostream>
#include <stdexcept>

namespace infinilm::cache {
Expand All @@ -16,6 +17,17 @@ StaticKVCacheConfig::StaticKVCacheConfig(
max_cache_len_(_max_cache_len) {
}

// Constructs a static KV-cache config with an explicit cache dtype.
// @param _max_batch_size  maximum number of sequences cached at once
// @param _max_cache_len   maximum cache length per sequence
// @param kv_cache_dtype   dtype name to parse; an empty string leaves the
//                         optional dtype unset (no override requested)
StaticKVCacheConfig::StaticKVCacheConfig(
    infinicore::Size _max_batch_size,
    infinicore::Size _max_cache_len,
    std::string kv_cache_dtype)
    : max_batch_size_(_max_batch_size),
      max_cache_len_(_max_cache_len) {
    if (!kv_cache_dtype.empty()) {
        kv_cache_dtype_ = parse_dtype(kv_cache_dtype);
    }
}

std::unique_ptr<CacheConfig>
StaticKVCacheConfig::unique_copy() const {
return std::make_unique<StaticKVCacheConfig>(*this);
Expand All @@ -42,7 +54,6 @@ StaticKVCache::StaticKVCache(
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::Size max_positional_embedding,
infinicore::DataType dtype,
const StaticKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info)
: Cache(),
Expand All @@ -53,7 +64,7 @@ StaticKVCache::StaticKVCache(
rank_batch_size_(config.max_batch_size()),
cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()),
rank_num_layers_(num_layers),
dtype_(dtype) {
dtype_(config.kv_cache_dtype()) {

// Allocate K cache
k_caches_ = infinicore::Tensor::empty(
Expand Down Expand Up @@ -115,9 +126,32 @@ StaticKVCache::update(size_t layer_idx,
return {k_cache_layer, v_cache_layer};
}

// Returns the configured KV-cache dtype.
// Throws a descriptive std::runtime_error when no dtype was ever set,
// instead of the opaque std::bad_optional_access that `.value()` raises.
infinicore::DataType
StaticKVCacheConfig::kv_cache_dtype() const {
    if (!kv_cache_dtype_.has_value()) {
        throw std::runtime_error(
            "StaticKVCacheConfig: kv_cache_dtype was never set "
            "(pass a dtype string to the constructor or call set_kv_cache_dtype)");
    }
    return *kv_cache_dtype_;
}
// Sets the KV-cache dtype. First assignment wins: a dtype already chosen
// (e.g. parsed from a config string in the constructor) is never overwritten.
void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) {
    // The original had a dead `else { return; }` branch; dropping it is
    // behavior-identical.
    if (!kv_cache_dtype_.has_value()) {
        kv_cache_dtype_ = dtype;
    }
}

// ==========================
// PagedKVCacheConfig
// ==========================
// Constructs a paged KV-cache config with an explicit cache dtype.
// @param num_blocks      number of cache blocks per layer
// @param kv_cache_dtype  dtype name to parse; an empty string leaves the
//                        optional dtype unset (no override requested)
// @param block_size      number of positions stored per block
PagedKVCacheConfig::PagedKVCacheConfig(
    size_t num_blocks,
    std::string kv_cache_dtype,
    size_t block_size)
    : num_blocks_(num_blocks),
      block_size_(block_size) {
    if (!kv_cache_dtype.empty()) {
        kv_cache_dtype_ = parse_dtype(kv_cache_dtype);
    }
}

PagedKVCacheConfig::PagedKVCacheConfig(
size_t num_blocks,
size_t block_size)
Expand All @@ -140,6 +174,19 @@ PagedKVCacheConfig::block_size() const {
return block_size_;
}

// Returns the configured KV-cache dtype.
// Throws a descriptive std::runtime_error when no dtype was ever set,
// instead of the opaque std::bad_optional_access that `.value()` raises.
infinicore::DataType
PagedKVCacheConfig::kv_cache_dtype() const {
    if (!kv_cache_dtype_.has_value()) {
        throw std::runtime_error(
            "PagedKVCacheConfig: kv_cache_dtype was never set "
            "(pass a dtype string to the constructor or call set_kv_cache_dtype)");
    }
    return *kv_cache_dtype_;
}

// Sets the KV-cache dtype. First assignment wins: a dtype already chosen
// (e.g. parsed from a config string in the constructor) is never overwritten.
void PagedKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) {
    // The original had a dead `else { return; }` branch; dropping it is
    // behavior-identical.
    if (!kv_cache_dtype_.has_value()) {
        kv_cache_dtype_ = dtype;
    }
}

// ==========================
// PagedKVCache
// ==========================
Expand All @@ -149,7 +196,6 @@ PagedKVCache::PagedKVCache(
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::DataType dtype,
const PagedKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info)
: Cache(),
Expand All @@ -158,7 +204,7 @@ PagedKVCache::PagedKVCache(
num_rank_k_heads_(num_k_heads / rank_info.tp_size),
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_num_layers_(num_layers),
dtype_(dtype),
dtype_(config.kv_cache_dtype()),
num_blocks_per_layer_(config.num_blocks()),
block_size_(config.block_size()) {
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
Expand Down
23 changes: 21 additions & 2 deletions csrc/cache/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "base_cache.hpp"

#include "../utils.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
Expand All @@ -10,6 +11,7 @@
#include <limits>
#include <memory>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <utility>

Expand All @@ -22,13 +24,23 @@ class StaticKVCacheConfig final : public CacheConfig {
infinicore::Size _max_batch_size = 1,
infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max());

StaticKVCacheConfig(
infinicore::Size _max_batch_size,
infinicore::Size _max_cache_len,
std::string kv_cache_dtype);

std::unique_ptr<CacheConfig> unique_copy() const override;
infinicore::Size max_batch_size() const;
infinicore::Size max_cache_len() const;

infinicore::DataType kv_cache_dtype() const;
void set_kv_cache_dtype(infinicore::DataType dtype);

private:
infinicore::Size max_batch_size_;
infinicore::Size max_cache_len_;

std::optional<infinicore::DataType> kv_cache_dtype_ = std::nullopt;
};

class StaticKVCache final : public Cache {
Expand All @@ -41,7 +53,6 @@ class StaticKVCache final : public Cache {
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::Size max_positional_embedding,
infinicore::DataType dtype,
const StaticKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info);

Expand Down Expand Up @@ -88,13 +99,22 @@ class PagedKVCacheConfig final : public CacheConfig {
size_t num_blocks,
size_t block_size = 256);

PagedKVCacheConfig(
size_t num_blocks,
std::string kv_cache_dtype,
size_t block_size = 16);

std::unique_ptr<CacheConfig> unique_copy() const override;
size_t num_blocks() const;
size_t block_size() const;
infinicore::DataType kv_cache_dtype() const;
void set_kv_cache_dtype(infinicore::DataType dtype);

private:
size_t num_blocks_;
size_t block_size_;

std::optional<infinicore::DataType> kv_cache_dtype_ = std::nullopt;
};

class PagedKVCache final : public Cache {
Expand All @@ -106,7 +126,6 @@ class PagedKVCache final : public Cache {
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::DataType dtype,
const PagedKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info);

Expand Down
21 changes: 3 additions & 18 deletions csrc/config/model_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,8 @@ ModelConfig::get_rope_scaling() const {
}
}

infinicore::DataType
ModelConfig::get_dtype() const {
try {
std::string dtype_str = this->get<std::string>("torch_dtype");
if (dtype_str == "float32") {
return infinicore::DataType::F32;
} else if (dtype_str == "float16") {
return infinicore::DataType::F16;
} else if (dtype_str == "bfloat16") {
return infinicore::DataType::BF16;
} else if (dtype_str == "int8") {
return infinicore::DataType::I8;
} else {
throw std::runtime_error("Unsupported dtype string: " + dtype_str);
}
} catch (const std::exception &e) {
throw std::runtime_error("Error getting dtype from config: " + std::string(e.what()));
}
// Reads "torch_dtype" from the model's config.json and maps it to an
// infinicore::DataType via the shared parse_dtype helper.
// NOTE(review): presumably parse_dtype throws for unsupported dtype
// strings, matching the validation this function used to do inline —
// confirm in utils.
infinicore::DataType ModelConfig::get_dtype() const {
    const auto dtype_str = this->get<std::string>("torch_dtype");
    return parse_dtype(dtype_str);
}
} // namespace infinilm::config
9 changes: 9 additions & 0 deletions csrc/config/model_config.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "../utils.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "quant_config.hpp"
Expand Down Expand Up @@ -63,6 +64,14 @@ class ModelConfig {
infinicore::DataType get_dtype() const;
infinicore::quantization::QuantScheme get_quant_scheme() const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;
void set_kv_quant_scheme(std::string kv_cache_dtype) {
if (kv_cache_dtype == "int8") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果没进if报个错什么的吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里默认是一个NONE,说明不使用量化

this->quant_config.set_kv_quant_scheme(kv_cache_dtype);
}
}
infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const {
return quant_config.get_kv_quant_scheme();
}

private:
nlohmann::json config_json;
Expand Down
20 changes: 19 additions & 1 deletion csrc/config/quant_config.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once
// #include "../quantization/quantization.hpp"
#include "../utils.hpp"
#include "infinicore/quantization.hpp"
#include "nlohmann/json.hpp"

Expand All @@ -22,9 +22,27 @@ class QuantConfig {
}
}

void set_kv_quant_scheme(std::string kv_cache_dtype) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉和model_config文件里的set功能应该换一下,model_config负责parse和分发,quant_config负责具体逻辑

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我感觉quant config应该完全控制量化相关的内容,model_config最多做一个转发

switch (parse_dtype(kv_cache_dtype)) {
case infinicore::DataType::I8: {
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8;
break;
}
default: {
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
break;
}
}
}

// Returns the selected KV-cache quantization algorithm
// (KVQuantAlgo::NONE unless set_kv_quant_scheme chose otherwise).
infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const {
return kv_quant_scheme;
}

private:
nlohmann::json quantization_config;
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_method;
infinicore::quantization::KVQuantAlgo kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
};

} // namespace infinilm::config
7 changes: 5 additions & 2 deletions csrc/engine/infer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,17 @@ InferEngine::InferEngine(
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling,
backends::AttentionBackend attention_backend) // Changed parameter
backends::AttentionBackend attention_backend,
const std::string &kv_cache_dtype) // Changed parameter
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}

// Load model config if model_path is provided, model_path must be valid, and config.json exists
this->model_config_ = std::make_shared<infinilm::config::ModelConfig>(model_path + "/config.json");
// Only support offline int8 kv cache quantization in this version
this->model_config_->set_kv_quant_scheme(kv_cache_dtype);
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
Expand Down Expand Up @@ -168,7 +171,7 @@ const distributed::DistConfig &InferEngine::get_dist_config() const {
//------------------------------------------------------
// reset_cache (overloaded with CacheConfig)
//------------------------------------------------------
void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
void InferEngine::reset_cache(cache::CacheConfig *new_config) {
for (auto &worker : workers_) {
worker->reset_cache(new_config);
}
Expand Down
5 changes: 3 additions & 2 deletions csrc/engine/infer_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class InferEngine {
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
const std::string &kv_cache_dtype = "");

// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);
Expand All @@ -59,7 +60,7 @@ class InferEngine {

void compile();

void reset_cache(const cache::CacheConfig *new_config);
void reset_cache( cache::CacheConfig *new_config);

~InferEngine();

Expand Down
2 changes: 1 addition & 1 deletion csrc/engine/rank_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ void RankWorker::wait() {
}
}

void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
void RankWorker::reset_cache(cache::CacheConfig *new_config) {
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot reset_cache");
Expand Down
2 changes: 1 addition & 1 deletion csrc/engine/rank_worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class RankWorker {
void run(const Input &args);

// Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig *new_config);
void reset_cache(cache::CacheConfig *new_config);

// Compile the model graph if enabled.
void compile();
Expand Down
52 changes: 52 additions & 0 deletions csrc/layers/kv_quant.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "kv_quant.hpp"
#include "infinicore/ops/per_tensor_dequant_i8.hpp"
#include "infinicore/ops/per_tensor_quant_i8.hpp"

namespace infinilm {

// Quantizes K and V tensors in place to INT8 using per-tensor scales.
// @param k, v            in/out: replaced by their quantized counterparts
// @param algo            KV quantization algorithm; NONE makes this a no-op
// @param k_scale, v_scale per-tensor quantization scales for K and V
void KVQuantUtils::quantize(
    infinicore::Tensor &k,
    infinicore::Tensor &v,
    infinicore::quantization::KVQuantAlgo algo,
    const infinicore::Tensor &k_scale,
    const infinicore::Tensor &v_scale) {

    // Quantization disabled — leave the tensors untouched.
    if (algo == infinicore::quantization::KVQuantAlgo::NONE) {
        return;
    }

    // Zero point of 0 in k's dtype/device (symmetric quantization).
    // NOTE(review): assumes k and v share a device and dtype — confirm
    // with callers before reusing the same zero point for both.
    const auto zero = infinicore::Tensor::zeros({1}, k->dtype(), k->device());

    k = infinicore::op::per_tensor_quant_i8(k, k_scale, zero, true);
    v = infinicore::op::per_tensor_quant_i8(v, v_scale, zero, true);
}

// Dequantizes INT8 K and V tensors back to the reference tensor's dtype,
// replacing k and v with freshly allocated dequantized tensors.
// @param k, v             in/out: quantized tensors; overwritten with the
//                         dequantized results (original shapes/strides kept)
// @param algo             KV quantization algorithm; NONE makes this a no-op
// @param k_scale, v_scale per-tensor scales used at quantization time
// @param reference        supplies the target dtype and device for the output
void KVQuantUtils::dequantize(
infinicore::Tensor &k,
infinicore::Tensor &v,
infinicore::quantization::KVQuantAlgo algo,
const infinicore::Tensor &k_scale,
const infinicore::Tensor &v_scale,
const infinicore::Tensor &reference) {

if (algo == infinicore::quantization::KVQuantAlgo::NONE) {
return; // no dequantization needed
}

// Zero point of 0 in the reference dtype (symmetric quantization).
auto zero_point = infinicore::Tensor::zeros({1}, reference->dtype(), reference->device());

// Allocate outputs with the same shape/strides as the quantized inputs so
// the in-place dequant ops write into a layout-compatible destination.
auto k_dequant = infinicore::Tensor::strided_empty(
k->shape(), k->strides(), reference->dtype(), reference->device());
auto v_dequant = infinicore::Tensor::strided_empty(
v->shape(), v->strides(), reference->dtype(), reference->device());

infinicore::op::per_tensor_dequant_i8_(k_dequant, k, k_scale, zero_point);
infinicore::op::per_tensor_dequant_i8_(v_dequant, v, v_scale, zero_point);

// Hand the dequantized tensors back through the in/out parameters.
k = std::move(k_dequant);
v = std::move(v_dequant);
}

} // namespace infinilm
Loading