paddlenlp/transformers/conversion_utils.py (2 additions, 0 deletions)
@@ -748,6 +748,8 @@ def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=Fals
if x is None:
return None
if transpose:
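# Materialize objects that expose a .get() method (presumably lazily loaded weights) before transposing.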
if hasattr(x, "get"):
x = x.get()
if isinstance(x, paddle.Tensor):
x = paddle.transpose(x, [1, 0])
else:
paddlenlp/transformers/llama/fusion_ops.py (6 additions, 1 deletion)
@@ -21,7 +21,10 @@
from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
fused_rotary_position_embedding = None

try:
from paddle.incubate.nn.functional import fused_rms_norm_ext
except ImportError:
fused_rms_norm_ext = None
try:
from paddle.incubate.nn.functional import swiglu
except ImportError:
@@ -132,6 +135,8 @@ def fusion_rope(


def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
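# Prefer the fused_rms_norm_ext kernel when it is available; otherwise fall back to the existing paths below.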
if fused_rms_norm_ext is not None:
return fused_rms_norm_ext(x_in, w, eps)[0].astype(w.dtype)
if use_fast_ln:
fast_ln = try_import("fast_ln")
return fast_ln.fast_rms_norm(x_in, w, eps)[0]
paddlenlp/transformers/qwen2/modeling.py (143 additions, 61 deletions)
@@ -92,6 +92,21 @@
"Qwen2ForTokenClassification",
"Qwen2SentenceEmbedding",
]
import os


def str2bool(v):
if isinstance(v, bool):
return v
elif v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise ValueError("Unsupported value encountered.")


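# Env toggle (default True): when enabled, use the PaddleFormers-aligned rotary embedding and the transposed lm_head weight layout.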
FLAGS_ALIGN_PADDLEFORMERS = str2bool(os.getenv("FLAGS_ALIGN_PADDLEFORMERS", "True"))


def get_triangle_upper_mask(x, mask=None):
@@ -329,65 +344,125 @@ def forward(self, hidden_states):
return hidden_states * self.weight


class Qwen2RotaryEmbedding(nn.Layer):
def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# [dim / 2]
self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim))
self._set_cos_sin_cache(seq_len=max_position_embeddings)

def _set_cos_sin_cache(self, seq_len):
self.max_seq_len_cached = seq_len
if self.inv_freq.dtype != paddle.float32:
self.inv_freq = 1.0 / (
self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
)
# [seq_len]
t = paddle.arange(seq_len, dtype="float32")
# [seq_len, dim/2]
freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
# This differs from the paper: a different permutation is used to obtain the same calculation
# [seq_len, dim]
emb = paddle.concat([freqs, freqs], axis=-1)
# [1, seqlen, 1, dim]
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len)
cos = self.cos_cached[:, :seq_len, :, :]
sin = self.sin_cached[:, :seq_len, :, :]
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
)


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
if position_ids is None:
# Note: Only for Qwen2MoEForCausalLMPipe model pretraining
cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
else:
cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
if FLAGS_ALIGN_PADDLEFORMERS:

def _apply_rotary_emb(
x: paddle.Tensor,
cos: paddle.Tensor,
sin: paddle.Tensor,
) -> paddle.Tensor:
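# q/k arrive as [bs, seq_len, num_heads, head_dim]; move heads forward so that
# cos/sin ([bs, 1, seq_len, head_dim] after unsqueeze) broadcast, then restore the original layout.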
x = x.transpose([0, 2, 1, 3])
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed.transpose([0, 2, 1, 3])

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors."""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = _apply_rotary_emb(q, cos, sin)
k_embed = _apply_rotary_emb(k, cos, sin)
return q_embed.astype(q.dtype), k_embed.astype(k.dtype)

class Qwen2RotaryEmbedding(nn.Layer):
def __init__(self, config: Qwen2Config):
super().__init__()
self.config = config
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
dim = int(head_dim * partial_rotary_factor)

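# Inverse frequencies over the even dims; partial_rotary_factor < 1.0 rotates only a prefix of head_dim.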
inv_freq = 1.0 / (
base ** (paddle.arange(0, dim, 2, dtype=paddle.int64).astype(dtype=paddle.float32) / dim)
)
self.attention_scaling = 1.0
self.register_buffer("inv_freq", inv_freq, persistable=False)
self.original_inv_freq = self.inv_freq

def forward(self, x, position_ids):
# NOTE: Paddle's Automatic Mixed Precision (AMP) has a default op whitelist that may automatically cast
# certain operations (like matmul) to FP16/BF16 for performance optimization. However, in scenarios where
# numerical stability is critical (e.g., RoPE init/compute), this conversion can lead to precision loss.
# Disabling auto_cast here ensures the matmul operation runs in the original precision (FP32) as intended.
with paddle.amp.auto_cast(False):
inv_freq_expanded = (
self.inv_freq.unsqueeze(0)
.unsqueeze(-1)
.cast(paddle.float32)
.expand([position_ids.shape[0], -1, 1])
.to(x.place)
)
position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)

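# Outer product of frequencies and positions: [bs, dim/2, 1] x [bs, 1, seq_len] -> [bs, dim/2, seq_len], then transpose to [bs, seq_len, dim/2].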
freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
emb = paddle.concat([freqs, freqs], axis=-1)
cos = paddle.cos(emb) * self.attention_scaling
sin = paddle.sin(emb) * self.attention_scaling

return cos.cast(dtype="float32"), sin.cast(dtype="float32")

else:

class Qwen2RotaryEmbedding(nn.Layer):
def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# [dim / 2]
self.inv_freq = 1.0 / (
self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
)
self._set_cos_sin_cache(seq_len=max_position_embeddings)

def _set_cos_sin_cache(self, seq_len):
self.max_seq_len_cached = seq_len
if self.inv_freq.dtype != paddle.float32:
self.inv_freq = 1.0 / (
self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
)
# [seq_len]
t = paddle.arange(seq_len, dtype="float32")
# [seq_len, dim/2]
freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
# This differs from the paper: a different permutation is used to obtain the same calculation
# [seq_len, dim]
emb = paddle.concat([freqs, freqs], axis=-1)
# [1, seqlen, 1, dim]
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len)
cos = self.cos_cached[:, :seq_len, :, :]
sin = self.sin_cached[:, :seq_len, :, :]
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
if position_ids is None:
# Note: Only for Qwen2MoEForCausalLMPipe model pretraining
cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
else:
cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


class Qwen2MLP(nn.Layer):
@@ -612,11 +687,14 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, skip_r
)
self.o_proj = Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias_attr=False)

self.rotary_emb = Qwen2RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
if FLAGS_ALIGN_PADDLEFORMERS:
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
else:
self.rotary_emb = Qwen2RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)

self.attn_func = scaled_dot_product_attention

@@ -692,7 +770,10 @@ def forward(
use_neox_rotary_style=False,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
if FLAGS_ALIGN_PADDLEFORMERS:
cos, sin = self.rotary_emb(hidden_states, position_ids)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# [bs, seq_len, num_head, head_dim]
@@ -948,8 +1029,9 @@ def get_tensor_parallel_split_mappings(num_layers):
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
"layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
}

if config.tie_word_embeddings:
if FLAGS_ALIGN_PADDLEFORMERS:
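# The aligned layout keeps lm_head.weight transposed (it is used with transpose_y in Qwen2LMHead), so the tensor-parallel mapping also transposes it.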
base_actions["lm_head.weight"] = partial(fn, is_column=False, transpose=True)
elif config.tie_word_embeddings:
base_actions["lm_head.weight"] = partial(fn, is_column=False)
else:
base_actions["lm_head.weight"] = partial(fn, is_column=True)
@@ -1518,7 +1600,7 @@ def __init__(self, config: Qwen2Config):
self.lm_head = Qwen2LMHead(config, embedding_weights=self.qwen2.embed_tokens.weight, transpose_y=True)
self.tie_weights()
else:
self.lm_head = Qwen2LMHead(config)
self.lm_head = Qwen2LMHead(config, transpose_y=FLAGS_ALIGN_PADDLEFORMERS)
self.criterion = Qwen2PretrainingCriterion(config)
self.vocab_size = config.vocab_size
