Add Gemma 3 support for FunctionGemma and other Gemma 3 models #436
base: main
Conversation
The branch was force-pushed from 1f4f847 to 13b3a33.
I generated the tiny models with:

```python
# Create tiny random Gemma 3 models for Bumblebee testing
# Run with: HF_TOKEN=hf_xxx python create_tiny_gemma3.py
import os

from huggingface_hub import login

# Login with token
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)
else:
    print("Warning: HF_TOKEN not set, using cached credentials")

from transformers import (
    Gemma3TextConfig,
    Gemma3TextModel,
    Gemma3ForCausalLM,  # No "Text" variant for CausalLM
    Gemma3TextForSequenceClassification,
)

# Tiny config matching Gemma 3 text architecture
config = Gemma3TextConfig(
    vocab_size=1024,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=8,
    intermediate_size=64,
    hidden_activation="gelu_pytorch_tanh",
    max_position_embeddings=512,
    rms_norm_eps=1e-6,
    rope_theta=10000.0,
    attention_bias=False,
    attention_dropout=0.0,
    sliding_window=128,
    pad_token_id=0,
    bos_token_id=2,
    eos_token_id=1,
    tie_word_embeddings=True,
    initializer_range=0.02,
    query_pre_attn_scalar=8,
)

# For sequence classification
config_seq_cls = Gemma3TextConfig(
    vocab_size=1024,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=8,
    intermediate_size=64,
    hidden_activation="gelu_pytorch_tanh",
    max_position_embeddings=512,
    rms_norm_eps=1e-6,
    rope_theta=10000.0,
    attention_bias=False,
    attention_dropout=0.0,
    sliding_window=128,
    pad_token_id=0,
    bos_token_id=2,
    eos_token_id=1,
    tie_word_embeddings=True,
    initializer_range=0.02,
    query_pre_attn_scalar=8,
    num_labels=2,
)

models = [
    (Gemma3TextModel, config, "tiny-random-Gemma3Model"),
    (Gemma3ForCausalLM, config, "tiny-random-Gemma3ForCausalLM"),
    (Gemma3TextForSequenceClassification, config_seq_cls, "tiny-random-Gemma3ForSequenceClassification"),
]

print("Creating tiny random Gemma 3 models...")
for model_class, model_config, name in models:
    print(f"\nCreating {name}...")
    model = model_class(model_config)

    local_path = f"./{name}"
    model.save_pretrained(local_path)
    print(f"  Saved to {local_path}")

    repo_id = f"nmaroulis/{name}"
    model.push_to_hub(repo_id)
    print(f"  Pushed to https://huggingface.co/{repo_id}")

print("\nDone!")
```
Gemma 3 architecture includes several key differences from Gemma v1:

- QK-norm (RMS normalization on query/key after projection)
- Pre/post FFN layer norms (`pre_feedforward_layernorm`, `post_feedforward_layernorm`)
- Different residual connection order (after `post_attention_layernorm`)
- Alternating local/global attention (sliding window)
- RMS norm with shift=1.0 formula: `output * (1.0 + weight)` (see the sketch after this list)

Files added:

- `lib/bumblebee/text/gemma3.ex`: Full Gemma 3 model implementation
- `test/bumblebee/text/gemma3_test.exs`: Unit tests
- `notebooks/function_calling.livemd`: Livebook with FunctionGemma examples

Files modified:

- `lib/bumblebee.ex`: Model and tokenizer registrations
- `lib/bumblebee/layers/transformer.ex`: Per-layer `attention_window_size` support
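For the last point, here is a minimal Nx sketch of the shifted RMS norm, as an illustration of the formula only, not the PR's actual Axon implementation:

```elixir
defmodule Gemma3NormSketch do
  import Nx.Defn

  # Standard RMS norm scales the normalized input by `weight`;
  # the Gemma-style variant shifts the scale to `1.0 + weight`.
  defn shifted_rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-6)
    variance = x |> Nx.pow(2) |> Nx.mean(axes: [-1], keep_axes: true)
    normalized = x * Nx.rsqrt(variance + opts[:eps])
    normalized * (1.0 + weight)
  end
end
```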
The branch was force-pushed from 13b3a33 to 1fc7aaf.
```elixir
outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
```
Currently the tests only assert shape. They should assert output against reference values from Python transformers, as we do for all other models. That's how we verify our implementation behaves correctly.
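For reference, the other model tests assert against a slice of the output, roughly like this; the tensor values below are placeholders, to be replaced with the values produced by Python transformers for the same inputs:

```elixir
# Placeholder reference values -- swap in the actual slice produced by
# Python transformers for the same tiny checkpoint and inputs.
assert_all_close(
  outputs.hidden_state[[.., 1..3, 1..3]],
  Nx.tensor([[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])
)
```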
```elixir
{hidden_state, attention, block_cache} =
  gemma3_block(state.hidden_state,
```
Is there anything preventing us from using `Layers.Transformer.blocks`? If so, perhaps we can add extra options to `Layers.Transformer.blocks` to accommodate the necessary behaviour. We do that for all other models, so that the core transformer implementation is shared, rather than duplicated everywhere.
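One hypothetical shape for such an option: accept a function of the block index, so the alternating pattern can live in the model file while `Layers.Transformer.blocks` stays generic. The callback and cadence below are illustrative, not an existing API:

```elixir
# Hypothetical per-block callback: `nil` selects global attention,
# `{left, right}` a sliding window. Gemma 3 alternates local
# (sliding-window) layers with periodic global layers.
sliding_window = 128

attention_window_size = fn block_idx ->
  if rem(block_idx + 1, 6) == 0 do
    nil
  else
    {sliding_window, sliding_window}
  end
end
```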
Adds support for the Gemma 3 architecture, enabling FunctionGemma (`google/functiongemma-270m-it`) and other Gemma 3 models to run in Bumblebee.

Why FunctionGemma?
Gemma 3 Architecture Changes

Gemma 3 has several key differences from Gemma v1:

- QK-norm (RMS normalization on query/key after projection)
- Pre/post FFN layer norms (`pre_feedforward_layernorm`, `post_feedforward_layernorm`)
- Different residual connection order (after `post_attention_layernorm`)
- Alternating local/global attention (sliding window)
- RMS norm with shift: `weight * normalized` becomes `(1 + weight) * normalized`

Files Changed
- `lib/bumblebee/text/gemma3.ex` - Full Gemma 3 model implementation with custom decoder supporting QK-norm and extra layer norms
- `lib/bumblebee.ex` - Model and tokenizer registrations for `Gemma3Model`, `Gemma3ForCausalLM`, etc.
- `lib/bumblebee/layers/transformer.ex` - Per-layer `attention_window_size` callback for alternating local/global attention
- `test/bumblebee/text/gemma3_test.exs` - Unit tests (require tiny-random models on HuggingFace)
- `notebooks/function_calling.livemd` - Comprehensive Livebook example with:
  - `FunctionGemma.Schema` - Build function declarations
  - `FunctionGemma.Parser` - Parse function call responses
  - `FunctionGemma.Executor` - Execute parsed calls
  - `SmartHome` - Mock functions demo (lights, thermostat, weather)

Example Usage
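A minimal sketch of running a Gemma 3 checkpoint through Bumblebee's standard generation serving; the repo id and prompt are illustrative:

```elixir
# Load the model, tokenizer, and generation config from the Hub.
{:ok, model_info} = Bumblebee.load_model({:hf, "google/functiongemma-270m-it"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/functiongemma-270m-it"})

{:ok, generation_config} =
  Bumblebee.load_generation_config({:hf, "google/functiongemma-270m-it"})

# Build a text-generation serving and run a prompt through it.
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Turn on the living room lights.")
```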
Test Plan