@@ -101,7 +101,12 @@ public List<Integer> encodeMessage(Message message) {
 
     @Override
     public int getBeginOfText() {
-        return beginOfText;
+        if (beginOfText == -1) {
+            // DeepSeek-R1: no dedicated begin-of-text token is configured; fall back to the start-header token.
+            return startHeader;
+        } else {
+            return beginOfText;
+        }
    }
 
     @Override
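
Note on the hunk above: the -1 sentinel means no begin-of-text token was configured for this chat format. A sketch of how such a sentinel could arise at construction time, assuming a hypothetical vocabulary lookup that returns -1 for absent tokens (names are illustrative, not this project's API):

    // Hypothetical construction-time lookup, for illustration only.
    int beginOfText = vocabulary.getIndex("<|begin_of_text|>");   // -1 if the vocab lacks the token
    int startHeader = vocabulary.getIndex("<|start_header_id|>"); // fallback used by getBeginOfText()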
@@ -12,6 +12,7 @@
 import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
 import org.beehive.gpullama3.model.format.ChatFormat;
 import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
+import org.beehive.gpullama3.model.qwen2.DeepSeekR1Qwen;
 import org.beehive.gpullama3.model.qwen2.Qwen2;
 import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
 import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer;
@@ -85,7 +86,9 @@ protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weig
         // Qwen2.5-Coder uses <|endoftext|> as stop-token.
         ChatTokens chatTokens = isDeepSeekR1DistillQwen ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
                 : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
-        return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
+        return isDeepSeekR1DistillQwen
+                ? new DeepSeekR1Qwen(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens))
+                : new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
     }
     // @formatter:on
 
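
With this dispatch, DeepSeek-R1 distilled checkpoints get their own Model subclass instead of piggybacking on Qwen2, so downstream code can branch on a dedicated ModelType rather than re-detecting the variant; the planner factory later in this diff uses the same pattern. An illustrative sketch (not code from this PR; both enum constants appear in the sources below):

    // Hypothetical caller, for illustration only.
    String label = switch (model.getModelType()) {
        case DEEPSEEK_R1_DISTILL_QWEN -> "DeepSeek-R1 distilled Qwen2 checkpoint";
        case QWEN_2 -> "plain Qwen2 checkpoint";
        default -> "other";
    };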
DeepSeekR1Qwen.java (new file)
@@ -0,0 +1,23 @@
+package org.beehive.gpullama3.model.qwen2;
+
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.ModelType;
+import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
+
+public class DeepSeekR1Qwen extends Qwen2 {
+
+    public DeepSeekR1Qwen(Qwen2Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
+        super(configuration, tokenizer, weights, chatFormat);
+    }
+
+    @Override
+    public ModelType getModelType() {
+        return ModelType.DEEPSEEK_R1_DISTILL_QWEN;
+    }
+
+    @Override
+    public boolean shouldAddBeginOfText() {
+        return true;
+    }
+}
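
The two overrides work together with the getBeginOfText() fallback from the first hunk: the distilled model asks for a begin-of-text token, and the chat format supplies the start-header token in its place. A sketch of the expected interaction at prompt-assembly time (caller shape is an assumption, not this project's actual API; encodeMessage(Message) appears in the first hunk's context):

    // Hypothetical prompt assembly, for illustration only.
    List<Integer> promptTokens = new ArrayList<>();
    if (model.shouldAddBeginOfText()) {
        // True for DeepSeekR1Qwen; getBeginOfText() resolves to the
        // start-header fallback because beginOfText is unset (-1).
        promptTokens.add(chatFormat.getBeginOfText());
    }
    promptTokens.addAll(chatFormat.encodeMessage(message));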
@@ -4,7 +4,8 @@
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.tensor.GGMLType;
-import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
 import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm;
+package org.beehive.gpullama3.tornadovm.layerplanner;
 
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
QuantizationPlannerFactory.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.base;
+package org.beehive.gpullama3.tornadovm.layerplanner;
 
 import org.beehive.gpullama3.inference.state.GraniteState;
 import org.beehive.gpullama3.tensor.GGMLType;
@@ -8,14 +8,15 @@
 import org.beehive.gpullama3.inference.state.Qwen3State;
 import org.beehive.gpullama3.inference.state.State;
 import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.MistralQ8_0LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner;
@@ -54,7 +55,8 @@ public static GenericLayerPlanner create(GGMLType quantization, State state, Mod
     // ============ FP16 Planners ============
     private static GenericLayerPlanner createFP16Planner(State state, Model model) {
         return switch (model.getModelType()) {
-            case LLAMA_3, MISTRAL -> new LlamaFP16LayerPlanner((LlamaState) state, model);
+            case LLAMA_3 -> new LlamaFP16LayerPlanner((LlamaState) state, model);
+            case MISTRAL -> new MistralFP16LayerPlanner((LlamaState) state, model);
             case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
             case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model);
             case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model);
@@ -67,7 +69,8 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) {
     // ============ Q8_0 Planners ============
     private static GenericLayerPlanner createQ8_0Planner(State state, Model model) {
         return switch (model.getModelType()) {
-            case LLAMA_3, MISTRAL -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
+            case LLAMA_3 -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
+            case MISTRAL -> new MistralQ8_0LayerPlanner((LlamaState) state, model);
             case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
             case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model);
             case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model);
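
Splitting the shared LLAMA_3, MISTRAL arm means Mistral models now receive their own planner. A hedged usage sketch of the factory (call sites are not part of this diff):

    // Hypothetical call site, for illustration only.
    GenericLayerPlanner planner = QuantizationPlannerFactory.create(GGMLType.F16, state, model);
    // For a MISTRAL model this now yields MistralFP16LayerPlanner, where it
    // previously shared LlamaFP16LayerPlanner with LLAMA_3.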
QuantizedLayerPlanner.java (new file)
@@ -0,0 +1,103 @@
+package org.beehive.gpullama3.tornadovm.layerplanner;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.KernelContext;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Abstract base for all quantization-specific planners.
+ *
+ * Extracts common state from the model, detects the hardware scheduler type,
+ * and assembles the full execution plan via createTornadoInferencePlan().
+ * Subclasses (FP16LayerPlanner, Q8_0LayerPlanner) only provide quantization validation.
+ */
+public abstract class QuantizedLayerPlanner<S extends State, C extends Configuration, W extends Weights>
+        implements GenericLayerPlanner {
+
+    protected final S state;
+    protected final C config;
+    protected final W weights;
+    protected final KernelContext context;
+    protected final Model model;
+    protected final SchedulerType schedulerType;
+
+    protected Activation activationLayer;
+    protected AbstractFFNLayers<W, C> ffnLayers;
+    protected AbstractLogitsLayer logitsLayer;
+
+    private List<ImmutableTaskGraph> immutableTaskGraphs;
+    private GridScheduler gridScheduler;
+
+    @SuppressWarnings("unchecked")
+    protected QuantizedLayerPlanner(S state, Model model) {
+        this.state = state;
+        this.model = model;
+        this.config = (C) model.configuration();
+        this.weights = (W) model.weights();
+        this.context = new KernelContext();
+        this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+        validateQuantizationType();
+    }
+
+    /** Validates that the model weights match the expected quantization type. */
+    protected abstract void validateQuantizationType();
+
+    /**
+     * Creates the TornadoVM inference execution pipeline.
+     * It spans the full forward pass of the model and consists of:
+     * <ul>
+     * <li>Activation layer</li>
+     * <li>FFN layers (N transformer layers, model-specific)</li>
+     * <li>Logits layer</li>
+     * </ul>
+     * <p>
+     * Each component is represented as an {@link ImmutableTaskGraph}, along with a
+     * corresponding {@link GridScheduler} configuration that defines how tasks are
+     * mapped onto the GPU.
+     * </p>
+     * This method assembles all components into a unified execution pipeline and
+     * caches the resulting task graphs and scheduler for reuse across inference runs.
+     */
+    protected final void createTornadoInferencePlan() {
+        List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
+        GridScheduler masterScheduler = new GridScheduler();
+
+        // 1. Activation layer (common to all models)
+        allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
+        activationLayer.updateGridScheduler(masterScheduler);
+
+        // 2. FFN layers (N transformer layers - model-specific)
+        allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs());
+        ffnLayers.updateGridScheduler(masterScheduler);
+
+        // 3. Logits layer (common to all models)
+        allTaskGraphs.add(logitsLayer.getImmutableTaskGraph());
+        logitsLayer.updateGridScheduler(masterScheduler);
+
+        // Cache for future retrievals
+        this.immutableTaskGraphs = allTaskGraphs;
+        this.gridScheduler = masterScheduler;
+    }
+
+    @Override
+    public final List<ImmutableTaskGraph> getImmutableTaskGraphs() {
+        return this.immutableTaskGraphs;
+    }
+
+    @Override
+    public final GridScheduler getGridScheduler() {
+        return this.gridScheduler;
+    }
+}
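
The cached task graphs and grid scheduler are what the inference engine ultimately executes. A sketch of the expected consumption, assuming TornadoVM's public execution-plan API (this call site is not shown in the diff; TornadoExecutionPlan is imported by one of the files above):

    // Hypothetical consumer, for illustration only.
    List<ImmutableTaskGraph> graphs = planner.getImmutableTaskGraphs();
    TornadoExecutionPlan plan = new TornadoExecutionPlan(graphs.toArray(new ImmutableTaskGraph[0]));
    plan.withGridScheduler(planner.getGridScheduler()).execute();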

This file was deleted.

FP16LayerPlanner.java (new file)
@@ -0,0 +1,25 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
+
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tornadovm.layerplanner.QuantizedLayerPlanner;
+
+/**
+ * Base for all FP16-quantized layer planners.
+ */
+public abstract class FP16LayerPlanner<S extends State, C extends Configuration, W extends TornadoWeights> extends QuantizedLayerPlanner<S, C, W> {
+
+    protected FP16LayerPlanner(S state, Model model) {
+        super(state, model);
+    }
+
+    @Override
+    protected void validateQuantizationType() {
+        if (this.weights.getWeightType() != GGMLType.F16) {
+            throw new IllegalArgumentException("FP16LayerPlanner requires GGMLType.F16, got: " + this.weights.getWeightType());
+        }
+    }
+}
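
QuantizedLayerPlanner's class comment names a Q8_0LayerPlanner sibling that does not appear in this section. Under the assumption that it mirrors the class above, with GGMLType.Q8_0 as the matching enum constant, it would look roughly like this:

    // Sketch of the presumed sibling class, not code from this diff.
    public abstract class Q8_0LayerPlanner<S extends State, C extends Configuration, W extends TornadoWeights> extends QuantizedLayerPlanner<S, C, W> {

        protected Q8_0LayerPlanner(S state, Model model) {
            super(state, model);
        }

        @Override
        protected void validateQuantizationType() {
            if (this.weights.getWeightType() != GGMLType.Q8_0) {
                throw new IllegalArgumentException("Q8_0LayerPlanner requires GGMLType.Q8_0, got: " + this.weights.getWeightType());
            }
        }
    }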
GraniteFP16LayerPlanner.java
@@ -4,23 +4,17 @@
 import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.granite.GraniteConfiguration;
-import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
 import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers;
 import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer;
 
 public class GraniteFP16LayerPlanner extends FP16LayerPlanner<GraniteState, GraniteConfiguration, GraniteTornadoWeights> {
 
     public GraniteFP16LayerPlanner(GraniteState state, Model model) {
         super(state, model);
-        validateQuantizationType();
-        setupTornadoForwardPlan();
-    }
-
-    @Override
-    protected void initializeLayerComponents() {
-        this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType);
-        this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+        this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config);
+        this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsGraniteFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
     }
 
 }
LlamaFP16LayerPlanner.java
@@ -4,7 +4,6 @@
 import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layers.Activation;
 import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
 import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
@@ -13,15 +12,9 @@ public class LlamaFP16LayerPlanner extends FP16LayerPlanner<LlamaState, LlamaCon
 
     public LlamaFP16LayerPlanner(LlamaState state, Model model) {
         super(state, model);
-        validateQuantizationType();
-        setupTornadoForwardPlan();
+        this.activationLayer = new Activation("activationUpdate", state, weights, config);
+        this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
     }
-
-    @Override
-    protected void initializeLayerComponents() {
-        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
-        this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
-    }
-
-}
+}
MistralFP16LayerPlanner.java (new file)
@@ -0,0 +1,20 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.mistral.MistralConfiguration;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+
+public class MistralFP16LayerPlanner extends FP16LayerPlanner<LlamaState, MistralConfiguration, LlamaTornadoWeights> {
+
+    public MistralFP16LayerPlanner(LlamaState state, Model model) {
+        super(state, model);
+        this.activationLayer = new Activation("activationUpdate", state, weights, config);
+        this.ffnLayers = new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
+    }
+}