diff --git a/src/main/java/com/example/LlamaApp.java b/src/main/java/com/example/LlamaApp.java index 5ea0cb2..4a8b37c 100644 --- a/src/main/java/com/example/LlamaApp.java +++ b/src/main/java/com/example/LlamaApp.java @@ -5,7 +5,7 @@ import com.example.inference.sampler.CategoricalSampler; import com.example.inference.sampler.Sampler; import com.example.inference.sampler.ToppSampler; -import com.example.loader.weights.ModelLoader; +import com.example.model.loader.ModelLoader; import com.example.model.Model; import com.example.tornadovm.FloatArrayUtils; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; diff --git a/src/main/java/com/example/aot/AOT.java b/src/main/java/com/example/aot/AOT.java index 837a56b..42e7ffc 100644 --- a/src/main/java/com/example/aot/AOT.java +++ b/src/main/java/com/example/aot/AOT.java @@ -3,11 +3,13 @@ import com.example.auxiliary.Timer; import com.example.core.model.GGUF; import com.example.core.model.tensor.GGMLTensorEntry; +import com.example.model.loader.LlamaModelLoader; import com.example.model.Model; import com.example.Options; +import com.example.model.format.LlamaChatFormat; import com.example.model.llama.Llama; -import com.example.loader.weights.ModelLoader; -import com.example.loader.weights.Weights; +import com.example.inference.weights.Weights; +import com.example.tokenizer.impl.LlamaTokenizer; import java.io.IOException; import java.nio.channels.FileChannel; @@ -28,6 +30,8 @@ public final class AOT { AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF; + static LlamaModelLoader modelLoader; + record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map tensorInfos) {} @@ -44,9 +48,10 @@ private static PartialModel preLoadGGUF(String modelPath) { } GGUF gguf = GGUF.loadModel(path); try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) { + modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false); return new PartialModel( path.getFileName().toString(), - Llama.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false), // TODO: needs proper handling for AOT + modelLoader.loadModel(), // TODO: needs proper handling for AOT gguf.getTensorDataOffset(), gguf.getTensorInfos() ); @@ -77,8 +82,8 @@ public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IO var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) { // Load only the tensors (mmap slices). 
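+                // (the preloaded base model supplies the configuration and tokenizer;
+                //  only the weights are re-materialized from the mapped tensor entries)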
Map tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos()); - Weights weights = ModelLoader.loadWeights(tensorEntries, baseModel.configuration()); - return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights); + Weights weights = modelLoader.loadWeights(tensorEntries, baseModel.configuration()); + return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights, new LlamaChatFormat((LlamaTokenizer) baseModel.tokenizer())); } } } diff --git a/src/main/java/com/example/auxiliary/Utf8Mask.java b/src/main/java/com/example/auxiliary/Utf8Mask.java new file mode 100644 index 0000000..d5ee242 --- /dev/null +++ b/src/main/java/com/example/auxiliary/Utf8Mask.java @@ -0,0 +1,10 @@ +package com.example.auxiliary; + +/** mask of a byte-sequence in UTF-8 encoding */ +public record Utf8Mask(int mask, int pattern, int len) { + public static final Utf8Mask[] MASKS = { + new Utf8Mask(0b11100000, 0b11000000, 2), + new Utf8Mask(0b11110000, 0b11100000, 3), + new Utf8Mask(0b11111000, 0b11110000, 4) + }; +} diff --git a/src/main/java/com/example/core/model/tensor/ArrayFloatTensor.java b/src/main/java/com/example/core/model/tensor/ArrayFloatTensor.java index 6efeea9..b8cfa23 100644 --- a/src/main/java/com/example/core/model/tensor/ArrayFloatTensor.java +++ b/src/main/java/com/example/core/model/tensor/ArrayFloatTensor.java @@ -13,7 +13,7 @@ public final class ArrayFloatTensor extends FloatTensor { final float[] values; - ArrayFloatTensor(float[] values) { + public ArrayFloatTensor(float[] values) { this.values = values; } diff --git a/src/main/java/com/example/core/model/tensor/F32FloatTensor.java b/src/main/java/com/example/core/model/tensor/F32FloatTensor.java new file mode 100644 index 0000000..b650c36 --- /dev/null +++ b/src/main/java/com/example/core/model/tensor/F32FloatTensor.java @@ -0,0 +1,48 @@ +package com.example.core.model.tensor; + +import com.example.core.model.GGMLType; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.VectorSpecies; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +public final class F32FloatTensor extends FloatTensor { + final int size; + final MemorySegment segment; + + public F32FloatTensor(int size, MemorySegment segment) { + this.size = size; + this.segment = segment; + } + + @Override + public int size() { + return size; + } + + @Override + public GGMLType type() { + return GGMLType.F32; + } + + @Override + public MemorySegment asMemorySegment() { + return null; + } + + @Override + public float getFloat(int index) { + return segment.get(ValueLayout.OfFloat.JAVA_FLOAT, index * Float.BYTES); + } + + @Override + public void setFloat(int index, float value) { + segment.set(ValueLayout.OfFloat.JAVA_FLOAT, index * Float.BYTES, value); + } + + @Override + protected FloatVector getFloatVector(VectorSpecies species, int offset) { + throw new UnsupportedOperationException("getFloatVector is not yet implemented."); + } +} diff --git a/src/main/java/com/example/inference/InferenceCore.java b/src/main/java/com/example/inference/InferenceCore.java index 81e432c..b135e1c 100644 --- a/src/main/java/com/example/inference/InferenceCore.java +++ b/src/main/java/com/example/inference/InferenceCore.java @@ -2,22 +2,35 @@ import com.example.auxiliary.Parallel; import com.example.core.model.tensor.FloatTensor; -import com.example.loader.weights.State; -import com.example.loader.weights.Weights; +import 
com.example.inference.state.State; +import com.example.inference.weights.standard.Qwen3StandardWeights; +import com.example.inference.weights.standard.StandardWeights; +import com.example.inference.weights.tornado.TornadoWeights; import com.example.model.Configuration; import com.example.model.Model; +import com.example.model.qwen3.Qwen3Configuration; import com.example.tornadovm.TornadoVMMasterPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import java.lang.foreign.MemorySegment; -import java.nio.FloatBuffer; /** * Low-level operations for model inference. * *

- * Provides core computational operations: RMS normalization and forward passes
- * through model layers. Supports both CPU and GPU implementations.
+ * This class provides core computational operations such as RMS normalization and
+ * forward passes through model layers. It supports both CPU and GPU implementations.
+ *
+ * <p>
+ * Specifically, it implements:
+ * <ul>
+ *   <li>{@code rmsnorm} – applies Root Mean Square Layer Normalization to input vectors</li>
+ *   <li>{@code forwardJava} – executes a forward pass for LLaMA and Mistral models on CPU</li>
+ *   <li>{@code forwardJavaQwen3} – executes a forward pass for Qwen3 models on CPU</li>
+ *   <li>{@code forwardTornadoVM} – executes a forward pass using TornadoVM for GPU acceleration</li>
+ * </ul>
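+ *
+ * <p>
+ * For example, the offset-based {@code rmsnorm} overload introduced here normalizes a
+ * single head-sized slice of a tensor in place; this is how the per-head Q/K norms of
+ * Qwen3 are applied (a sketch, using the same call shape as {@code forwardJavaQwen3} below):
+ * <pre>{@code
+ * // normalize the i-th query head against its per-head norm weights
+ * rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
+ * }</pre>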
*/ public final class InferenceCore { @@ -26,21 +39,21 @@ private InferenceCore() { // prevent instantiation } - public static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) { + public static void rmsnorm(FloatTensor out, FloatTensor x, FloatTensor weight, int offset, int size, float rmsNormEps) { // calculate sum of squares - float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi); + float ss = x.reduce(offset, size, 0f, (acc, xi) -> acc + xi * xi); ss /= size; ss += rmsNormEps; ss = (float) (1.0 / Math.sqrt(ss)); // normalize and scale final float finalss = ss; // for the lambda - out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index))); + out.mapWithIndexInPlace(offset, size, (value, index) -> weight.getFloat(index % size) * (finalss * x.getFloat(index))); } public static FloatTensor forwardJava(Model model, State state, int token, int position) { // a few convenience variables final Configuration config = model.configuration(); - final Weights weights = model.weights(); + final StandardWeights weights = (StandardWeights) model.weights(); int dim = config.dim(); int headSize = config.headSize(); int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); @@ -53,7 +66,7 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p // forward all the layers for (int l = 0; l < config.numberOfLayers(); l++) { // attention rmsnorm - rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps()); + rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps()); // qkv matmuls for this position @@ -64,8 +77,8 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p // RoPE relative positional encoding: complex-valued rotate q and k in each head for (int i = 0; i < dim; i += 2) { int head_dim = i % headSize; - float fcr = weights.freq_cis_real.get(position * (headSize / 2) + (head_dim / 2)); - float fci = weights.freq_cis_imag.get(position * (headSize / 2) + (head_dim / 2)); + float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2)); int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only for (int v = 0; v < rotn; v++) { FloatTensor vec = v == 0 ? 
state.q : state.k; // the vector to rotate (query or key) @@ -133,7 +146,146 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p state.x.addInPlace(state.xb2); // ffn rmsnorm - rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, config.rmsNormEps()); + rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps()); + + // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) + // first calculate self.w1(x) and self.w3(x) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + + // SwiGLU non-linearity + // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + + // elementwise multiply with w3(x) + state.hb.multiplyInPlace(state.hb2); + + // final matmul to get the output of the ffn + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + + // residual connection + state.x.addInPlace(state.xb); + } + + rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps()); + + weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim); + + return state.logits; + } + + public static FloatTensor forwardJavaQwen3(Model model, State state, int token, int position) { + // a few convenience variables + final Qwen3Configuration config = (Qwen3Configuration) model.configuration(); // same + final Qwen3StandardWeights weights = (Qwen3StandardWeights) model.weights(); // same + int dim = config.dim(); // same + int nHeadKv = config.numberOfKeyValueHeads(); // n_head_kv = numberOfKeyValueHeads + int nEmbdHeadK = config.numberOfHeadsKey(); // n_embd_head_k = n_embd / n_head; %s.attention.key_length + int nEmbdHeadV = config.numberOfHeadsValue(); // n_embd_head_v = n_embd / n_head; %s.attention.value_length + int nEmbdVGqa = nEmbdHeadV * nHeadKv; // n_embd_v_gqa = n_embd_head_v * n_head_kv + int nEmbdHead = nEmbdHeadV; + int nEmbdGqa = nEmbdVGqa; + int gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery + float sqrtHeadSize = (float) Math.sqrt(nEmbdHead); + + // copy the token embedding into x + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + + // forward all the layers + for (int l = 0; l < config.numberOfLayers(); l++) { + // attention rmsnorm + final int curLayer = l; + rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], 0, dim, config.rmsNormEps()); + + // qkv matmuls for this position + weights.wq[curLayer].matmul(state.xb, state.q, nEmbdHeadK * config.numberOfHeads(), dim); + weights.wk[curLayer].matmul(state.xb, state.k, nEmbdGqa, dim); + weights.wv[curLayer].matmul(state.xb, state.v, nEmbdGqa, dim); + + // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + for (int i = 0; i < config.numberOfHeads(); i++) { + rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps()); + } + // Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + for (int i = 0; i < config.numberOfKeyValueHeads(); i++) { + rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps()); + } + + // RoPE relative positional encoding: complex-valued rotate q and k in each head + //for (int i = 0; i < config.numberOfHeads(); i += 2) { + for (int h = 0; h < config.numberOfHeads(); ++h) { + int rotn = h < config.numberOfKeyValueHeads() ? 
2 : 1; // how many vectors? 2 = q & k, 1 = q only + int poffset = h * nEmbdHead; + int nComplEmbdHead = nEmbdHead / 2; + for (int ic = 0; ic < nComplEmbdHead; ic++) { + float fcr = weights.freq_cis_real.getFloat(position * nComplEmbdHead + ic); + float fci = weights.freq_cis_imag.getFloat(position * nComplEmbdHead + ic); + for (int vi = 0; vi < rotn; vi++) { + FloatTensor vec = (vi == 0) ? state.q : state.k; // the vector to rotate (query or key) + float v0 = vec.getFloat(poffset + ic); + float v1 = vec.getFloat(poffset + ic + nComplEmbdHead); + vec.setFloat(poffset + ic, v0 * fcr - v1 * fci); + vec.setFloat(poffset + ic + nComplEmbdHead, v0 * fci + v1 * fcr); + } + } + } + + // save key,value at this time step (position) to our kv cache + //int loff = l * config.seq_len * kvDim; + // kv cache layer offset for convenience + state.k.copyTo(0, state.keyCache[curLayer], position * nEmbdGqa, nEmbdGqa); + state.v.copyTo(0, state.valueCache[curLayer], position * nEmbdGqa, nEmbdGqa); + + // multihead attention. iterate over all heads + // Process tokens one by one instead of in parallel + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + // get the query vector for this head + // float* q = s.q + h * headSize; + int qOffset = h * nEmbdHead; + // attention scores for this head + // float* att = s.att + h * config.seq_len; + int attOffset = h * config.contextLength(); + + // iterate over all timesteps, including the current one + for (int t = 0; t <= position; t++) { + // get the key vector for this head and at this timestep + // float* k = s.key_cache + loff + t * dim + h * headSize; + int keyCacheOffset = /* loff + */ (t * nEmbdGqa + (h / gqa) * nEmbdHead); + // calculate the attention score as the dot product of q and k + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, nEmbdHeadK); + //state.kq.setFloat(h + t, score); + score /= sqrtHeadSize; + // save the score to the attention buffer + state.att.setFloat(attOffset + t, score); + } + + state.att.softmaxInPlace(attOffset, position + 1); // position + 0 + 1 + + // weighted sum of the values, store back into xb + // float* xb = s.xb + h * headSize; + int xbOffset = h * nEmbdHeadV; + // memset(xb, 0, headSize * sizeof(float)); + state.xb.fillInPlace(xbOffset, nEmbdHeadV, 0f); + + for (int t = 0; t <= position; t++) { + // get the value vector for this head and at this timestep + // float* v = s.value_cache + loff + t * dim + h * headSize;C + int vOffset = /* loff + */ t * nEmbdGqa + (h / gqa) * nEmbdHeadV; + // get the attention weight for this timestep + float a = state.att.getFloat(attOffset + t); + // accumulate the weighted value into xb + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, nEmbdHeadV, a); + } + }); + + // final matmul to get the output of the attention + weights.wo[l].matmul(state.xb, state.xb2, dim, nEmbdHeadK * config.numberOfHeads()); + + // residual connection back into x + state.x.addInPlace(state.xb2); + + // ffn rmsnorm + rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], 0, dim, config.rmsNormEps()); // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) // first calculate self.w1(x) and self.w3(x) @@ -154,7 +306,7 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p state.x.addInPlace(state.xb); } - rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps()); + rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps()); weights.wcls.matmul(state.x, state.logits, 
config.vocabularySize(), dim); @@ -188,7 +340,7 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p */ public static FloatArray forwardTornadoVM(Model model, State state, int token, int position, TornadoVMMasterPlan tornadoVMMasterPlan) { final Configuration configuration = model.configuration(); - final Weights weights = model.weights(); + final TornadoWeights weights = (TornadoWeights) model.weights(); MemorySegment.copy(weights.tokenEmbeddingTable.getSegment(), token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES); diff --git a/src/main/java/com/example/inference/InferenceEngine.java b/src/main/java/com/example/inference/InferenceEngine.java index 814b1ae..bad6b67 100644 --- a/src/main/java/com/example/inference/InferenceEngine.java +++ b/src/main/java/com/example/inference/InferenceEngine.java @@ -2,7 +2,7 @@ import com.example.auxiliary.LastRunMetrics; import com.example.inference.sampler.Sampler; -import com.example.loader.weights.State; +import com.example.inference.state.State; import com.example.model.Configuration; import com.example.model.Model; import com.example.tokenizer.impl.Tokenizer; @@ -20,6 +20,16 @@ *

 * Orchestrates the complete inference process: ingests prompt tokens, then generates
 * new tokens until a stop condition is met. Supports both CPU and GPU execution.
+ *
+ * <p>
+ * It provides unified logic for the following methods:
+ * <ul>
+ *   <li>{@code generateTokensLlama} – for LLaMA and Mistral models running on CPU</li>
+ *   <li>{@code generateTokensQwen3} – for Qwen3 models running on CPU</li>
+ *   <li>{@code generateTokensGPU} – for models executed on GPU</li>
+ * </ul>
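+ *
+ * <p>
+ * Typical CPU-side usage (a sketch; creation of the model, state, prompt tokens,
+ * sampler, options and token consumer is assumed):
+ * <pre>{@code
+ * List<Integer> responseTokens = InferenceEngine.generateTokensLlama(
+ *         model, state, startPosition, promptTokens, stopTokens,
+ *         options.maxTokens(), sampler, options.echo(), tokenConsumer);
+ * }</pre>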
*/ public final class InferenceEngine { @@ -46,7 +56,7 @@ private InferenceEngine() { * @param onTokenGenerated callback, if non-null, it's called every time a token is inferred e.g. it's not called when ingesting prompt tokens * @return list of generated/inferred tokens, including the stop token, if any e.g. does not include any token from the prompt */ - public static List generateTokens(Model model, State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + public static List generateTokensLlama(Model model, State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { // Start timing the whole process long startNanos = System.nanoTime(); @@ -122,6 +132,90 @@ public static List generateTokens(Model model, State state, int startPo return generatedTokens; } + public static List generateTokensQwen3(Model model, State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + // Start timing the whole process + long startNanos = System.nanoTime(); + long startGen = 0; + long inferenceStartNanos = 0; + + // Validate and adjust maxTokens if necessary + if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) { + maxTokens = model.configuration().contextLength(); + } + + // Storage for generated tokens + List generatedTokens = new ArrayList<>(); + + // Initialize token variables + int currentToken = state.latestToken; // BOS? + int nextToken = 0; + int promptIndex = 0; + + for (int position = startPosition; position < maxTokens; ++position) { + + // Handle token processing + if (promptIndex < promptTokens.size()) { + // We're still processing the prompt tokens + final int token = promptTokens.get(promptIndex); + + model.forward(state, token, position); + + promptIndex++; + if (promptIndex < promptTokens.size()) { + continue; + } + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + // We have reached the last prompt token and computed the first response-token. 
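+                    // Its logits are already sitting in state.logits, so the sampling step
+                    // below runs without another forward call for this position.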
+ startGen = System.nanoTime(); + position++; // The current logit belongs to the next position + } else { + // Mark the start of actual generation (after prompt processing) + if (inferenceStartNanos == 0) { + inferenceStartNanos = System.nanoTime(); + } + + model.forward(state, currentToken, position); + + } + + // Sample the next token + nextToken = sampler.sampleToken(state.logits); + + // Output the token if echo is enabled + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + + // Track the generated token + generatedTokens.add(nextToken); + + // Notify via callback if provided + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + // Check for stop condition + if (stopTokens.contains(nextToken)) { + break; + } + + // Update for next iteration + state.latestToken = currentToken = nextToken; + } + + // Calculate and print performance metrics + long endNanos = System.nanoTime(); + double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0; + int totalTokens = promptIndex + generatedTokens.size(); + + LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds); + + return generatedTokens; + } + public static List generateTokensGPU(Model model, State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { // === Setup and Initialization === diff --git a/src/main/java/com/example/inference/state/LlamaState.java b/src/main/java/com/example/inference/state/LlamaState.java new file mode 100644 index 0000000..350ad72 --- /dev/null +++ b/src/main/java/com/example/inference/state/LlamaState.java @@ -0,0 +1,65 @@ +package com.example.inference.state; + +import com.example.core.model.tensor.ArrayFloatTensor; +import com.example.core.model.tensor.FloatTensor; +import com.example.model.Configuration; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +import java.util.stream.Stream; + +public final class LlamaState extends State { + + public LlamaState(Configuration config, int batchsize) { + super(config, batchsize); + } + + @Override + protected StateFields createStateFields(Configuration config) { + StateFields fields = new StateFields(); + + // Allocation with Llama/Mistral dimensions + fields.x = ArrayFloatTensor.allocate(config.dim()); + fields.xb = ArrayFloatTensor.allocate(config.dim()); + fields.xb2 = ArrayFloatTensor.allocate(config.dim()); + fields.hb = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.q = ArrayFloatTensor.allocate(config.dim()); + fields.k = ArrayFloatTensor.allocate(config.dim()); + fields.v = ArrayFloatTensor.allocate(config.dim()); + fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength()); + fields.logits = ArrayFloatTensor.allocate(config.vocabularySize()); + + // Key-value cache with Llama/Mistral dimensions + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + + // TornadoVM wrappers with Llama/Mistral dimensions + fields.wrapX = 
new FloatArray(config.dim()); + fields.wrapXb = new FloatArray(config.dim()); + fields.wrapXb2 = new FloatArray(config.dim()); + fields.wrapHb = new FloatArray(config.hiddenDim()); + fields.wrapHb2 = new FloatArray(config.hiddenDim()); + + fields.wrapLogits = new FloatArray(config.vocabularySize()); + fields.wrapQ = new FloatArray(config.dim()); + fields.wrapK = new FloatArray(config.dim()); + fields.wrapV = new FloatArray(config.dim()); + + // dim vs kvdim + fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); + fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); + fields.wrapValueCache.init(0.f); + fields.wrapKeyCache.init(0.f); + fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength()); + fields.positionHolder = new IntArray(1); + + // Temporary arrays + fields.temp = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); + fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); + fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); + + return fields; + } +} diff --git a/src/main/java/com/example/inference/state/Qwen3State.java b/src/main/java/com/example/inference/state/Qwen3State.java new file mode 100644 index 0000000..b82a056 --- /dev/null +++ b/src/main/java/com/example/inference/state/Qwen3State.java @@ -0,0 +1,78 @@ +package com.example.inference.state; + +import com.example.core.model.tensor.ArrayFloatTensor; +import com.example.core.model.tensor.FloatTensor; +import com.example.model.Configuration; +import com.example.model.qwen3.Qwen3Configuration; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +import java.util.stream.Stream; + +public final class Qwen3State extends State { + + // Qwen3-specific field + public final FloatTensor kq; + + public Qwen3State(Configuration config, int batchsize) { + super(config, batchsize); + // Initialize Qwen3-specific field + this.kq = ArrayFloatTensor.allocate(config.numberOfHeads(), 32, 15); + } + + @Override + protected StateFields createStateFields(Configuration configuration) { + StateFields fields = new StateFields(); + + Qwen3Configuration config = (Qwen3Configuration) configuration; + + // Qwen3-specific calculations + int nHeadKv = config.numberOfKeyValueHeads(); + int nEmbdHeadK = config.numberOfHeadsKey(); + int nEmbdKGqa = nEmbdHeadK * nHeadKv; + + // Qwen3-specific allocation logic + fields.x = ArrayFloatTensor.allocate(config.dim()); + fields.xb = ArrayFloatTensor.allocate(nEmbdHeadK * config.numberOfHeads()); + fields.xb2 = ArrayFloatTensor.allocate(config.dim()); + fields.hb = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.q = ArrayFloatTensor.allocate(nEmbdHeadK * config.numberOfHeads()); + fields.k = ArrayFloatTensor.allocate(nEmbdKGqa); // Different from Llama! + fields.v = ArrayFloatTensor.allocate(nEmbdKGqa); // Different from Llama! 
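+        // Grouped-query attention: k and v hold nEmbdHeadK * nHeadKv floats,
+        // which (unlike in LlamaState) need not equal config.dim().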
+ fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength()); + fields.logits = ArrayFloatTensor.allocate(config.vocabularySize()); + + // Key-value cache with Qwen3 dimensions + int kvDim = nEmbdKGqa; + fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)) + .limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)) + .limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + + // TornadoVM wrappers with Qwen3-specific sizes + fields.wrapX = new FloatArray(config.dim()); + fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads()); // Different from Llama! + fields.wrapXb2 = new FloatArray(config.dim()); + fields.wrapHb = new FloatArray(config.hiddenDim()); + fields.wrapHb2 = new FloatArray(config.hiddenDim()); + fields.wrapLogits = new FloatArray(config.vocabularySize()); + fields.wrapQ = new FloatArray(nEmbdHeadK * config.numberOfHeads()); // Different from Llama! + fields.wrapK = new FloatArray(nEmbdKGqa); // Different from Llama! + fields.wrapV = new FloatArray(nEmbdKGqa); // Different from Llama! + + fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); + fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); + fields.wrapValueCache.init(0.f); + fields.wrapKeyCache.init(0.f); + fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength()); + fields.positionHolder = new IntArray(1); + + // Temporary arrays + fields.temp = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); + fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); + fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); + + return fields; + } +} diff --git a/src/main/java/com/example/inference/state/State.java b/src/main/java/com/example/inference/state/State.java new file mode 100644 index 0000000..a4436f2 --- /dev/null +++ b/src/main/java/com/example/inference/state/State.java @@ -0,0 +1,116 @@ +package com.example.inference.state; + +import com.example.core.model.tensor.FloatTensor; +import com.example.model.Configuration; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +/** + * Base class for State + */ +public abstract class State{ + + // current wave of activations + public final FloatTensor x; // activation at current time stamp (dim,) + public final FloatTensor xb; // same, but inside a residual branch (dim,) + public final FloatTensor xb2; // an additional buffer just for convenience (dim,) + public final FloatTensor hb; // buffer for hidden dimension in the ffn (hidden_dim,) + public final FloatTensor hb2; // buffer for hidden dimension in the ffn (hidden_dim,) + public final FloatTensor q; // query (dim,) + public final FloatTensor k; // key (dim,) + public final FloatTensor v; // value (dim,) + public final FloatTensor att; // buffer for scores/attention values (n_heads, seq_len) + public final FloatTensor logits; // output logits + public final int batchsize; + + // kv cache + public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim) + public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim) + + // Wrappers for TornadoVM compatibility (FloatArray data structure for TornadoVM acceleration) + // TornadoVM uses FloatArray for more efficient handling of data, 
particularly when running on GPU or other accelerators. + public final FloatArray wrapLogits; // FloatArray wrapper for the logits tensor, compatible with TornadoVM for GPU execution. + public final FloatArray wrapXb; // FloatArray wrapper for xb (residual branch activation), optimized for TornadoVM usage. + public final FloatArray wrapXb2; // FloatArray wrapper for xb2, another residual buffer to aid in computations with TornadoVM. + public final FloatArray wrapHb; // FloatArray wrapper for hb (hidden dimension buffer for FFN), optimized for TornadoVM. + public final FloatArray wrapHb2; // FloatArray wrapper for hb2, additional hidden buffer for FFN, for compatibility with TornadoVM. + public final FloatArray wrapX; // FloatArray wrapper for the current activation tensor, optimized for TornadoVM. + public final FloatArray wrapQ; // FloatArray wrapper for the query tensor, optimized for TornadoVM. + public final FloatArray wrapK; // FloatArray wrapper for the key tensor, optimized for TornadoVM. + public final FloatArray wrapV; // FloatArray wrapper for the value tensor, optimized for TornadoVM. + public final FloatArray wrapAtt; // FloatArray wrapper for the attention scores, optimized for TornadoVM. + public final FloatArray wrapKeyCache; // FloatArray wrapper for the key cache, optimized for TornadoVM. + public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM. + public final IntArray positionHolder; + + // store inter + public int localSize; + public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size. + public FloatArray tempFFN; // Temporary buffer for feed-forward network calculations, size adjusted for local workgroup size. + public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size. + public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models. 
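+
+    // Construction uses a template method: the base constructor asks the concrete
+    // subclass for fully-populated StateFields, then assigns the final fields.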
+ + /** last index in previous block */ + + protected State(Configuration config, int batchsize) { + this.batchsize = -1; + this.latestToken = -1; + this.localSize = 256; + + // Initialize all fields through the creation method + StateFields fields = createStateFields(config); + + this.x = fields.x; + this.xb = fields.xb; + this.xb2 = fields.xb2; + this.hb = fields.hb; + this.hb2 = fields.hb2; + this.q = fields.q; + this.k = fields.k; + this.v = fields.v; + this.att = fields.att; + this.logits = fields.logits; + //int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + this.keyCache = fields.keyCache; + this.valueCache = fields.valueCache; + + this.wrapX = fields.wrapX; + this.wrapXb = fields.wrapXb; + this.wrapXb2 = fields.wrapXb2; + this.wrapHb = fields.wrapHb; + this.wrapHb2 = fields.wrapHb2; + this.wrapLogits = fields.wrapLogits; + this.wrapQ = fields.wrapQ; + this.wrapK = fields.wrapK; + this.wrapV = fields.wrapV; + + // dim vs kvdim + this.wrapKeyCache = fields.wrapKeyCache; + this.wrapValueCache = fields.wrapValueCache; + this.wrapAtt = fields.wrapAtt; + this.positionHolder = fields.positionHolder; + + // You need at least 9 elements: 1 for the final result + 8 for the workgroup partial sums + this.temp = fields.temp; + this.tempFFN = fields.tempFFN; + this.tempLogits = fields.tempLogits; + } + + // Abstract method - subclasses implement their specific allocation logic and sizes + protected abstract StateFields createStateFields(Configuration config); + + // Helper class to hold all the state fields during construction + protected static class StateFields { + public FloatTensor x, xb, xb2, hb, hb2, q, k, v, att, logits; + public FloatTensor[] keyCache, valueCache; + public FloatArray wrapX, wrapXb, wrapXb2, wrapHb, wrapHb2, wrapLogits; + public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache; + public IntArray positionHolder; + public FloatArray temp, tempFFN, tempLogits; + } + + @Override + public State clone() throws CloneNotSupportedException { + return (State) super.clone(); + } +} \ No newline at end of file diff --git a/src/main/java/com/example/inference/weights/Weights.java b/src/main/java/com/example/inference/weights/Weights.java new file mode 100644 index 0000000..f2c495a --- /dev/null +++ b/src/main/java/com/example/inference/weights/Weights.java @@ -0,0 +1,9 @@ +package com.example.inference.weights; + +import com.example.core.model.GGMLType; + +public interface Weights { + + GGMLType getWeightType(); + +} \ No newline at end of file diff --git a/src/main/java/com/example/inference/weights/standard/LlamaStandardWeights.java b/src/main/java/com/example/inference/weights/standard/LlamaStandardWeights.java new file mode 100644 index 0000000..621d69a --- /dev/null +++ b/src/main/java/com/example/inference/weights/standard/LlamaStandardWeights.java @@ -0,0 +1,17 @@ +package com.example.inference.weights.standard; + +import com.example.core.model.GGMLType; +import com.example.core.model.tensor.FloatTensor; + +public class LlamaStandardWeights extends StandardWeights { + + public LlamaStandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatTensor[] rms_ffn_weight, + FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatTensor rms_final_weight, FloatTensor freq_cis_real, FloatTensor freq_cis_imag, FloatTensor wcls, GGMLType weightType) { + super(token_embedding_table, rms_att_weight, wq, wk, wv, wo, rms_ffn_weight, w1, 
w2, w3, rms_final_weight, freq_cis_real, freq_cis_imag, wcls, weightType); + } + + @Override + public GGMLType getWeightType() { + return weightType; + } +} diff --git a/src/main/java/com/example/inference/weights/standard/Qwen3StandardWeights.java b/src/main/java/com/example/inference/weights/standard/Qwen3StandardWeights.java new file mode 100644 index 0000000..e40c679 --- /dev/null +++ b/src/main/java/com/example/inference/weights/standard/Qwen3StandardWeights.java @@ -0,0 +1,25 @@ +package com.example.inference.weights.standard; + +import com.example.core.model.GGMLType; +import com.example.core.model.tensor.FloatTensor; + +public class Qwen3StandardWeights extends StandardWeights { + public final FloatTensor[] attnKNorm, attnQNorm; + + public Qwen3StandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight, + FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, + FloatTensor[] attnKNorm, FloatTensor[] attnQNorm, + FloatTensor[] rms_ffn_weight, + FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, + FloatTensor rms_final_weight, FloatTensor freq_cis_real, FloatTensor freq_cis_imag, FloatTensor wcls, GGMLType weightType) { + // call to StandardWeights constructor + super(token_embedding_table, rms_att_weight, wq, wk, wv, wo, rms_ffn_weight, w1, w2, w3, rms_final_weight, freq_cis_real, freq_cis_imag, wcls, weightType); + this.attnKNorm = attnKNorm; + this.attnQNorm = attnQNorm; + } + + @Override + public GGMLType getWeightType() { + return weightType; + } +} diff --git a/src/main/java/com/example/inference/weights/standard/StandardWeights.java b/src/main/java/com/example/inference/weights/standard/StandardWeights.java new file mode 100644 index 0000000..afebf60 --- /dev/null +++ b/src/main/java/com/example/inference/weights/standard/StandardWeights.java @@ -0,0 +1,95 @@ +package com.example.inference.weights.standard; + +import com.example.core.model.GGMLType; +import com.example.core.model.tensor.FloatTensor; +import com.example.inference.weights.Weights; + +public abstract class StandardWeights implements Weights { + // token embedding table + public final FloatTensor token_embedding_table; // (vocab_size, dim) + // weights for rmsnorms + public final FloatTensor[] rms_att_weight; // (layer, dim) rmsnorm weights + // weights for matmuls + public final FloatTensor[] wq; // (layer, n_heads * head_size) + public final FloatTensor[] wk; // (layer, n_kv_heads, head_size) + public final FloatTensor[] wv; // (layer, n_kv_heads * head_size) + public final FloatTensor[] wo; // (layer, n_heads * head_size, dim) + //public final FloatTensor[] attnKNorm; // qwen3 + //public final FloatTensor[] attnQNorm; // qwen3 + public final FloatTensor[] rms_ffn_weight; // (layer, dim) + + // weights for ffn + public final FloatTensor[] w1; // (layer, hidden_dim, dim) + public final FloatTensor[] w2; // (layer, dim, hidden_dim) + public final FloatTensor[] w3; // (layer, hidden_dim, dim) + // + public final FloatTensor wcls; // (vocab_size, dim) + // public final rmsnorm + public final FloatTensor rms_final_weight; // (dim,) + // freq_cis for RoPE relatively positional embeddings + public final FloatTensor freq_cis_real; // (seq_len, head_size/2) + public final FloatTensor freq_cis_imag; // (seq_len, head_size/2) + + // (optional) classifier weights for the logits, on the last layer + protected final GGMLType weightType; + + /** + * Constructor for standard (non-TornadoVM) mode + * + * @param token_embedding_table + * Token embeddings matrix + * @param rms_att_weight + 
* RMSNorm weights for attention layers + * @param wq + * Query weight matrices + * @param wk + * Key weight matrices + * @param wv + * Value weight matrices + * @param wo + * Output projection matrices + * @param rms_ffn_weight + * RMSNorm weights for FFN layers + * @param w1 + * First FFN weight matrices + * @param w2 + * Second FFN weight matrices + * @param w3 + * Third FFN weight matrices (gate) + * @param rms_final_weight + * Final layer normalization weights + * @param freq_cis_real + * RoPE cosine components + * @param freq_cis_imag + * RoPE sine components + * @param wcls + * Classifier weights for output logits + */ + protected StandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight, + FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, + //FloatTensor[] attnKNorm, FloatTensor[] attnQNorm, + FloatTensor[] rms_ffn_weight, + FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, + FloatTensor rms_final_weight, + FloatTensor freq_cis_real, FloatTensor freq_cis_imag, + FloatTensor wcls, GGMLType weightType) { + + // Standard format + this.token_embedding_table = token_embedding_table; + this.rms_att_weight = rms_att_weight; + this.wq = wq; + this.wk = wk; + this.wv = wv; + this.wo = wo; + + this.rms_ffn_weight = rms_ffn_weight; + this.w1 = w1; + this.w2 = w2; + this.w3 = w3; + this.wcls = wcls; + this.rms_final_weight = rms_final_weight; + this.freq_cis_real = freq_cis_real; + this.freq_cis_imag = freq_cis_imag; + this.weightType = weightType; + } +} diff --git a/src/main/java/com/example/inference/weights/tornado/LlamaTornadoWeights.java b/src/main/java/com/example/inference/weights/tornado/LlamaTornadoWeights.java new file mode 100644 index 0000000..ec786e3 --- /dev/null +++ b/src/main/java/com/example/inference/weights/tornado/LlamaTornadoWeights.java @@ -0,0 +1,15 @@ +package com.example.inference.weights.tornado; + +import com.example.core.model.GGMLType; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +public class LlamaTornadoWeights extends TornadoWeights{ + + public LlamaTornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, + HalfFloatArray[] woLayered, FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, + FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray, GGMLType weightType) { + super(tokenEmbeddingTable, rms_att_weightLayered, wqLayered, wkLayered, wvLayered, woLayered, rms_ffn_weightLayered, w1Layered, w2Layered, w3Layered, rms_final_weight_as_floatArray, + freq_cis_realFlat, freq_cis_imagFlat, wclsByteArray, weightType); + } +} diff --git a/src/main/java/com/example/inference/weights/tornado/TornadoWeights.java b/src/main/java/com/example/inference/weights/tornado/TornadoWeights.java new file mode 100644 index 0000000..39ab75b --- /dev/null +++ b/src/main/java/com/example/inference/weights/tornado/TornadoWeights.java @@ -0,0 +1,68 @@ +package com.example.inference.weights.tornado; + +import com.example.core.model.GGMLType; +import com.example.inference.weights.Weights; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +public abstract class TornadoWeights implements Weights { + + public FloatArray[] 
rms_att_weightLayered; // (layer, dim) rmsnorm weights + public HalfFloatArray[] wqLayered; // (layer, n_heads * head_size) + public HalfFloatArray[] wkLayered; // (layer, n_kv_heads, head_size) + public HalfFloatArray[] wvLayered; // (layer, n_kv_heads * head_size) + public HalfFloatArray[] woLayered; // (layer, n_heads * head_size, dim) + public FloatArray[] rms_ffn_weightLayered; // (layer, dim) + public HalfFloatArray[] w1Layered; // (layer, hidden_dim, dim) + public HalfFloatArray[] w2Layered; // (layer, dim, hidden_dim) + public HalfFloatArray[] w3Layered; // (layer, hidden_dim, dim) + public FloatArray rms_final_weight_as_floatArray; + public FloatArray tokenEmbeddingTable; // (vocab_size, dim) + public FloatArray freq_cis_realFlat; // (seq_len, head_size/2) + public FloatArray freq_cis_imagFlat; // (seq_len, head_size/2) + public HalfFloatArray wclsHalfFloat; + + // (optional) classifier weights for the logits, on the last layer + protected final GGMLType weightType; + + protected TornadoWeights( + FloatArray tokenEmbeddingTable, + FloatArray[] rms_att_weightLayered, + HalfFloatArray[] wqLayered, + HalfFloatArray[] wkLayered, + HalfFloatArray[] wvLayered, + HalfFloatArray[] woLayered, + FloatArray[] rms_ffn_weightLayered, + HalfFloatArray[] w1Layered, + HalfFloatArray[] w2Layered, + HalfFloatArray[] w3Layered, + FloatArray rms_final_weight_as_floatArray, + FloatArray freq_cis_realFlat, + FloatArray freq_cis_imagFlat, + HalfFloatArray wclsByteArray, + GGMLType weightType) { + // TornadoVM format + this.tokenEmbeddingTable = tokenEmbeddingTable; + this.rms_att_weightLayered = rms_att_weightLayered; + this.wqLayered = wqLayered; + this.wkLayered = wkLayered; + this.wvLayered = wvLayered; + this.woLayered = woLayered; + this.rms_ffn_weightLayered = rms_ffn_weightLayered; + this.w1Layered = w1Layered; + this.w2Layered = w2Layered; + this.w3Layered = w3Layered; + this.rms_final_weight_as_floatArray = rms_final_weight_as_floatArray; + this.freq_cis_realFlat = freq_cis_realFlat; + this.freq_cis_imagFlat = freq_cis_imagFlat; + this.wclsHalfFloat = wclsByteArray; + this.weightType = weightType; + } + + @Override + public GGMLType getWeightType() { + return weightType; + } + + +} diff --git a/src/main/java/com/example/loader/weights/State.java b/src/main/java/com/example/loader/weights/State.java deleted file mode 100644 index 12f968a..0000000 --- a/src/main/java/com/example/loader/weights/State.java +++ /dev/null @@ -1,107 +0,0 @@ -package com.example.loader.weights; - -import com.example.core.model.tensor.ArrayFloatTensor; -import com.example.core.model.tensor.FloatTensor; -import com.example.model.Configuration; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.IntArray; - -import java.util.stream.Stream; - -public final class State { - - // current wave of activations - public final FloatTensor x; // activation at current time stamp (dim,) - public final FloatTensor xb; // same, but inside a residual branch (dim,) - public final FloatTensor xb2; // an additional buffer just for convenience (dim,) - public final FloatTensor hb; // buffer for hidden dimension in the ffn (hidden_dim,) - public final FloatTensor hb2; // buffer for hidden dimension in the ffn (hidden_dim,) - public final FloatTensor q; // query (dim,) - public final FloatTensor k; // key (dim,) - public final FloatTensor v; // value (dim,) - public final FloatTensor att; // buffer for scores/attention values (n_heads, seq_len) - public final FloatTensor logits; // 
output logits - public final int batchsize; - - // kv cache - public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim) - public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim) - - // Wrappers for TornadoVM compatibility (FloatArray data structure for TornadoVM acceleration) - // TornadoVM uses FloatArray for more efficient handling of data, particularly when running on GPU or other accelerators. - public final FloatArray wrapLogits; // FloatArray wrapper for the logits tensor, compatible with TornadoVM for GPU execution. - public final FloatArray wrapXb; // FloatArray wrapper for xb (residual branch activation), optimized for TornadoVM usage. - public final FloatArray wrapXb2; // FloatArray wrapper for xb2, another residual buffer to aid in computations with TornadoVM. - public final FloatArray wrapHb; // FloatArray wrapper for hb (hidden dimension buffer for FFN), optimized for TornadoVM. - public final FloatArray wrapHb2; // FloatArray wrapper for hb2, additional hidden buffer for FFN, for compatibility with TornadoVM. - public final FloatArray wrapX; // FloatArray wrapper for the current activation tensor, optimized for TornadoVM. - - public final FloatArray wrapQ; // FloatArray wrapper for the query tensor, optimized for TornadoVM. - public final FloatArray wrapK; // FloatArray wrapper for the key tensor, optimized for TornadoVM. - public final FloatArray wrapV; // FloatArray wrapper for the value tensor, optimized for TornadoVM. - public final FloatArray wrapAtt; // FloatArray wrapper for the attention scores, optimized for TornadoVM. - public final FloatArray wrapKeyCache;// FloatArray wrapper for the key cache, optimized for TornadoVM. - public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM. - public final IntArray positionHolder; - - // store inter - // - public int localSize; - public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size. - public FloatArray tempFFN; // Temporary buffer for feed-forward network calculations, size adjusted for local workgroup size. - public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size. - - public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models. 
- - /** last index in previous block */ - - public State(Configuration config, int batchsize) { - this.batchsize = -1; - - this.x = ArrayFloatTensor.allocate(config.dim()); - this.xb = ArrayFloatTensor.allocate(config.dim()); - this.xb2 = ArrayFloatTensor.allocate(config.dim()); - this.hb = ArrayFloatTensor.allocate(config.hiddenDim()); - this.hb2 = ArrayFloatTensor.allocate(config.hiddenDim()); - this.q = ArrayFloatTensor.allocate(config.dim()); - this.k = ArrayFloatTensor.allocate(config.dim()); - this.v = ArrayFloatTensor.allocate(config.dim()); - this.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength()); - this.logits = ArrayFloatTensor.allocate(config.vocabularySize()); - int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); - this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); - this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); - - this.wrapX = new FloatArray(config.dim()); - this.wrapXb = new FloatArray(config.dim()); - this.wrapXb2 = new FloatArray(config.dim()); - this.wrapHb = new FloatArray(config.hiddenDim()); - this.wrapHb2 = new FloatArray(config.hiddenDim()); - - this.wrapLogits = new FloatArray(config.vocabularySize()); - this.wrapQ = new FloatArray(config.dim()); - this.wrapK = new FloatArray(config.dim()); - this.wrapV = new FloatArray(config.dim()); - - // dim vs kvdim - this.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); - this.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); - this.wrapValueCache.init(0.f); - this.wrapKeyCache.init(0.f); - this.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength()); - this.positionHolder = new IntArray(1); - this.latestToken = -1; - - // - this.localSize = 256; - // You need at least 9 elements: 1 for the final result + 8 for the workgroup partial sums - this.temp = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); - this.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); - this.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize)); - } - - @Override - public State clone() throws CloneNotSupportedException { - return (State) super.clone(); - } -} \ No newline at end of file diff --git a/src/main/java/com/example/loader/weights/Weights.java b/src/main/java/com/example/loader/weights/Weights.java deleted file mode 100644 index d5b004c..0000000 --- a/src/main/java/com/example/loader/weights/Weights.java +++ /dev/null @@ -1,173 +0,0 @@ -package com.example.loader.weights; - -import com.example.core.model.GGMLType; -import com.example.core.model.tensor.FloatTensor; -import com.example.core.model.tensor.GGMLTensorEntry; -import com.example.core.types.Float16; -import uk.ac.manchester.tornado.api.types.HalfFloat; -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; - -import java.lang.foreign.MemorySegment; -import java.nio.ByteOrder; -import java.nio.FloatBuffer; -import java.util.function.IntFunction; - -import static com.example.core.model.tensor.FloatTensor.readByte; -import static com.example.core.model.tensor.FloatTensor.readShort; - -public class Weights { - // token embedding 
table - public final FloatTensor token_embedding_table; // (vocab_size, dim) - // weights for rmsnorms - public final FloatBuffer[] rms_att_weight; // (layer, dim) rmsnorm weights - // weights for matmuls - public final FloatTensor[] wq; // (layer, n_heads * head_size) - public final FloatTensor[] wk; // (layer, n_kv_heads, head_size) - public final FloatTensor[] wv; // (layer, n_kv_heads * head_size) - public final FloatTensor[] wo; // (layer, n_heads * head_size, dim) - public final FloatBuffer[] rms_ffn_weight; // (layer, dim) - - // weights for ffn - public final FloatTensor[] w1; // (layer, hidden_dim, dim) - public final FloatTensor[] w2; // (layer, dim, hidden_dim) - public final FloatTensor[] w3; // (layer, hidden_dim, dim) - // - public final FloatTensor wcls; // (vocab_size, dim) - public final HalfFloatArray wclsHalfFloat; - // public final rmsnorm - public final FloatBuffer rms_final_weight; // (dim,) - // freq_cis for RoPE relatively positional embeddings - public final FloatBuffer freq_cis_real; // (seq_len, head_size/2) - public final FloatBuffer freq_cis_imag; // (seq_len, head_size/2) - // // Layered Data structures - public FloatArray[] rms_att_weightLayered; // (layer, dim) rmsnorm weights - public HalfFloatArray[] wqLayered; // (layer, n_heads * head_size) - public HalfFloatArray[] wkLayered; // (layer, n_kv_heads, head_size) - public HalfFloatArray[] wvLayered; // (layer, n_kv_heads * head_size) - public HalfFloatArray[] woLayered; // (layer, n_heads * head_size, dim) - public FloatArray[] rms_ffn_weightLayered; // (layer, dim) - public HalfFloatArray[] w1Layered; // (layer, hidden_dim, dim) - public HalfFloatArray[] w2Layered; // (layer, dim, hidden_dim) - // - public HalfFloatArray[] w3Layered; // (layer, hidden_dim, dim) - public FloatArray rms_final_weight_as_floatArray; - public FloatArray tokenEmbeddingTable; // (vocab_size, dim) - public FloatArray freq_cis_realFlat; // (seq_len, head_size/2) - public FloatArray freq_cis_imagFlat; // (seq_len, head_size/2) - // (optional) classifier weights for the logits, on the last layer - public GGMLType weightType; - - /** - * Constructor to initialize all weight tensors for the model. Automatically creates TornadoVM-compatible versions when needed. 
- * - * @param token_embedding_table - * Token embeddings matrix - * @param rms_att_weight - * RMSNorm weights for attention layers - * @param wq - * Query weight matrices - * @param wk - * Key weight matrices - * @param wv - * Value weight matrices - * @param wo - * Output projection matrices - * @param rms_ffn_weight - * RMSNorm weights for FFN layers - * @param w1 - * First FFN weight matrices - * @param w2 - * Second FFN weight matrices - * @param w3 - * Third FFN weight matrices (gate) - * @param rms_final_weight - * Final layer normalization weights - * @param freq_cis_real - * RoPE cosine components - * @param freq_cis_imag - * RoPE sine components - * @param wcls - * Classifier weights for output logits - * - /** - * Constructor for standard (non-TornadoVM) mode - */ - public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight, - FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls, GGMLType weightType) { - // Standard format - this.token_embedding_table = token_embedding_table; - this.rms_att_weight = rms_att_weight; - this.wq = wq; - this.wk = wk; - this.wv = wv; - this.wo = wo; - this.rms_ffn_weight = rms_ffn_weight; - this.w1 = w1; - this.w2 = w2; - this.w3 = w3; - this.wcls = wcls; - this.rms_final_weight = rms_final_weight; - this.freq_cis_real = freq_cis_real; - this.freq_cis_imag = freq_cis_imag; - this.weightType = weightType; - - // TornadoVM format (null when not using TornadoVM) - this.tokenEmbeddingTable = null; - this.rms_att_weightLayered = null; - this.wqLayered = null; - this.wkLayered = null; - this.wvLayered = null; - this.woLayered = null; - this.rms_ffn_weightLayered = null; - this.w1Layered = null; - this.w2Layered = null; - this.w3Layered = null; - this.rms_final_weight_as_floatArray = null; - this.freq_cis_realFlat = null; - this.freq_cis_imagFlat = null; - this.wclsHalfFloat = null; - } - - /** - * Constructor for TornadoVM mode - */ - public Weights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, HalfFloatArray[] woLayered, - FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, FloatArray rms_final_weight_as_floatArray, - FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray, GGMLType weightType) { - // Standard format (null when using TornadoVM) - this.token_embedding_table = null; - this.rms_att_weight = null; - this.wq = null; - this.wk = null; - this.wv = null; - this.wo = null; - this.rms_ffn_weight = null; - this.w1 = null; - this.w2 = null; - this.w3 = null; - this.wcls = null; - this.rms_final_weight = null; - this.freq_cis_real = null; - this.freq_cis_imag = null; - - // TornadoVM format - this.tokenEmbeddingTable = tokenEmbeddingTable; - this.rms_att_weightLayered = rms_att_weightLayered; - this.wqLayered = wqLayered; - this.wkLayered = wkLayered; - this.wvLayered = wvLayered; - this.woLayered = woLayered; - this.rms_ffn_weightLayered = rms_ffn_weightLayered; - this.w1Layered = w1Layered; - this.w2Layered = w2Layered; - this.w3Layered = w3Layered; - this.rms_final_weight_as_floatArray = rms_final_weight_as_floatArray; - this.freq_cis_realFlat = freq_cis_realFlat; - this.freq_cis_imagFlat = freq_cis_imagFlat; - 
-        this.wclsHalfFloat = wclsByteArray;
-        this.weightType = weightType;
-    }
-
-}
\ No newline at end of file
diff --git a/src/main/java/com/example/model/Configuration.java b/src/main/java/com/example/model/Configuration.java
index 6ff558c..56e14b8 100644
--- a/src/main/java/com/example/model/Configuration.java
+++ b/src/main/java/com/example/model/Configuration.java
@@ -17,12 +17,17 @@
     /** Number of key/value heads (can be fewer than query heads in multi-query attention) */
     int numberOfKeyValueHeads();
 
+    /** Per-head key length (used by Qwen3, where it differs from dim / numberOfHeads) */
+    int numberOfHeadsKey();
+
     /** Size of the vocabulary (token set) */
     int vocabularySize();
 
     /** Maximum sequence length the model can process */
     int contextLength();
 
+    /** Maximum sequence length native to the model, as recorded in its metadata */
+    int contextLengthModel();
+
     /** Epsilon value for RMSNorm layers (stabilizes normalization) */
     float rmsNormEps();
diff --git a/src/main/java/com/example/model/Model.java b/src/main/java/com/example/model/Model.java
index e42349b..3147ee2 100644
--- a/src/main/java/com/example/model/Model.java
+++ b/src/main/java/com/example/model/Model.java
@@ -4,8 +4,8 @@
 import com.example.auxiliary.LastRunMetrics;
 import com.example.inference.InferenceEngine;
 import com.example.inference.sampler.Sampler;
-import com.example.loader.weights.State;
-import com.example.loader.weights.Weights;
+import com.example.inference.state.State;
+import com.example.inference.weights.Weights;
 import com.example.model.format.ChatFormat;
 import com.example.tokenizer.impl.Tokenizer;
 import com.example.tornadovm.TornadoVMMasterPlan;
@@ -20,18 +20,36 @@ import static com.example.LlamaApp.USE_TORNADOVM;
 public interface Model {
+
     Configuration configuration();
 
     Tokenizer tokenizer();
 
     Weights weights();
 
+    ChatFormat chatFormat();
+
     ModelType getModelType();
 
     State createNewState();
 
     State createNewState(int batchsize);
 
+    /**
+     * Wrapper for invoking the model-specific forward pass via InferenceCore.
+     *
+     * <p>
+     * Delegates to the appropriate InferenceCore method based on the model type
+     * (e.g., {@code forwardJava}, {@code forwardJavaQwen3}).
+     * </p>
+     */
+    void forward(State state, int token, int position);
+
+    /**
+     * Wrapper for invoking the model-specific {@code InferenceEngine.generateTokens} call.
+     */
+    List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated);
+
     /**
      * Model agnostic default implementation for interactive mode.
      * @param sampler
@@ -41,8 +59,11 @@ default void runInteractive(Sampler sampler, Options options) {
         State state = null;
         List<Integer> conversationTokens = new ArrayList<>();
 
-        ChatFormat chatFormat = ChatFormat.create(tokenizer());
-        conversationTokens.add(chatFormat.getBeginOfText());
+        ChatFormat chatFormat = chatFormat();
+
+        if (!getModelType().equals(ModelType.QWEN_3)) {
+            conversationTokens.add(chatFormat.getBeginOfText());
+        }
 
         if (options.systemPrompt() != null) {
             conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
@@ -92,7 +113,7 @@ default void runInteractive(Sampler sampler, Options options) {
                         options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
             } else {
                 // CPU path
-                responseTokens = InferenceEngine.generateTokens(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
+                responseTokens = generateTokens(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
                         sampler, options.echo(), tokenConsumer);
             }
@@ -138,11 +159,14 @@
      */
     default void runInstructOnce(Sampler sampler, Options options) {
         State state = createNewState();
-        ChatFormat chatFormat = ChatFormat.create(tokenizer());
+        ChatFormat chatFormat = chatFormat();
         TornadoVMMasterPlan tornadoVMPlan = null;
 
         List<Integer> promptTokens = new ArrayList<>();
-        promptTokens.add(chatFormat.getBeginOfText());
+
+        if (!getModelType().equals(ModelType.QWEN_3)) {
+            promptTokens.add(chatFormat.getBeginOfText());
+        }
 
         if (options.systemPrompt() != null) {
             promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
@@ -168,7 +192,7 @@
             responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(),
                     options.stream() ? tokenConsumer : null, tornadoVMPlan);
         } else {
-            responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
+            responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
         }
 
         if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
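A minimal sketch (not repository code) of a caller driving the two new wrappers; only API surface introduced in this diff is used, and sampler construction is deliberately left to the caller:

    // Sketch: model-agnostic single-turn generation through the new Model wrappers.
    static List<Integer> answerOnce(Model model, Sampler sampler, String question) {
        State state = model.createNewState();
        ChatFormat chat = model.chatFormat();
        List<Integer> prompt = new ArrayList<>();
        if (!model.getModelType().equals(ModelType.QWEN_3)) {
            prompt.add(chat.getBeginOfText()); // Qwen3's ChatML format has no begin-of-text token
        }
        prompt.addAll(chat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, question)));
        prompt.addAll(chat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
        // Dispatches to generateTokensLlama / generateTokensQwen3 behind the interface.
        return model.generateTokens(state, 0, prompt, chat.getStopTokens(), 256, sampler, false, null);
    }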
diff --git a/src/main/java/com/example/model/ModelType.java b/src/main/java/com/example/model/ModelType.java
index 0426824..f36a4e1 100644
--- a/src/main/java/com/example/model/ModelType.java
+++ b/src/main/java/com/example/model/ModelType.java
@@ -1,8 +1,9 @@
 package com.example.model;
 
 import com.example.core.model.GGUF;
-import com.example.model.llama.Llama;
-import com.example.model.mistral.Mistral;
+import com.example.model.loader.LlamaModelLoader;
+import com.example.model.loader.MistralModelLoader;
+import com.example.model.loader.Qwen3ModelLoader;
 
 import java.nio.channels.FileChannel;
 
@@ -10,14 +11,21 @@ public enum ModelType {
     LLAMA_3 {
         @Override
         public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
-            return Llama.loadModel(fileChannel, gguf, contextLength, loadWeights);
+            return new LlamaModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
         }
     },
 
     MISTRAL {
         @Override
         public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
-            return Mistral.loadModel(fileChannel, gguf, contextLength, loadWeights);
+            return new MistralModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
+        }
+    },
+
+    QWEN_3 {
+        @Override
+        public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
+            return new Qwen3ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
         }
     },
diff --git a/src/main/java/com/example/model/format/ChatFormat.java b/src/main/java/com/example/model/format/ChatFormat.java
index b114c7d..d6e7e74 100644
--- a/src/main/java/com/example/model/format/ChatFormat.java
+++ b/src/main/java/com/example/model/format/ChatFormat.java
@@ -2,17 +2,26 @@
 import com.example.tokenizer.impl.LlamaTokenizer;
 import com.example.tokenizer.impl.MistralTokenizer;
+import com.example.tokenizer.impl.Qwen3Tokenizer;
 
 import java.util.List;
 import java.util.Set;
 
 public interface ChatFormat {
 
-    static ChatFormat create(Object tokenizer) {
+    default ChatTokens chatTokens() {
+        throw new UnsupportedOperationException("ChatFormat for Llama and Mistral does not support chatTokens");
+    }
+
+    public record ChatTokens(String tStartHeader, String tEndHeader, String tEndOfTurn, String tEndOfText, String tEndOfTextFim) {
+    }
+
+    static ChatFormat create(Object tokenizer, ChatTokens chatTokens) {
         if (tokenizer instanceof LlamaTokenizer llamaTokenizer) {
             return new LlamaChatFormat(llamaTokenizer);
         } else if (tokenizer instanceof MistralTokenizer mistralTokenizer) {
             return new MistralChatFormat(mistralTokenizer);
+        } else if (tokenizer instanceof Qwen3Tokenizer qwen3Tokenizer) {
+            return new Qwen3ChatFormat(qwen3Tokenizer, chatTokens);
         } else {
             throw new IllegalArgumentException("Unsupported tokenizer type: " + tokenizer.getClass().getName());
         }
@@ -54,6 +63,9 @@ record Role(String name) {
         public static Role SYSTEM = new Role("system");
         public static Role USER = new Role("user");
         public static Role ASSISTANT = new Role("assistant");
+        public static Role FIM_PREFIX = new ChatFormat.Role("fim_prefix");
+        public static Role FIM_SUFFIX = new ChatFormat.Role("fim_suffix");
+        public static Role FIM_MIDDLE = new ChatFormat.Role("fim_middle");
 
         @Override
         public String toString() {
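For reference, a sketch of the intended call sites for the widened factory: Llama and Mistral pass null for ChatTokens (their formats ignore it), while Qwen3 supplies ChatML markers; the marker values mirror those used by Qwen3ModelLoader later in this diff.

    // Sketch: two-argument factory usage (tokenizer variables assumed to exist).
    ChatFormat llamaFormat = ChatFormat.create(llamaTokenizer, null);
    ChatFormat qwenFormat = ChatFormat.create(qwen3Tokenizer,
            new ChatFormat.ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>"));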
diff --git a/src/main/java/com/example/model/format/Qwen3ChatFormat.java b/src/main/java/com/example/model/format/Qwen3ChatFormat.java
new file mode 100644
index 0000000..2a42a03
--- /dev/null
+++ b/src/main/java/com/example/model/format/Qwen3ChatFormat.java
@@ -0,0 +1,123 @@
+package com.example.model.format;
+
+import com.example.tokenizer.impl.Qwen3Tokenizer;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Utility tailored for the Chat Markup Language (ChatML) prompt format.
+ */
+public class Qwen3ChatFormat implements ChatFormat {
+
+    protected Qwen3Tokenizer tokenizer;
+    protected ChatTokens chatTokens;
+
+    protected final int beginOfText;
+    protected final int startHeader;
+    protected final int endHeader;
+    protected final int endOfTurn;
+    protected final int endOfText;
+    protected final int endOfMessage;
+    protected final int endOfTextFim;
+
+    protected final int imStart; // beginOfText
+    protected final int imEnd; // endOfText
+
+    protected final int fimPrefix;
+    protected final int fimSuffix;
+    protected final int fimMiddle;
+
+    public Qwen3ChatFormat(Qwen3Tokenizer tokenizer, ChatTokens chatTokens) {
+        //super(tokenizer, "", chatTokens.tStartHeader(), chatTokens.tEndHeader(), chatTokens.tEndOfTurn(), chatTokens.tEndOfText(), "", chatTokens.tEndOfTextFim());
+        this.tokenizer = tokenizer;
+        this.chatTokens = chatTokens;
+        Map<String, Integer> specialTokens = tokenizer.getSpecialTokens();
+        this.beginOfText = specialTokens.getOrDefault("", -1);
+        this.startHeader = specialTokens.getOrDefault(chatTokens.tStartHeader(), -1);
+        this.endHeader = specialTokens.getOrDefault(chatTokens.tEndHeader(), -1);
+        this.endOfTurn = specialTokens.getOrDefault(chatTokens.tEndOfTurn(), -1);
+        this.endOfText = specialTokens.getOrDefault(chatTokens.tEndOfText(), -1);
+        this.endOfTextFim = specialTokens.getOrDefault(chatTokens.tEndOfTextFim(), -1);
+        this.endOfMessage = specialTokens.getOrDefault("", -1); // Use default value if key not found
+
+        this.imStart = startHeader;
+        this.imEnd = endHeader;
+
+        fimPrefix = specialTokens.getOrDefault("<|fim_prefix|>", -1);
+        fimSuffix = specialTokens.getOrDefault("<|fim_suffix|>", -1);
+        fimMiddle = specialTokens.getOrDefault("<|fim_middle|>", -1);
+    }
+
+    public ChatTokens chatTokens() {
+        return chatTokens;
+    }
+
+    @Override
+    public List<Integer> encodeHeader(Message message) {
+        List<Integer> tokens = new ArrayList<>();
+        if (endHeader == -1) {
+            // DeepSeek-R1
+            String sToken = switch (message.role().name()) {
+                case "system" -> null;
+                case "user" -> "<|User|>";
+                case "assistant" -> "<|Assistant|>";
+                case "fim_prefix" -> "<|fim_prefix|>";
+                case "fim_middle" -> "<|fim_middle|>";
+                case "fim_suffix" -> "<|fim_suffix|>";
+                default -> null;
+            };
+            if (sToken != null) {
+                Integer token = tokenizer.getSpecialTokens().get(sToken);
+                if (token == null) {
+                    throw new IllegalStateException(String.format("Unknown token '%s'", sToken));
+                }
+                tokens.add(token);
+            }
+        } else if (Role.FIM_PREFIX.equals(message.role())) {
+            // fill-in-the-middle, token fim_prefix.
+            tokens.add(fimPrefix);
+        } else if (Role.FIM_SUFFIX.equals(message.role())) {
+            tokens.add(fimSuffix);
+        } else if (Role.FIM_MIDDLE.equals(message.role())) {
+            tokens.add(fimMiddle);
+        } else {
+            tokens.add(imStart);
+            tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
+            tokens.addAll(this.tokenizer.encodeAsList("\n"));
+        }
+        return tokens;
+    }
+
+    @Override
+    public List<Integer> encodeMessage(Message message) {
+        List<Integer> tokens = this.encodeHeader(message);
+        tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
+        boolean isFim = Role.FIM_PREFIX.equals(message.role())
+                || Role.FIM_SUFFIX.equals(message.role())
+                || Role.FIM_MIDDLE.equals(message.role());
+        if (imEnd != -1 && !isFim) {
+            tokens.add(imEnd);
+        }
+        return tokens;
+    }
+
+    @Override
+    public int getBeginOfText() {
+        return beginOfText;
+    }
+
+    @Override
+    public Set<Integer> getStopTokens() {
+        if (imEnd == -1 && endOfText == -1) {
+            throw new IllegalStateException("No stop token is defined.");
+        }
+        if (imEnd == -1) {
+            return Set.of(endOfText);
+        }
+        return Set.of(imEnd, endOfText, endOfTextFim);
+    }
+}
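To make the ChatML framing concrete, a sketch (assuming a constructed Qwen3ChatFormat named format) of the surface form the encoder above produces for one user turn plus the assistant cue:

    // Sketch: token-level equivalent of
    //   <|im_start|>user\nHi<|im_end|>
    //   <|im_start|>assistant\n
    List<Integer> prompt = new ArrayList<>();
    prompt.addAll(format.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, "Hi")));
    prompt.addAll(format.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));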
diff --git a/src/main/java/com/example/model/llama/Llama.java b/src/main/java/com/example/model/llama/Llama.java
index c8326a1..0ad3427 100644
--- a/src/main/java/com/example/model/llama/Llama.java
+++ b/src/main/java/com/example/model/llama/Llama.java
@@ -1,24 +1,22 @@
 package com.example.model.llama;
 
-import com.example.auxiliary.Timer;
-import com.example.core.model.GGUF;
-import com.example.core.model.tensor.GGMLTensorEntry;
+import com.example.inference.InferenceCore;
+import com.example.inference.InferenceEngine;
+import com.example.inference.sampler.Sampler;
+import com.example.inference.state.LlamaState;
+import com.example.inference.state.State;
+import com.example.inference.weights.Weights;
 import com.example.model.Model;
-import com.example.loader.weights.State;
-import com.example.loader.weights.Weights;
 import com.example.model.ModelType;
+import com.example.model.format.ChatFormat;
 import com.example.tokenizer.impl.LlamaTokenizer;
 import com.example.tokenizer.impl.Tokenizer;
-import com.example.tokenizer.vocabulary.Vocabulary;
 
-import java.io.IOException;
-import java.nio.channels.FileChannel;
-import java.util.Map;
+import java.util.List;
+import java.util.Set;
+import java.util.function.IntConsumer;
 
-import static com.example.loader.weights.ModelLoader.loadWeights;
-
-public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
-    private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);
+public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model {
 
     /* For explicit use */
     private LlamaTokenizer getAsLlamaTokenizer() {
@@ -32,53 +30,27 @@ public ModelType getModelType() {
 
     @Override
     public State createNewState() {
-        State state = new State(configuration(), -1);
+        State state = new LlamaState(configuration(), -1);
         state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>");
         return state;
     }
 
     @Override
     public State createNewState(int batchsize) {
-        State state = new State(configuration(), batchsize);
+        State state = new LlamaState(configuration(), batchsize);
         state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>");
         return state;
     }
 
-    // @formatter:off
-    public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
-        try (var ignored = Timer.log("Load LlaMa model")) {
-            Map<String, Object> metadata = gguf.getMetadata();
-
-            Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata);
-            Tokenizer tokenizer = new LlamaTokenizer(metadata, vocabulary);
-
-            LlamaConfiguration config = new LlamaConfiguration(
-                    (int) metadata.get("llama.embedding_length"),
-                    (int) metadata.get("llama.feed_forward_length"),
-                    (int) metadata.get("llama.block_count"),
-                    (int) metadata.get("llama.attention.head_count"),
-
-                    metadata.containsKey("llama.attention.head_count_kv") ?
-                            (int) metadata.get("llama.attention.head_count_kv") :
-                            (int) metadata.get("llama.attention.head_count"),
+    @Override
+    public void forward(State state, int token, int position) {
+        InferenceCore.forwardJava(this, state, token, position);
+    }
 
-                    vocabulary.size(),
-                    (int) metadata.get("llama.context_length"),
-                    (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
-                    (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
-            ).withContextLength(contextLength);
-            Weights weights = null;
-            if (loadWeights) {
-                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
-                weights = loadWeights(tensorEntries, config);
-            }
-            return new Llama(config, tokenizer, weights);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
+    @Override
+    public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
+        return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
     }
-    // @formatter:on
-
 }
diff --git a/src/main/java/com/example/model/llama/LlamaConfiguration.java b/src/main/java/com/example/model/llama/LlamaConfiguration.java
index 53195d1..a363bc6 100644
--- a/src/main/java/com/example/model/llama/LlamaConfiguration.java
+++ b/src/main/java/com/example/model/llama/LlamaConfiguration.java
@@ -5,6 +5,16 @@
 public record LlamaConfiguration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, float rmsNormEps,
         float ropeTheta) implements Configuration {
 
+    @Override
+    public int numberOfHeadsKey() {
+        throw new UnsupportedOperationException("Not supported for Llama.");
+    }
+
+    @Override
+    public int contextLengthModel() {
+        throw new UnsupportedOperationException("Not supported for Llama.");
+    }
+
     public int headSize() {
         return dim / numberOfHeads;
     }
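A quick worked example of the derived quantities LlamaConfiguration keeps (and Qwen3, below, cannot use). The shape figures are the published Llama-3-8B ones, quoted here for illustration; the kvDim relation is the standard grouped-query-attention (GQA) one, assumed to match this codebase:

    // Sketch: headSize and kvDim for a Llama-3-8B-like shape.
    int dim = 4096, heads = 32, kvHeads = 8;
    int headSize = dim / heads;          // 128
    int kvDim = dim * kvHeads / heads;   // 1024: K/V projections are 4x narrower than Q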
diff --git a/src/main/java/com/example/model/loader/LlamaModelLoader.java b/src/main/java/com/example/model/loader/LlamaModelLoader.java
new file mode 100644
index 0000000..f6cb245
--- /dev/null
+++ b/src/main/java/com/example/model/loader/LlamaModelLoader.java
@@ -0,0 +1,60 @@
+package com.example.model.loader;
+
+import com.example.auxiliary.Timer;
+import com.example.core.model.GGUF;
+import com.example.core.model.tensor.GGMLTensorEntry;
+import com.example.inference.weights.Weights;
+import com.example.model.format.ChatFormat;
+import com.example.model.llama.Llama;
+import com.example.model.llama.LlamaConfiguration;
+import com.example.tokenizer.impl.LlamaTokenizer;
+import com.example.tokenizer.impl.Tokenizer;
+import com.example.tokenizer.vocabulary.Vocabulary;
+
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.util.Map;
+
+public class LlamaModelLoader extends ModelLoader {
+
+    public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
+        super(fileChannel, gguf, contextLength, loadWeights);
+    }
+
+    // @formatter:off
+    @Override
+    public Llama loadModel() {
+        try (var ignored = Timer.log("Load Llama model")) {
+            Map<String, Object> metadata = gguf.getMetadata();
+
+            Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata);
+            Tokenizer tokenizer = new LlamaTokenizer(metadata, vocabulary);
+
+            LlamaConfiguration config = new LlamaConfiguration(
+                    (int) metadata.get("llama.embedding_length"),
+                    (int) metadata.get("llama.feed_forward_length"),
+                    (int) metadata.get("llama.block_count"),
+                    (int) metadata.get("llama.attention.head_count"),
+
+                    metadata.containsKey("llama.attention.head_count_kv") ?
+                            (int) metadata.get("llama.attention.head_count_kv") :
+                            (int) metadata.get("llama.attention.head_count"),
+
+                    vocabulary.size(),
+                    (int) metadata.get("llama.context_length"),
+                    (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
+                    (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
+            ).withContextLength(contextLength);
+
+            Weights weights = null;
+            if (loadWeights) {
+                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
+                weights = loadWeights(tensorEntries, config);
+            }
+            return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+    // @formatter:on
+}
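A usage sketch for the extracted loader (the GGUF file name is hypothetical; GGUF.loadModel and FileChannel.open are used exactly as elsewhere in this diff):

    // Sketch: loading a Llama model directly through the new loader class.
    Path ggufPath = Path.of("Meta-Llama-3-8B-Instruct-Q4_0.gguf");
    GGUF gguf = GGUF.loadModel(ggufPath);
    try (FileChannel channel = FileChannel.open(ggufPath, StandardOpenOption.READ)) {
        Llama model = new LlamaModelLoader(channel, gguf, 4096, true).loadModel();
    }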
(int) metadata.get("llama.attention.head_count_kv") + : (int) metadata.get("llama.attention.head_count"), + + vocabulary.size(), + contextLength, + false, + (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) + ); + + Weights weights = null; + if (loadWeights) { + Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); + weights = loadWeights(tensorEntries, config); + } + return new Mistral(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + // @formatter:on +} diff --git a/src/main/java/com/example/loader/weights/ModelLoader.java b/src/main/java/com/example/model/loader/ModelLoader.java similarity index 79% rename from src/main/java/com/example/loader/weights/ModelLoader.java rename to src/main/java/com/example/model/loader/ModelLoader.java index 353600e..6482414 100644 --- a/src/main/java/com/example/loader/weights/ModelLoader.java +++ b/src/main/java/com/example/model/loader/ModelLoader.java @@ -1,14 +1,20 @@ -package com.example.loader.weights; +package com.example.model.loader; import com.example.LlamaApp; import com.example.core.model.GGMLType; import com.example.core.model.GGUF; +import com.example.core.model.tensor.ArrayFloatTensor; import com.example.core.model.tensor.F16FloatTensor; +import com.example.core.model.tensor.F32FloatTensor; import com.example.core.model.tensor.FloatTensor; import com.example.core.model.tensor.GGMLTensorEntry; import com.example.core.model.tensor.Q4_0FloatTensor; import com.example.core.model.tensor.Q8_0FloatTensor; import com.example.core.types.Pair; +import com.example.inference.weights.standard.LlamaStandardWeights; +import com.example.inference.weights.tornado.LlamaTornadoWeights; +import com.example.inference.weights.tornado.TornadoWeights; +import com.example.inference.weights.Weights; import com.example.model.Configuration; import com.example.model.Model; import com.example.model.ModelType; @@ -27,10 +33,22 @@ import java.util.Map; import java.util.function.IntFunction; -public final class ModelLoader { +public abstract class ModelLoader { private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2"; private static final String TOKENIZER_MISTRAL_MODEL = "llama"; + protected FileChannel fileChannel; + GGUF gguf; + int contextLength; + boolean loadWeights; + + public ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { + this.fileChannel = fileChannel; + this.gguf = gguf; + this.contextLength = contextLength; + this.loadWeights = loadWeights; + } + private static ModelType detectModelType(Map metadata) { String name = (String) metadata.get("general.name"); String tokenizerModel = (String) metadata.get("tokenizer.ggml.model"); @@ -43,6 +61,8 @@ private static ModelType detectModelType(Map metadata) { return ModelType.MISTRAL; } else if (lowerName.contains("llama")) { return ModelType.LLAMA_3; + } else if (lowerName.contains("qwen3")) { + return ModelType.QWEN_3; } } @@ -65,6 +85,8 @@ private static ModelType detectModelType(Map metadata) { return ModelType.UNKNOWN; } + public abstract Model loadModel(); + public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException { // initial load of metadata from gguf file GGUF gguf = GGUF.loadModel(ggufPath); @@ -75,7 +97,7 @@ public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeig return 
diff --git a/src/main/java/com/example/loader/weights/ModelLoader.java b/src/main/java/com/example/model/loader/ModelLoader.java
similarity index 79%
rename from src/main/java/com/example/loader/weights/ModelLoader.java
rename to src/main/java/com/example/model/loader/ModelLoader.java
index 353600e..6482414 100644
--- a/src/main/java/com/example/loader/weights/ModelLoader.java
+++ b/src/main/java/com/example/model/loader/ModelLoader.java
@@ -1,14 +1,20 @@
-package com.example.loader.weights;
+package com.example.model.loader;
 
 import com.example.LlamaApp;
 import com.example.core.model.GGMLType;
 import com.example.core.model.GGUF;
+import com.example.core.model.tensor.ArrayFloatTensor;
 import com.example.core.model.tensor.F16FloatTensor;
+import com.example.core.model.tensor.F32FloatTensor;
 import com.example.core.model.tensor.FloatTensor;
 import com.example.core.model.tensor.GGMLTensorEntry;
 import com.example.core.model.tensor.Q4_0FloatTensor;
 import com.example.core.model.tensor.Q8_0FloatTensor;
 import com.example.core.types.Pair;
+import com.example.inference.weights.standard.LlamaStandardWeights;
+import com.example.inference.weights.tornado.LlamaTornadoWeights;
+import com.example.inference.weights.tornado.TornadoWeights;
+import com.example.inference.weights.Weights;
 import com.example.model.Configuration;
 import com.example.model.Model;
 import com.example.model.ModelType;
@@ -27,10 +33,22 @@
 import java.util.Map;
 import java.util.function.IntFunction;
 
-public final class ModelLoader {
+public abstract class ModelLoader {
     private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";
     private static final String TOKENIZER_MISTRAL_MODEL = "llama";
 
+    protected FileChannel fileChannel;
+    GGUF gguf;
+    int contextLength;
+    boolean loadWeights;
+
+    public ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
+        this.fileChannel = fileChannel;
+        this.gguf = gguf;
+        this.contextLength = contextLength;
+        this.loadWeights = loadWeights;
+    }
+
     private static ModelType detectModelType(Map<String, Object> metadata) {
         String name = (String) metadata.get("general.name");
         String tokenizerModel = (String) metadata.get("tokenizer.ggml.model");
@@ -43,6 +61,8 @@
             return ModelType.MISTRAL;
         } else if (lowerName.contains("llama")) {
             return ModelType.LLAMA_3;
+        } else if (lowerName.contains("qwen3")) {
+            return ModelType.QWEN_3;
         }
     }
 
@@ -65,6 +85,8 @@
         return ModelType.UNKNOWN;
     }
 
+    public abstract Model loadModel();
+
     public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException {
         // initial load of metadata from gguf file
         GGUF gguf = GGUF.loadModel(ggufPath);
@@ -75,7 +97,7 @@
         return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights);
     }
 
-    public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
+    public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
         boolean ropeScaling = tensorEntries.containsKey("rope_freqs");
         RopeConfig ropeConfig = new RopeConfig(8.0f, // scaleFactor
                 1.0f, // loFreqFactor
@@ -83,14 +105,15 @@
                 8192 // oldContextLength
         );
 
-        Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength(), // Maximum sequence length the model can process
-                config.headSize(), // Dimension of each attention head
-                config.ropeTheta(), // Base frequency parameter (typically 10000.0)
-                ropeScaling, // Whether to apply frequency scaling (determined by model type)
-                ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling)
-                ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies
-                ropeConfig.hiFreqFactor, // High frequency scaling factor for preserving local precision
-                ropeConfig.oldContextLength // Original context length the model was trained with
+        Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
+                config.contextLength(), // Maximum sequence length the model can process
+                config.headSize(), // Dimension of each attention head
+                config.ropeTheta(), // Base frequency parameter (typically 10000.0)
+                ropeScaling, // Whether to apply frequency scaling (determined by model type)
+                ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling)
+                ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies
+                ropeConfig.hiFreqFactor, // High frequency scaling factor for preserving local precision
+                ropeConfig.oldContextLength // Original context length the model was trained with
         );
 
         GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
@@ -104,9 +127,9 @@
         }
     }
 
-    private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+    public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
-        return new Weights(
+        return new LlamaTornadoWeights(
                 // Load directly to TornadoVM format
                 loadTensorAsFloatArray(tokenEmbeddings),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
@@ -117,30 +140,37 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tenso
                 loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
                 floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
-                FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType());
+                FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType()) {
+        };
     }
 
     /**
      * Creates weights in standard format only
     */
-    private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+    public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
-        return new Weights(loadQuantized(tokenEmbeddings), loadArrayOfFloatBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+        return new LlamaStandardWeights(
+                loadQuantized(tokenEmbeddings),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
-                loadArrayOfFloatBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), toFloatBuffer(tensorEntries.get("output_norm.weight")),
-                FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType());
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+                loadQuantized(tensorEntries.get("output_norm.weight")),
+                new ArrayFloatTensor(ropeFreqs.first()),
+                new ArrayFloatTensor(ropeFreqs.second()),
+                loadQuantized(outputWeight),
+                outputWeight.ggmlType());
     }
 
     public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
         GGMLType ggmlType = entry.ggmlType();
         return switch (ggmlType) {
-            // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
             case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
             case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
             case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
(int) metadata.get("qwen3.attention.head_count_kv") + : (int) metadata.get("qwen3.attention.head_count"), + (int) metadata.get("qwen3.attention.key_length"), + (int) metadata.get("qwen3.attention.value_length"), + + vocabulary.size(), + modelContextLength, contextLength, + false, + (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"), + (float) metadata.get("qwen3.rope.freq_base") + ); + + Weights weights = null; + if (loadWeights) { + Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); + weights = loadWeights(tensorEntries, config); + } + // Qwen2.5-coder uses <|endoftext|> as stop-token. + ChatTokens chatTokens = isDeepSeekR1DistillQwen ? + new ChatTokens( "<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") : + new ChatTokens( "<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>"); + return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Weights loadWeights(Map tensorEntries, Configuration config) { + Pair ropeFreqs = RoPE.precomputeFreqsCis( + config.contextLengthModel(), + config.numberOfHeadsKey(), + config.ropeTheta(), + false, + 0, + 0, + 0, + 0 + ); + + GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight"); + GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings); + + if (LlamaApp.USE_TORNADOVM) { + System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); + return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); + } else { + return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); + } + } + + @Override + public Weights createTornadoVMWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + throw new UnsupportedOperationException("Not supported yet."); + } + + @Override + public Weights createStandardWeights(Map tensorEntries, + Configuration config, + Pair ropeFreqs, + GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + float[] ropeFreqsReal = ropeFreqs.first(); + float[] ropeFreqsImag = ropeFreqs.second(); + return new Qwen3StandardWeights( + loadQuantized(tokenEmbeddings), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo + + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm + + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), //rms_ffn_weight + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." 
+ i + ".ffn_gate.weight")), // w1 + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2 + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3 + loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight + new ArrayFloatTensor(ropeFreqsReal), + new ArrayFloatTensor(ropeFreqsImag), + tensorEntries.containsKey("output.weight") + ? ModelLoader.loadQuantized(tensorEntries.get("output.weight")) + : loadQuantized(tokenEmbeddings), // weights are shared + null + ); + } +} diff --git a/src/main/java/com/example/model/mistral/Mistral.java b/src/main/java/com/example/model/mistral/Mistral.java index c1f9d09..9dc74d7 100644 --- a/src/main/java/com/example/model/mistral/Mistral.java +++ b/src/main/java/com/example/model/mistral/Mistral.java @@ -1,23 +1,22 @@ package com.example.model.mistral; -import com.example.auxiliary.Timer; -import com.example.core.model.GGUF; -import com.example.core.model.tensor.GGMLTensorEntry; +import com.example.inference.InferenceCore; +import com.example.inference.InferenceEngine; +import com.example.inference.sampler.Sampler; +import com.example.inference.state.LlamaState; +import com.example.inference.state.State; +import com.example.inference.weights.Weights; import com.example.model.Model; -import com.example.loader.weights.State; -import com.example.loader.weights.Weights; import com.example.model.ModelType; +import com.example.model.format.ChatFormat; import com.example.tokenizer.impl.MistralTokenizer; import com.example.tokenizer.impl.Tokenizer; -import com.example.tokenizer.vocabulary.Vocabulary; -import java.io.IOException; -import java.nio.channels.FileChannel; -import java.util.Map; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; -import static com.example.loader.weights.ModelLoader.loadWeights; - -public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model { +public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model { /* For explicit use */ private MistralTokenizer getAsMistralTokenizer() { @@ -30,57 +29,25 @@ public ModelType getModelType() { } public State createNewState() { - State state = new State(configuration(), -1); + State state = new LlamaState(configuration(), -1); state.latestToken = tokenizer.getSpecialTokens().get(""); return state; } public State createNewState(int batchsize) { - State state = new State(configuration(), batchsize); + State state = new LlamaState(configuration(), batchsize); state.latestToken = tokenizer.getSpecialTokens().get(""); return state; } - // @formatter:off - public static Mistral loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { - try (var ignored = Timer.log("Load Mistral model")) { - Map metadata = gguf.getMetadata(); - - Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata); - Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary); - - int modelContextLength = (int) metadata.get("llama.context_length"); - if (contextLength < 0 || modelContextLength < contextLength) { - contextLength = modelContextLength; - } - - MistralConfiguration config = new MistralConfiguration( - (int) metadata.get("llama.embedding_length"), - (int) metadata.get("llama.feed_forward_length"), - (int) metadata.get("llama.block_count"), - (int) metadata.get("llama.attention.head_count"), - - 
diff --git a/src/main/java/com/example/model/mistral/Mistral.java b/src/main/java/com/example/model/mistral/Mistral.java
index c1f9d09..9dc74d7 100644
--- a/src/main/java/com/example/model/mistral/Mistral.java
+++ b/src/main/java/com/example/model/mistral/Mistral.java
@@ -1,23 +1,22 @@
 package com.example.model.mistral;
 
-import com.example.auxiliary.Timer;
-import com.example.core.model.GGUF;
-import com.example.core.model.tensor.GGMLTensorEntry;
+import com.example.inference.InferenceCore;
+import com.example.inference.InferenceEngine;
+import com.example.inference.sampler.Sampler;
+import com.example.inference.state.LlamaState;
+import com.example.inference.state.State;
+import com.example.inference.weights.Weights;
 import com.example.model.Model;
-import com.example.loader.weights.State;
-import com.example.loader.weights.Weights;
 import com.example.model.ModelType;
+import com.example.model.format.ChatFormat;
 import com.example.tokenizer.impl.MistralTokenizer;
 import com.example.tokenizer.impl.Tokenizer;
-import com.example.tokenizer.vocabulary.Vocabulary;
 
-import java.io.IOException;
-import java.nio.channels.FileChannel;
-import java.util.Map;
+import java.util.List;
+import java.util.Set;
+import java.util.function.IntConsumer;
 
-import static com.example.loader.weights.ModelLoader.loadWeights;
-
-public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
+public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model {
 
     /* For explicit use */
     private MistralTokenizer getAsMistralTokenizer() {
@@ -30,57 +29,25 @@ public ModelType getModelType() {
     }
 
     public State createNewState() {
-        State state = new State(configuration(), -1);
+        State state = new LlamaState(configuration(), -1);
         state.latestToken = tokenizer.getSpecialTokens().get("<s>");
         return state;
     }
 
     public State createNewState(int batchsize) {
-        State state = new State(configuration(), batchsize);
+        State state = new LlamaState(configuration(), batchsize);
         state.latestToken = tokenizer.getSpecialTokens().get("<s>");
         return state;
     }
 
-    // @formatter:off
-    public static Mistral loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
-        try (var ignored = Timer.log("Load Mistral model")) {
-            Map<String, Object> metadata = gguf.getMetadata();
-
-            Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata);
-            Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary);
-
-            int modelContextLength = (int) metadata.get("llama.context_length");
-            if (contextLength < 0 || modelContextLength < contextLength) {
-                contextLength = modelContextLength;
-            }
-
-            MistralConfiguration config = new MistralConfiguration(
-                    (int) metadata.get("llama.embedding_length"),
-                    (int) metadata.get("llama.feed_forward_length"),
-                    (int) metadata.get("llama.block_count"),
-                    (int) metadata.get("llama.attention.head_count"),
-
-                    metadata.containsKey("llama.attention.head_count_kv")
-                            ? (int) metadata.get("llama.attention.head_count_kv")
-                            : (int) metadata.get("llama.attention.head_count"),
-
-                    vocabulary.size(),
-                    contextLength,
-                    false,
-                    (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
-                    (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
-            );
+    @Override
+    public void forward(State state, int token, int position) {
+        InferenceCore.forwardJava(this, state, token, position);
+    }
 
-            Weights weights = null;
-            if (loadWeights) {
-                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
-                weights = loadWeights(tensorEntries, config);
-            }
-            return new Mistral(config, tokenizer, weights);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
+    @Override
+    public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
+        return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
     }
-    // @formatter:on
 }
diff --git a/src/main/java/com/example/model/mistral/MistralConfiguration.java b/src/main/java/com/example/model/mistral/MistralConfiguration.java
index d2c14d4..dad7e0b 100644
--- a/src/main/java/com/example/model/mistral/MistralConfiguration.java
+++ b/src/main/java/com/example/model/mistral/MistralConfiguration.java
@@ -13,6 +13,16 @@ public int kvMul() {
         return numberOfHeads / numberOfKeyValueHeads;
     }
 
+    @Override
+    public int numberOfHeadsKey() {
+        throw new UnsupportedOperationException("Not supported for Mistral.");
+    }
+
+    @Override
+    public int contextLengthModel() {
+        throw new UnsupportedOperationException("Not supported for Mistral.");
+    }
+
     public int headSize() {
         return dim / numberOfHeads;
     }
diff --git a/src/main/java/com/example/model/qwen3/Qwen3.java b/src/main/java/com/example/model/qwen3/Qwen3.java
new file mode 100644
index 0000000..ad1fc35
--- /dev/null
+++ b/src/main/java/com/example/model/qwen3/Qwen3.java
@@ -0,0 +1,49 @@
+package com.example.model.qwen3;
+
+import com.example.inference.InferenceCore;
+import com.example.inference.InferenceEngine;
+import com.example.inference.sampler.Sampler;
+import com.example.inference.state.Qwen3State;
+import com.example.inference.state.State;
+import com.example.inference.weights.Weights;
+import com.example.model.Model;
+import com.example.model.ModelType;
+import com.example.model.format.ChatFormat;
+import com.example.tokenizer.impl.Tokenizer;
+
+import java.util.List;
+import java.util.Set;
+import java.util.function.IntConsumer;
+
+public record Qwen3(Qwen3Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) implements Model {
+
+    @Override
+    public ModelType getModelType() {
+        return ModelType.QWEN_3;
+    }
+
+    @Override
+    public State createNewState() {
+        State state = new Qwen3State(configuration(), -1);
+        state.latestToken = tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader());
+        return state;
+    }
+
+    @Override
+    public State createNewState(int batchsize) {
+        State state = new Qwen3State(configuration(), batchsize);
+        state.latestToken = tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader());
+        return state;
+    }
+
+    @Override
+    public void forward(State state, int token, int position) {
+        InferenceCore.forwardJavaQwen3(this, state, token, position);
+    }
+
+    @Override
+    public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
+        return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
+    }
+
+}
diff --git a/src/main/java/com/example/model/qwen3/Qwen3Configuration.java b/src/main/java/com/example/model/qwen3/Qwen3Configuration.java
new file mode 100644
index 0000000..f1cdf87
--- /dev/null
+++ b/src/main/java/com/example/model/qwen3/Qwen3Configuration.java
@@ -0,0 +1,37 @@
+package com.example.model.qwen3;
+
+import com.example.model.Configuration;
+
+public record Qwen3Configuration(int dim,
+        int hiddenDim,
+        int numberOfLayers,
+        int numberOfHeads,
+        int numberOfKeyValueHeads,
+        int numberOfHeadsKey,
+        int numberOfHeadsValue,
+        int vocabularySize,
+        int contextLengthModel,
+        int contextLength,
+        boolean sharedWeights,
+        float rmsNormEps,
+        float ropeTheta) implements Configuration {
+
+    @Override
+    public int headSize() {
+        throw new UnsupportedOperationException("Not supported for Qwen3.");
+    }
+
+    @Override
+    public int kvDim() {
+        throw new UnsupportedOperationException("Not supported for Qwen3.");
+    }
+
+    @Override
+    public int kvMul() {
+        throw new UnsupportedOperationException("Not supported for Qwen3.");
+    }
+
+    @Override
+    public int contextLengthModel() {
+        return contextLengthModel;
+    }
+}
diff --git a/src/main/java/com/example/tokenizer/impl/Qwen3Tokenizer.java b/src/main/java/com/example/tokenizer/impl/Qwen3Tokenizer.java
new file mode 100644
index 0000000..24f7cf2
--- /dev/null
+++ b/src/main/java/com/example/tokenizer/impl/Qwen3Tokenizer.java
@@ -0,0 +1,317 @@
+package com.example.tokenizer.impl;
+
+import com.example.auxiliary.Utf8Mask;
+import com.example.core.types.Pair;
+import com.example.tokenizer.vocabulary.Vocabulary;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class Qwen3Tokenizer implements Tokenizer {
+    private final static String QWEN3_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+    private final Pattern compiledPattern;
+    private final Vocabulary vocabulary;
+    private final Map<Pair<Integer, Integer>, Integer> merges;
+    private final Map<String, Integer> specialTokens;
+    private final int[] tokenTypes;
+
+    /** buffer to store incomplete UTF-8 sequence */
+    private final byte[] bufUtf8 = new byte[4];
+    /** index in UTF-8 buffer */
+    private int currUtf8Index = 0;
+    /** current UTF-8 mask */
+    private Utf8Mask currUtf8Mask;
+
+    @Override
+    public String regexPattern() {
+        if (compiledPattern == null) {
+            return null;
+        }
+        return compiledPattern.pattern();
+    }
+
+    @Override
+    public Map<String, Integer> getSpecialTokens() {
+        return specialTokens;
+    }
+
+    @Override
+    public boolean isSpecialToken(int tokenIndex) {
+        return specialTokens.containsValue(tokenIndex);
+    }
+
+    @Override
+    public boolean shouldDisplayToken(int token) {
+        int tokenType = getTokenType(token);
+        return tokenType == 1 || tokenType == 6;
+    }
+
+    public int getTokenType(int tokenIndex) {
+        if (tokenTypes == null) {
+            throw new IllegalStateException("Qwen3Tokenizer hasn't been constructed using tokenTypes");
+        }
+        return tokenTypes[tokenIndex];
+    }
+
+    public Qwen3Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary, boolean isDeepSeekR1DistillQwen) {
+        int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
+        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
+        List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines)
+                .map(line -> line.split(" "))
+                .map(parts ->
+                        new Pair<>(
+                                vocabulary.getIndex(parts[0]).orElseThrow(),
+                                vocabulary.getIndex(parts[1]).orElseThrow())
+                ).toList();
+
+        int allTokens = vocabulary.size();
+        String firstSpecialToken = isDeepSeekR1DistillQwen ? "<|end▁of▁sentence|>" : "<|endoftext|>";
+        int baseTokens = vocabulary.getIndex(firstSpecialToken).orElseThrow(); // assume all tokens after the base ones are special.
+        // int reservedSpecialTokens = allTokens - baseTokens;
+        List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
+
+        assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
+
+        Map<String, Integer> specialTokens =
+                IntStream.range(0, specialTokensList.size())
+                        .boxed()
+                        .collect(Collectors.toMap(
+                                i -> specialTokensList.get(i),
+                                i -> baseTokens + i)
+                        );
+        specialTokens.remove("<think>");
+        specialTokens.remove("</think>");
+
+        this.vocabulary = vocabulary;
+        this.compiledPattern = Pattern.compile(QWEN3_PATTERN);
+        this.specialTokens = new HashMap<>(specialTokens);
+        this.merges = new HashMap<>();
+        this.tokenTypes = tokenTypes;
+        for (Pair<Integer, Integer> pair : merges) {
+            int firstIndex = pair.first();
+            int secondIndex = pair.second();
+            int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow();
+            this.merges.put(pair, mergeIndex);
+        }
+    }
+
+    private int[] encodeImpl(String text) {
+        return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
+    }
+
+    static List<String> findAll(Pattern pattern, String text) {
+        List<String> allMatches = new ArrayList<>();
+        Matcher matcher = pattern.matcher(text);
+        while (matcher.find()) {
+            allMatches.add(matcher.group());
+        }
+        return allMatches;
+    }
+
+    /**
+     * Encoding that ignores any special tokens.
+     */
+    public List<Integer> encodeOrdinary(String text) {
+        // split text into chunks of text by categories defined in regex pattern
+        List<String> textChunks = findAll(compiledPattern, text);
+        // all chunks of text are encoded separately, then results are joined
+        List<Integer> ids = new ArrayList<>();
+        for (String chunk : textChunks) {
+            List<Integer> chunkIds = encodeChunk(chunk);
+            ids.addAll(chunkIds);
+        }
+        return ids;
+    }
+
+    private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
+        Map<Pair<Integer, Integer>, Integer> map = new HashMap<>();
+        for (int i = 0; i + 1 < ids.size(); i++) {
+            Pair<Integer, Integer> key = new Pair<>(ids.get(i), ids.get(i + 1));
+            map.put(key, map.getOrDefault(key, 0) + 1);
+        }
+        return map;
+    }
+
+    private List<Integer> encodeChunk(String chunk) {
+        // return the token ids
+        // let's begin. first, convert all bytes to integers in range 0..255
+        List<Integer> ids = new ArrayList<>();
+        for (int b : chunk.toCharArray()) {
+            int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow();
+            ids.add(tokenIndex);
+        }
+
+        while (ids.size() >= 2) {
+            // find the pair with the lowest merge index
+            Map<Pair<Integer, Integer>, Integer> stats = getStats(ids);
+            Pair<Integer, Integer> pair = stats.keySet().stream().min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow();
+            // subtle: if there are no more merges available, the key will
+            // result in an inf for every single pair, and the min will be
+            // just the first pair in the list, arbitrarily
+            // we can detect this terminating case by a membership check
+            if (!this.merges.containsKey(pair)) {
+                break; // nothing else can be merged anymore
+            }
+            // otherwise let's merge the best pair (lowest merge index)
+            int idx = this.merges.get(pair);
+            ids = merge(ids, pair, idx);
+        }
+        return ids;
+    }
+
+    static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
+        List<Integer> newids = new ArrayList<>();
+        int i = 0;
+        while (i < ids.size()) {
+            // if not at the very last position AND the pair matches, replace it
+            if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
+                newids.add(idx);
+                i += 2;
+            } else {
+                newids.add(ids.get(i));
+                i += 1;
+            }
+        }
+        return newids;
+    }
+
+    /**
+     * Returns a list of utf-8 bytes and a corresponding list of unicode strings.
+     * The reversible bpe codes work on unicode strings.
+     * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+     * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+     * This is a significant percentage of your normal, say, 32K bpe vocab.
+     * To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+     * And avoids mapping to whitespace/control characters the bpe code barfs on.
+     */
+    static Map<Integer, Integer> bytesToUnicode() {
+        List<Integer> bs = new ArrayList<>();
+        IntStream.rangeClosed('!', '~').forEach(bs::add);
+        IntStream.rangeClosed('¡', '¬').forEach(bs::add);
+        IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
+
+        List<Integer> cs = new ArrayList<>(bs);
+        int n = 0;
+        for (int b = 0; b < 256; ++b) {
+            if (!bs.contains(b)) {
+                bs.add(b);
+                cs.add(256 + n);
+                n += 1;
+            }
+        }
+
+        // return dict(zip(bs, cs))
+        return IntStream.range(0, bs.size())
+                .boxed()
+                .collect(Collectors.toMap(bs::get, cs::get));
+    }
+
+    static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
+    static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet()
+            .stream()
+            .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+
+    public int[] encode(String text) {
+        StringBuilder sb = new StringBuilder();
+        byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
+        for (byte b : bytes) {
+            sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
+        }
+        return encodeImpl(sb.toString());
+    }
+
+    @Override
+    public List<Integer> encode(String text, Set<String> allowedSpecial) {
+        // decode the user desire w.r.t. handling of special tokens
+        Set<String> special = allowedSpecial;
+        assert getSpecialTokens().keySet().containsAll(special);
+        if (special.isEmpty()) {
+            // shortcut: if no special tokens, just use the ordinary encoding
+            return encodeOrdinary(text);
+        }
+
+        // otherwise, we have to be careful with potential special tokens in text
+        // we handle special tokens by splitting the text
+        // based on the occurrence of any exact match with any of the special tokens
+        // we can use re.split for this. note that surrounding the pattern with ()
+        // makes it into a capturing group, so the special tokens will be included
+        String specialPattern = special
+                .stream()
+                .map(Pattern::quote)
+                .collect(Collectors.joining("|", "(", ")"));
+
+        String[] specialChunks = text.split(specialPattern);
+        // now all the special characters are separated from the rest of the text
+        // all chunks of text are encoded separately, then results are joined
+        List<Integer> ids = new ArrayList<>();
+        for (String part : specialChunks) {
+            if (special.contains(part)) {
+                // this is a special token, encode it separately as a special case
+                ids.add(getSpecialTokens().get(part));
+            } else {
+                // this is an ordinary sequence, encode it normally
+                ids.addAll(encodeOrdinary(part));
+            }
+        }
+        return ids;
+    }
+
+    @Override
+    public List<Integer> encodeAsList(String text) {
+        return Arrays.stream(encode(text)).boxed().toList();
+    }
+
+    public String decodeImpl(List<Integer> tokens) {
+        StringBuilder sb = new StringBuilder();
+        for (int token : tokens) {
+            String tokenString = vocabulary.get(token);
+            sb.append(tokenString);
+        }
+        return sb.toString();
+    }
+
+    @Override
+    public String decode(List<Integer> tokens) {
+        String decoded = decodeImpl(tokens);
+        // The '|' in '<|end▁of▁sentence|>' of DeepSeek-R1 has code-point 65372.
+        int[] decodedBytesAsInts = decoded.codePoints().map(cp -> cp <= 512 ? BYTE_DECODER.get(cp) : cp).toArray();
+        byte[] rawBytes = new byte[decodedBytesAsInts.length + 3];
+        int indexRawByte = 0;
+        loopDecoded:
+        for (int i = 0; i < decoded.length(); i++) {
+            byte b = (byte) decodedBytesAsInts[i];
+            if (currUtf8Index == 0) {
+                for (Utf8Mask utf8Mask : Utf8Mask.MASKS) {
+                    if ((b & utf8Mask.mask()) == utf8Mask.pattern()) {
+                        currUtf8Mask = utf8Mask;
+                        bufUtf8[currUtf8Index++] = b;
+                        continue loopDecoded;
+                    }
+                }
+            }
+            if (currUtf8Index > 0 && currUtf8Mask != null) {
+                bufUtf8[currUtf8Index++] = b;
+                if (currUtf8Index == currUtf8Mask.len()) {
+                    System.arraycopy(bufUtf8, 0, rawBytes, indexRawByte, currUtf8Mask.len());
+                    indexRawByte += currUtf8Mask.len();
+                    currUtf8Index = 0;
+                    currUtf8Mask = null;
+                }
+                continue;
+            }
+            rawBytes[indexRawByte++] = b;
+        }
+        return new String(rawBytes, 0, indexRawByte, StandardCharsets.UTF_8);
+    }
+}
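A round-trip sketch for the tokenizer above (assumes an already-constructed Qwen3Tokenizer; not repository code). Multi-byte characters exercise the Utf8Mask buffering in decode:

    // Sketch: byte-level BPE encode/decode round trip.
    static void roundTrip(Qwen3Tokenizer tokenizer) {
        String text = "héllo wörld";              // deliberately multi-byte UTF-8
        List<Integer> ids = tokenizer.encode(text, Set.of());
        String back = tokenizer.decode(ids);      // decode re-assembles split UTF-8 sequences
        assert text.equals(back);
    }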
diff --git a/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java b/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java
index a45c073..459fb69 100644
--- a/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java
+++ b/src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java
@@ -40,6 +40,12 @@ public static Vocabulary loadMistralVocabulary(Map<String, Object> metadata) {
         return v;
     }
 
+    public static Vocabulary loadQwen3Vocabulary(Map<String, Object> metadata) {
+        String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens");
+        float[] scores = (float[]) metadata.get("tokenizer.ggml.scores");
+        return new Vocabulary(tokens, scores);
+    }
+
     public int size() {
         return tokens.length;
     }
diff --git a/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java b/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java
index e1864a4..18618b8 100644
--- a/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java
+++ b/src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java
@@ -1,10 +1,10 @@
 package com.example.tornadovm;
 
 import com.example.auxiliary.Tuple2;
+import com.example.inference.weights.tornado.TornadoWeights;
 import com.example.model.Configuration;
 import com.example.model.Model;
-import com.example.loader.weights.State;
-import com.example.loader.weights.Weights;
+import com.example.inference.state.State;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
 import uk.ac.manchester.tornado.api.KernelContext;
@@ -49,7 +49,7 @@ public class TornadoVMLayerPlanner {
     private final State state;
     private final Configuration config;
-    private final Weights weights;
+    private final TornadoWeights weights;
     private final KernelContext context;
 
     /**
@@ -63,7 +63,7 @@
     public TornadoVMLayerPlanner(State state, Model model) {
         this.state = state;
         this.config = model.configuration();
-        this.weights = model.weights();
+        this.weights = (TornadoWeights) model.weights();
         this.context = new KernelContext();
     }
 
@@ -182,7 +182,7 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
      */
     // @formatter:on
     private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
-        switch (weights.weightType) {
+        switch (weights.getWeightType()) {
             case F16:
             case Q8_0:
             case Q4_0:
@@ -191,7 +191,7 @@ private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
                         config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
                 // break;
             default:
-                throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.weightType + ". Only Q8_0 and Q4_0 are supported.");
+                throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.getWeightType() + ". Only F16, Q8_0 and Q4_0 are supported.");
         }
         return logits;
     }
diff --git a/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java b/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
index af167dd..594a15b 100644
--- a/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
@@ -1,7 +1,7 @@
 package com.example.tornadovm;
 
 import com.example.auxiliary.Tuple2;
-import com.example.loader.weights.State;
+import com.example.inference.state.State;
 import com.example.model.Configuration;
 import com.example.model.Model;
 import com.example.model.ModelType;