Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/layers/mlp/moe_mlp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class MoeMLP : public infinicore::nn::Module {

size_t hidden_size() const { return hidden_size_; }
size_t moe_intermediate_size() const { return moe_intermediate_size_; }
void set_alpha(float alpha) { down_proj_->set_alpha(alpha); }

protected:
INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, gate_proj);
Expand Down
59 changes: 59 additions & 0 deletions csrc/models/qwen3_moe/qwen3_moe_experts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "qwen3_moe_experts.hpp"

namespace infinilm::models::qwen3_moe {
/// @brief Construct the expert bank: reads routing sizes from the model config
///        and registers one Qwen3MoeMLP sub-module per expert.
/// @param model_config  provides "num_experts" and "num_experts_per_tok".
/// @param device        device on which each expert's weights are created.
Qwen3MoeExperts::Qwen3MoeExperts(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                                 const infinicore::Device &device) {

    num_experts_ = model_config->get<size_t>("num_experts");
    num_experts_per_tok_ = model_config->get<size_t>("num_experts_per_tok");

    // Reject degenerate routing configurations up front.
    ASSERT((num_experts_ > 0) && (num_experts_per_tok_ > 0) && (num_experts_per_tok_ <= num_experts_));

    // Reserve once so registering experts does not reallocate the vector.
    experts_.reserve(num_experts_);
    for (size_t i = 0; i < num_experts_; ++i) {
        // Module name is the expert index, matching checkpoint key layout
        // (e.g. "experts.0", "experts.1", ...) — presumably; verify against loader.
        experts_.push_back(this->register_module<Qwen3MoeMLP>(std::to_string(i), model_config, device));
    }
}

/// @brief Route each token through its top-k experts and blend the outputs.
///
/// For every token: run the k selected experts on that token's row, scale each
/// expert's output by its routing weight (applied via set_alpha on the expert),
/// and accumulate the sum into the output tensor.
///
/// @param hidden_states  2-D activations, shape [ntoken, hidden_size].
/// @param top_k_index    per-token expert indices (read as int32 on the host).
/// @param top_k_weights  per-token routing weights (read as float32 on the host).
/// @return tensor with the same shape/dtype/device as hidden_states.
///
/// NOTE(review): set_alpha mutates the selected expert even though forward() is
/// const, so concurrent calls on the same module would race — confirm callers
/// invoke this single-threaded.
infinicore::Tensor Qwen3MoeExperts::forward(const infinicore::Tensor &hidden_states,
                                            const infinicore::Tensor &top_k_index,
                                            const infinicore::Tensor &top_k_weights) const {
    ASSERT(hidden_states->ndim() == 2);

    // Routing metadata is consumed on the host; bring it to CPU first.
    auto top_k_weights_cpu = top_k_weights->to(infinicore::Device::Type::CPU);
    auto top_k_index_cpu = top_k_index->to(infinicore::Device::Type::CPU);

    // Named casts instead of C-style casts; layout assumed row-major
    // [ntoken, num_experts_per_tok_] — TODO confirm against the router's output.
    const int *top_k_index_ptr = reinterpret_cast<const int *>(top_k_index_cpu->data());
    const float *top_k_weights_ptr = reinterpret_cast<const float *>(top_k_weights_cpu->data());

    const size_t ntoken = hidden_states->shape()[0];

    auto final_hidden_states = infinicore::Tensor::empty(hidden_states->shape(), hidden_states->dtype(), hidden_states->device());
    for (size_t itok = 0; itok < ntoken; ++itok) {
        // One-token view: shape [1, hidden_size].
        auto hidden_states_i = hidden_states->narrow({{0, itok, 1}});
        const size_t route_row = itok * num_experts_per_tok_;

        infinicore::Tensor final_hidden_states_i;
        for (size_t k = 0; k < num_experts_per_tok_; ++k) {
            const int index = top_k_index_ptr[route_row + k];
            const float score = top_k_weights_ptr[route_row + k];

            // Check the sign first, then widen, to avoid a signed/unsigned
            // comparison against num_experts_.
            ASSERT(index >= 0 && static_cast<size_t>(index) < num_experts_);

            // The routing weight is folded into the expert's down projection.
            experts_[index]->set_alpha(score);
            auto expert_out = experts_[index]->forward(hidden_states_i);

            if (k == 0) {
                final_hidden_states_i = expert_out;
            } else {
                // In-place accumulate: final_hidden_states_i += expert_out.
                infinicore::op::add_(final_hidden_states_i, final_hidden_states_i, expert_out);
            }
        }

        final_hidden_states->narrow({{0, itok, 1}})->copy_from(final_hidden_states_i);
    }
    return final_hidden_states;
}

} // namespace infinilm::models::qwen3_moe
27 changes: 27 additions & 0 deletions csrc/models/qwen3_moe/qwen3_moe_experts.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@


#pragma once
#include "../../layers/common_modules.hpp"

#include <memory>

namespace infinilm::models::qwen3_moe {

using Qwen3MoeMLP = infinilm::layers::MoeMLP;

// Bank of per-expert MLPs for Qwen3-MoE. Owns num_experts_ Qwen3MoeMLP
// sub-modules and, given router output, blends the top-k expert outputs
// per token (see the .cpp for the routing loop).
class Qwen3MoeExperts : public infinicore::nn::Module {
public:
// Reads "num_experts" / "num_experts_per_tok" from model_config and
// registers one expert module per expert on the given device.
Qwen3MoeExperts(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device);

// hidden_states: [ntoken, hidden_size]; top_k_index / top_k_weights: the
// router's per-token expert selections and weights. Returns blended
// activations with the same shape as hidden_states.
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &top_k_index,
const infinicore::Tensor &top_k_weights) const;

protected:
// One MLP module per expert, indexed by expert id.
INFINICORE_NN_MODULE_VEC(Qwen3MoeMLP, experts);
// Experts consulted per token (k of top-k).
size_t num_experts_per_tok_{0};
// Total number of experts in the bank.
size_t num_experts_{0};
};

} // namespace infinilm::models::qwen3_moe
1 change: 1 addition & 0 deletions csrc/models/qwen3_moe/qwen3_moe_for_causal_lm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ using Qwen3MoeModel = infinilm::layers::causal_lm_templates::TextModel<Qwen3MoeD
using Qwen3MoeForCausalLM = infinilm::layers::causal_lm_templates::TextCausalLM<Qwen3MoeModel>;

std::shared_ptr<infinilm::config::ModelConfig> create_qwen3_moe_model_config(std::shared_ptr<infinilm::config::ModelConfig> model_config);

} // namespace infinilm::models::qwen3_moe
30 changes: 11 additions & 19 deletions csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.cpp
Original file line number Diff line number Diff line change
@@ -1,31 +1,23 @@
#include "qwen3_moe_sparse_moe_block.hpp"
#include <spdlog/spdlog.h>

namespace infinilm::models::qwen3_moe {

Qwen3MoeSparseMoeBlock::Qwen3MoeSparseMoeBlock(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device) {
const auto &dtype{model_config->get_dtype()};
size_t hidden_size = model_config->get<size_t>("hidden_size");
size_t moe_intermediate_size = model_config->get<size_t>("moe_intermediate_size");
size_t shared_expert_intermediate_size = model_config->get_or<size_t>("shared_expert_intermediate_size", 0);
size_t num_experts = model_config->get<size_t>("num_experts");

INFINICORE_NN_MODULE_INIT(gate, hidden_size, num_experts, false, dtype, device);
experts_.reserve(num_experts);
for (size_t i = 0; i < num_experts; ++i) {
experts_.push_back(this->register_module<Qwen3MoeMLP>("experts." + std::to_string(i), model_config, device));
}

if (shared_expert_intermediate_size > 0) {
INFINICORE_NN_MODULE_INIT(shared_expert, model_config, device);
INFINICORE_NN_MODULE_INIT(shared_expert_gate, hidden_size, 1, false, dtype, device);
}
INFINICORE_NN_MODULE_INIT(gate, model_config, device);
INFINICORE_NN_MODULE_INIT(experts, model_config, device);
}

infinicore::Tensor Qwen3MoeSparseMoeBlock::forward(const infinicore::Tensor &hidden_states) const {
spdlog::error("Qwen3MoeSparseMoeBlock: forward not implemented");
return hidden_states;
ASSERT(hidden_states->ndim() == 3);

auto shape = hidden_states->shape(); // shape[ 1 11 2048 ]
auto hidden_states_reshaped = hidden_states->view({shape[0] * shape[1], shape[2]});

auto [routing_weights, selected_experts] = gate_->forward(hidden_states_reshaped);
auto final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights);

return final_hidden_states->view({shape[0], shape[1], shape[2]});
}

} // namespace infinilm::models::qwen3_moe
10 changes: 4 additions & 6 deletions csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include "../../layers/common_modules.hpp"
#include "qwen3_moe_experts.hpp"
#include "qwen3_moe_topk_router.hpp"

namespace infinilm::models::qwen3_moe {
using Qwen3MoeMLP = infinilm::layers::MoeMLP;

class Qwen3MoeSparseMoeBlock : public infinicore::nn::Module {
public:
Expand All @@ -13,10 +13,8 @@ class Qwen3MoeSparseMoeBlock : public infinicore::nn::Module {
infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const;

protected:
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, gate);
INFINICORE_NN_MODULE_VEC(Qwen3MoeMLP, experts);
INFINICORE_NN_MODULE(Qwen3MoeMLP, shared_expert);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, shared_expert_gate);
INFINICORE_NN_MODULE(Qwen3MoeTopKRouter, gate);
INFINICORE_NN_MODULE(Qwen3MoeExperts, experts);
};

} // namespace infinilm::models::qwen3_moe
36 changes: 36 additions & 0 deletions csrc/models/qwen3_moe/qwen3_moe_topk_router.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "qwen3_moe_topk_router.hpp"

#include "infinicore/ops.hpp"

namespace infinilm::models::qwen3_moe {

/// @brief Build the router: load routing sizes from the config, validate them,
///        and register the [num_experts, hidden_size] gate weight.
Qwen3MoeTopKRouter::Qwen3MoeTopKRouter(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                                       const infinicore::Device &device) {
    const auto &dtype{model_config->get_dtype()};

    const size_t hidden_dim = model_config->get<size_t>("hidden_size");
    const size_t expert_count = model_config->get<size_t>("num_experts");
    num_experts_per_tok_ = model_config->get<size_t>("num_experts_per_tok");
    norm_topk_prob_ = model_config->get<bool>("norm_topk_prob");

    // Reject degenerate routing configurations before allocating the weight.
    ASSERT((expert_count > 0) && (num_experts_per_tok_ > 0) && (num_experts_per_tok_ <= expert_count));

    // Router projection matrix: one row of logit weights per expert.
    INFINICORE_NN_PARAMETER_INIT(weight, ({expert_count, hidden_dim}, dtype, device));
}

/// @brief Score every token against every expert and pick the top-k per token.
/// @param hidden_states  2-D activations, shape [ntoken, hidden_size].
/// @return (scores, indices): scores are F32 [ntoken, k], indices are I32
///         [ntoken, k], both on the input's device.
std::tuple<infinicore::Tensor, infinicore::Tensor> Qwen3MoeTopKRouter::forward(const infinicore::Tensor &hidden_states) const {

    ASSERT(hidden_states->ndim() == 2);

    const size_t num_tokens = hidden_states->shape()[0];
    const auto device = hidden_states->device();

    // Per-expert logits: hidden_states @ weight^T with no bias, unit scale.
    auto router_logits = infinicore::op::linear(hidden_states, weight_, std::nullopt, 1.0f);

    // Output buffers for the fused top-k + softmax kernel.
    auto router_scores = infinicore::Tensor::empty({num_tokens, num_experts_per_tok_}, infinicore::DataType::F32, device);
    auto router_indices = infinicore::Tensor::empty({num_tokens, num_experts_per_tok_}, infinicore::DataType::I32, device);

    // norm_topk_prob_ presumably renormalizes the k probabilities to sum to 1
    // (mirrors the HF Qwen3-MoE config flag) — confirm against the kernel.
    infinicore::op::topksoftmax(router_scores, router_indices, router_logits, num_experts_per_tok_, norm_topk_prob_);

    return {router_scores, router_indices};
}

} // namespace infinilm::models::qwen3_moe
25 changes: 25 additions & 0 deletions csrc/models/qwen3_moe/qwen3_moe_topk_router.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@


#pragma once
#include "../../layers/common_modules.hpp"

#include <memory>
#include <tuple>

namespace infinilm::models::qwen3_moe {

// Top-k softmax router for Qwen3-MoE: projects hidden states onto a
// [num_experts, hidden_size] weight to obtain per-expert logits, then selects
// the num_experts_per_tok_ best experts per token (see the .cpp).
class Qwen3MoeTopKRouter : public infinicore::nn::Module {
public:
// Reads "hidden_size", "num_experts", "num_experts_per_tok" and
// "norm_topk_prob" from model_config; allocates the gate weight on device.
Qwen3MoeTopKRouter(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device);

// hidden_states: [ntoken, hidden_size]. Returns (scores F32, indices I32),
// each of shape [ntoken, num_experts_per_tok_].
std::tuple<infinicore::Tensor, infinicore::Tensor> forward(const infinicore::Tensor &hidden_states) const;

protected:
// Router projection weight, shape [num_experts, hidden_size].
INFINICORE_NN_PARAMETER(weight);

// Experts selected per token (k of top-k).
size_t num_experts_per_tok_{0};
// Whether top-k probabilities are renormalized (config "norm_topk_prob").
bool norm_topk_prob_{false};
};

} // namespace infinilm::models::qwen3_moe
17 changes: 1 addition & 16 deletions python/infinilm/auto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,7 @@ def from_pretrained(model_path):
config_dict["model_type"] == "qwen2" or config_dict["model_type"] == "qwen3"
):
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "minicpm":
    elif config_dict["model_type"] in ("minicpm", "fm9g", "fm9g7b"):
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "fm9g":
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "fm9g7b":
return LlamaConfig(**config_dict)
elif config_dict["model_type"] in [
"qwen3_next",
"minicpm_sala",
"qwen3_vl",
"qwen3_moe",
]:
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "minicpmv":
cfg = LlamaConfig(**config_dict)
cfg.model_type = "minicpmv"
return cfg

raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")