Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
self.key_hidden_size = self.q_head_dim
self.val_hidden_size = self.config.v_head_dim

mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale_all_dim)
self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)

if self.config.rope_type == "rope":
Expand All @@ -102,7 +102,7 @@ def __init__(
self.config.qk_pos_emb_head_dim,
rotary_base=self.config.rotary_base,
scaling_factor=self.config.rotary_scaling_factor,
original_max_position_embeddings=self.config.max_position_embeddings,
original_max_position_embeddings=self.config.original_max_position_embeddings,
beta_fast=self.config.beta_fast,
beta_slow=self.config.beta_slow,
mscale=self.config.mscale,
Expand Down
14 changes: 11 additions & 3 deletions aiak_megatron/megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,16 +1078,24 @@ class MLATransformerConfig(TransformerConfig):
"""Rotary scaling factor for the rotary embeddings, used by yarn."""

max_position_embeddings: int = 4096
"""Maximum position embeddings for the original model, used by yarn."""
"""This arg is not used, will be deprecated."""

original_max_position_embeddings: int = 4096
"""Original maximum position embeddings for the original model, used by yarn."""

beta_fast: float = 32
"""Beta fast for YaRN RoPE, used by yarn."""

beta_slow: float = 1
"""Beta slow for YaRN RoPE, used by yarn."""

mscale: float = 0.707
mscale: float = 1.0
"""Mscale for YaRN RoPE in Multi-Latent Attention, used by yarn."""

mscale_all_dim: float = 0.707
mscale_all_dim: float = 0.0
"""Mscale all dimensions for YaRN RoPE in Multi-Latent Attention, used by yarn."""

def __post_init__(self):
super().__post_init__()
if self.multi_latent_attention and self.apply_rope_fusion and self.rope_type != "yarn":
raise ValueError("apply_rope_fusion for MLA only works with YARN RoPE.")
9 changes: 8 additions & 1 deletion aiak_megatron/megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2619,10 +2619,17 @@ def _add_mla_args(parser):
help="Dimension of the head in the V projection.")
group.add_argument('--rotary-scaling-factor', type=float, default=1.0,
help="Rotary scaling factor for the rotary embeddings.")
group.add_argument('--beta-fast', type=float, default=32.0,
help="beta_fast for YaRN RoPE in multi-latent attention.")
group.add_argument('--beta-slow', type=float, default=1.0,
help="beta_slow for YaRN RoPE in multi-latent attention.")
group.add_argument('--mscale', type=float, default=1.0,
help="Mscale for YaRN RoPE in multi-latent attention.")
group.add_argument('--mscale-all-dim', type=float, default=1.0,
group.add_argument('--mscale-all-dim', type=float, default=0.0,
help="Mscale all dimensions for YaRN RoPE in multi-latent attention.")
group.add_argument('--original-max-position-embeddings', type=int, default=None,
help='Original maximum number of position embeddings, used by yarn.'
'This is the size of position embedding.')

return parser

Expand Down
2 changes: 1 addition & 1 deletion aiak_megatron/megatron/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _compare(arg_name, old_arg_name=None, default=None):
_compare("num_attention_heads")
_compare("add_position_embedding", default=False)
if args.vocab_file:
_compare("max_position_embeddings")
# _compare("max_position_embeddings")
_compare("make_vocab_size_divisible_by")
if not args.use_dist_ckpt:
_compare("padded_vocab_size")
Expand Down