[WIP] Support for Qwen3 models #37

Draft · wants to merge 5 commits into base: main
2 changes: 1 addition & 1 deletion src/main/java/com/example/LlamaApp.java
@@ -5,7 +5,7 @@
import com.example.inference.sampler.CategoricalSampler;
import com.example.inference.sampler.Sampler;
import com.example.inference.sampler.ToppSampler;
import com.example.loader.weights.ModelLoader;
import com.example.model.loader.ModelLoader;
import com.example.model.Model;
import com.example.tornadovm.FloatArrayUtils;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
15 changes: 10 additions & 5 deletions src/main/java/com/example/aot/AOT.java
@@ -3,11 +3,13 @@
import com.example.auxiliary.Timer;
import com.example.core.model.GGUF;
import com.example.core.model.tensor.GGMLTensorEntry;
import com.example.model.loader.LlamaModelLoader;
import com.example.model.Model;
import com.example.Options;
import com.example.model.format.LlamaChatFormat;
import com.example.model.llama.Llama;
import com.example.loader.weights.ModelLoader;
import com.example.loader.weights.Weights;
import com.example.inference.weights.Weights;
import com.example.tokenizer.impl.LlamaTokenizer;

import java.io.IOException;
import java.nio.channels.FileChannel;
@@ -28,6 +30,8 @@
public final class AOT {
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;

static LlamaModelLoader modelLoader;
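// (assigned in preLoadGGUF() and reused by tryUsePreLoaded() to load the weights)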


record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {}

@@ -44,9 +48,10 @@ private static PartialModel preLoadGGUF(String modelPath) {
}
GGUF gguf = GGUF.loadModel(path);
try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false);
return new PartialModel(
path.getFileName().toString(),
Llama.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false), // TODO: needs proper handling for AOT
modelLoader.loadModel(), // TODO: needs proper handling for AOT
gguf.getTensorDataOffset(),
gguf.getTensorInfos()
);
@@ -77,8 +82,8 @@ public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IO
var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
// Load only the tensors (mmap slices).
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos());
Weights weights = ModelLoader.loadWeights(tensorEntries, baseModel.configuration());
return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights);
Weights weights = modelLoader.loadWeights(tensorEntries, baseModel.configuration());
return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights, new LlamaChatFormat((LlamaTokenizer) baseModel.tokenizer()));
}
}
}
10 changes: 10 additions & 0 deletions src/main/java/com/example/auxiliary/Utf8Mask.java
@@ -0,0 +1,10 @@
package com.example.auxiliary;

/** mask of a byte-sequence in UTF-8 encoding */
public record Utf8Mask(int mask, int pattern, int len) {
public static final Utf8Mask[] MASKS = {
new Utf8Mask(0b11100000, 0b11000000, 2),
new Utf8Mask(0b11110000, 0b11100000, 3),
new Utf8Mask(0b11111000, 0b11110000, 4)
};
}
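The record only stores the bit patterns; a byte-streaming decoder would typically match a lead byte against MASKS to learn how many bytes the UTF-8 sequence spans. A minimal sketch of such a check (the helper name is hypothetical, not part of this PR):

static int utf8SequenceLength(int leadByte) {
    for (Utf8Mask m : Utf8Mask.MASKS) {
        if ((leadByte & m.mask()) == m.pattern()) {
            return m.len(); // 2-, 3- or 4-byte sequence
        }
    }
    return 1; // ASCII lead byte (continuation bytes are not classified here)
}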
@@ -13,7 +13,7 @@ public final class ArrayFloatTensor extends FloatTensor {

final float[] values;

ArrayFloatTensor(float[] values) {
public ArrayFloatTensor(float[] values) {
this.values = values;
}

48 changes: 48 additions & 0 deletions src/main/java/com/example/core/model/tensor/F32FloatTensor.java
@@ -0,0 +1,48 @@
package com.example.core.model.tensor;

import com.example.core.model.GGMLType;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorSpecies;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

public final class F32FloatTensor extends FloatTensor {
final int size;
final MemorySegment segment;

public F32FloatTensor(int size, MemorySegment segment) {
this.size = size;
this.segment = segment;
}

@Override
public int size() {
return size;
}

@Override
public GGMLType type() {
return GGMLType.F32;
}

@Override
public MemorySegment asMemorySegment() {
return null;
}

@Override
public float getFloat(int index) {
return segment.get(ValueLayout.OfFloat.JAVA_FLOAT, index * Float.BYTES);
}

@Override
public void setFloat(int index, float value) {
segment.set(ValueLayout.OfFloat.JAVA_FLOAT, index * Float.BYTES, value);
}

@Override
protected FloatVector getFloatVector(VectorSpecies<Float> species, int offset) {
throw new UnsupportedOperationException("getFloatVector is not yet implemented.");
}
}
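A short usage sketch for the new tensor class; the example class and the Arena allocation are illustrative assumptions, not code from this PR:

import com.example.core.model.tensor.F32FloatTensor;

import java.lang.foreign.Arena;

class F32FloatTensorExample {
    public static void main(String[] args) {
        try (Arena arena = Arena.ofConfined()) {
            var segment = arena.allocate(4 * Float.BYTES); // room for 4 float32 values
            var tensor = new F32FloatTensor(4, segment);
            tensor.setFloat(2, 3.5f);                      // stores at byte offset 2 * Float.BYTES
            System.out.println(tensor.getFloat(2));        // prints 3.5
        }
    }
}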
182 changes: 167 additions & 15 deletions src/main/java/com/example/inference/InferenceCore.java
@@ -2,22 +2,35 @@

import com.example.auxiliary.Parallel;
import com.example.core.model.tensor.FloatTensor;
import com.example.loader.weights.State;
import com.example.loader.weights.Weights;
import com.example.inference.state.State;
import com.example.inference.weights.standard.Qwen3StandardWeights;
import com.example.inference.weights.standard.StandardWeights;
import com.example.inference.weights.tornado.TornadoWeights;
import com.example.model.Configuration;
import com.example.model.Model;
import com.example.model.qwen3.Qwen3Configuration;
import com.example.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

import java.lang.foreign.MemorySegment;
import java.nio.FloatBuffer;

/**
* Low-level operations for model inference.
*
* <p>
* Provides core computational operations: RMS normalization and forward passes
* through model layers. Supports both CPU and GPU implementations.
* This class provides core computational operations such as RMS normalization and
* forward passes through model layers. It supports both CPU and GPU implementations.
* </p>
*
* <p>
* Specifically, it implements:
* <ul>
* <li>{@code rmsnorm} – applies Root Mean Square Layer Normalization to input vectors</li>
* <li>{@code forwardJava} – executes a forward pass for LLaMA and Mistral models on CPU</li>
* <li>{@code forwardJavaQwen3} – executes a forward pass for Qwen3 models on CPU</li>
* <li>{@code forwardTornadoVM} – executes a forward pass using TornadoVM for GPU acceleration</li>
* </ul>
* </p>
*/

public final class InferenceCore {
@@ -26,21 +39,21 @@ private InferenceCore() {
// prevent instantiation
}

public static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
public static void rmsnorm(FloatTensor out, FloatTensor x, FloatTensor weight, int offset, int size, float rmsNormEps) {
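// `offset` selects the slice [offset, offset + size) of x/out, so a caller can
// normalize a single attention-head slice in place (see forwardJavaQwen3);
// `weight` has length `size` and is indexed modulo size below.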
// calculate sum of squares
float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi);
float ss = x.reduce(offset, size, 0f, (acc, xi) -> acc + xi * xi);
ss /= size;
ss += rmsNormEps;
ss = (float) (1.0 / Math.sqrt(ss));
// normalize and scale
final float finalss = ss; // for the lambda
out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index)));
out.mapWithIndexInPlace(offset, size, (value, index) -> weight.getFloat(index % size) * (finalss * x.getFloat(index)));
}

public static FloatTensor forwardJava(Model model, State state, int token, int position) {
// a few convenience variables
final Configuration config = model.configuration();
final Weights weights = model.weights();
final StandardWeights weights = (StandardWeights) model.weights();
int dim = config.dim();
int headSize = config.headSize();
int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
@@ -53,7 +66,7 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
// forward all the layers
for (int l = 0; l < config.numberOfLayers(); l++) {
// attention rmsnorm
rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps());
rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());

// qkv matmuls for this position

@@ -64,8 +77,8 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i += 2) {
int head_dim = i % headSize;
float fcr = weights.freq_cis_real.get(position * (headSize / 2) + (head_dim / 2));
float fci = weights.freq_cis_imag.get(position * (headSize / 2) + (head_dim / 2));
float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2));
float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2));
int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int v = 0; v < rotn; v++) {
FloatTensor vec = v == 0 ? state.q : state.k; // the vector to rotate (query or key)
@@ -133,7 +146,146 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
state.x.addInPlace(state.xb2);

// ffn rmsnorm
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, config.rmsNormEps());
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());

// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);

// SwiGLU non-linearity
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));

// elementwise multiply with w3(x)
state.hb.multiplyInPlace(state.hb2);

// final matmul to get the output of the ffn
weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());

// residual connection
state.x.addInPlace(state.xb);
}

rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());

weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);

return state.logits;
}

public static FloatTensor forwardJavaQwen3(Model model, State state, int token, int position) {
// a few convenience variables
final Qwen3Configuration config = (Qwen3Configuration) model.configuration(); // same
final Qwen3StandardWeights weights = (Qwen3StandardWeights) model.weights(); // same
int dim = config.dim(); // same
int nHeadKv = config.numberOfKeyValueHeads(); // n_head_kv = numberOfKeyValueHeads
int nEmbdHeadK = config.numberOfHeadsKey(); // n_embd_head_k = n_embd / n_head; %s.attention.key_length
int nEmbdHeadV = config.numberOfHeadsValue(); // n_embd_head_v = n_embd / n_head; %s.attention.value_length
int nEmbdVGqa = nEmbdHeadV * nHeadKv; // n_embd_v_gqa = n_embd_head_v * n_head_kv
int nEmbdHead = nEmbdHeadV;
int nEmbdGqa = nEmbdVGqa;
int gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
float sqrtHeadSize = (float) Math.sqrt(nEmbdHead);

// copy the token embedding into x
weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);

// forward all the layers
for (int l = 0; l < config.numberOfLayers(); l++) {
// attention rmsnorm
final int curLayer = l;
rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], 0, dim, config.rmsNormEps());

// qkv matmuls for this position
weights.wq[curLayer].matmul(state.xb, state.q, nEmbdHeadK * config.numberOfHeads(), dim);
weights.wk[curLayer].matmul(state.xb, state.k, nEmbdGqa, dim);
weights.wv[curLayer].matmul(state.xb, state.v, nEmbdGqa, dim);

// Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
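// Qwen3 QK-Norm: each query head and each key head is RMS-normalized with its
// own weights (attnQNorm / attnKNorm) before RoPE is applied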
for (int i = 0; i < config.numberOfHeads(); i++) {
rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
}
// Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
for (int i = 0; i < config.numberOfKeyValueHeads(); i++) {
rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
}

// RoPE relative positional encoding: complex-valued rotate q and k in each head
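// (unlike forwardJava, which rotates adjacent element pairs, this pairs element ic
// with element ic + nEmbdHead/2 inside each head, i.e. a half-split rotation)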
//for (int i = 0; i < config.numberOfHeads(); i += 2) {
for (int h = 0; h < config.numberOfHeads(); ++h) {
int rotn = h < config.numberOfKeyValueHeads() ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
int poffset = h * nEmbdHead;
int nComplEmbdHead = nEmbdHead / 2;
for (int ic = 0; ic < nComplEmbdHead; ic++) {
float fcr = weights.freq_cis_real.getFloat(position * nComplEmbdHead + ic);
float fci = weights.freq_cis_imag.getFloat(position * nComplEmbdHead + ic);
for (int vi = 0; vi < rotn; vi++) {
FloatTensor vec = (vi == 0) ? state.q : state.k; // the vector to rotate (query or key)
float v0 = vec.getFloat(poffset + ic);
float v1 = vec.getFloat(poffset + ic + nComplEmbdHead);
vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
vec.setFloat(poffset + ic + nComplEmbdHead, v0 * fci + v1 * fcr);
}
}
}

// save key,value at this time step (position) to our kv cache
//int loff = l * config.seq_len * kvDim;
// kv cache layer offset for convenience
state.k.copyTo(0, state.keyCache[curLayer], position * nEmbdGqa, nEmbdGqa);
state.v.copyTo(0, state.valueCache[curLayer], position * nEmbdGqa, nEmbdGqa);

// multihead attention. iterate over all heads in parallel;
// within each head, timesteps up to the current position are processed sequentially
Parallel.parallelFor(0, config.numberOfHeads(), h -> {
// get the query vector for this head
// float* q = s.q + h * headSize;
int qOffset = h * nEmbdHead;
// attention scores for this head
// float* att = s.att + h * config.seq_len;
int attOffset = h * config.contextLength();

// iterate over all timesteps, including the current one
for (int t = 0; t <= position; t++) {
// get the key vector for this head and at this timestep
// float* k = s.key_cache + loff + t * dim + h * headSize;
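// grouped-query attention: every `gqa` consecutive query heads share one
// key/value head, hence the (h / gqa) indexing into the kv cache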
int keyCacheOffset = /* loff + */ (t * nEmbdGqa + (h / gqa) * nEmbdHead);
// calculate the attention score as the dot product of q and k
float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, nEmbdHeadK);
//state.kq.setFloat(h + t, score);
score /= sqrtHeadSize;
// save the score to the attention buffer
state.att.setFloat(attOffset + t, score);
}

state.att.softmaxInPlace(attOffset, position + 1); // position + 0 + 1

// weighted sum of the values, store back into xb
// float* xb = s.xb + h * headSize;
int xbOffset = h * nEmbdHeadV;
// memset(xb, 0, headSize * sizeof(float));
state.xb.fillInPlace(xbOffset, nEmbdHeadV, 0f);

for (int t = 0; t <= position; t++) {
// get the value vector for this head and at this timestep
// float* v = s.value_cache + loff + t * dim + h * headSize;
int vOffset = /* loff + */ t * nEmbdGqa + (h / gqa) * nEmbdHeadV;
// get the attention weight for this timestep
float a = state.att.getFloat(attOffset + t);
// accumulate the weighted value into xb
state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, nEmbdHeadV, a);
}
});

// final matmul to get the output of the attention
weights.wo[l].matmul(state.xb, state.xb2, dim, nEmbdHeadK * config.numberOfHeads());

// residual connection back into x
state.x.addInPlace(state.xb2);

// ffn rmsnorm
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], 0, dim, config.rmsNormEps());

// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
@@ -154,7 +306,7 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
state.x.addInPlace(state.xb);
}

rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps());
rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());

weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);

@@ -188,7 +340,7 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
*/
public static FloatArray forwardTornadoVM(Model model, State state, int token, int position, TornadoVMMasterPlan tornadoVMMasterPlan) {
final Configuration configuration = model.configuration();
final Weights weights = model.weights();
final TornadoWeights weights = (TornadoWeights) model.weights();
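// next, copy the embedding row of `token` (dim floats) from the token-embedding
// table into the TornadoVM input buffer state.wrapX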

MemorySegment.copy(weights.tokenEmbeddingTable.getSegment(), token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);

Expand Down
Loading