
Conversation

@nyo16 (Contributor) commented Dec 28, 2025

Adds support for the Gemma 3 architecture, enabling FunctionGemma (google/functiongemma-270m-it) and other Gemma 3 models to run in Bumblebee.

Why FunctionGemma?

  • Lightweight - Only 270M parameters, runs on CPU or modest GPU
  • Function calling - Specifically trained for tool/function invocation
  • Easy to fine-tune - Small enough to train on Google Colab T4
  • Edge/IoT ready - Perfect for home assistants, voice interfaces, embedded systems

Gemma 3 Architecture Changes

Gemma 3 has several key differences from Gemma v1:

| Feature | Gemma v1 | Gemma 3 |
| --- | --- | --- |
| QK-norm | No | Yes (RMS norm on Q/K after projection) |
| FFN layer norms | 1 (post-attention) | 3 (post-attention, pre-FFN, post-FFN) |
| Residual order | Before post-attention norm | After post-attention norm |
| Attention | Global only | Alternating local/global (5:1 ratio) |
| RMS norm formula | weight * normalized | (1 + weight) * normalized |
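
The last row is the subtle one: the Gemma-family RMS norm folds a +1 shift into the learned scale, and QK-norm applies this same normalization to the query/key projections. A minimal Nx sketch of the formula (not the PR's actual implementation, which lives in lib/bumblebee/text/gemma3.ex):

defmodule GemmaNormSketch do
  import Nx.Defn

  # RMS norm with Gemma's +1 shift on the learned scale:
  # output = (1 + weight) * x / sqrt(mean(x^2) + eps)
  defn rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-6)
    variance = x |> Nx.pow(2) |> Nx.mean(axes: [-1], keep_axes: true)
    x * Nx.rsqrt(variance + opts[:eps]) * (1.0 + weight)
  end
end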

Files Changed

  • lib/bumblebee/text/gemma3.ex - Full Gemma 3 model implementation with custom decoder supporting QK-norm and extra layer norms
  • lib/bumblebee.ex - Model and tokenizer registrations for Gemma3Model, Gemma3ForCausalLM, etc.
  • lib/bumblebee/layers/transformer.ex - Per-layer attention_window_size callback for alternating local/global attention (sketched after this list)
  • test/bumblebee/text/gemma3_test.exs - Unit tests (require tiny-random models on HuggingFace)
  • notebooks/function_calling.livemd - Comprehensive Livebook example with:
    • FunctionGemma.Schema - Build function declarations
    • FunctionGemma.Parser - Parse function call responses
    • FunctionGemma.Executor - Execute parsed calls
    • SmartHome - Mock functions demo (lights, thermostat, weather)
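
For the alternating attention, a per-layer window function along these lines drives the pattern; the 5:1 ratio means every sixth layer attends globally. This is a hedged sketch only, and the exact callback shape and window size in the PR may differ:

# Illustrative only: returns nil for global layers and a {left, right}
# sliding window for local ones; the 512 window size is an assumption.
attention_window_size = fn layer_idx ->
  if rem(layer_idx + 1, 6) == 0 do
    # every 6th layer attends globally
    nil
  else
    # local attention with a symmetric sliding window
    {512, 512}
  end
end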

Example Usage

{:ok, model_info} = Bumblebee.load_model({:hf, "google/functiongemma-270m-it", auth_token: token})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/functiongemma-270m-it", auth_token: token})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "google/functiongemma-270m-it", auth_token: token})

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config,
  compile: [batch_size: 1, sequence_length: 512],
  defn_options: [compiler: EXLA]
)

prompt = """
<start_of_turn>developer
You are a helpful assistant.
<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{location:{type:<escape>STRING<escape>}},required:[<escape>location<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>
<start_of_turn>user
What's the weather in Paris?<end_of_turn>
<start_of_turn>model
"""

%{results: [%{text: text}]} = Nx.Serving.run(serving, prompt)
# => "<start_function_call>call:get_weather{location:<escape>Paris<escape>}<end_function_call>"
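
Downstream, the notebook's FunctionGemma.Parser turns that text back into a structured call. A minimal regex-based sketch of the same idea (the notebook's parser is presumably more robust):

defmodule CallParserSketch do
  @call_re ~r/<start_function_call>call:(\w+)\{(.*)\}<end_function_call>/

  def parse(text) do
    case Regex.run(@call_re, text) do
      [_, name, args] -> {:ok, name, parse_args(args)}
      nil -> :error
    end
  end

  # Arguments look like key:<escape>value<escape>, comma-separated
  defp parse_args(args) do
    Regex.scan(~r/(\w+):<escape>(.*?)<escape>/, args)
    |> Map.new(fn [_, key, value] -> {key, value} end)
  end
end

CallParserSketch.parse(text)
# => {:ok, "get_weather", %{"location" => "Paris"}}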

Test Plan

  • Model loads without unused params warnings
  • FunctionGemma generates correct function call format
  • Multiple function types work (weather, lights, thermostat)
  • Livebook runs end-to-end with mock function execution
  • Unit tests pass (requires creating tiny-random-Gemma3* models on HuggingFace)

@nyo16 force-pushed the feat/add-functiongemma-support branch from 1f4f847 to 13b3a33 on December 28, 2025 at 17:28
@nyo16 (Contributor, Author) commented Dec 28, 2025

I generated the tiny models with the following script:

# Create tiny random Gemma 3 models for Bumblebee testing
# Run with: HF_TOKEN=hf_xxx python create_tiny_gemma3.py

import os
from huggingface_hub import login

# Login with token
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)
else:
    print("Warning: HF_TOKEN not set, using cached credentials")

from transformers import (
    Gemma3TextConfig,
    Gemma3TextModel,
    Gemma3ForCausalLM,  # No "Text" variant for CausalLM
    Gemma3TextForSequenceClassification,
)

# Tiny config matching Gemma 3 text architecture
config = Gemma3TextConfig(
    vocab_size=1024,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=8,
    intermediate_size=64,
    hidden_activation="gelu_pytorch_tanh",
    max_position_embeddings=512,
    rms_norm_eps=1e-6,
    rope_theta=10000.0,
    attention_bias=False,
    attention_dropout=0.0,
    sliding_window=128,
    pad_token_id=0,
    bos_token_id=2,
    eos_token_id=1,
    tie_word_embeddings=True,
    initializer_range=0.02,
    query_pre_attn_scalar=8,
)

# For sequence classification: identical config, plus a label count
import copy

config_seq_cls = copy.deepcopy(config)
config_seq_cls.num_labels = 2

models = [
    (Gemma3TextModel, config, "tiny-random-Gemma3Model"),
    (Gemma3ForCausalLM, config, "tiny-random-Gemma3ForCausalLM"),
    (Gemma3TextForSequenceClassification, config_seq_cls, "tiny-random-Gemma3ForSequenceClassification"),
]

print("Creating tiny random Gemma 3 models...")

for model_class, model_config, name in models:
    print(f"\nCreating {name}...")
    model = model_class(model_config)

    local_path = f"./{name}"
    model.save_pretrained(local_path)
    print(f"  Saved to {local_path}")

    repo_id = f"nmaroulis/{name}"
    model.push_to_hub(repo_id)
    print(f"  Pushed to https://huggingface.co/{repo_id}")

print("\nDone!")

The Gemma 3 architecture includes several key differences from Gemma v1:
- QK-norm (RMS normalization on query/key after projection)
- Pre/post FFN layer norms (pre_feedforward_layernorm, post_feedforward_layernorm)
- Different residual connection order (after post_attention_layernorm)
- Alternating local/global attention (sliding window)
- RMS norm with shift=1.0 formula: output * (1.0 + weight)

Files added:
- lib/bumblebee/text/gemma3.ex: Full Gemma 3 model implementation
- test/bumblebee/text/gemma3_test.exs: Unit tests
- notebooks/function_calling.livemd: Livebook with FunctionGemma examples

Files modified:
- lib/bumblebee.ex: Model and tokenizer registrations
- lib/bumblebee/layers/transformer.ex: Per-layer attention_window_size support

@nyo16 force-pushed the feat/add-functiongemma-support branch from 13b3a33 to 1fc7aaf on December 28, 2025 at 17:30

outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.hidden_state) == {1, 10, 32}

@jonatanklosko (Member) commented Dec 29, 2025

Currently the tests only assert shape. They should assert output against reference values from Python transformers, as we do for all other models. That's how we verify our implementation behaves correctly.
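
For reference, the assertion pattern in other Bumblebee model tests looks roughly like this; the tensor values below are dummy placeholders, to be replaced with reference outputs produced by running the same inputs through Python transformers:

# Sketch of the requested assertion (assert_all_close comes from the test
# helpers); the zeros are placeholders, not real reference values.
assert_all_close(
  outputs.hidden_state[[.., 1..3, 1..3]],
  Nx.tensor([
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
  ])
)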

end

{hidden_state, attention, block_cache} =
  gemma3_block(state.hidden_state,

@jonatanklosko (Member) commented Dec 29, 2025

Is there anything preventing us from using Layers.Transformer.blocks? If so, perhaps we can add extra options to Layers.Transformer.blocks to accommodate the necessary behaviour. We do that for all other models, so that the core transformer implementation is shared, rather than duplicated everywhere.
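
To make the suggestion concrete, the Gemma 3 extras could plausibly become options on the shared blocks. Every option name below is invented for illustration and is not Bumblebee's actual Layers.Transformer.blocks interface:

# Hypothetical shape of the refactor, not a real API; window_for_layer/1
# stands in for the per-layer window function shown earlier.
Layers.Transformer.blocks(embeddings,
  num_blocks: spec.num_blocks,
  num_attention_heads: spec.num_attention_heads,
  hidden_size: spec.hidden_size,
  # Gemma 3 extras expressed as options instead of a custom decoder
  query_key_norm: true,
  extra_ffn_norms: [:pre, :post],
  attention_window_size: &window_for_layer/1,
  name: "decoder"
)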
