Integration of Llama3 models with TornadoVM to enable accelerated inference in Java using GPUs and CPUs. This project lets you run Llama3 inference efficiently, leveraging TornadoVM's parallel computing features for enhanced performance.
This project builds on Llama3.java, based on the original Llama 3, 3.1, and 3.2 models, with TornadoVM support for parallelism and hardware acceleration. Thanks to @mukel for the original implementation of Llama3.java. A previous integration of TornadoVM and Llama2 can be found in llama2.tornadovm.
This table shows inference performance across different hardware and quantization options.
Hardware | Llama-3.2-1B-Instruct (Q8_0) | Llama-3.2-1B-Instruct (Q4_0) | Llama-3.2-3B-Instruct (Q4_0) | Optimizations Support |
---|---|---|---|---|
NVIDIA GPUs | | | | |
RTX 3070 | 42.3 tokens/s | 78.6 tokens/s | 22.1 tokens/s | ✅ |
RTX 4090 | 96.7 tokens/s | 158.2 tokens/s | 52.9 tokens/s | ✅ |
RTX 5090 | 156.8 tokens/s | 243.5 tokens/s | 84.7 tokens/s | ✅ |
H100 | 178.3 tokens/s | 289.7 tokens/s | 102.5 tokens/s | ✅ |
Apple Silicon | | | | |
M3 Pro | 18.4 tokens/s | 35.7 tokens/s | 11.6 tokens/s | ✅ |
M4 Pro | 28.9 tokens/s | 52.3 tokens/s | 17.2 tokens/s | ✅ |
AMD GPUs | | | | |
Radeon RX | (WIP) | (WIP) | (WIP) | ✅ |
Note: ✅ indicates hardware with optimized kernels for maximum performance. Benchmark details: settings used include a context length of 4096, batch size 1, and default parameters.
- TornadoVM-accelerated Llama 3 inference with pure Java
- Support for GGUF format models with Q8_0 and Q4_0 quantization
- Instruction-following and chat modes for various use cases
- Cross-platform compatibility:
  - ✅ NVIDIA GPUs (OpenCL & PTX (Soon))
- Interactive CLI with `--interactive` and `--instruct` modes
- Flexible backend switching: choose OpenCL or PTX at runtime (requires TornadoVM built with both backends enabled)
Ensure you have the following installed and configured:
- Java 21+: Required for Vector API support.
- TornadoVM: To install TornadoVM, you'll need to set up the environment variables `TORNADO_ROOT` and `TORNADO_SDK` as part of the configuration process. For detailed installation instructions, visit the TornadoVM GitHub repository.
- Maven: For building the Java project.
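Before building, it can help to confirm these prerequisites from Java itself. Below is a minimal, hypothetical preflight check (the `PreflightCheck` class is illustrative and not part of this project) that verifies the running JDK is version 21 or newer and that `TORNADO_ROOT` and `TORNADO_SDK` are set. It does not validate that the paths actually point to a working TornadoVM installation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical preflight check for the prerequisites listed above (not part of this project). */
public class PreflightCheck {

    /** Returns human-readable problems; an empty list means the checks passed. */
    public static List<String> check(int javaFeatureVersion, Map<String, String> env) {
        List<String> problems = new ArrayList<>();
        if (javaFeatureVersion < 21) {
            problems.add("Java 21+ is required (found " + javaFeatureVersion + ")");
        }
        // Both variables are expected to be set by the TornadoVM configuration step.
        for (String var : List.of("TORNADO_ROOT", "TORNADO_SDK")) {
            String value = env.get(var);
            if (value == null || value.isBlank()) {
                problems.add(var + " is not set");
            }
        }
        return problems;
    }

    public static void main(String[] args) {
        List<String> problems = check(Runtime.version().feature(), System.getenv());
        if (problems.isEmpty()) {
            System.out.println("Prerequisites look OK");
        } else {
            problems.forEach(p -> System.out.println("Problem: " + p));
        }
    }
}
```

Passing the version and environment in as parameters, rather than reading them inside `check`, keeps the logic easy to test with fake values.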
When cloning this repository, use the `--recursive` flag to ensure that TornadoVM is properly included as a submodule:
# Clone the repository with all submodules
git clone --recursive git@github.com:mikepapadim/GPULlama3.java.git
# Navigate to the project directory
cd GPULlama3.java
# Enter the TornadoVM submodule directory
cd external/tornadovm
# Optional: Create and activate a Python virtual environment if needed
python3 -m venv venv
source ./venv/bin/activate
# Install TornadoVM with the OpenCL backend and OpenJDK 21 (optional: use --backend opencl,ptx to build both backends)
# Make sure the correct JDK version is installed and that the TornadoVM installer script runs successfully.
# You can then run: tornado --devices to check whether the installation was successful
./bin/tornadovm-installer --jdk jdk21 --backend opencl
# Source the TornadoVM environment variables
source setvars.sh
# Navigate back to the project root directory
cd ../../
# Make the llama-tornado script executable
chmod +x llama-tornado
# Source the project-specific environment paths -> this will ensure the correct paths are set for the project and the TornadoVM SDK
# Expect to see: [INFO] Environment configured for LLaMA3 with TornadoVM at: /home/YOUR_PATH_TO_TORNADOVM
source set_paths
# Build the project using Maven (skip tests for faster build)
# mvn clean package -DskipTests or just make
make
# Run the model (make sure you have downloaded the model file first - see below)
./llama-tornado --gpu --opencl --model Llama-3.2-1B-Instruct-Q4_0.gguf --prompt "tell me a joke"
The above model can be swapped with one of the other models, such as `Llama-3.2-3B-Instruct-Q4_0.gguf` or `Meta-Llama-3-8B-Instruct-Q4_0.gguf`, depending on your needs. See the available models below.
Download pure `Q4_0` and (optionally) `Q8_0` quantized `.gguf` files from Hugging Face. The pure `Q4_0` quantized models are recommended, except for the very small (1B) models. Please be gentle with the huggingface.co servers:
# Llama 3.2 (1B) - Q4_0
curl -L -O https://huggingface.co/mukel/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf
# Llama 3.2 (3B) - Q4_0
curl -L -O https://huggingface.co/mukel/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf
# Llama 3 (8B) - Q4_0
curl -L -O https://huggingface.co/mukel/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf
# Llama 3.2 (1B) - Q8_0
curl -L -O https://huggingface.co/mukel/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q8_0.gguf
# Llama 3.1 (8B) - Q4_0
curl -L -O https://huggingface.co/mukel/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_0.gguf
# Llama 3 (8B) and Llama 3.1 (8B) - Q8_0
# Optionally download the Q8_0 quantized models
# curl -L -O https://huggingface.co/mukel/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q8_0.gguf
# curl -L -O https://huggingface.co/mukel/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf
In the wild, `Q8_0` quantizations are fine, but `Q4_0` quantizations are rarely pure; for example, the `token_embd.weights`/`output.weights` tensors are quantized with `Q6_K` instead of `Q4_0`.
A pure `Q4_0` quantization can be generated from a high-precision (F32, F16, BFLOAT16) `.gguf` source with the `llama-quantize` utility from llama.cpp as follows:
./llama-quantize --pure ./Meta-Llama-3-8B-Instruct-F32.gguf ./Meta-Llama-3-8B-Instruct-Q4_0.gguf Q4_0
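To make the `Q8_0`/`Q4_0` distinction concrete, here is a minimal sketch of dequantizing one block of each format. It assumes the standard GGML block layouts: `Q8_0` stores a 2-byte little-endian FP16 scale followed by 32 signed 8-bit values (34 bytes per 32 weights), and `Q4_0` stores a 2-byte FP16 scale followed by 16 bytes of packed 4-bit values (18 bytes per 32 weights), where the low nibble of byte `j` is element `j`, the high nibble is element `j + 16`, and each nibble is offset by 8. The `GgufBlocks` class and its methods are hypothetical, not this project's API; `Float.float16ToFloat` is available since Java 20.

```java
/** Hypothetical helpers illustrating GGML Q8_0 and Q4_0 block dequantization. */
public class GgufBlocks {
    public static final int BLOCK_SIZE = 32;       // weights per block in both formats
    public static final int Q8_0_BLOCK_BYTES = 34; // 2-byte FP16 scale + 32 int8 values
    public static final int Q4_0_BLOCK_BYTES = 18; // 2-byte FP16 scale + 16 bytes of 4-bit values

    /** Reads a little-endian FP16 value at offset and widens it to float. */
    public static float readScale(byte[] block, int offset) {
        short bits = (short) ((block[offset] & 0xFF) | ((block[offset + 1] & 0xFF) << 8));
        return Float.float16ToFloat(bits);
    }

    /** Q8_0: value[j] = scale * int8[j]. */
    public static float[] dequantizeQ8_0(byte[] block) {
        float scale = readScale(block, 0);
        float[] out = new float[BLOCK_SIZE];
        for (int j = 0; j < BLOCK_SIZE; j++) {
            out[j] = scale * block[2 + j]; // Java bytes are already signed
        }
        return out;
    }

    /** Q4_0: low nibble of byte j is element j, high nibble is element j + 16, both offset by 8. */
    public static float[] dequantizeQ4_0(byte[] block) {
        float scale = readScale(block, 0);
        float[] out = new float[BLOCK_SIZE];
        for (int j = 0; j < BLOCK_SIZE / 2; j++) {
            int packed = block[2 + j] & 0xFF;
            out[j] = scale * ((packed & 0x0F) - 8);
            out[j + BLOCK_SIZE / 2] = scale * ((packed >>> 4) - 8);
        }
        return out;
    }

    public static void main(String[] args) {
        // One Q4_0 block with scale 0.5 (FP16 bits 0x3800, little-endian) and every data byte 0x9A.
        byte[] q4 = new byte[Q4_0_BLOCK_BYTES];
        q4[0] = 0x00;
        q4[1] = 0x38;
        for (int j = 2; j < Q4_0_BLOCK_BYTES; j++) {
            q4[j] = (byte) 0x9A;
        }
        float[] w = dequantizeQ4_0(q4);
        System.out.println(w[0] + " " + w[16]); // prints "1.0 0.5"
    }
}
```

The low nibble `0xA` decodes to (10 - 8) × 0.5 and the high nibble `0x9` to (9 - 8) × 0.5. The byte counts also show why `Q4_0` is smaller: about 4.5 bits per weight versus 8.5 bits for `Q8_0`, including the per-block scale.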
Set up environment variables by editing and sourcing the `set_paths.sh` script in the project root directory:
# Point to your TornadoVM installation directory
export TORNADO_ROOT=/path/to/TornadoVM
# Locate the TornadoVM SDK binaries and libraries
export TORNADO_SDK=${TORNADO_ROOT}/bin/sdk
# Set the path to this GPULlama3.java project
export LLAMA_ROOT=/path/to/this-project
# Add the project's binary directory to your PATH for easy access
export PATH="${PATH}:${LLAMA_ROOT}/bin"
# Clean previous builds and package the project (skip tests for faster builds)
mvn clean package -DskipTests
The `llama-tornado` script executes Llama3 models on TornadoVM. By default, models run on CPU; add `--gpu` for GPU acceleration.
Run a model with a text prompt:
./llama-tornado --gpu --opencl --model Llama-3.2-1B-Instruct-Q8_0.gguf --prompt "Explain the benefits of GPU acceleration."
Enable GPU acceleration with Q8_0 quantization:
llama-tornado --gpu --model Llama-3.2-1B-Instruct-Q8_0.gguf --prompt "tell me a joke"
Run with Q4_0 quantization for lower memory usage:
llama-tornado --gpu --model Llama-3.2-1B-Instruct-Q4_0.gguf --prompt "tell me a joke"
Specify the backend (OpenCL or PTX):
# Use OpenCL backend (default)
./llama-tornado --gpu --opencl --model model.gguf --prompt "..."
# Use PTX backend for NVIDIA GPUs
./llama-tornado --gpu --ptx --model model.gguf --prompt "..."
If you encounter an out-of-memory error like:
Exception in thread "main" uk.ac.manchester.tornado.api.exceptions.TornadoOutOfMemoryException: Unable to allocate 100663320 bytes of memory.
To increase the maximum device memory, use -Dtornado.device.memory=<X>GB
This indicates that the default GPU memory allocation (7GB) is insufficient for your model.
Increase the GPU memory allocation using the `--gpu-memory` flag:
# For 3B models, try increasing to 15GB
./llama-tornado --gpu --model Llama-3.2-3B-Instruct-Q4_0.gguf --prompt "Tell me a joke" --gpu-memory 15GB
# For 8B models, you may need even more (20GB or higher)
./llama-tornado --gpu --model Meta-Llama-3-8B-Instruct-Q4_0.gguf --prompt "Tell me a joke" --gpu-memory 20GB
Model Size | Recommended GPU Memory |
---|---|
1B models | 7GB (default) |
3B models | 15GB |
8B models | 20GB+ |
Note: The values above are starting points; how much memory you can actually allocate is limited by your GPU's available memory. Check your GPU specifications and adjust accordingly. If you still encounter memory issues, try:
- Using Q4_0 instead of Q8_0 quantization (requires less memory)
- Closing other GPU-intensive applications
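For a rough sense of why larger models and `Q8_0` weights need more device memory, here is a back-of-the-envelope sketch. It assumes GGML block sizes (`Q4_0`: 18 bytes per 32 weights; `Q8_0`: 34 bytes per 32 weights) and an FP32 key/value cache sized as 2 × layers × context length × KV dimension. The `MemoryEstimate` class is hypothetical, and the model numbers in `main` are illustrative assumptions (roughly Llama-3.2-1B: about 1.2B parameters, 16 layers, 8 KV heads × 64 = 512 KV dimensions). The estimate ignores activations, intermediate buffers, and TornadoVM's own allocations, so treat it as a lower bound, not a replacement for the recommendations in the table above.

```java
/** Hypothetical back-of-the-envelope estimate of device memory for weights and KV cache. */
public class MemoryEstimate {
    public static final long GIB = 1L << 30;

    /** Bytes for quantized weights: one block of blockBytes per 32 weights (rounded up). */
    public static long weightBytes(long parameters, int blockBytes) {
        long blocks = (parameters + 31) / 32;
        return blocks * blockBytes;
    }

    /** Bytes for an FP32 KV cache: keys + values for every layer, position, and KV dimension. */
    public static long kvCacheBytes(int layers, int contextLength, int kvDim) {
        return 2L * layers * contextLength * kvDim * Float.BYTES;
    }

    public static void main(String[] args) {
        // Illustrative, assumed numbers for a ~1.2B-parameter model with 16 layers and kvDim = 512.
        long q4 = weightBytes(1_200_000_000L, 18);
        long q8 = weightBytes(1_200_000_000L, 34);
        long kv = kvCacheBytes(16, 4096, 512);
        System.out.printf("Q4_0 weights: %.2f GiB%n", (double) q4 / GIB);
        System.out.printf("Q8_0 weights: %.2f GiB%n", (double) q8 / GIB);
        System.out.printf("KV cache (4096 tokens): %.2f GiB%n", (double) kv / GIB);
    }
}
```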
View TornadoVM's internal behavior:
# Print thread information during execution
./llama-tornado --gpu --model model.gguf --prompt "..." --print-threads
# Show bytecode compilation details
./llama-tornado --gpu --model model.gguf --prompt "..." --print-bytecodes
# Display generated GPU kernel code
./llama-tornado --gpu --model model.gguf --prompt "..." --print-kernel
# Enable full debug output with all details
./llama-tornado --gpu --model model.gguf --prompt "..." --debug --full-dump
# Combine debug options
./llama-tornado --gpu --model model.gguf --prompt "..." --print-threads --print-bytecodes --print-kernel
Supported command-line options include:
llama-tornado --help
usage: llama-tornado [-h] --model MODEL_PATH [--prompt PROMPT] [-sp SYSTEM_PROMPT] [--temperature TEMPERATURE] [--top-p TOP_P] [--seed SEED] [-n MAX_TOKENS]
[--stream STREAM] [--echo ECHO] [-i] [--instruct] [--gpu] [--opencl] [--ptx] [--gpu-memory GPU_MEMORY] [--heap-min HEAP_MIN] [--heap-max HEAP_MAX]
[--debug] [--profiler] [--profiler-dump-dir PROFILER_DUMP_DIR] [--print-bytecodes] [--print-threads] [--print-kernel] [--full-dump]
[--show-command] [--execute-after-show] [--opencl-flags OPENCL_FLAGS] [--max-wait-events MAX_WAIT_EVENTS] [--verbose]
GPU-accelerated LLaMA.java model runner using TornadoVM
options:
-h, --help show this help message and exit
--model MODEL_PATH Path to the LLaMA model file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf) (default: None)
LLaMA Configuration:
--prompt PROMPT Input prompt for the model (default: None)
-sp SYSTEM_PROMPT, --system-prompt SYSTEM_PROMPT
System prompt for the model (default: None)
--temperature TEMPERATURE
Sampling temperature (0.0 to 2.0) (default: 0.1)
--top-p TOP_P Top-p sampling parameter (default: 0.95)
--seed SEED Random seed (default: current timestamp) (default: None)
-n MAX_TOKENS, --max-tokens MAX_TOKENS
Maximum number of tokens to generate (default: 512)
--stream STREAM Enable streaming output (default: True)
--echo ECHO Echo the input prompt (default: False)
Mode Selection:
-i, --interactive Run in interactive/chat mode (default: False)
--instruct Run in instruction mode (default) (default: True)
Hardware Configuration:
--gpu Enable GPU acceleration (default: False)
--opencl Use OpenCL backend (default) (default: None)
--ptx Use PTX/CUDA backend (default: None)
--gpu-memory GPU_MEMORY
GPU memory allocation (default: 7GB)
--heap-min HEAP_MIN Minimum JVM heap size (default: 20g)
--heap-max HEAP_MAX Maximum JVM heap size (default: 20g)
Debug and Profiling:
--debug Enable debug output (default: False)
--profiler Enable TornadoVM profiler (default: False)
--profiler-dump-dir PROFILER_DUMP_DIR
Directory for profiler output (default: /home/mikepapadim/repos/gpu-llama3.java/prof.json)
TornadoVM Execution Verbose:
--print-bytecodes Print bytecodes (tornado.print.bytecodes=true) (default: False)
--print-threads Print thread information (tornado.threadInfo=true) (default: False)
--print-kernel Print kernel information (tornado.printKernel=true) (default: False)
--full-dump Enable full debug dump (tornado.fullDebug=true) (default: False)
Command Display Options:
--show-command Display the full Java command that will be executed (default: False)
--execute-after-show Execute the command after showing it (use with --show-command) (default: False)
Advanced Options:
--opencl-flags OPENCL_FLAGS
OpenCL compiler flags (default: -cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only)
--max-wait-events MAX_WAIT_EVENTS
Maximum wait events for TornadoVM event pool (default: 32000)
--verbose, -v Verbose output (default: False)
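The help text above documents how several wrapper flags map to TornadoVM system properties, and the out-of-memory message shows that `--gpu-memory` corresponds to `-Dtornado.device.memory`. The following hypothetical Java sketch reproduces only that documented mapping to show how the wrapper's flags become ordinary `-D` JVM options. It is not the actual `llama-tornado` script, and the `TornadoFlags` class and its `Options` record are illustrative assumptions.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical re-creation of the flag-to-property mapping documented in the help output. */
public class TornadoFlags {

    /** Subset of llama-tornado options that map directly to TornadoVM system properties. */
    public record Options(boolean printBytecodes, boolean printThreads, boolean printKernel,
                          boolean fullDump, String gpuMemory) {}

    public static List<String> toJvmProperties(Options o) {
        List<String> props = new ArrayList<>();
        props.add("-Dtornado.print.bytecodes=" + o.printBytecodes()); // --print-bytecodes
        props.add("-Dtornado.threadInfo=" + o.printThreads());        // --print-threads
        props.add("-Dtornado.printKernel=" + o.printKernel());        // --print-kernel
        props.add("-Dtornado.fullDebug=" + o.fullDump());             // --full-dump
        props.add("-Dtornado.device.memory=" + o.gpuMemory());        // --gpu-memory (default 7GB)
        return props;
    }

    public static void main(String[] args) {
        // Equivalent of: --print-threads --gpu-memory 15GB
        Options opts = new Options(false, true, false, false, "15GB");
        toJvmProperties(opts).forEach(System.out::println);
    }
}
```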
Want to see exactly what's happening under the hood? Our `llama-tornado` wrapper script makes it crystal clear. Just add the `--show-command` flag and witness the beauty of the underlying Java invocation:
llama-tornado --gpu --model Llama-3.2-1B-Instruct-Q8_0.gguf --prompt "tell me a joke" --show-command
/home/mikepapadim/.sdkman/candidates/java/current/bin/java \
-server \
-XX:+UnlockExperimentalVMOptions \
-XX:+EnableJVMCI \
-Xms20g -Xmx20g \
--enable-preview \
-Djava.library.path=/home/mikepapadim/manchester/TornadoVM/bin/sdk/lib \
-Djdk.module.showModuleResolution=false \
--module-path .:/home/mikepapadim/manchester/TornadoVM/bin/sdk/share/java/tornado \
-Dtornado.load.api.implementation=uk.ac.manchester.tornado.runtime.tasks.TornadoTaskGraph \
-Dtornado.load.runtime.implementation=uk.ac.manchester.tornado.runtime.TornadoCoreRuntime \
-Dtornado.load.tornado.implementation=uk.ac.manchester.tornado.runtime.common.Tornado \
-Dtornado.load.annotation.implementation=uk.ac.manchester.tornado.annotation.ASMClassVisitor \
-Dtornado.load.annotation.parallel=uk.ac.manchester.tornado.api.annotations.Parallel \
-Duse.tornadovm=true \
-Dtornado.threadInfo=false \
-Dtornado.debug=false \
-Dtornado.fullDebug=false \
-Dtornado.printKernel=false \
-Dtornado.print.bytecodes=false \
-Dtornado.device.memory=7GB \
-Dtornado.profiler=false \
-Dtornado.log.profiler=false \
-Dtornado.profiler.dump.dir=/home/mikepapadim/repos/gpu-llama3.java/prof.json \
-Dtornado.enable.fastMathOptimizations=true \
-Dtornado.enable.mathOptimizations=false \
-Dtornado.enable.nativeFunctions=fast \
-Dtornado.loop.interchange=true \
-Dtornado.eventpool.maxwaitevents=32000 \
"-Dtornado.opencl.compiler.flags=-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only" \
--upgrade-module-path /home/mikepapadim/manchester/TornadoVM/bin/sdk/share/java/graalJars \
@/home/mikepapadim/manchester/TornadoVM/bin/sdk/etc/exportLists/common-exports \
@/home/mikepapadim/manchester/TornadoVM/bin/sdk/etc/exportLists/opencl-exports \
--add-modules ALL-SYSTEM,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.opencl \
-cp /home/mikepapadim/repos/gpu-llama3.java/target/gpu-llama3-1.0-SNAPSHOT.jar \
com.example.LlamaApp \
-m Llama-3.2-1B-Instruct-Q8_0.gguf \
--temperature 0.1 \
--top-p 0.95 \
--seed 1746903566 \
--max-tokens 512 \
--stream true \
--echo false \
-p "tell me a joke" \
--instruct
That's right! Behind all the GPU acceleration and performance optimizations, you're looking at a standard Java application:
- Entry Point: `com.example.LlamaApp`
- JAR File: `/path/to/gpu-llama3-1.0-SNAPSHOT.jar`
- JVM Flags: Standard OpenJDK flags with TornadoVM extensions
- Arguments: Plain old command-line arguments
TornadoVM is the secret sauce that transforms regular Java code into GPU-accelerated compute kernels. All those `-Dtornado.*` flags? They're just configuring TornadoVM to:
- Automatically compile Java methods to GPU kernels
- Manage GPU memory and data transfers
- Optimize loop execution for parallel hardware
- Provide debugging and profiling capabilities
Key optimizations in the implementation include:
- Quantized Weight Support
  - Optimized implementations for Q8_0 and Q4_0 formats
  - Block-based quantization with an FP16 scale per 32-element block
- Vectorized Matrix Operations
  - Uses vector parallelism with configurable unroll factors
  - Processes 4 elements at once with vectorization
- Loop Unrolling
  - Strategic unrolling for performance (16x factor in matrix operations)
  - Reduces branch penalties and improves instruction-level parallelism
- Fused Multiply-Add (FMA)
  - Uses fused operations for better numerical precision and performance
  - Optimizes dot product calculations
- Key-Value Cache
  - Efficiently stores past key/value vectors for autoregressive generation
  - Organized by layer, position, and dimension for fast access
- Scale Caching
  - Avoids redundant decompression of quantized weights
  - Caches scale factors for efficient block processing
- Optimized GPU Memory Transfers
  - Minimizes host-device data movement
  - One-time transfer of static data (weights, caches)
  - Per-execution transfer of dynamic data (position, activations)
- Device-to-Device Data Consumption
  - Efficient data transfer between operations
  - Reduces PCIe bandwidth bottlenecks
- Parallel Reduction RMS Normalization
  - Implements two-phase reduction for efficient normalization
  - Work group optimization for parallel sums
- Rotary Position Embeddings (RoPE)
  - Optimized implementation for positional encoding
  - Efficient rotation of query and key vectors
- Optimized Float16 Decoding
  - Fast decoder for the half-precision floating-point format
  - Special case handling for better performance
- Parallelized Attention
  - Computes attention heads in parallel
  - Optimized softmax with max subtraction for numerical stability
- Fused Feed-Forward Networks
  - Combines operations for the SwiGLU variant used in LLaMA models
  - Optimized SiLU and GELU activation functions
- Layered Execution Planning
  - Organizes computation as separate layer-based task graphs
  - Strategic scheduling of operations
- Work Group Optimization
  - Tailored worker grid configurations for different operations
  - Matches GPU hardware characteristics
- Local Memory Optimization
  - Strategic use of local/shared memory for reductions
  - Optimizes bandwidth-intensive operations
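Several items in the optimization list (per-block FP16 scales, FMA, loop unrolling) meet in the inner dot product of a quantized matrix-vector multiply. Below is a plain-Java CPU sketch of one `Q8_0` row multiplied by a float vector, assuming the GGML `Q8_0` layout (2-byte little-endian FP16 scale, then 32 signed bytes). Each block's scale is decoded once and applied after the inner loop, which is unrolled by 4 for brevity (the list mentions a 16x factor for the real matrix kernels) and uses `Math.fma`. This illustrates the techniques only; it is not the project's GPU kernel, and `Q8DotProduct` is a hypothetical name.

```java
/** Hypothetical CPU sketch: dot product of one Q8_0-quantized row with a float vector. */
public class Q8DotProduct {
    static final int BLOCK_SIZE = 32;
    static final int BLOCK_BYTES = 34; // 2-byte FP16 scale + 32 int8 quants

    /** row holds x.length / 32 consecutive Q8_0 blocks; x.length must be a multiple of 32. */
    public static float dot(byte[] row, float[] x) {
        float result = 0f;
        int blocks = x.length / BLOCK_SIZE;
        for (int b = 0; b < blocks; b++) {
            int base = b * BLOCK_BYTES;
            // Decode the block's FP16 scale once, instead of once per element.
            short bits = (short) ((row[base] & 0xFF) | ((row[base + 1] & 0xFF) << 8));
            float scale = Float.float16ToFloat(bits);
            float blockSum = 0f;
            int xi = b * BLOCK_SIZE;
            // Inner loop unrolled by 4; each step is a fused multiply-add on the raw int8 quant.
            for (int j = 0; j < BLOCK_SIZE; j += 4) {
                blockSum = Math.fma(row[base + 2 + j], x[xi + j], blockSum);
                blockSum = Math.fma(row[base + 3 + j], x[xi + j + 1], blockSum);
                blockSum = Math.fma(row[base + 4 + j], x[xi + j + 2], blockSum);
                blockSum = Math.fma(row[base + 5 + j], x[xi + j + 3], blockSum);
            }
            // Apply the per-block scale once to the accumulated block sum.
            result = Math.fma(scale, blockSum, result);
        }
        return result;
    }
}
```

Factoring the scale out of the inner loop is valid because every element in a block shares it: the sum of `scale * q[j] * x[j]` equals `scale` times the sum of `q[j] * x[j]`.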
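The Key-Value Cache item in the optimization list describes storage organized by layer, position, and dimension. One way to lay that out in flat arrays, assuming FP32 entries and a row-major `[layer][position][kvDim]` order, is sketched below. The `KvCacheSketch` class is hypothetical and only illustrates the indexing; it uses `int` offsets, which is enough for the model sizes discussed here.

```java
/** Hypothetical flat KV cache indexed as [layer][position][kvDim]. */
public class KvCacheSketch {
    private final int contextLength;
    private final int kvDim;
    private final float[] keys;
    private final float[] values;

    public KvCacheSketch(int layers, int contextLength, int kvDim) {
        this.contextLength = contextLength;
        this.kvDim = kvDim;
        this.keys = new float[layers * contextLength * kvDim];
        this.values = new float[layers * contextLength * kvDim];
    }

    /** Start offset of the row for (layer, position); its kvDim entries follow contiguously. */
    public int offset(int layer, int position) {
        return (layer * contextLength + position) * kvDim;
    }

    /** Stores the current token's key and value vectors so later tokens can reuse them. */
    public void store(int layer, int position, float[] k, float[] v) {
        int off = offset(layer, position);
        System.arraycopy(k, 0, keys, off, kvDim);
        System.arraycopy(v, 0, values, off, kvDim);
    }

    public float key(int layer, int position, int dim) {
        return keys[offset(layer, position) + dim];
    }

    public float value(int layer, int position, int dim) {
        return values[offset(layer, position) + dim];
    }
}
```

Keeping each position's vector contiguous means attention for one layer reads consecutive memory as it scans past positions.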
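The two-phase reduction behind the RMS normalization item can be modeled sequentially. Phase one computes a partial sum of squares per chunk, as each GPU work group would in local memory; phase two combines the partials into a single normalization factor. The chunk size, epsilon, and the `RmsNormSketch` name are assumptions for illustration; the formula is the standard RMSNorm, `out[i] = weight[i] * x[i] / sqrt(mean(x^2) + eps)`.

```java
/** Hypothetical sequential model of a two-phase parallel reduction for RMSNorm. */
public class RmsNormSketch {

    public static float[] rmsNorm(float[] x, float[] weight, int groupSize, float eps) {
        int groups = (x.length + groupSize - 1) / groupSize;
        // Phase 1: each "work group" reduces its own chunk to one partial sum of squares.
        float[] partial = new float[groups];
        for (int g = 0; g < groups; g++) {
            int end = Math.min(x.length, (g + 1) * groupSize);
            float sum = 0f;
            for (int i = g * groupSize; i < end; i++) {
                sum = Math.fma(x[i], x[i], sum);
            }
            partial[g] = sum;
        }
        // Phase 2: combine the partial sums and compute the normalization factor once.
        float total = 0f;
        for (float p : partial) {
            total += p;
        }
        float scale = (float) (1.0 / Math.sqrt(total / x.length + eps));
        // Apply the shared factor and the learned per-element weight.
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = weight[i] * (scale * x[i]);
        }
        return out;
    }
}
```

On a GPU, phase one runs in parallel across work groups and phase two is a small reduction over `groups` values, which is why the split pays off for large hidden dimensions.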
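For the RoPE item, here is a minimal sketch of rotating query or key vectors in place. It assumes the interleaved-pair convention used by llama2.c-style implementations (dimensions `i` and `i + 1` within each head are rotated together, with frequency `1 / theta^(headDim / headSize)`) and takes the base `theta` as a parameter; Llama 3 uses a base of 500000. The sketch omits the frequency scaling that Llama 3.1 and 3.2 apply, and `RopeSketch` is a hypothetical name.

```java
/** Hypothetical sketch of rotary position embeddings applied in place to one vector. */
public class RopeSketch {

    /**
     * Rotates consecutive pairs (vec[i], vec[i + 1]) within each head of size headSize by an
     * angle that depends on the token position and the pair's offset in the head.
     * vec.length must be a multiple of headSize, and headSize must be even.
     */
    public static void apply(float[] vec, int headSize, int position, double theta) {
        for (int i = 0; i < vec.length; i += 2) {
            int headDim = i % headSize; // even offset within the current head
            double freq = 1.0 / Math.pow(theta, (double) headDim / headSize);
            double angle = position * freq;
            float cos = (float) Math.cos(angle);
            float sin = (float) Math.sin(angle);
            float v0 = vec[i];
            float v1 = vec[i + 1];
            vec[i] = v0 * cos - v1 * sin;
            vec[i + 1] = v0 * sin + v1 * cos;
        }
    }
}
```

Because each step is a 2D rotation, it preserves the length of every pair; only the relative angle between query and key pairs changes with position.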
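The softmax mentioned under Parallelized Attention relies on a standard trick: subtracting the maximum score before exponentiating gives the same result mathematically but keeps `exp` from overflowing. A minimal sketch follows; `StableSoftmax` is a hypothetical name, and the `size` parameter reflects that, per head, only positions up to the current token are normalized.

```java
/** Hypothetical sketch of a numerically stable in-place softmax. */
public class StableSoftmax {

    /** Softmax over x[0..size); subtracting the max keeps exp() from overflowing. */
    public static void softmax(float[] x, int size) {
        float max = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < size; i++) {
            max = Math.max(max, x[i]);
        }
        float sum = 0f;
        for (int i = 0; i < size; i++) {
            x[i] = (float) Math.exp(x[i] - max);
            sum += x[i];
        }
        for (int i = 0; i < size; i++) {
            x[i] /= sum;
        }
    }
}
```

With scores such as `{1000, 1000, -1000}`, a naive `exp` would overflow to infinity, while this version yields `{0.5, 0.5, 0.0}`.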
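The SwiGLU feed-forward block mentioned under Fused Feed-Forward Networks computes `w2 · (SiLU(w1 · x) ⊙ (w3 · x))`, where `SiLU(z) = z · sigmoid(z)`. Below is a plain-Java sketch with dense float matrices that fuses the SiLU gate and the elementwise product into the loop producing the two projections. The row-major `float[][]` layout and the `SwiGluSketch` name are assumptions, not the project's kernel, and GELU is not shown.

```java
/** Hypothetical dense-matrix sketch of a LLaMA-style SwiGLU feed-forward block. */
public class SwiGluSketch {

    /** SiLU(z) = z * sigmoid(z) = z / (1 + e^-z). */
    public static float silu(float z) {
        return (float) (z / (1.0 + Math.exp(-z)));
    }

    /** Dot product of row r of a row-major matrix with vector v. */
    static float rowDot(float[][] m, int r, float[] v) {
        float s = 0f;
        for (int c = 0; c < v.length; c++) {
            s = Math.fma(m[r][c], v[c], s);
        }
        return s;
    }

    /** w1, w3: hidden x dim; w2: dim x hidden. Returns w2 * (silu(w1 x) * (w3 x)). */
    public static float[] ffn(float[][] w1, float[][] w2, float[][] w3, float[] x) {
        int hidden = w1.length;
        float[] h = new float[hidden];
        for (int r = 0; r < hidden; r++) {
            // Fused step: both projections, the SiLU gate, and the elementwise product per unit.
            h[r] = silu(rowDot(w1, r, x)) * rowDot(w3, r, x);
        }
        float[] out = new float[w2.length];
        for (int r = 0; r < w2.length; r++) {
            out[r] = rowDot(w2, r, h);
        }
        return out;
    }
}
```

Fusing the gate avoids materializing the two projection vectors separately before combining them, which is the memory-traffic saving the list item refers to.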
License: MIT