59 changes: 38 additions & 21 deletions csrc/models/llama/llama_decoder_layer.cpp
@@ -1,6 +1,7 @@
#include "llama_decoder_layer.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/ops.hpp"
#include <optional>

namespace infinilm::models::llama {

@@ -21,34 +22,50 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
}

infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const {
// Save residual for attention
auto residual = hidden_states;

// 1. Pre-attention layer normalization
auto normed_states = input_layernorm_->forward(hidden_states);
std::pair<infinicore::Tensor, infinicore::Tensor> LlamaDecoderLayer::forward(
const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions,
const std::optional<infinicore::Tensor> &residual_in) const {

infinicore::Tensor normed_states;
infinicore::Tensor residual;

// 1. Pre-attention layer normalization with optional residual add from previous layer
if (residual_in.has_value()) {
// Fuse previous layer's MLP residual add with current layer's input normalization
// This avoids a separate add operation: residual_in + hidden_states
auto [normed_result, add_result] = infinicore::op::add_rms_norm(
residual_in.value(), hidden_states,
input_layernorm_->weight(),
static_cast<float>(input_layernorm_->eps()));
normed_states = normed_result;
residual = add_result; // This is residual_in + hidden_states
} else {
// First layer: no residual to add, just normalize
normed_states = input_layernorm_->forward(hidden_states);
residual = hidden_states;
}

// 2. Self-attention with residual connection
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_positions);

// Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output);
// Save residual for MLP
residual = output;

// 3. Post-attention layer normalization
normed_states = post_attention_layernorm_->forward(output);
// 3. Add attention residual and apply post-attention layer normalization (fused)
auto [normed_states_result, add_result] = infinicore::op::add_rms_norm(
residual, attn_output,
post_attention_layernorm_->weight(),
static_cast<float>(post_attention_layernorm_->eps()));

normed_states = normed_states_result;
residual = add_result; // Save for MLP residual connection

// 4. MLP with residual connection
// 4. MLP
auto mlp_output = mlp_->forward(normed_states);

// Add residual: output = output + mlp_output
output = infinicore::op::add(residual, mlp_output);

return output;
// Return (mlp_output, residual) WITHOUT doing the final add
// Next layer will fuse this add with its input_layernorm using add_rms_norm
return std::make_pair(mlp_output, residual);
}

} // namespace infinilm::models::llama
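
The layer change above hinges on `infinicore::op::add_rms_norm` returning both the normalized tensor and the raw element-wise sum, so each layer can hand its un-added residual to the next layer's input norm instead of issuing a separate add. A minimal scalar sketch of the semantics the diff assumes for that op (the name `add_rms_norm_ref` and the plain `std::vector` types are illustrative only, not part of infinicore):

```cpp
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Reference semantics assumed for infinicore::op::add_rms_norm (scalar sketch,
// not the actual fused kernel): returns {rms_norm(a + b) * weight, a + b}.
std::pair<std::vector<float>, std::vector<float>>
add_rms_norm_ref(const std::vector<float> &a, const std::vector<float> &b,
                 const std::vector<float> &weight, float eps) {
    const std::size_t n = a.size();
    std::vector<float> sum(n), normed(n);
    float mean_sq = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        sum[i] = a[i] + b[i];            // the residual add the fused op absorbs
        mean_sq += sum[i] * sum[i];
    }
    mean_sq /= static_cast<float>(n);
    const float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
    for (std::size_t i = 0; i < n; ++i) {
        normed[i] = sum[i] * inv_rms * weight[i];  // RMSNorm over the summed activation
    }
    return {normed, sum};                // second element becomes the next layer's residual
}
```

The returned sum is exactly the tensor the old code produced with a standalone `op::add`, which is why forwarding it as the next layer's `residual_in` keeps the math unchanged.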
15 changes: 10 additions & 5 deletions csrc/models/llama/llama_decoder_layer.hpp
@@ -9,6 +9,7 @@
#include "llama_mlp.hpp"

#include "../../engine/distributed/distributed.hpp"
#include <optional>

namespace infinilm::models::llama {

@@ -44,12 +45,16 @@ class LlamaDecoderLayer : public infinicore::nn::Module {
* @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param kv_cache Optional KV cache for incremental decoding
* @return Output tensor of shape [batch, seq_len, hidden_size]
* @param cache_positions Cache positions tensor
* @param residual Optional residual tensor from the previous layer (the pending MLP residual add)
* @return Pair of (output, residual) tensors, where the residual can be reused by the next layer
*/
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const;
std::pair<infinicore::Tensor, infinicore::Tensor> forward(
const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions,
const std::optional<infinicore::Tensor> &residual = std::nullopt) const;

/**
* @brief Get the layer index
26 changes: 22 additions & 4 deletions csrc/models/llama/llama_model.cpp
@@ -50,18 +50,36 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
auto hidden_states = embed_tokens_->forward(input_ids);

// 2. Process through all decoder layers
// Reuse residual across layers to avoid redundant add operations
size_t num_layers = layers_.size();
std::optional<infinicore::Tensor> residual = std::nullopt;
for (size_t i = 0; i < num_layers; ++i) {
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_positions);
auto [output, next_residual] = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_positions, residual);
hidden_states = output;
residual = next_residual;
}

// 3. Apply final layer normalization to last token only (aligns with transformers)
// Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
auto shape = hidden_states->shape();
size_t seq_len = shape[1];
auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}});

auto normalized_last_token = norm_->forward(last_token);

// Narrow both residual and mlp_output to last token before fusing add and norm
// Note: narrow() creates a view (no data copy), so this is equivalent to:
// narrow(add(residual, mlp_output)) == add(narrow(residual), narrow(mlp_output))
// But doing narrow first allows us to:
// 1. Only compute add for the last token (not the entire sequence) - saves computation
// 2. Fuse add with norm in a single kernel using add_rms_norm - avoids separate add kernel
auto residual_last_token = residual.value()->narrow({{1, seq_len - 1, 1}});
auto mlp_output_last_token = hidden_states->narrow({{1, seq_len - 1, 1}});

// Fuse final residual add with layer normalization using add_rms_norm
// This avoids a separate add operation - add and norm are computed in one fused kernel
// Result is mathematically equivalent to: norm(add(residual, mlp_output))[last_token]
auto [normalized_last_token, _] = infinicore::op::add_rms_norm(
residual_last_token, mlp_output_last_token,
norm_->weight(),
static_cast<float>(norm_->eps()));

return normalized_last_token;
}
Expand Down