diff --git a/src/main/java/com/example/LlamaApp.java b/src/main/java/com/example/LlamaApp.java
index 5ea0cb2..4a8b37c 100644
--- a/src/main/java/com/example/LlamaApp.java
+++ b/src/main/java/com/example/LlamaApp.java
@@ -5,7 +5,7 @@
import com.example.inference.sampler.CategoricalSampler;
import com.example.inference.sampler.Sampler;
import com.example.inference.sampler.ToppSampler;
-import com.example.loader.weights.ModelLoader;
+import com.example.model.loader.ModelLoader;
import com.example.model.Model;
import com.example.tornadovm.FloatArrayUtils;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/com/example/aot/AOT.java b/src/main/java/com/example/aot/AOT.java
index 837a56b..42e7ffc 100644
--- a/src/main/java/com/example/aot/AOT.java
+++ b/src/main/java/com/example/aot/AOT.java
@@ -3,11 +3,13 @@
import com.example.auxiliary.Timer;
import com.example.core.model.GGUF;
import com.example.core.model.tensor.GGMLTensorEntry;
+import com.example.model.loader.LlamaModelLoader;
import com.example.model.Model;
import com.example.Options;
+import com.example.model.format.LlamaChatFormat;
import com.example.model.llama.Llama;
-import com.example.loader.weights.ModelLoader;
-import com.example.loader.weights.Weights;
+import com.example.inference.weights.Weights;
+import com.example.tokenizer.impl.LlamaTokenizer;
import java.io.IOException;
import java.nio.channels.FileChannel;
@@ -28,6 +30,8 @@
public final class AOT {
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
+ static LlamaModelLoader modelLoader;
+
record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {}
@@ -44,9 +48,10 @@ private static PartialModel preLoadGGUF(String modelPath) {
}
GGUF gguf = GGUF.loadModel(path);
try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
+ modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false);
return new PartialModel(
path.getFileName().toString(),
- Llama.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false), // TODO: needs proper handling for AOT
+ modelLoader.loadModel(), // TODO: needs proper handling for AOT
gguf.getTensorDataOffset(),
gguf.getTensorInfos()
);
@@ -77,8 +82,8 @@ public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IO
var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
// Load only the tensors (mmap slices).
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos());
- Weights weights = ModelLoader.loadWeights(tensorEntries, baseModel.configuration());
- return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights);
+ Weights weights = modelLoader.loadWeights(tensorEntries, baseModel.configuration());
+ return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights, new LlamaChatFormat((LlamaTokenizer) baseModel.tokenizer()));
}
}
}
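
Note: the preload path now goes through the per-model loader instead of the static Llama.loadModel. A minimal call-site sketch, assuming an Options object exposing modelPath() and maxTokens() (assumed names, not part of this patch):

    // Prefer weights baked in at build time; fall back to loading from disk.
    Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
    if (model == null) {
        GGUF gguf = GGUF.loadModel(options.modelPath());
        try (FileChannel channel = FileChannel.open(options.modelPath(), StandardOpenOption.READ)) {
            model = new LlamaModelLoader(channel, gguf, options.maxTokens(), false).loadModel();
        }
    }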
diff --git a/src/main/java/com/example/auxiliary/Utf8Mask.java b/src/main/java/com/example/auxiliary/Utf8Mask.java
new file mode 100644
index 0000000..d5ee242
--- /dev/null
+++ b/src/main/java/com/example/auxiliary/Utf8Mask.java
@@ -0,0 +1,10 @@
+package com.example.auxiliary;
+
+/** Mask identifying the lead byte of a multi-byte sequence in UTF-8 encoding. */
+public record Utf8Mask(int mask, int pattern, int len) {
+ public static final Utf8Mask[] MASKS = {
+ new Utf8Mask(0b11100000, 0b11000000, 2),
+ new Utf8Mask(0b11110000, 0b11100000, 3),
+ new Utf8Mask(0b11111000, 0b11110000, 4)
+ };
+}
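
The three masks cover 2-, 3-, and 4-byte sequences; single-byte (ASCII) is the default when no mask matches. A sketch of the intended lead-byte lookup (the surrounding decoder loop is assumed):

    int lead = 0xE2; // lead byte of a three-byte sequence, e.g. the first byte of '€'
    int len = 1;     // single-byte unless a mask matches
    for (Utf8Mask m : Utf8Mask.MASKS) {
        if ((lead & m.mask()) == m.pattern()) {
            len = m.len();
            break;
        }
    }
    // len == 3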
diff --git a/src/main/java/com/example/core/model/tensor/ArrayFloatTensor.java b/src/main/java/com/example/core/model/tensor/ArrayFloatTensor.java
index 6efeea9..b8cfa23 100644
--- a/src/main/java/com/example/core/model/tensor/ArrayFloatTensor.java
+++ b/src/main/java/com/example/core/model/tensor/ArrayFloatTensor.java
@@ -13,7 +13,7 @@ public final class ArrayFloatTensor extends FloatTensor {
final float[] values;
- ArrayFloatTensor(float[] values) {
+ public ArrayFloatTensor(float[] values) {
this.values = values;
}
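
Widening the constructor to public allows direct construction outside the package, e.g. in tests (hypothetical):

    FloatTensor bias = new ArrayFloatTensor(new float[] { 0.1f, 0.2f, 0.3f });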
diff --git a/src/main/java/com/example/core/model/tensor/F32FloatTensor.java b/src/main/java/com/example/core/model/tensor/F32FloatTensor.java
new file mode 100644
index 0000000..b650c36
--- /dev/null
+++ b/src/main/java/com/example/core/model/tensor/F32FloatTensor.java
@@ -0,0 +1,48 @@
+package com.example.core.model.tensor;
+
+import com.example.core.model.GGMLType;
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.VectorSpecies;
+
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+
+public final class F32FloatTensor extends FloatTensor {
+ final int size;
+ final MemorySegment segment;
+
+ public F32FloatTensor(int size, MemorySegment segment) {
+ this.size = size;
+ this.segment = segment;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public GGMLType type() {
+ return GGMLType.F32;
+ }
+
+ @Override
+ public MemorySegment asMemorySegment() {
+ return null;
+ }
+
+ @Override
+ public float getFloat(int index) {
+ return segment.get(ValueLayout.JAVA_FLOAT, (long) index * Float.BYTES);
+ }
+
+ @Override
+ public void setFloat(int index, float value) {
+ segment.set(ValueLayout.JAVA_FLOAT, (long) index * Float.BYTES, value);
+ }
+
+ @Override
+ protected FloatVector getFloatVector(VectorSpecies<Float> species, int offset) {
+ throw new UnsupportedOperationException("getFloatVector is not yet implemented.");
+ }
+}
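
A usage sketch for the new tensor, assuming an Arena-backed native segment (Java FFM API):

    try (Arena arena = Arena.ofConfined()) {
        MemorySegment seg = arena.allocate(4L * Float.BYTES);
        F32FloatTensor t = new F32FloatTensor(4, seg);
        t.setFloat(2, 3.14f);
        System.out.println(t.getFloat(2)); // prints 3.14
    }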
diff --git a/src/main/java/com/example/inference/InferenceCore.java b/src/main/java/com/example/inference/InferenceCore.java
index 81e432c..b135e1c 100644
--- a/src/main/java/com/example/inference/InferenceCore.java
+++ b/src/main/java/com/example/inference/InferenceCore.java
@@ -2,22 +2,35 @@
import com.example.auxiliary.Parallel;
import com.example.core.model.tensor.FloatTensor;
-import com.example.loader.weights.State;
-import com.example.loader.weights.Weights;
+import com.example.inference.state.State;
+import com.example.inference.weights.standard.Qwen3StandardWeights;
+import com.example.inference.weights.standard.StandardWeights;
+import com.example.inference.weights.tornado.TornadoWeights;
import com.example.model.Configuration;
import com.example.model.Model;
+import com.example.model.qwen3.Qwen3Configuration;
import com.example.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import java.lang.foreign.MemorySegment;
-import java.nio.FloatBuffer;
/**
* Low-level operations for model inference.
*
*
- * Provides core computational operations: RMS normalization and forward passes
- * through model layers. Supports both CPU and GPU implementations.
+ * This class provides core computational operations such as RMS normalization and
+ * forward passes through model layers. It supports both CPU and GPU implementations.
+ * <p>
+ * Specifically, it implements:
+ * <ul>
+ * <li>{@code rmsnorm} – applies Root Mean Square Layer Normalization to input vectors</li>
+ * <li>{@code forwardJava} – executes a forward pass for LLaMA and Mistral models on CPU</li>
+ * <li>{@code forwardJavaQwen3} – executes a forward pass for Qwen3 models on CPU</li>
+ * <li>{@code forwardTornadoVM} – executes a forward pass using TornadoVM for GPU acceleration</li>
+ * </ul>
*/
public final class InferenceCore {
@@ -26,21 +39,21 @@ private InferenceCore() {
// prevent instantiation
}
- public static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
+ public static void rmsnorm(FloatTensor out, FloatTensor x, FloatTensor weight, int offset, int size, float rmsNormEps) {
// calculate sum of squares
- float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi);
+ float ss = x.reduce(offset, size, 0f, (acc, xi) -> acc + xi * xi);
ss /= size;
ss += rmsNormEps;
ss = (float) (1.0 / Math.sqrt(ss));
// normalize and scale
final float finalss = ss; // for the lambda
- out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index)));
+ out.mapWithIndexInPlace(offset, size, (value, index) -> weight.getFloat(index % size) * (finalss * x.getFloat(index)));
}
public static FloatTensor forwardJava(Model model, State state, int token, int position) {
// a few convenience variables
final Configuration config = model.configuration();
- final Weights weights = model.weights();
+ final StandardWeights weights = (StandardWeights) model.weights();
int dim = config.dim();
int headSize = config.headSize();
int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
@@ -53,7 +66,7 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
// forward all the layers
for (int l = 0; l < config.numberOfLayers(); l++) {
// attention rmsnorm
- rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps());
+ rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());
// qkv matmuls for this position
@@ -64,8 +77,8 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i += 2) {
int head_dim = i % headSize;
- float fcr = weights.freq_cis_real.get(position * (headSize / 2) + (head_dim / 2));
- float fci = weights.freq_cis_imag.get(position * (headSize / 2) + (head_dim / 2));
+ float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2));
+ float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2));
int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int v = 0; v < rotn; v++) {
FloatTensor vec = v == 0 ? state.q : state.k; // the vector to rotate (query or key)
@@ -133,7 +146,146 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
state.x.addInPlace(state.xb2);
// ffn rmsnorm
- rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, config.rmsNormEps());
+ rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());
+
+ // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
+ // first calculate self.w1(x) and self.w3(x)
+ weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
+ weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);
+
+ // SwiGLU non-linearity
+ // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
+ state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
+
+ // elementwise multiply with w3(x)
+ state.hb.multiplyInPlace(state.hb2);
+
+ // final matmul to get the output of the ffn
+ weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());
+
+ // residual connection
+ state.x.addInPlace(state.xb);
+ }
+
+ rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());
+
+ weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);
+
+ return state.logits;
+ }
+
+ public static FloatTensor forwardJavaQwen3(Model model, State state, int token, int position) {
+ // a few convenience variables
+ final Qwen3Configuration config = (Qwen3Configuration) model.configuration();
+ final Qwen3StandardWeights weights = (Qwen3StandardWeights) model.weights();
+ int dim = config.dim();
+ int nHeadKv = config.numberOfKeyValueHeads(); // n_head_kv = numberOfKeyValueHeads
+ int nEmbdHeadK = config.numberOfHeadsKey(); // n_embd_head_k = n_embd / n_head; %s.attention.key_length
+ int nEmbdHeadV = config.numberOfHeadsValue(); // n_embd_head_v = n_embd / n_head; %s.attention.value_length
+ int nEmbdVGqa = nEmbdHeadV * nHeadKv; // n_embd_v_gqa = n_embd_head_v * n_head_kv
+ int nEmbdHead = nEmbdHeadV;
+ int nEmbdGqa = nEmbdVGqa;
+ int gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
+ float sqrtHeadSize = (float) Math.sqrt(nEmbdHead);
+
+ // copy the token embedding into x
+ weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
+
+ // forward all the layers
+ for (int l = 0; l < config.numberOfLayers(); l++) {
+ // attention rmsnorm
+ final int curLayer = l;
+ rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], 0, dim, config.rmsNormEps());
+
+ // qkv matmuls for this position
+ weights.wq[curLayer].matmul(state.xb, state.q, nEmbdHeadK * config.numberOfHeads(), dim);
+ weights.wk[curLayer].matmul(state.xb, state.k, nEmbdGqa, dim);
+ weights.wv[curLayer].matmul(state.xb, state.v, nEmbdGqa, dim);
+
+ // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ for (int i = 0; i < config.numberOfHeads(); i++) {
+ rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
+ }
+ // Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ for (int i = 0; i < config.numberOfKeyValueHeads(); i++) {
+ rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
+ }
+
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
+ for (int h = 0; h < config.numberOfHeads(); ++h) {
+ int rotn = h < config.numberOfKeyValueHeads() ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
+ int poffset = h * nEmbdHead;
+ int nComplEmbdHead = nEmbdHead / 2;
+ for (int ic = 0; ic < nComplEmbdHead; ic++) {
+ float fcr = weights.freq_cis_real.getFloat(position * nComplEmbdHead + ic);
+ float fci = weights.freq_cis_imag.getFloat(position * nComplEmbdHead + ic);
+ for (int vi = 0; vi < rotn; vi++) {
+ FloatTensor vec = (vi == 0) ? state.q : state.k; // the vector to rotate (query or key)
+ float v0 = vec.getFloat(poffset + ic);
+ float v1 = vec.getFloat(poffset + ic + nComplEmbdHead);
+ vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
+ vec.setFloat(poffset + ic + nComplEmbdHead, v0 * fci + v1 * fcr);
+ }
+ }
+ }
+
+ // save key,value at this time step (position) to our kv cache
+ //int loff = l * config.seq_len * kvDim;
+ // kv cache layer offset for convenience
+ state.k.copyTo(0, state.keyCache[curLayer], position * nEmbdGqa, nEmbdGqa);
+ state.v.copyTo(0, state.valueCache[curLayer], position * nEmbdGqa, nEmbdGqa);
+
+ // multihead attention. iterate over all heads in parallel
+ Parallel.parallelFor(0, config.numberOfHeads(), h -> {
+ // get the query vector for this head
+ // float* q = s.q + h * headSize;
+ int qOffset = h * nEmbdHead;
+ // attention scores for this head
+ // float* att = s.att + h * config.seq_len;
+ int attOffset = h * config.contextLength();
+
+ // iterate over all timesteps, including the current one
+ for (int t = 0; t <= position; t++) {
+ // get the key vector for this head and at this timestep
+ // float* k = s.key_cache + loff + t * dim + h * headSize;
+ int keyCacheOffset = /* loff + */ (t * nEmbdGqa + (h / gqa) * nEmbdHead);
+ // calculate the attention score as the dot product of q and k
+ float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, nEmbdHeadK);
+ //state.kq.setFloat(h + t, score);
+ score /= sqrtHeadSize;
+ // save the score to the attention buffer
+ state.att.setFloat(attOffset + t, score);
+ }
+
+ state.att.softmaxInPlace(attOffset, position + 1); // softmax over positions 0..position
+
+ // weighted sum of the values, store back into xb
+ // float* xb = s.xb + h * headSize;
+ int xbOffset = h * nEmbdHeadV;
+ // memset(xb, 0, headSize * sizeof(float));
+ state.xb.fillInPlace(xbOffset, nEmbdHeadV, 0f);
+
+ for (int t = 0; t <= position; t++) {
+ // get the value vector for this head and at this timestep
+ // float* v = s.value_cache + loff + t * dim + h * headSize;
+ int vOffset = /* loff + */ t * nEmbdGqa + (h / gqa) * nEmbdHeadV;
+ // get the attention weight for this timestep
+ float a = state.att.getFloat(attOffset + t);
+ // accumulate the weighted value into xb
+ state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, nEmbdHeadV, a);
+ }
+ });
+
+ // final matmul to get the output of the attention
+ weights.wo[l].matmul(state.xb, state.xb2, dim, nEmbdHeadK * config.numberOfHeads());
+
+ // residual connection back into x
+ state.x.addInPlace(state.xb2);
+
+ // ffn rmsnorm
+ rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], 0, dim, config.rmsNormEps());
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
@@ -154,7 +306,7 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
state.x.addInPlace(state.xb);
}
- rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps());
+ rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());
weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);
@@ -188,7 +340,7 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
*/
public static FloatArray forwardTornadoVM(Model model, State state, int token, int position, TornadoVMMasterPlan tornadoVMMasterPlan) {
final Configuration configuration = model.configuration();
- final Weights weights = model.weights();
+ final TornadoWeights weights = (TornadoWeights) model.weights();
MemorySegment.copy(weights.tokenEmbeddingTable.getSegment(), token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
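
To make the grouped-query attention indexing in forwardJavaQwen3 concrete, a worked example under assumed dimensions (not taken from any specific checkpoint):

    // Assume nHeads = 16, nHeadKv = 8, nEmbdHead = 128:
    //   gqa      = 16 / 8  = 2     (two query heads share each kv head)
    //   nEmbdGqa = 128 * 8 = 1024  (floats cached per position)
    // For query head h = 5 at timestep t = 3:
    //   keyCacheOffset = t * nEmbdGqa + (h / gqa) * nEmbdHead
    //                  = 3 * 1024 + (5 / 2) * 128 = 3072 + 256 = 3328
    // i.e. heads 4 and 5 both read the cache slice belonging to kv head 2.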
diff --git a/src/main/java/com/example/inference/InferenceEngine.java b/src/main/java/com/example/inference/InferenceEngine.java
index 814b1ae..bad6b67 100644
--- a/src/main/java/com/example/inference/InferenceEngine.java
+++ b/src/main/java/com/example/inference/InferenceEngine.java
@@ -2,7 +2,7 @@
import com.example.auxiliary.LastRunMetrics;
import com.example.inference.sampler.Sampler;
-import com.example.loader.weights.State;
+import com.example.inference.state.State;
import com.example.model.Configuration;
import com.example.model.Model;
import com.example.tokenizer.impl.Tokenizer;
@@ -20,6 +20,16 @@
*
* Orchestrates the complete inference process: ingests prompt tokens, then generates
* new tokens until a stop condition is met. Supports both CPU and GPU execution.
+ *
+ * <p>
+ * It provides unified logic for the following methods:
+ * <ul>
+ * <li>{@code generateTokensLlama} – for LLaMA and Mistral models running on CPU</li>
+ * <li>{@code generateTokensQwen3} – for Qwen3 models running on CPU</li>
+ * <li>{@code generateTokensGPU} – for models executed on GPU</li>
+ * </ul>
*/
public final class InferenceEngine {
@@ -46,7 +56,7 @@ private InferenceEngine() {
* @param onTokenGenerated callback, if non-null, it's called every time a token is inferred e.g. it's not called when ingesting prompt tokens
* @return list of generated/inferred tokens, including the stop token, if any e.g. does not include any token from the prompt
*/
- public static List<Integer> generateTokens(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
+ public static List<Integer> generateTokensLlama(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated) {
// Start timing the whole process
long startNanos = System.nanoTime();
@@ -122,6 +132,90 @@ public static List generateTokens(Model model, State state, int startPo
return generatedTokens;
}
+ public static List<Integer> generateTokensQwen3(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
+ IntConsumer onTokenGenerated) {
+ // Start timing the whole process
+ long startNanos = System.nanoTime();
+ long startGen = 0;
+ long inferenceStartNanos = 0;
+
+ // Validate and adjust maxTokens if necessary
+ if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) {
+ maxTokens = model.configuration().contextLength();
+ }
+
+ // Storage for generated tokens
+ List<Integer> generatedTokens = new ArrayList<>();
+
+ // Initialize token variables
+ int currentToken = state.latestToken; // BOS?
+ int nextToken = 0;
+ int promptIndex = 0;
+
+ for (int position = startPosition; position < maxTokens; ++position) {
+
+ // Handle token processing
+ if (promptIndex < promptTokens.size()) {
+ // We're still processing the prompt tokens
+ final int token = promptTokens.get(promptIndex);
+
+ model.forward(state, token, position);
+
+ promptIndex++;
+ if (promptIndex < promptTokens.size()) {
+ continue;
+ }
+ if (echo) {
+ System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
+ }
+ // We have reached the last prompt token and computed the first response-token.
+ startGen = System.nanoTime();
+ position++; // The current logit belongs to the next position
+ } else {
+ // Mark the start of actual generation (after prompt processing)
+ if (inferenceStartNanos == 0) {
+ inferenceStartNanos = System.nanoTime();
+ }
+
+ model.forward(state, currentToken, position);
+
+ }
+
+ // Sample the next token
+ nextToken = sampler.sampleToken(state.logits);
+
+ // Output the token if echo is enabled
+ if (echo) {
+ System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
+ }
+
+ // Track the generated token
+ generatedTokens.add(nextToken);
+
+ // Notify via callback if provided
+ if (onTokenGenerated != null) {
+ onTokenGenerated.accept(nextToken);
+ }
+
+ // Check for stop condition
+ if (stopTokens.contains(nextToken)) {
+ break;
+ }
+
+ // Update for next iteration
+ state.latestToken = currentToken = nextToken;
+ }
+
+ // Calculate and print performance metrics
+ long endNanos = System.nanoTime();
+ double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0;
+ int totalTokens = promptIndex + generatedTokens.size();
+
+ LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds);
+
+ return generatedTokens;
+ }
+
public static List<Integer> generateTokensGPU(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
// === Setup and Initialization ===
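
A hypothetical caller of the new Qwen3 path; createNewState() and the sampler wiring are assumptions based on the surrounding codebase, not part of this patch:

    State state = model.createNewState();
    List<Integer> generated = InferenceEngine.generateTokensQwen3(
            model, state, 0, promptTokens, stopTokens,
            model.configuration().contextLength(), sampler, false,
            token -> System.out.print(model.tokenizer().decode(List.of(token))));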
diff --git a/src/main/java/com/example/inference/state/LlamaState.java b/src/main/java/com/example/inference/state/LlamaState.java
new file mode 100644
index 0000000..350ad72
--- /dev/null
+++ b/src/main/java/com/example/inference/state/LlamaState.java
@@ -0,0 +1,65 @@
+package com.example.inference.state;
+
+import com.example.core.model.tensor.ArrayFloatTensor;
+import com.example.core.model.tensor.FloatTensor;
+import com.example.model.Configuration;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+import java.util.stream.Stream;
+
+public final class LlamaState extends State {
+
+ public LlamaState(Configuration config, int batchsize) {
+ super(config, batchsize);
+ }
+
+ @Override
+ protected StateFields createStateFields(Configuration config) {
+ StateFields fields = new StateFields();
+
+ // Allocation with Llama/Mistral dimensions
+ fields.x = ArrayFloatTensor.allocate(config.dim());
+ fields.xb = ArrayFloatTensor.allocate(config.dim());
+ fields.xb2 = ArrayFloatTensor.allocate(config.dim());
+ fields.hb = ArrayFloatTensor.allocate(config.hiddenDim());
+ fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim());
+ fields.q = ArrayFloatTensor.allocate(config.dim());
+ fields.k = ArrayFloatTensor.allocate(config.dim());
+ fields.v = ArrayFloatTensor.allocate(config.dim());
+ fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength());
+ fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());
+
+ // Key-value cache with Llama/Mistral dimensions
+ int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
+ fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
+ fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
+
+ // TornadoVM wrappers with Llama/Mistral dimensions
+ fields.wrapX = new FloatArray(config.dim());
+ fields.wrapXb = new FloatArray(config.dim());
+ fields.wrapXb2 = new FloatArray(config.dim());
+ fields.wrapHb = new FloatArray(config.hiddenDim());
+ fields.wrapHb2 = new FloatArray(config.hiddenDim());
+
+ fields.wrapLogits = new FloatArray(config.vocabularySize());
+ fields.wrapQ = new FloatArray(config.dim());
+ fields.wrapK = new FloatArray(config.dim());
+ fields.wrapV = new FloatArray(config.dim());
+
+ // note: the caches are sized by kvDim, not dim
+ fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
+ fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
+ fields.wrapValueCache.init(0.f);
+ fields.wrapKeyCache.init(0.f);
+ fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength());
+ fields.positionHolder = new IntArray(1);
+
+ // Temporary arrays
+ fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+ fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+ fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+
+ return fields;
+ }
+}
diff --git a/src/main/java/com/example/inference/state/Qwen3State.java b/src/main/java/com/example/inference/state/Qwen3State.java
new file mode 100644
index 0000000..b82a056
--- /dev/null
+++ b/src/main/java/com/example/inference/state/Qwen3State.java
@@ -0,0 +1,78 @@
+package com.example.inference.state;
+
+import com.example.core.model.tensor.ArrayFloatTensor;
+import com.example.core.model.tensor.FloatTensor;
+import com.example.model.Configuration;
+import com.example.model.qwen3.Qwen3Configuration;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+import java.util.stream.Stream;
+
+public final class Qwen3State extends State {
+
+ // Qwen3-specific field
+ public final FloatTensor kq;
+
+ public Qwen3State(Configuration config, int batchsize) {
+ super(config, batchsize);
+ // Initialize Qwen3-specific field
+ this.kq = ArrayFloatTensor.allocate(config.numberOfHeads(), 32, 15);
+ }
+
+ @Override
+ protected StateFields createStateFields(Configuration configuration) {
+ StateFields fields = new StateFields();
+
+ Qwen3Configuration config = (Qwen3Configuration) configuration;
+
+ // Qwen3-specific calculations
+ int nHeadKv = config.numberOfKeyValueHeads();
+ int nEmbdHeadK = config.numberOfHeadsKey();
+ int nEmbdKGqa = nEmbdHeadK * nHeadKv;
+
+ // Qwen3-specific allocation logic
+ fields.x = ArrayFloatTensor.allocate(config.dim());
+ fields.xb = ArrayFloatTensor.allocate(nEmbdHeadK * config.numberOfHeads());
+ fields.xb2 = ArrayFloatTensor.allocate(config.dim());
+ fields.hb = ArrayFloatTensor.allocate(config.hiddenDim());
+ fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim());
+ fields.q = ArrayFloatTensor.allocate(nEmbdHeadK * config.numberOfHeads());
+ fields.k = ArrayFloatTensor.allocate(nEmbdKGqa); // Different from Llama!
+ fields.v = ArrayFloatTensor.allocate(nEmbdKGqa); // Different from Llama!
+ fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength());
+ fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());
+
+ // Key-value cache with Qwen3 dimensions
+ int kvDim = nEmbdKGqa;
+ fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim))
+ .limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
+ fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim))
+ .limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
+
+ // TornadoVM wrappers with Qwen3-specific sizes
+ fields.wrapX = new FloatArray(config.dim());
+ fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads()); // Different from Llama!
+ fields.wrapXb2 = new FloatArray(config.dim());
+ fields.wrapHb = new FloatArray(config.hiddenDim());
+ fields.wrapHb2 = new FloatArray(config.hiddenDim());
+ fields.wrapLogits = new FloatArray(config.vocabularySize());
+ fields.wrapQ = new FloatArray(nEmbdHeadK * config.numberOfHeads()); // Different from Llama!
+ fields.wrapK = new FloatArray(nEmbdKGqa); // Different from Llama!
+ fields.wrapV = new FloatArray(nEmbdKGqa); // Different from Llama!
+
+ fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
+ fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
+ fields.wrapValueCache.init(0.f);
+ fields.wrapKeyCache.init(0.f);
+ fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength());
+ fields.positionHolder = new IntArray(1);
+
+ // Temporary arrays
+ fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+ fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+ fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+
+ return fields;
+ }
+}
diff --git a/src/main/java/com/example/inference/state/State.java b/src/main/java/com/example/inference/state/State.java
new file mode 100644
index 0000000..a4436f2
--- /dev/null
+++ b/src/main/java/com/example/inference/state/State.java
@@ -0,0 +1,116 @@
+package com.example.inference.state;
+
+import com.example.core.model.tensor.FloatTensor;
+import com.example.model.Configuration;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+/**
+ * Base class for State
+ */
+public abstract class State {
+
+ // current wave of activations
+ public final FloatTensor x; // activation at current time stamp (dim,)
+ public final FloatTensor xb; // same, but inside a residual branch (dim,)
+ public final FloatTensor xb2; // an additional buffer just for convenience (dim,)
+ public final FloatTensor hb; // buffer for hidden dimension in the ffn (hidden_dim,)
+ public final FloatTensor hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
+ public final FloatTensor q; // query (dim,)
+ public final FloatTensor k; // key (dim,)
+ public final FloatTensor v; // value (dim,)
+ public final FloatTensor att; // buffer for scores/attention values (n_heads, seq_len)
+ public final FloatTensor logits; // output logits
+ public final int batchsize;
+
+ // kv cache
+ public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim)
+ public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim)
+
+ // Wrappers for TornadoVM compatibility (FloatArray data structure for TornadoVM acceleration)
+ // TornadoVM uses FloatArray for more efficient handling of data, particularly when running on GPU or other accelerators.
+ public final FloatArray wrapLogits; // FloatArray wrapper for the logits tensor, compatible with TornadoVM for GPU execution.
+ public final FloatArray wrapXb; // FloatArray wrapper for xb (residual branch activation), optimized for TornadoVM usage.
+ public final FloatArray wrapXb2; // FloatArray wrapper for xb2, another residual buffer to aid in computations with TornadoVM.
+ public final FloatArray wrapHb; // FloatArray wrapper for hb (hidden dimension buffer for FFN), optimized for TornadoVM.
+ public final FloatArray wrapHb2; // FloatArray wrapper for hb2, additional hidden buffer for FFN, for compatibility with TornadoVM.
+ public final FloatArray wrapX; // FloatArray wrapper for the current activation tensor, optimized for TornadoVM.
+ public final FloatArray wrapQ; // FloatArray wrapper for the query tensor, optimized for TornadoVM.
+ public final FloatArray wrapK; // FloatArray wrapper for the key tensor, optimized for TornadoVM.
+ public final FloatArray wrapV; // FloatArray wrapper for the value tensor, optimized for TornadoVM.
+ public final FloatArray wrapAtt; // FloatArray wrapper for the attention scores, optimized for TornadoVM.
+ public final FloatArray wrapKeyCache; // FloatArray wrapper for the key cache, optimized for TornadoVM.
+ public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM.
+ public final IntArray positionHolder;
+
+ // workgroup size and temporary reduction buffers for the TornadoVM kernels
+ public int localSize;
+ public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
+ public FloatArray tempFFN; // Temporary buffer for feed-forward network calculations, size adjusted for local workgroup size.
+ public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size.
+ public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models.
+
+
+ protected State(Configuration config, int batchsize) {
+ this.batchsize = -1; // note: the batchsize parameter is currently unused
+ this.latestToken = -1;
+ this.localSize = 256;
+
+ // Initialize all fields through the creation method
+ StateFields fields = createStateFields(config);
+
+ this.x = fields.x;
+ this.xb = fields.xb;
+ this.xb2 = fields.xb2;
+ this.hb = fields.hb;
+ this.hb2 = fields.hb2;
+ this.q = fields.q;
+ this.k = fields.k;
+ this.v = fields.v;
+ this.att = fields.att;
+ this.logits = fields.logits;
+ //int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
+ this.keyCache = fields.keyCache;
+ this.valueCache = fields.valueCache;
+
+ this.wrapX = fields.wrapX;
+ this.wrapXb = fields.wrapXb;
+ this.wrapXb2 = fields.wrapXb2;
+ this.wrapHb = fields.wrapHb;
+ this.wrapHb2 = fields.wrapHb2;
+ this.wrapLogits = fields.wrapLogits;
+ this.wrapQ = fields.wrapQ;
+ this.wrapK = fields.wrapK;
+ this.wrapV = fields.wrapV;
+
+ // note: the caches are sized by kvDim, not dim
+ this.wrapKeyCache = fields.wrapKeyCache;
+ this.wrapValueCache = fields.wrapValueCache;
+ this.wrapAtt = fields.wrapAtt;
+ this.positionHolder = fields.positionHolder;
+
+ // You need at least 9 elements: 1 for the final result + 8 for the workgroup partial sums
+ this.temp = fields.temp;
+ this.tempFFN = fields.tempFFN;
+ this.tempLogits = fields.tempLogits;
+ }
+
+ // Abstract method - subclasses implement their specific allocation logic and sizes
+ protected abstract StateFields createStateFields(Configuration config);
+
+ // Helper class to hold all the state fields during construction
+ protected static class StateFields {
+ public FloatTensor x, xb, xb2, hb, hb2, q, k, v, att, logits;
+ public FloatTensor[] keyCache, valueCache;
+ public FloatArray wrapX, wrapXb, wrapXb2, wrapHb, wrapHb2, wrapLogits;
+ public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache;
+ public IntArray positionHolder;
+ public FloatArray temp, tempFFN, tempLogits;
+ }
+
+ @Override
+ public State clone() throws CloneNotSupportedException {
+ return (State) super.clone();
+ }
+}
\ No newline at end of file
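
The temp buffer sizing shared by both subclasses is a ceiling division over the workgroup size; for example, assuming dim = 2048 and localSize = 256:

    // 1 + ((2048 + 256 - 1) / 256) = 1 + 8 = 9 floats:
    // one slot for the final reduction result plus eight workgroup partial sums,
    // matching the "at least 9 elements" note in the constructor.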
diff --git a/src/main/java/com/example/inference/weights/Weights.java b/src/main/java/com/example/inference/weights/Weights.java
new file mode 100644
index 0000000..f2c495a
--- /dev/null
+++ b/src/main/java/com/example/inference/weights/Weights.java
@@ -0,0 +1,9 @@
+package com.example.inference.weights;
+
+import com.example.core.model.GGMLType;
+
+public interface Weights {
+
+ GGMLType getWeightType();
+
+}
\ No newline at end of file
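
The marker interface replaces the old dual-constructor Weights class, in which the unused half of the fields was always null. A sketch of the intended dispatch (local names assumed):

    Weights w = model.weights();
    if (w instanceof StandardWeights sw) {
        // CPU path: FloatTensor-based matmuls, e.g. sw.wq[layer].matmul(...)
    } else if (w instanceof TornadoWeights tw) {
        // GPU path: FloatArray/HalfFloatArray buffers handed to the TornadoVM task graph
    }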
diff --git a/src/main/java/com/example/inference/weights/standard/LlamaStandardWeights.java b/src/main/java/com/example/inference/weights/standard/LlamaStandardWeights.java
new file mode 100644
index 0000000..621d69a
--- /dev/null
+++ b/src/main/java/com/example/inference/weights/standard/LlamaStandardWeights.java
@@ -0,0 +1,17 @@
+package com.example.inference.weights.standard;
+
+import com.example.core.model.GGMLType;
+import com.example.core.model.tensor.FloatTensor;
+
+public class LlamaStandardWeights extends StandardWeights {
+
+ public LlamaStandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatTensor[] rms_ffn_weight,
+ FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatTensor rms_final_weight, FloatTensor freq_cis_real, FloatTensor freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
+ super(token_embedding_table, rms_att_weight, wq, wk, wv, wo, rms_ffn_weight, w1, w2, w3, rms_final_weight, freq_cis_real, freq_cis_imag, wcls, weightType);
+ }
+
+ @Override
+ public GGMLType getWeightType() {
+ return weightType;
+ }
+}
diff --git a/src/main/java/com/example/inference/weights/standard/Qwen3StandardWeights.java b/src/main/java/com/example/inference/weights/standard/Qwen3StandardWeights.java
new file mode 100644
index 0000000..e40c679
--- /dev/null
+++ b/src/main/java/com/example/inference/weights/standard/Qwen3StandardWeights.java
@@ -0,0 +1,25 @@
+package com.example.inference.weights.standard;
+
+import com.example.core.model.GGMLType;
+import com.example.core.model.tensor.FloatTensor;
+
+public class Qwen3StandardWeights extends StandardWeights {
+ public final FloatTensor[] attnKNorm, attnQNorm;
+
+ public Qwen3StandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight,
+ FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo,
+ FloatTensor[] attnKNorm, FloatTensor[] attnQNorm,
+ FloatTensor[] rms_ffn_weight,
+ FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3,
+ FloatTensor rms_final_weight, FloatTensor freq_cis_real, FloatTensor freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
+ // call to StandardWeights constructor
+ super(token_embedding_table, rms_att_weight, wq, wk, wv, wo, rms_ffn_weight, w1, w2, w3, rms_final_weight, freq_cis_real, freq_cis_imag, wcls, weightType);
+ this.attnKNorm = attnKNorm;
+ this.attnQNorm = attnQNorm;
+ }
+
+ @Override
+ public GGMLType getWeightType() {
+ return weightType;
+ }
+}
diff --git a/src/main/java/com/example/inference/weights/standard/StandardWeights.java b/src/main/java/com/example/inference/weights/standard/StandardWeights.java
new file mode 100644
index 0000000..afebf60
--- /dev/null
+++ b/src/main/java/com/example/inference/weights/standard/StandardWeights.java
@@ -0,0 +1,95 @@
+package com.example.inference.weights.standard;
+
+import com.example.core.model.GGMLType;
+import com.example.core.model.tensor.FloatTensor;
+import com.example.inference.weights.Weights;
+
+public abstract class StandardWeights implements Weights {
+ // token embedding table
+ public final FloatTensor token_embedding_table; // (vocab_size, dim)
+ // weights for rmsnorms
+ public final FloatTensor[] rms_att_weight; // (layer, dim) rmsnorm weights
+ // weights for matmuls
+ public final FloatTensor[] wq; // (layer, n_heads * head_size)
+ public final FloatTensor[] wk; // (layer, n_kv_heads, head_size)
+ public final FloatTensor[] wv; // (layer, n_kv_heads * head_size)
+ public final FloatTensor[] wo; // (layer, n_heads * head_size, dim)
+ public final FloatTensor[] rms_ffn_weight; // (layer, dim)
+
+ // weights for ffn
+ public final FloatTensor[] w1; // (layer, hidden_dim, dim)
+ public final FloatTensor[] w2; // (layer, dim, hidden_dim)
+ public final FloatTensor[] w3; // (layer, hidden_dim, dim)
+ public final FloatTensor wcls; // (vocab_size, dim)
+ public final FloatTensor rms_final_weight; // (dim,)
+ // freq_cis for RoPE relatively positional embeddings
+ public final FloatTensor freq_cis_real; // (seq_len, head_size/2)
+ public final FloatTensor freq_cis_imag; // (seq_len, head_size/2)
+
+ // GGML type of the quantized weight tensors
+ protected final GGMLType weightType;
+
+ /**
+ * Constructor for standard (non-TornadoVM) mode
+ *
+ * @param token_embedding_table
+ * Token embeddings matrix
+ * @param rms_att_weight
+ * RMSNorm weights for attention layers
+ * @param wq
+ * Query weight matrices
+ * @param wk
+ * Key weight matrices
+ * @param wv
+ * Value weight matrices
+ * @param wo
+ * Output projection matrices
+ * @param rms_ffn_weight
+ * RMSNorm weights for FFN layers
+ * @param w1
+ * First FFN weight matrices
+ * @param w2
+ * Second FFN weight matrices
+ * @param w3
+ * Third FFN weight matrices (gate)
+ * @param rms_final_weight
+ * Final layer normalization weights
+ * @param freq_cis_real
+ * RoPE cosine components
+ * @param freq_cis_imag
+ * RoPE sine components
+ * @param wcls
+ * Classifier weights for output logits
+ * @param weightType
+ * GGML type of the quantized weights
+ */
+ protected StandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight,
+ FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo,
+ FloatTensor[] rms_ffn_weight,
+ FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3,
+ FloatTensor rms_final_weight,
+ FloatTensor freq_cis_real, FloatTensor freq_cis_imag,
+ FloatTensor wcls, GGMLType weightType) {
+
+ // Standard format
+ this.token_embedding_table = token_embedding_table;
+ this.rms_att_weight = rms_att_weight;
+ this.wq = wq;
+ this.wk = wk;
+ this.wv = wv;
+ this.wo = wo;
+
+ this.rms_ffn_weight = rms_ffn_weight;
+ this.w1 = w1;
+ this.w2 = w2;
+ this.w3 = w3;
+ this.wcls = wcls;
+ this.rms_final_weight = rms_final_weight;
+ this.freq_cis_real = freq_cis_real;
+ this.freq_cis_imag = freq_cis_imag;
+ this.weightType = weightType;
+ }
+}
diff --git a/src/main/java/com/example/inference/weights/tornado/LlamaTornadoWeights.java b/src/main/java/com/example/inference/weights/tornado/LlamaTornadoWeights.java
new file mode 100644
index 0000000..ec786e3
--- /dev/null
+++ b/src/main/java/com/example/inference/weights/tornado/LlamaTornadoWeights.java
@@ -0,0 +1,15 @@
+package com.example.inference.weights.tornado;
+
+import com.example.core.model.GGMLType;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+
+public class LlamaTornadoWeights extends TornadoWeights {
+
+ public LlamaTornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered,
+ HalfFloatArray[] woLayered, FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered,
+ FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray, GGMLType weightType) {
+ super(tokenEmbeddingTable, rms_att_weightLayered, wqLayered, wkLayered, wvLayered, woLayered, rms_ffn_weightLayered, w1Layered, w2Layered, w3Layered, rms_final_weight_as_floatArray,
+ freq_cis_realFlat, freq_cis_imagFlat, wclsByteArray, weightType);
+ }
+}
diff --git a/src/main/java/com/example/inference/weights/tornado/TornadoWeights.java b/src/main/java/com/example/inference/weights/tornado/TornadoWeights.java
new file mode 100644
index 0000000..39ab75b
--- /dev/null
+++ b/src/main/java/com/example/inference/weights/tornado/TornadoWeights.java
@@ -0,0 +1,68 @@
+package com.example.inference.weights.tornado;
+
+import com.example.core.model.GGMLType;
+import com.example.inference.weights.Weights;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+
+public abstract class TornadoWeights implements Weights {
+
+ public FloatArray[] rms_att_weightLayered; // (layer, dim) rmsnorm weights
+ public HalfFloatArray[] wqLayered; // (layer, n_heads * head_size)
+ public HalfFloatArray[] wkLayered; // (layer, n_kv_heads, head_size)
+ public HalfFloatArray[] wvLayered; // (layer, n_kv_heads * head_size)
+ public HalfFloatArray[] woLayered; // (layer, n_heads * head_size, dim)
+ public FloatArray[] rms_ffn_weightLayered; // (layer, dim)
+ public HalfFloatArray[] w1Layered; // (layer, hidden_dim, dim)
+ public HalfFloatArray[] w2Layered; // (layer, dim, hidden_dim)
+ public HalfFloatArray[] w3Layered; // (layer, hidden_dim, dim)
+ public FloatArray rms_final_weight_as_floatArray;
+ public FloatArray tokenEmbeddingTable; // (vocab_size, dim)
+ public FloatArray freq_cis_realFlat; // (seq_len, head_size/2)
+ public FloatArray freq_cis_imagFlat; // (seq_len, head_size/2)
+ public HalfFloatArray wclsHalfFloat;
+
+ // GGML type of the quantized weight tensors
+ protected final GGMLType weightType;
+
+ protected TornadoWeights(
+ FloatArray tokenEmbeddingTable,
+ FloatArray[] rms_att_weightLayered,
+ HalfFloatArray[] wqLayered,
+ HalfFloatArray[] wkLayered,
+ HalfFloatArray[] wvLayered,
+ HalfFloatArray[] woLayered,
+ FloatArray[] rms_ffn_weightLayered,
+ HalfFloatArray[] w1Layered,
+ HalfFloatArray[] w2Layered,
+ HalfFloatArray[] w3Layered,
+ FloatArray rms_final_weight_as_floatArray,
+ FloatArray freq_cis_realFlat,
+ FloatArray freq_cis_imagFlat,
+ HalfFloatArray wclsByteArray,
+ GGMLType weightType) {
+ // TornadoVM format
+ this.tokenEmbeddingTable = tokenEmbeddingTable;
+ this.rms_att_weightLayered = rms_att_weightLayered;
+ this.wqLayered = wqLayered;
+ this.wkLayered = wkLayered;
+ this.wvLayered = wvLayered;
+ this.woLayered = woLayered;
+ this.rms_ffn_weightLayered = rms_ffn_weightLayered;
+ this.w1Layered = w1Layered;
+ this.w2Layered = w2Layered;
+ this.w3Layered = w3Layered;
+ this.rms_final_weight_as_floatArray = rms_final_weight_as_floatArray;
+ this.freq_cis_realFlat = freq_cis_realFlat;
+ this.freq_cis_imagFlat = freq_cis_imagFlat;
+ this.wclsHalfFloat = wclsByteArray;
+ this.weightType = weightType;
+ }
+
+ @Override
+ public GGMLType getWeightType() {
+ return weightType;
+ }
+}
diff --git a/src/main/java/com/example/loader/weights/State.java b/src/main/java/com/example/loader/weights/State.java
deleted file mode 100644
index 12f968a..0000000
--- a/src/main/java/com/example/loader/weights/State.java
+++ /dev/null
@@ -1,107 +0,0 @@
-package com.example.loader.weights;
-
-import com.example.core.model.tensor.ArrayFloatTensor;
-import com.example.core.model.tensor.FloatTensor;
-import com.example.model.Configuration;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.IntArray;
-
-import java.util.stream.Stream;
-
-public final class State {
-
- // current wave of activations
- public final FloatTensor x; // activation at current time stamp (dim,)
- public final FloatTensor xb; // same, but inside a residual branch (dim,)
- public final FloatTensor xb2; // an additional buffer just for convenience (dim,)
- public final FloatTensor hb; // buffer for hidden dimension in the ffn (hidden_dim,)
- public final FloatTensor hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
- public final FloatTensor q; // query (dim,)
- public final FloatTensor k; // key (dim,)
- public final FloatTensor v; // value (dim,)
- public final FloatTensor att; // buffer for scores/attention values (n_heads, seq_len)
- public final FloatTensor logits; // output logits
- public final int batchsize;
-
- // kv cache
- public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim)
- public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim)
-
- // Wrappers for TornadoVM compatibility (FloatArray data structure for TornadoVM acceleration)
- // TornadoVM uses FloatArray for more efficient handling of data, particularly when running on GPU or other accelerators.
- public final FloatArray wrapLogits; // FloatArray wrapper for the logits tensor, compatible with TornadoVM for GPU execution.
- public final FloatArray wrapXb; // FloatArray wrapper for xb (residual branch activation), optimized for TornadoVM usage.
- public final FloatArray wrapXb2; // FloatArray wrapper for xb2, another residual buffer to aid in computations with TornadoVM.
- public final FloatArray wrapHb; // FloatArray wrapper for hb (hidden dimension buffer for FFN), optimized for TornadoVM.
- public final FloatArray wrapHb2; // FloatArray wrapper for hb2, additional hidden buffer for FFN, for compatibility with TornadoVM.
- public final FloatArray wrapX; // FloatArray wrapper for the current activation tensor, optimized for TornadoVM.
-
- public final FloatArray wrapQ; // FloatArray wrapper for the query tensor, optimized for TornadoVM.
- public final FloatArray wrapK; // FloatArray wrapper for the key tensor, optimized for TornadoVM.
- public final FloatArray wrapV; // FloatArray wrapper for the value tensor, optimized for TornadoVM.
- public final FloatArray wrapAtt; // FloatArray wrapper for the attention scores, optimized for TornadoVM.
- public final FloatArray wrapKeyCache;// FloatArray wrapper for the key cache, optimized for TornadoVM.
- public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM.
- public final IntArray positionHolder;
-
- // store inter
- //
- public int localSize;
- public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
- public FloatArray tempFFN; // Temporary buffer for feed-forward network calculations, size adjusted for local workgroup size.
- public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size.
-
- public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models.
-
- /** last index in previous block */
-
- public State(Configuration config, int batchsize) {
- this.batchsize = -1;
-
- this.x = ArrayFloatTensor.allocate(config.dim());
- this.xb = ArrayFloatTensor.allocate(config.dim());
- this.xb2 = ArrayFloatTensor.allocate(config.dim());
- this.hb = ArrayFloatTensor.allocate(config.hiddenDim());
- this.hb2 = ArrayFloatTensor.allocate(config.hiddenDim());
- this.q = ArrayFloatTensor.allocate(config.dim());
- this.k = ArrayFloatTensor.allocate(config.dim());
- this.v = ArrayFloatTensor.allocate(config.dim());
- this.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength());
- this.logits = ArrayFloatTensor.allocate(config.vocabularySize());
- int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
- this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
- this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
-
- this.wrapX = new FloatArray(config.dim());
- this.wrapXb = new FloatArray(config.dim());
- this.wrapXb2 = new FloatArray(config.dim());
- this.wrapHb = new FloatArray(config.hiddenDim());
- this.wrapHb2 = new FloatArray(config.hiddenDim());
-
- this.wrapLogits = new FloatArray(config.vocabularySize());
- this.wrapQ = new FloatArray(config.dim());
- this.wrapK = new FloatArray(config.dim());
- this.wrapV = new FloatArray(config.dim());
-
- // dim vs kvdim
- this.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
- this.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
- this.wrapValueCache.init(0.f);
- this.wrapKeyCache.init(0.f);
- this.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength());
- this.positionHolder = new IntArray(1);
- this.latestToken = -1;
-
- //
- this.localSize = 256;
- // You need at least 9 elements: 1 for the final result + 8 for the workgroup partial sums
- this.temp = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
- this.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
- this.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
- }
-
- @Override
- public State clone() throws CloneNotSupportedException {
- return (State) super.clone();
- }
-}
\ No newline at end of file
diff --git a/src/main/java/com/example/loader/weights/Weights.java b/src/main/java/com/example/loader/weights/Weights.java
deleted file mode 100644
index d5b004c..0000000
--- a/src/main/java/com/example/loader/weights/Weights.java
+++ /dev/null
@@ -1,173 +0,0 @@
-package com.example.loader.weights;
-
-import com.example.core.model.GGMLType;
-import com.example.core.model.tensor.FloatTensor;
-import com.example.core.model.tensor.GGMLTensorEntry;
-import com.example.core.types.Float16;
-import uk.ac.manchester.tornado.api.types.HalfFloat;
-import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-
-import java.lang.foreign.MemorySegment;
-import java.nio.ByteOrder;
-import java.nio.FloatBuffer;
-import java.util.function.IntFunction;
-
-import static com.example.core.model.tensor.FloatTensor.readByte;
-import static com.example.core.model.tensor.FloatTensor.readShort;
-
-public class Weights {
- // token embedding table
- public final FloatTensor token_embedding_table; // (vocab_size, dim)
- // weights for rmsnorms
- public final FloatBuffer[] rms_att_weight; // (layer, dim) rmsnorm weights
- // weights for matmuls
- public final FloatTensor[] wq; // (layer, n_heads * head_size)
- public final FloatTensor[] wk; // (layer, n_kv_heads, head_size)
- public final FloatTensor[] wv; // (layer, n_kv_heads * head_size)
- public final FloatTensor[] wo; // (layer, n_heads * head_size, dim)
- public final FloatBuffer[] rms_ffn_weight; // (layer, dim)
-
- // weights for ffn
- public final FloatTensor[] w1; // (layer, hidden_dim, dim)
- public final FloatTensor[] w2; // (layer, dim, hidden_dim)
- public final FloatTensor[] w3; // (layer, hidden_dim, dim)
- //
- public final FloatTensor wcls; // (vocab_size, dim)
- public final HalfFloatArray wclsHalfFloat;
- // public final rmsnorm
- public final FloatBuffer rms_final_weight; // (dim,)
- // freq_cis for RoPE relatively positional embeddings
- public final FloatBuffer freq_cis_real; // (seq_len, head_size/2)
- public final FloatBuffer freq_cis_imag; // (seq_len, head_size/2)
- // // Layered Data structures
- public FloatArray[] rms_att_weightLayered; // (layer, dim) rmsnorm weights
- public HalfFloatArray[] wqLayered; // (layer, n_heads * head_size)
- public HalfFloatArray[] wkLayered; // (layer, n_kv_heads, head_size)
- public HalfFloatArray[] wvLayered; // (layer, n_kv_heads * head_size)
- public HalfFloatArray[] woLayered; // (layer, n_heads * head_size, dim)
- public FloatArray[] rms_ffn_weightLayered; // (layer, dim)
- public HalfFloatArray[] w1Layered; // (layer, hidden_dim, dim)
- public HalfFloatArray[] w2Layered; // (layer, dim, hidden_dim)
- //
- public HalfFloatArray[] w3Layered; // (layer, hidden_dim, dim)
- public FloatArray rms_final_weight_as_floatArray;
- public FloatArray tokenEmbeddingTable; // (vocab_size, dim)
- public FloatArray freq_cis_realFlat; // (seq_len, head_size/2)
- public FloatArray freq_cis_imagFlat; // (seq_len, head_size/2)
- // (optional) classifier weights for the logits, on the last layer
- public GGMLType weightType;
-
- /**
- * Constructor to initialize all weight tensors for the model. Automatically creates TornadoVM-compatible versions when needed.
- *
- * @param token_embedding_table
- * Token embeddings matrix
- * @param rms_att_weight
- * RMSNorm weights for attention layers
- * @param wq
- * Query weight matrices
- * @param wk
- * Key weight matrices
- * @param wv
- * Value weight matrices
- * @param wo
- * Output projection matrices
- * @param rms_ffn_weight
- * RMSNorm weights for FFN layers
- * @param w1
- * First FFN weight matrices
- * @param w2
- * Second FFN weight matrices
- * @param w3
- * Third FFN weight matrices (gate)
- * @param rms_final_weight
- * Final layer normalization weights
- * @param freq_cis_real
- * RoPE cosine components
- * @param freq_cis_imag
- * RoPE sine components
- * @param wcls
- * Classifier weights for output logits
- *
- /**
- * Constructor for standard (non-TornadoVM) mode
- */
- public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight,
- FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
- // Standard format
- this.token_embedding_table = token_embedding_table;
- this.rms_att_weight = rms_att_weight;
- this.wq = wq;
- this.wk = wk;
- this.wv = wv;
- this.wo = wo;
- this.rms_ffn_weight = rms_ffn_weight;
- this.w1 = w1;
- this.w2 = w2;
- this.w3 = w3;
- this.wcls = wcls;
- this.rms_final_weight = rms_final_weight;
- this.freq_cis_real = freq_cis_real;
- this.freq_cis_imag = freq_cis_imag;
- this.weightType = weightType;
-
- // TornadoVM format (null when not using TornadoVM)
- this.tokenEmbeddingTable = null;
- this.rms_att_weightLayered = null;
- this.wqLayered = null;
- this.wkLayered = null;
- this.wvLayered = null;
- this.woLayered = null;
- this.rms_ffn_weightLayered = null;
- this.w1Layered = null;
- this.w2Layered = null;
- this.w3Layered = null;
- this.rms_final_weight_as_floatArray = null;
- this.freq_cis_realFlat = null;
- this.freq_cis_imagFlat = null;
- this.wclsHalfFloat = null;
- }
-
- /**
- * Constructor for TornadoVM mode
- */
- public Weights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, HalfFloatArray[] woLayered,
- FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, FloatArray rms_final_weight_as_floatArray,
- FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray, GGMLType weightType) {
- // Standard format (null when using TornadoVM)
- this.token_embedding_table = null;
- this.rms_att_weight = null;
- this.wq = null;
- this.wk = null;
- this.wv = null;
- this.wo = null;
- this.rms_ffn_weight = null;
- this.w1 = null;
- this.w2 = null;
- this.w3 = null;
- this.wcls = null;
- this.rms_final_weight = null;
- this.freq_cis_real = null;
- this.freq_cis_imag = null;
-
- // TornadoVM format
- this.tokenEmbeddingTable = tokenEmbeddingTable;
- this.rms_att_weightLayered = rms_att_weightLayered;
- this.wqLayered = wqLayered;
- this.wkLayered = wkLayered;
- this.wvLayered = wvLayered;
- this.woLayered = woLayered;
- this.rms_ffn_weightLayered = rms_ffn_weightLayered;
- this.w1Layered = w1Layered;
- this.w2Layered = w2Layered;
- this.w3Layered = w3Layered;
- this.rms_final_weight_as_floatArray = rms_final_weight_as_floatArray;
- this.freq_cis_realFlat = freq_cis_realFlat;
- this.freq_cis_imagFlat = freq_cis_imagFlat;
- this.wclsHalfFloat = wclsByteArray;
- this.weightType = weightType;
- }
-
-}
\ No newline at end of file
diff --git a/src/main/java/com/example/model/Configuration.java b/src/main/java/com/example/model/Configuration.java
index 6ff558c..56e14b8 100644
--- a/src/main/java/com/example/model/Configuration.java
+++ b/src/main/java/com/example/model/Configuration.java
@@ -17,12 +17,17 @@ public interface Configuration {
/** Number of key/value heads (can be fewer than query heads in multi-query attention) */
int numberOfKeyValueHeads();
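+
+ /** Size of each attention key head (Qwen3 reads qwen3.attention.key_length); models that derive head size from dim may not support it */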
+ int numberOfHeadsKey();
+
/** Size of the vocabulary (token set) */
int vocabularySize();
/** Maximum sequence length the model can process */
int contextLength();
+ /** Maximum sequence length the model was trained with; the runtime context length may be clamped to this */
+ int contextLengthModel();
+
/** Epsilon value for RMSNorm layers (stabilizes normalization) */
float rmsNormEps();
diff --git a/src/main/java/com/example/model/Model.java b/src/main/java/com/example/model/Model.java
index e42349b..3147ee2 100644
--- a/src/main/java/com/example/model/Model.java
+++ b/src/main/java/com/example/model/Model.java
@@ -4,8 +4,8 @@
import com.example.auxiliary.LastRunMetrics;
import com.example.inference.InferenceEngine;
import com.example.inference.sampler.Sampler;
-import com.example.loader.weights.State;
-import com.example.loader.weights.Weights;
+import com.example.inference.state.State;
+import com.example.inference.weights.Weights;
import com.example.model.format.ChatFormat;
import com.example.tokenizer.impl.Tokenizer;
import com.example.tornadovm.TornadoVMMasterPlan;
@@ -20,18 +20,36 @@
import static com.example.LlamaApp.USE_TORNADOVM;
public interface Model {
+
Configuration configuration();
Tokenizer tokenizer();
Weights weights();
+ ChatFormat chatFormat();
+
ModelType getModelType();
State createNewState();
State createNewState(int batchsize);
+ /**
+ * Wrapper for invoking the model-specific forward pass via InferenceCore.
+ *
+ * <p>
+ * Delegates to the appropriate InferenceCore method based on the model type
+ * (e.g., {@code forwardJava}, {@code forwardJavaQwen3}).
+ * </p>
+ */
+ void forward(State state, int token, int position);
+
+ /**
+ * Wrapper for invoking the model-specific {@code InferenceEngine.generateTokens} call.
+ */
+ List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated);
+
/**
* Model agnostic default implementation for interactive mode.
* @param sampler
@@ -41,8 +59,11 @@ default void runInteractive(Sampler sampler, Options options) {
State state = null;
List<Integer> conversationTokens = new ArrayList<>();
- ChatFormat chatFormat = ChatFormat.create(tokenizer());
- conversationTokens.add(chatFormat.getBeginOfText());
+ ChatFormat chatFormat = chatFormat();
+
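+ // Qwen3's ChatML template carries no begin-of-text token, so only prepend it for other model types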
+ if (!getModelType().equals(ModelType.QWEN_3)) {
+ conversationTokens.add(chatFormat.getBeginOfText());
+ }
if (options.systemPrompt() != null) {
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
@@ -92,7 +113,7 @@ default void runInteractive(Sampler sampler, Options options) {
options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
} else {
// CPU path
- responseTokens = InferenceEngine.generateTokens(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
+ responseTokens = generateTokens(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
sampler, options.echo(), tokenConsumer);
}
@@ -138,11 +159,14 @@ default void runInteractive(Sampler sampler, Options options) {
*/
default void runInstructOnce(Sampler sampler, Options options) {
State state = createNewState();
- ChatFormat chatFormat = ChatFormat.create(tokenizer());
+ ChatFormat chatFormat = chatFormat();
TornadoVMMasterPlan tornadoVMPlan = null;
List<Integer> promptTokens = new ArrayList<>();
- promptTokens.add(chatFormat.getBeginOfText());
+
+ if (!getModelType().equals(ModelType.QWEN_3)) {
+ promptTokens.add(chatFormat.getBeginOfText());
+ }
if (options.systemPrompt() != null) {
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
@@ -168,7 +192,7 @@ default void runInstructOnce(Sampler sampler, Options options) {
responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null,
tornadoVMPlan);
} else {
- responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
+ responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
}
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
diff --git a/src/main/java/com/example/model/ModelType.java b/src/main/java/com/example/model/ModelType.java
index 0426824..f36a4e1 100644
--- a/src/main/java/com/example/model/ModelType.java
+++ b/src/main/java/com/example/model/ModelType.java
@@ -1,8 +1,9 @@
package com.example.model;
import com.example.core.model.GGUF;
-import com.example.model.llama.Llama;
-import com.example.model.mistral.Mistral;
+import com.example.model.loader.LlamaModelLoader;
+import com.example.model.loader.MistralModelLoader;
+import com.example.model.loader.Qwen3ModelLoader;
import java.nio.channels.FileChannel;
@@ -10,14 +11,21 @@ public enum ModelType {
LLAMA_3 {
@Override
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
- return Llama.loadModel(fileChannel, gguf, contextLength, loadWeights);
+ return new LlamaModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
}
},
MISTRAL {
@Override
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
- return Mistral.loadModel(fileChannel, gguf, contextLength, loadWeights);
+ return new MistralModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
+ }
+ },
+
+ QWEN_3 {
+ @Override
+ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
+ return new Qwen3ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
}
},
diff --git a/src/main/java/com/example/model/format/ChatFormat.java b/src/main/java/com/example/model/format/ChatFormat.java
index b114c7d..d6e7e74 100644
--- a/src/main/java/com/example/model/format/ChatFormat.java
+++ b/src/main/java/com/example/model/format/ChatFormat.java
@@ -2,17 +2,26 @@
import com.example.tokenizer.impl.LlamaTokenizer;
import com.example.tokenizer.impl.MistralTokenizer;
+import com.example.tokenizer.impl.Qwen3Tokenizer;
import java.util.List;
import java.util.Set;
public interface ChatFormat {
- static ChatFormat create(Object tokenizer) {
+ default ChatTokens chatTokens() {
+ throw new UnsupportedOperationException("ChatFormat for Llama and Mistral does not support chatTokens");
+ }
+
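+ /** Marker strings of a ChatML-style template (header start/end, end-of-turn, stop tokens), resolved against the model's vocabulary */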
+ public record ChatTokens(String tStartHeader, String tEndHeader, String tEndOfTurn, String tEndOfText, String tEndOfTextFim) { }
+
+ static ChatFormat create(Object tokenizer, ChatTokens chatTokens) {
if (tokenizer instanceof LlamaTokenizer llamaTokenizer) {
return new LlamaChatFormat(llamaTokenizer);
} else if (tokenizer instanceof MistralTokenizer mistralTokenizer) {
return new MistralChatFormat(mistralTokenizer);
+ } else if (tokenizer instanceof Qwen3Tokenizer qwen3Tokenizer) {
+ return new Qwen3ChatFormat(qwen3Tokenizer, chatTokens);
} else {
throw new IllegalArgumentException("Unsupported tokenizer type: " + tokenizer.getClass().getName());
}
@@ -54,6 +63,9 @@ record Role(String name) {
public static Role SYSTEM = new Role("system");
public static Role USER = new Role("user");
public static Role ASSISTANT = new Role("assistant");
+ public static Role FIM_PREFIX = new Role("fim_prefix");
+ public static Role FIM_SUFFIX = new Role("fim_suffix");
+ public static Role FIM_MIDDLE = new Role("fim_middle");
@Override
public String toString() {
diff --git a/src/main/java/com/example/model/format/Qwen3ChatFormat.java b/src/main/java/com/example/model/format/Qwen3ChatFormat.java
new file mode 100644
index 0000000..2a42a03
--- /dev/null
+++ b/src/main/java/com/example/model/format/Qwen3ChatFormat.java
@@ -0,0 +1,123 @@
+package com.example.model.format;
+
+import com.example.tokenizer.impl.Qwen3Tokenizer;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Utility tailored for the Chat Markup Language (ChatML) prompt format.
+ */
+public class Qwen3ChatFormat implements ChatFormat {
+
+ protected Qwen3Tokenizer tokenizer;
+ protected ChatTokens chatTokens;
+
+ protected final int beginOfText;
+ protected final int startHeader;
+ protected final int endHeader;
+ protected final int endOfTurn;
+ protected final int endOfText;
+ protected final int endOfMessage;
+ protected final int endOfTextFim;
+
+ protected final int imStart; // alias of startHeader (ChatML <|im_start|>)
+ protected final int imEnd; // alias of endHeader (ChatML <|im_end|>)
+
+ protected final int fimPrefix;
+ protected final int fimSuffix;
+ protected final int fimMiddle;
+
+ public Qwen3ChatFormat(Qwen3Tokenizer tokenizer, ChatTokens chatTokens) {
+ //super(tokenizer, "", chatTokens.tStartHeader(), chatTokens.tEndHeader(), chatTokens.tEndOfTurn(), chatTokens.tEndOfText(), "", chatTokens.tEndOfTextFim());
+ this.tokenizer = tokenizer;
+ this.chatTokens = chatTokens;
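+ // Resolve template tokens to ids; -1 marks a token this vocabulary does not define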
+ Map<String, Integer> specialTokens = tokenizer.getSpecialTokens();
+ this.beginOfText = specialTokens.getOrDefault("", -1);
+ this.startHeader = specialTokens.getOrDefault(chatTokens.tStartHeader(), -1);
+ this.endHeader = specialTokens.getOrDefault(chatTokens.tEndHeader(), -1);
+ this.endOfTurn = specialTokens.getOrDefault(chatTokens.tEndOfTurn(), -1);
+ this.endOfText = specialTokens.getOrDefault(chatTokens.tEndOfText(), -1);
+ this.endOfTextFim = specialTokens.getOrDefault(chatTokens.tEndOfTextFim(), -1);
+ this.endOfMessage = specialTokens.getOrDefault("", -1); // Use default value if key not found
+
+ this.imStart = startHeader;
+ this.imEnd = endHeader;
+
+ fimPrefix = specialTokens.getOrDefault("<|fim_prefix|>", -1);
+ fimSuffix = specialTokens.getOrDefault("<|fim_suffix|>", -1);
+ fimMiddle = specialTokens.getOrDefault("<|fim_middle|>", -1);
+ }
+
+ public ChatTokens chatTokens() {
+ return chatTokens;
+ }
+
+ @Override
+ public List<Integer> encodeHeader(Message message) {
+ List<Integer> tokens = new ArrayList<>();
+ if (endHeader == -1) {
+ // DeepSeek-R1
+ String sToken = switch (message.role().name()) {
+ case "system" -> null;
+ case "user" -> "<|User|>";
+ case "assistant" -> "<|Assistant|>";
+ case "fim_prefix" -> "<|fim_prefix|>";
+ case "fim_middle" -> "<|fim_middle|>";
+ case "fim_suffix" -> "<|fim_suffix|>";
+ default -> null;
+ };
+ if (sToken != null) {
+ Integer token = tokenizer.getSpecialTokens().get(sToken);
+ if (token == null) {
+ throw new IllegalStateException(String.format("Unknown token '%s'", sToken));
+ }
+ tokens.add(token);
+ }
+ } else if (Role.FIM_PREFIX.equals(message.role())) {
+ // fill-in-the-middle, token fim_prefix.
+ tokens.add(fimPrefix);
+ } else if (Role.FIM_SUFFIX.equals(message.role())) {
+ tokens.add(fimSuffix);
+ } else if (Role.FIM_MIDDLE.equals(message.role())) {
+ tokens.add(fimMiddle);
+ } else {
+ tokens.add(imStart);
+ tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
+ tokens.addAll(this.tokenizer.encodeAsList("\n"));
+ }
+ return tokens;
+ }
+
+ @Override
+ public List<Integer> encodeMessage(Message message) {
+ List<Integer> tokens = this.encodeHeader(message);
+ tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
+ boolean isFim = Role.FIM_PREFIX.equals(message.role())
+ || Role.FIM_SUFFIX.equals(message.role())
+ || Role.FIM_MIDDLE.equals(message.role());
+ if (imEnd != -1 && !isFim) {
+ tokens.add(imEnd);
+ }
+ return tokens;
+ }
+
+ @Override
+ public int getBeginOfText() {
+ return beginOfText;
+ }
+
+ @Override
+ public Set<Integer> getStopTokens() {
+ if (imEnd == -1 && endOfText == -1) {
+ throw new IllegalStateException("No stop token is defined.");
+ }
+ if (imEnd == -1) {
+ return Set.of(endOfText);
+ }
+ return Set.of(imEnd, endOfText, endOfTextFim);
+ }
+}
diff --git a/src/main/java/com/example/model/llama/Llama.java b/src/main/java/com/example/model/llama/Llama.java
index c8326a1..0ad3427 100644
--- a/src/main/java/com/example/model/llama/Llama.java
+++ b/src/main/java/com/example/model/llama/Llama.java
@@ -1,24 +1,22 @@
package com.example.model.llama;
-import com.example.auxiliary.Timer;
-import com.example.core.model.GGUF;
-import com.example.core.model.tensor.GGMLTensorEntry;
+import com.example.inference.InferenceCore;
+import com.example.inference.InferenceEngine;
+import com.example.inference.sampler.Sampler;
+import com.example.inference.state.LlamaState;
+import com.example.inference.state.State;
+import com.example.inference.weights.Weights;
import com.example.model.Model;
-import com.example.loader.weights.State;
-import com.example.loader.weights.Weights;
import com.example.model.ModelType;
+import com.example.model.format.ChatFormat;
import com.example.tokenizer.impl.LlamaTokenizer;
import com.example.tokenizer.impl.Tokenizer;
-import com.example.tokenizer.vocabulary.Vocabulary;
-import java.io.IOException;
-import java.nio.channels.FileChannel;
-import java.util.Map;
+import java.util.List;
+import java.util.Set;
+import java.util.function.IntConsumer;
-import static com.example.loader.weights.ModelLoader.loadWeights;
-
-public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
- private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);
+public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model {
/* For explicit use */
private LlamaTokenizer getAsLlamaTokenizer() {
@@ -32,53 +30,27 @@ public ModelType getModelType() {
@Override
public State createNewState() {
- State state = new State(configuration(), -1);
+ State state = new LlamaState(configuration(), -1);
state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>");
return state;
}
@Override
public State createNewState(int batchsize) {
- State state = new State(configuration(), batchsize);
+ State state = new LlamaState(configuration(), batchsize);
state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>");
return state;
}
- // @formatter:off
- public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
- try (var ignored = Timer.log("Load LlaMa model")) {
- Map<String, Object> metadata = gguf.getMetadata();
-
- Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata);
- Tokenizer tokenizer = new LlamaTokenizer(metadata, vocabulary);
-
- LlamaConfiguration config = new LlamaConfiguration(
- (int) metadata.get("llama.embedding_length"),
- (int) metadata.get("llama.feed_forward_length"),
- (int) metadata.get("llama.block_count"),
- (int) metadata.get("llama.attention.head_count"),
-
- metadata.containsKey("llama.attention.head_count_kv") ?
- (int) metadata.get("llama.attention.head_count_kv") :
- (int) metadata.get("llama.attention.head_count"),
+ @Override
+ public void forward(State state, int token, int position) {
+ InferenceCore.forwardJava(this, state, token, position);
+ }
- vocabulary.size(),
- (int) metadata.get("llama.context_length"),
- (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
- (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
- ).withContextLength(contextLength);
- Weights weights = null;
- if (loadWeights) {
- Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config);
- }
- return new Llama(config, tokenizer, weights);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
+ @Override
+ public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
+ return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
}
- // @formatter:on
-
}
diff --git a/src/main/java/com/example/model/llama/LlamaConfiguration.java b/src/main/java/com/example/model/llama/LlamaConfiguration.java
index 53195d1..a363bc6 100644
--- a/src/main/java/com/example/model/llama/LlamaConfiguration.java
+++ b/src/main/java/com/example/model/llama/LlamaConfiguration.java
@@ -5,6 +5,16 @@
public record LlamaConfiguration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta)
implements Configuration {
+ @Override
+ public int numberOfHeadsKey() {
+ throw new UnsupportedOperationException("Not supported for Llama.");
+ }
+
+ @Override
+ public int contextLengthModel() {
+ throw new UnsupportedOperationException("Not supported for Llama.");
+ }
+
public int headSize() {
return dim / numberOfHeads;
}
diff --git a/src/main/java/com/example/model/loader/LlamaModelLoader.java b/src/main/java/com/example/model/loader/LlamaModelLoader.java
new file mode 100644
index 0000000..f6cb245
--- /dev/null
+++ b/src/main/java/com/example/model/loader/LlamaModelLoader.java
@@ -0,0 +1,60 @@
+package com.example.model.loader;
+
+import com.example.auxiliary.Timer;
+import com.example.core.model.GGUF;
+import com.example.core.model.tensor.GGMLTensorEntry;
+import com.example.inference.weights.Weights;
+import com.example.model.format.ChatFormat;
+import com.example.model.llama.Llama;
+import com.example.model.llama.LlamaConfiguration;
+import com.example.tokenizer.impl.LlamaTokenizer;
+import com.example.tokenizer.impl.Tokenizer;
+import com.example.tokenizer.vocabulary.Vocabulary;
+
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.util.Map;
+
+public class LlamaModelLoader extends ModelLoader {
+
+ public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
+ super(fileChannel, gguf, contextLength, loadWeights);
+ }
+
+ // @formatter:off
+ @Override
+ public Llama loadModel() {
+ try (var ignored = Timer.log("Load LlaMa model")) {
+ Map<String, Object> metadata = gguf.getMetadata();
+
+ Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata);
+ Tokenizer tokenizer = new LlamaTokenizer(metadata, vocabulary);
+
+ LlamaConfiguration config = new LlamaConfiguration(
+ (int) metadata.get("llama.embedding_length"),
+ (int) metadata.get("llama.feed_forward_length"),
+ (int) metadata.get("llama.block_count"),
+ (int) metadata.get("llama.attention.head_count"),
+
+ metadata.containsKey("llama.attention.head_count_kv") ?
+ (int) metadata.get("llama.attention.head_count_kv") :
+ (int) metadata.get("llama.attention.head_count"),
+
+ vocabulary.size(),
+ (int) metadata.get("llama.context_length"),
+ (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
+ (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
+ ).withContextLength(contextLength);
+
+ Weights weights = null;
+ if (loadWeights) {
+ Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
+ weights = loadWeights(tensorEntries, config);
+ }
+ return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ // @formatter:on
+}
diff --git a/src/main/java/com/example/model/loader/MistralModelLoader.java b/src/main/java/com/example/model/loader/MistralModelLoader.java
new file mode 100644
index 0000000..159ad9a
--- /dev/null
+++ b/src/main/java/com/example/model/loader/MistralModelLoader.java
@@ -0,0 +1,66 @@
+package com.example.model.loader;
+
+import com.example.auxiliary.Timer;
+import com.example.core.model.GGUF;
+import com.example.core.model.tensor.GGMLTensorEntry;
+import com.example.inference.weights.Weights;
+import com.example.model.format.ChatFormat;
+import com.example.model.mistral.Mistral;
+import com.example.model.mistral.MistralConfiguration;
+import com.example.tokenizer.impl.MistralTokenizer;
+import com.example.tokenizer.impl.Tokenizer;
+import com.example.tokenizer.vocabulary.Vocabulary;
+
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.util.Map;
+
+public class MistralModelLoader extends ModelLoader {
+
+ public MistralModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
+ super(fileChannel, gguf, contextLength, loadWeights);
+ }
+
+ // @formatter:off
+ @Override
+ public Mistral loadModel() {
+ try (var ignored = Timer.log("Load Mistral model")) {
+ Map<String, Object> metadata = gguf.getMetadata();
+
+ Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata);
+ Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary);
+
+ int modelContextLength = (int) metadata.get("llama.context_length");
+ if (contextLength < 0 || modelContextLength < contextLength) {
+ contextLength = modelContextLength;
+ }
+
+ MistralConfiguration config = new MistralConfiguration(
+ (int) metadata.get("llama.embedding_length"),
+ (int) metadata.get("llama.feed_forward_length"),
+ (int) metadata.get("llama.block_count"),
+ (int) metadata.get("llama.attention.head_count"),
+
+ metadata.containsKey("llama.attention.head_count_kv")
+ ? (int) metadata.get("llama.attention.head_count_kv")
+ : (int) metadata.get("llama.attention.head_count"),
+
+ vocabulary.size(),
+ contextLength,
+ false,
+ (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
+ (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
+ );
+
+ Weights weights = null;
+ if (loadWeights) {
+ Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
+ weights = loadWeights(tensorEntries, config);
+ }
+ return new Mistral(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ // @formatter:on
+}
diff --git a/src/main/java/com/example/loader/weights/ModelLoader.java b/src/main/java/com/example/model/loader/ModelLoader.java
similarity index 79%
rename from src/main/java/com/example/loader/weights/ModelLoader.java
rename to src/main/java/com/example/model/loader/ModelLoader.java
index 353600e..6482414 100644
--- a/src/main/java/com/example/loader/weights/ModelLoader.java
+++ b/src/main/java/com/example/model/loader/ModelLoader.java
@@ -1,14 +1,20 @@
-package com.example.loader.weights;
+package com.example.model.loader;
import com.example.LlamaApp;
import com.example.core.model.GGMLType;
import com.example.core.model.GGUF;
+import com.example.core.model.tensor.ArrayFloatTensor;
import com.example.core.model.tensor.F16FloatTensor;
+import com.example.core.model.tensor.F32FloatTensor;
import com.example.core.model.tensor.FloatTensor;
import com.example.core.model.tensor.GGMLTensorEntry;
import com.example.core.model.tensor.Q4_0FloatTensor;
import com.example.core.model.tensor.Q8_0FloatTensor;
import com.example.core.types.Pair;
+import com.example.inference.weights.standard.LlamaStandardWeights;
+import com.example.inference.weights.tornado.LlamaTornadoWeights;
+import com.example.inference.weights.tornado.TornadoWeights;
+import com.example.inference.weights.Weights;
import com.example.model.Configuration;
import com.example.model.Model;
import com.example.model.ModelType;
@@ -27,10 +33,22 @@
import java.util.Map;
import java.util.function.IntFunction;
-public final class ModelLoader {
+public abstract class ModelLoader {
private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";
private static final String TOKENIZER_MISTRAL_MODEL = "llama";
+ protected FileChannel fileChannel;
+ protected GGUF gguf;
+ protected int contextLength;
+ protected boolean loadWeights;
+
+ public ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
+ this.fileChannel = fileChannel;
+ this.gguf = gguf;
+ this.contextLength = contextLength;
+ this.loadWeights = loadWeights;
+ }
+
private static ModelType detectModelType(Map<String, Object> metadata) {
String name = (String) metadata.get("general.name");
String tokenizerModel = (String) metadata.get("tokenizer.ggml.model");
@@ -43,6 +61,8 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
return ModelType.MISTRAL;
} else if (lowerName.contains("llama")) {
return ModelType.LLAMA_3;
+ } else if (lowerName.contains("qwen3")) {
+ return ModelType.QWEN_3;
}
}
@@ -65,6 +85,8 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
return ModelType.UNKNOWN;
}
+ public abstract Model loadModel();
+
public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException {
// initial load of metadata from gguf file
GGUF gguf = GGUF.loadModel(ggufPath);
@@ -75,7 +97,7 @@ public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException {
return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights);
}
- public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
+ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
boolean ropeScaling = tensorEntries.containsKey("rope_freqs");
RopeConfig ropeConfig = new RopeConfig(8.0f, // scaleFactor
1.0f, // loFreqFactor
@@ -83,14 +105,15 @@ public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
8192 // oldContextLength
);
- Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength(), // Maximum sequence length the model can process
- config.headSize(), // Dimension of each attention head
- config.ropeTheta(), // Base frequency parameter (typically 10000.0)
- ropeScaling, // Whether to apply frequency scaling (determined by model type)
- ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling)
- ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies
- ropeConfig.hiFreqFactor, // High frequency scaling factor for preserving local precision
- ropeConfig.oldContextLength // Original context length the model was trained with
+ Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
+ config.contextLength(), // Maximum sequence length the model can process
+ config.headSize(), // Dimension of each attention head
+ config.ropeTheta(), // Base frequency parameter (typically 10000.0)
+ ropeScaling, // Whether to apply frequency scaling (determined by model type)
+ ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling)
+ ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies
+ ropeConfig.hiFreqFactor, // High frequency scaling factor for preserving local precision
+ ropeConfig.oldContextLength // Original context length the model was trained with
);
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
@@ -104,9 +127,9 @@ public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
}
}
- private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
GGMLTensorEntry outputWeight) {
- return new Weights(
+ return new LlamaTornadoWeights(
// Load directly to TornadoVM format
loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
@@ -117,30 +140,37 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
- FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType());
+ FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType()) {
+ };
}
/**
* Creates weights in standard format only
*/
- private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
GGMLTensorEntry outputWeight) {
- return new Weights(loadQuantized(tokenEmbeddings), loadArrayOfFloatBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ return new LlamaStandardWeights(
+ loadQuantized(tokenEmbeddings),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
- loadArrayOfFloatBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), toFloatBuffer(tensorEntries.get("output_norm.weight")),
- FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType());
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ loadQuantized(tensorEntries.get("output_norm.weight")),
+ new ArrayFloatTensor(ropeFreqs.first()),
+ new ArrayFloatTensor(ropeFreqs.second()),
+ loadQuantized(outputWeight),
+ outputWeight.ggmlType());
}
public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
GGMLType ggmlType = entry.ggmlType();
return switch (ggmlType) {
- // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+ case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
diff --git a/src/main/java/com/example/model/loader/Qwen3ModelLoader.java b/src/main/java/com/example/model/loader/Qwen3ModelLoader.java
new file mode 100644
index 0000000..51e5451
--- /dev/null
+++ b/src/main/java/com/example/model/loader/Qwen3ModelLoader.java
@@ -0,0 +1,146 @@
+package com.example.model.loader;
+
+import com.example.LlamaApp;
+import com.example.auxiliary.Timer;
+import com.example.core.model.GGMLType;
+import com.example.core.model.GGUF;
+import com.example.core.model.tensor.ArrayFloatTensor;
+import com.example.core.model.tensor.GGMLTensorEntry;
+import com.example.core.types.Pair;
+import com.example.inference.operation.RoPE;
+import com.example.inference.weights.standard.Qwen3StandardWeights;
+import com.example.inference.weights.Weights;
+import com.example.model.Configuration;
+import com.example.model.format.ChatFormat;
+import com.example.model.format.ChatFormat.ChatTokens;
+import com.example.model.qwen3.Qwen3;
+import com.example.model.qwen3.Qwen3Configuration;
+import com.example.tokenizer.impl.Qwen3Tokenizer;
+import com.example.tokenizer.impl.Tokenizer;
+import com.example.tokenizer.vocabulary.Vocabulary;
+
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.util.Map;
+
+import static com.example.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary;
+
+public class Qwen3ModelLoader extends ModelLoader {
+
+ public Qwen3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
+ super(fileChannel, gguf, contextLength, loadWeights);
+ }
+
+ @Override
+ public Qwen3 loadModel() {
+ try (var ignored = Timer.log("Load Qwen3 model")) {
+ Map<String, Object> metadata = gguf.getMetadata();
+
+ Vocabulary vocabulary = loadQwen3Vocabulary(metadata);
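+ // DeepSeek-R1-Distill-Qwen shares the Qwen architecture but uses its own special tokens and chat template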
+ boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+ Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
+
+ int modelContextLength = (int) metadata.get("qwen3.context_length");
+ if (contextLength < 0 || modelContextLength < contextLength) {
+ contextLength = modelContextLength;
+ }
+
+ //String modelName = ggufPath.getFileName().toString();
+ Qwen3Configuration config = new Qwen3Configuration(
+ //modelName,
+ (int) metadata.get("qwen3.embedding_length"),
+ (int) metadata.get("qwen3.feed_forward_length"),
+ (int) metadata.get("qwen3.block_count"),
+ (int) metadata.get("qwen3.attention.head_count"),
+
+ metadata.containsKey("qwen3.attention.head_count_kv")
+ ? (int) metadata.get("qwen3.attention.head_count_kv")
+ : (int) metadata.get("qwen3.attention.head_count"),
+ (int) metadata.get("qwen3.attention.key_length"),
+ (int) metadata.get("qwen3.attention.value_length"),
+
+ vocabulary.size(),
+ modelContextLength, contextLength,
+ false,
+ (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"),
+ (float) metadata.get("qwen3.rope.freq_base")
+ );
+
+ Weights weights = null;
+ if (loadWeights) {
+ Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
+ weights = loadWeights(tensorEntries, config);
+ }
+ // Qwen2.5-coder uses <|endoftext|> as stop-token.
+ ChatTokens chatTokens = isDeepSeekR1DistillQwen ?
+ new ChatTokens( "<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") :
+ new ChatTokens( "<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
+ return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
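+ // Qwen3 sizes the RoPE tables by the key-head dimension and disables frequency scaling (zeroed arguments)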
+ Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
+ config.contextLengthModel(),
+ config.numberOfHeadsKey(),
+ config.ropeTheta(),
+ false,
+ 0,
+ 0,
+ 0,
+ 0
+ );
+
+ GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
+ GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
+
+ if (LlamaApp.USE_TORNADOVM) {
+ System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
+ return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ } else {
+ return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ }
+ }
+
+ @Override
+ public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
+
+ @Override
+ public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
+ Configuration config,
+ Pair<float[], float[]> ropeFreqs,
+ GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ float[] ropeFreqsReal = ropeFreqs.first();
+ float[] ropeFreqsImag = ropeFreqs.second();
+ return new Qwen3StandardWeights(
+ loadQuantized(tokenEmbeddings),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm
+
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), //rms_ffn_weight
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
+ loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight
+ new ArrayFloatTensor(ropeFreqsReal),
+ new ArrayFloatTensor(ropeFreqsImag),
+ tensorEntries.containsKey("output.weight")
+ ? ModelLoader.loadQuantized(tensorEntries.get("output.weight"))
+ : loadQuantized(tokenEmbeddings), // weights are shared
+ null
+ );
+ }
+}
diff --git a/src/main/java/com/example/model/mistral/Mistral.java b/src/main/java/com/example/model/mistral/Mistral.java
index c1f9d09..9dc74d7 100644
--- a/src/main/java/com/example/model/mistral/Mistral.java
+++ b/src/main/java/com/example/model/mistral/Mistral.java
@@ -1,23 +1,22 @@
package com.example.model.mistral;
-import com.example.auxiliary.Timer;
-import com.example.core.model.GGUF;
-import com.example.core.model.tensor.GGMLTensorEntry;
+import com.example.inference.InferenceCore;
+import com.example.inference.InferenceEngine;
+import com.example.inference.sampler.Sampler;
+import com.example.inference.state.LlamaState;
+import com.example.inference.state.State;
+import com.example.inference.weights.Weights;
import com.example.model.Model;
-import com.example.loader.weights.State;
-import com.example.loader.weights.Weights;
import com.example.model.ModelType;
+import com.example.model.format.ChatFormat;
import com.example.tokenizer.impl.MistralTokenizer;
import com.example.tokenizer.impl.Tokenizer;
-import com.example.tokenizer.vocabulary.Vocabulary;
-import java.io.IOException;
-import java.nio.channels.FileChannel;
-import java.util.Map;
+import java.util.List;
+import java.util.Set;
+import java.util.function.IntConsumer;
-import static com.example.loader.weights.ModelLoader.loadWeights;
-
-public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
+public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model {
/* For explicit use */
private MistralTokenizer getAsMistralTokenizer() {
@@ -30,57 +29,25 @@ public ModelType getModelType() {
}
public State createNewState() {
- State state = new State(configuration(), -1);
+ State state = new LlamaState(configuration(), -1);
state.latestToken = tokenizer.getSpecialTokens().get("<s>");
return state;
}
public State createNewState(int batchsize) {
- State state = new State(configuration(), batchsize);
+ State state = new LlamaState(configuration(), batchsize);
state.latestToken = tokenizer.getSpecialTokens().get("<s>");
return state;
}
- // @formatter:off
- public static Mistral loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
- try (var ignored = Timer.log("Load Mistral model")) {
- Map<String, Object> metadata = gguf.getMetadata();
-
- Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata);
- Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary);
-
- int modelContextLength = (int) metadata.get("llama.context_length");
- if (contextLength < 0 || modelContextLength < contextLength) {
- contextLength = modelContextLength;
- }
-
- MistralConfiguration config = new MistralConfiguration(
- (int) metadata.get("llama.embedding_length"),
- (int) metadata.get("llama.feed_forward_length"),
- (int) metadata.get("llama.block_count"),
- (int) metadata.get("llama.attention.head_count"),
-
- metadata.containsKey("llama.attention.head_count_kv")
- ? (int) metadata.get("llama.attention.head_count_kv")
- : (int) metadata.get("llama.attention.head_count"),
-
- vocabulary.size(),
- contextLength,
- false,
- (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
- (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
- );
+ @Override
+ public void forward(State state, int token, int position) {
+ InferenceCore.forwardJava(this, state, token, position);
+ }
- Weights weights = null;
- if (loadWeights) {
- Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config);
- }
- return new Mistral(config, tokenizer, weights);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
+ @Override
+ public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
+ return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
}
- // @formatter:on
}
diff --git a/src/main/java/com/example/model/mistral/MistralConfiguration.java b/src/main/java/com/example/model/mistral/MistralConfiguration.java
index d2c14d4..dad7e0b 100644
--- a/src/main/java/com/example/model/mistral/MistralConfiguration.java
+++ b/src/main/java/com/example/model/mistral/MistralConfiguration.java
@@ -13,6 +13,16 @@ public int kvMul() {
return numberOfHeads / numberOfKeyValueHeads;
}
+ @Override
+ public int numberOfHeadsKey() {
+ throw new UnsupportedOperationException("Not supported for Mistral.");
+ }
+
+ @Override
+ public int contextLengthModel() {
+ throw new UnsupportedOperationException("Not supported for Mistral.");
+ }
+
public int headSize() {
return dim / numberOfHeads;
}
diff --git a/src/main/java/com/example/model/qwen3/Qwen3.java b/src/main/java/com/example/model/qwen3/Qwen3.java
new file mode 100644
index 0000000..ad1fc35
--- /dev/null
+++ b/src/main/java/com/example/model/qwen3/Qwen3.java
@@ -0,0 +1,49 @@
+package com.example.model.qwen3;
+
+import com.example.inference.InferenceCore;
+import com.example.inference.InferenceEngine;
+import com.example.inference.sampler.Sampler;
+import com.example.inference.state.Qwen3State;
+import com.example.inference.state.State;
+import com.example.inference.weights.Weights;
+import com.example.model.Model;
+import com.example.model.ModelType;
+import com.example.model.format.ChatFormat;
+import com.example.tokenizer.impl.Tokenizer;
+
+import java.util.List;
+import java.util.Set;
+import java.util.function.IntConsumer;
+
+public record Qwen3(Qwen3Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model {
+
+ @Override
+ public ModelType getModelType() {
+ return ModelType.QWEN_3;
+ }
+
+ @Override
+ public State createNewState() {
+ State state = new Qwen3State(configuration(), -1);
+ state.latestToken = tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader());
+ return state;
+ }
+
+ @Override
+ public State createNewState(int batchsize) {
+ State state = new Qwen3State(configuration(), batchsize);
+ state.latestToken = tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader());
+ return state;
+ }
+
+ @Override
+ public void forward(State state, int token, int position) {
+ InferenceCore.forwardJavaQwen3(this, state, token, position);
+ }
+
+ @Override
+ public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
+ return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
+ }
+
+}
diff --git a/src/main/java/com/example/model/qwen3/Qwen3Configuration.java b/src/main/java/com/example/model/qwen3/Qwen3Configuration.java
new file mode 100644
index 0000000..f1cdf87
--- /dev/null
+++ b/src/main/java/com/example/model/qwen3/Qwen3Configuration.java
@@ -0,0 +1,37 @@
+package com.example.model.qwen3;
+
+import com.example.model.Configuration;
+
+public record Qwen3Configuration(int dim,
+ int hiddenDim,
+ int numberOfLayers,
+ int numberOfHeads,
+ int numberOfKeyValueHeads,
+ int numberOfHeadsKey,
+ int numberOfHeadsValue,
+ int vocabularySize,
+ int contextLengthModel,
+ int contextLength,
+ boolean sharedWeights,
+ float rmsNormEps,
+ float ropeTheta) implements Configuration {
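+ // Qwen3 sizes attention heads via numberOfHeadsKey/numberOfHeadsValue, so the dim-derived helpers below are unsupported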
+ @Override
+ public int headSize() {
+ throw new UnsupportedOperationException("Not supported for Qwen3.");
+ }
+
+ @Override
+ public int kvDim() {
+ throw new UnsupportedOperationException("Not supported for Qwen3.");
+ }
+
+ @Override
+ public int kvMul() {
+ throw new UnsupportedOperationException("Not supported for Qwen3.");
+ }
+
+ @Override
+ public int contextLengthModel() {
+ return contextLengthModel;
+ }
+}
diff --git a/src/main/java/com/example/tokenizer/impl/Qwen3Tokenizer.java b/src/main/java/com/example/tokenizer/impl/Qwen3Tokenizer.java
new file mode 100644
index 0000000..24f7cf2
--- /dev/null
+++ b/src/main/java/com/example/tokenizer/impl/Qwen3Tokenizer.java
@@ -0,0 +1,317 @@
+package com.example.tokenizer.impl;
+
+import com.example.auxiliary.Utf8Mask;
+import com.example.core.types.Pair;
+import com.example.tokenizer.vocabulary.Vocabulary;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class Qwen3Tokenizer implements Tokenizer {
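+ // GPT-2-style pre-tokenization regex: contractions, letter runs, single digits, punctuation runs, and whitespace groups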
+ private final static String QWEN3_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+ private final Pattern compiledPattern;
+ private final Vocabulary vocabulary;
+ private final Map<Pair<Integer, Integer>, Integer> merges;
+ private final Map<String, Integer> specialTokens;
+ private final int[] tokenTypes;
+
+ /** buffer to store incomplete UTF-8 sequence */
+ private final byte[] bufUtf8 = new byte[4];
+ /** index in UTF-8 buffer */
+ private int currUtf8Index = 0;
+ /** current UTF-8 mask */
+ private Utf8Mask currUtf8Mask;
+
+ @Override
+ public String regexPattern() {
+ if (compiledPattern == null) {
+ return null;
+ }
+ return compiledPattern.pattern();
+ }
+
+ @Override
+ public Map<String, Integer> getSpecialTokens() {
+ return specialTokens;
+ }
+
+ @Override
+ public boolean isSpecialToken(int tokenIndex) {
+ return specialTokens.containsValue(tokenIndex);
+ }
+
+ @Override
+ public boolean shouldDisplayToken(int token) {
+ int tokenType = getTokenType(token);
+
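+ // GGUF token types: 1 = normal, 6 = byte; control and other special types are not displayed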
+ return tokenType == 1 || tokenType == 6;
+ }
+
+ public int getTokenType(int tokenIndex) {
+ if (tokenTypes == null) {
+ throw new IllegalStateException("Qwen3Tokenizer hasn't been constructed using tokenTypes");
+ }
+ return tokenTypes[tokenIndex];
+ }
+
+ public Qwen3Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary, boolean isDeepSeekR1DistillQwen) {
+ int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
+ String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
+ List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines)
+ .map(line -> line.split(" "))
+ .map(parts ->
+ new Pair<>(
+ vocabulary.getIndex(parts[0]).orElseThrow(),
+ vocabulary.getIndex(parts[1]).orElseThrow())
+ ).toList();
+
+ int allTokens = vocabulary.size();
+ String firstSpecialToken = isDeepSeekR1DistillQwen ? "<｜end▁of▁sentence｜>" : "<|endoftext|>";
+ int baseTokens = vocabulary.getIndex(firstSpecialToken).orElseThrow(); // assume all tokens after the base ones are special.
+ // int reservedSpecialTokens = allTokens - baseTokens;
+ List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
+
+ assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
+
+ Map<String, Integer> specialTokens =
+ IntStream.range(0, specialTokensList.size())
+ .boxed()
+ .collect(Collectors.toMap(
+ i -> specialTokensList.get(i),
+ i -> baseTokens + i)
+ );
+ specialTokens.remove("");
+ specialTokens.remove("");
+
+ this.vocabulary = vocabulary;
+ this.compiledPattern = Pattern.compile(QWEN3_PATTERN);
+ this.specialTokens = new HashMap<>(specialTokens);
+ this.merges = new HashMap<>();
+ this.tokenTypes = tokenTypes;
+ for (Pair<Integer, Integer> pair : merges) {
+ int firstIndex = pair.first();
+ int secondIndex = pair.second();
+ int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow();
+ this.merges.put(pair, mergeIndex);
+ }
+ }
+
+ private int[] encodeImpl(String text) {
+ return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
+ }
+
+ static List<String> findAll(Pattern pattern, String text) {
+ List<String> allMatches = new ArrayList<>();
+ Matcher matcher = pattern.matcher(text);
+ while (matcher.find()) {
+ allMatches.add(matcher.group());
+ }
+ return allMatches;
+ }
+
+ /**
+ * Encoding that ignores any special tokens.
+ */
+ public List<Integer> encodeOrdinary(String text) {
+ // split text into chunks of text by categories defined in regex pattern
+ List<String> textChunks = findAll(compiledPattern, text);
+ // all chunks of text are encoded separately, then results are joined
+ List<Integer> ids = new ArrayList<>();
+ for (String chunk : textChunks) {
+ List<Integer> chunkIds = encodeChunk(chunk);
+ ids.addAll(chunkIds);
+ }
+ return ids;
+ }
+
+ private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
+ Map<Pair<Integer, Integer>, Integer> map = new HashMap<>();
+ for (int i = 0; i + 1 < ids.size(); i++) {
+ Pair<Integer, Integer> key = new Pair<>(ids.get(i), ids.get(i + 1));
+ map.put(key, map.getOrDefault(key, 0) + 1);
+ }
+ return map;
+ }
+
+ private List<Integer> encodeChunk(String chunk) {
+ // return the token ids
+ // let's begin. first, convert all bytes to integers in range 0..255
+ List<Integer> ids = new ArrayList<>();
+ for (int b : chunk.toCharArray()) {
+ int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow();
+ ids.add(tokenIndex);
+ }
+
+ while (ids.size() >= 2) {
+ // find the pair with the lowest merge index
+ Map<Pair<Integer, Integer>, Integer> stats = getStats(ids);
+ Pair<Integer, Integer> pair = stats.keySet().stream().min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow();
+ // subtle: if there are no more merges available, the key will
+ // result in an inf for every single pair, and the min will be
+ // just the first pair in the list, arbitrarily
+ // we can detect this terminating case by a membership check
+ if (!this.merges.containsKey(pair)) {
+ break; // nothing else can be merged anymore
+ }
+ // otherwise let's merge the best pair (lowest merge index)
+ int idx = this.merges.get(pair);
+ ids = merge(ids, pair, idx);
+ }
+ return ids;
+ }
+
+ static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
+ List<Integer> newids = new ArrayList<>();
+ int i = 0;
+ while (i < ids.size()) {
+ // if not at the very last position AND the pair matches, replace it
+ if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
+ newids.add(idx);
+ i += 2;
+ } else {
+ newids.add(ids.get(i));
+ i += 1;
+ }
+ }
+ return newids;
+ }
+
+ /**
+ * Returns a mapping between UTF-8 bytes and printable unicode code points.
+ * The reversible bpe codes work on unicode strings.
+ * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ * This is a significant percentage of your normal, say, 32K bpe vocab.
+ * To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ * And avoids mapping to whitespace/control characters the bpe code barfs on.
+ */
+ static Map<Integer, Integer> bytesToUnicode() {
+ List<Integer> bs = new ArrayList<>();
+ IntStream.rangeClosed('!', '~').forEach(bs::add);
+ IntStream.rangeClosed('¡', '¬').forEach(bs::add);
+ IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
+
+ List<Integer> cs = new ArrayList<>(bs);
+ int n = 0;
+ for (int b = 0; b < 256; ++b) {
+ if (!bs.contains(b)) {
+ bs.add(b);
+ cs.add(256 + n);
+ n += 1;
+ }
+ }
+
+ // return dict(zip(bs, cs))
+ return IntStream.range(0, bs.size())
+ .boxed()
+ .collect(Collectors.toMap(bs::get, cs::get));
+ }
+
+ static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
+ static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet()
+ .stream()
+ .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+
+ public int[] encode(String text) {
+ StringBuilder sb = new StringBuilder();
+ byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
+ for (byte b : bytes) {
+ sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
+ }
+ return encodeImpl(sb.toString());
+ }
+
+ @Override
+ public List<Integer> encode(String text, Set<String> allowedSpecial) {
+ // decode the user desire w.r.t. handling of special tokens
+ Set<String> special = allowedSpecial;
+ assert getSpecialTokens().keySet().containsAll(special);
+ if (special.isEmpty()) {
+ // shortcut: if no special tokens, just use the ordinary encoding
+ return encodeOrdinary(text);
+ }
+
+ // otherwise, we have to be careful with potential special tokens in text
+ // we handle special tokens by splitting the text
+ // based on the occurrence of any exact match with any of the special tokens
+ // we can use re.split for this. note that surrounding the pattern with ()
+ // makes it into a capturing group, so the special tokens will be included
+ String specialPattern = special
+ .stream()
+ .map(Pattern::quote)
+ .collect(Collectors.joining("|", "(", ")"));
+
+ String[] specialChunks = text.split(specialPattern);
+ // now all the special characters are separated from the rest of the text
+ // all chunks of text are encoded separately, then results are joined
+ List<Integer> ids = new ArrayList<>();
+ for (String part : specialChunks) {
+ if (special.contains(part)) {
+ // this is a special token, encode it separately as a special case
+ ids.add(getSpecialTokens().get(part));
+ } else {
+ // this is an ordinary sequence, encode it normally
+ ids.addAll(encodeOrdinary(part));
+ }
+ }
+ return ids;
+ }
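+ // example (illustrative, Qwen-style special token): encoding "<|im_start|>user" with
+ // allowedSpecial = {"<|im_start|>"} yields the id of "<|im_start|>" followed by the
+ // ordinary BPE encoding of "user"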
+
+ @Override
+ public List<Integer> encodeAsList(String text) {
+ return Arrays.stream(encode(text)).boxed().toList();
+ }
+
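+ // concatenates the vocabulary strings of the given tokens; the result is still in
+ // the pseudo-character encoding and is mapped back to raw bytes by decode() below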
+ public String decodeImpl(List<Integer> tokens) {
+ StringBuilder sb = new StringBuilder();
+ for (int token : tokens) {
+ String tokenString = vocabulary.get(token);
+ sb.append(tokenString);
+ }
+ return sb.toString();
+ }
+
+ @Override
+ public String decode(List<Integer> tokens) {
+ String decoded = decodeImpl(tokens);
+ // map each pseudo-character back to its raw byte; code points above 512 cannot come
+ // from BYTE_ENCODER and are passed through unchanged, e.g. the fullwidth '｜' in
+ // DeepSeek-R1's '<｜end▁of▁sentence｜>' has code point 65372
+ int[] decodedBytesAsInts = decoded.codePoints().map(cp -> cp <= 512 ? BYTE_DECODER.get(cp) : cp).toArray();
+ byte[] rawBytes = new byte[decodedBytesAsInts.length + 3];
+ int indexRawByte = 0;
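+ // a token may end in the middle of a multi-byte UTF-8 sequence, so lead and
+ // continuation bytes are buffered in bufUtf8 until currUtf8Mask.len() bytes have
+ // arrived, then flushed into rawBytes as one complete sequence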
+ loopDecoded:
+ for (int i = 0; i < decoded.length(); i++) {
+ byte b = (byte) decodedBytesAsInts[i];
+ if (currUtf8Index == 0) {
+ for (Utf8Mask utf8Mask : Utf8Mask.MASKS) {
+ if ((b & utf8Mask.mask()) == utf8Mask.pattern()) {
+ currUtf8Mask = utf8Mask;
+ bufUtf8[currUtf8Index++] = b;
+ continue loopDecoded;
+ }
+ }
+ }
+ if (currUtf8Index > 0 && currUtf8Mask != null) {
+ bufUtf8[currUtf8Index++] = b;
+ if (currUtf8Index == currUtf8Mask.len()) {
+ System.arraycopy(bufUtf8, 0, rawBytes, indexRawByte, currUtf8Mask.len());
+ indexRawByte += currUtf8Mask.len();
+ currUtf8Index = 0;
+ currUtf8Mask = null;
+ }
+ continue;
+ }
+ rawBytes[indexRawByte++] = b;
+ }
+ return new String(rawBytes, 0, indexRawByte, StandardCharsets.UTF_8);
+ }
+}
diff --git a/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java b/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java
index a45c073..459fb69 100644
--- a/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java
+++ b/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java
@@ -40,6 +40,12 @@ public static Vocabulary loadMistralVocabulary(Map<String, Object> metadata) {
return v;
}
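+ /** Builds the Qwen3 vocabulary from GGUF metadata; metadata.get returns null for
+ * an absent "tokenizer.ggml.scores" entry, so scores may be null for BPE-style models. */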
+ public static Vocabulary loadQwen3Vocabulary(Map<String, Object> metadata) {
+ String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens");
+ float[] scores = (float[]) metadata.get("tokenizer.ggml.scores");
+ return new Vocabulary(tokens, scores);
+ }
+
public int size() {
return tokens.length;
}
diff --git a/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java b/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java
index e1864a4..18618b8 100644
--- a/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java
+++ b/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java
@@ -1,10 +1,10 @@
package com.example.tornadovm;
import com.example.auxiliary.Tuple2;
+import com.example.inference.weights.tornado.TornadoWeights;
import com.example.model.Configuration;
import com.example.model.Model;
-import com.example.loader.weights.State;
-import com.example.loader.weights.Weights;
+import com.example.inference.state.State;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
@@ -49,7 +49,7 @@ public class TornadoVMLayerPlanner {
private final State state;
private final Configuration config;
- private final Weights weights;
+ private final TornadoWeights weights;
private final KernelContext context;
/**
@@ -63,7 +63,7 @@ public class TornadoVMLayerPlanner {
public TornadoVMLayerPlanner(State state, Model model) {
this.state = state;
this.config = model.configuration();
- this.weights = model.weights();
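+ // the GPU planner needs TornadoVM-native weight buffers; this cast fails fast if
+ // the model was loaded with plain CPU weights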
+ this.weights = (TornadoWeights) model.weights();
this.context = new KernelContext();
}
@@ -182,7 +182,7 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
*/
// @formatter:on
private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
- switch (weights.weightType) {
+ switch (weights.getWeightType()) {
case F16:
case Q8_0:
case Q4_0:
@@ -191,7 +191,7 @@ private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); //
break;
default:
- throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.weightType + ". Only Q8_0 and Q4_0 are supported.");
+ throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.getWeightType() + ". Only F16, Q8_0 and Q4_0 are supported.");
}
return logits;
}
diff --git a/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java b/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
index af167dd..594a15b 100644
--- a/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
@@ -1,7 +1,7 @@
package com.example.tornadovm;
import com.example.auxiliary.Tuple2;
-import com.example.loader.weights.State;
+import com.example.inference.state.State;
import com.example.model.Configuration;
import com.example.model.Model;
import com.example.model.ModelType;