Skip to content

Add GUI Chatbox for GPULlama3.java Inference #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
4f2598f
Add JavaFX dependencies and set up main class.
svntax Jun 21, 2025
e00486e
Set up MVCI framework for the GUI chatbox.
svntax Jun 23, 2025
f65a1f6
Implement run inference button.
svntax Jun 24, 2025
44c9750
Implement Browse button for Llama3 path.
svntax Jun 24, 2025
a0ce181
Implement Reload button for scanning for model files.
svntax Jun 24, 2025
f21cec1
Change output text area to autoscroll to the bottom.
svntax Jun 24, 2025
a281461
Change buttons to be disabled while inference is running.
svntax Jun 24, 2025
ad22925
Adjust buttons, labels, and container nodes for GUI chatbox.
svntax Jun 25, 2025
7c29d68
Add AtlantaFX for new theme styles.
svntax Jun 25, 2025
c43f7a9
Decrease padding between left and right panels.
svntax Jun 25, 2025
bed9b66
Remove unused import
svntax Jun 25, 2025
eb454a7
Remove the need for setting Llama3 path.
svntax Jun 26, 2025
f1b565f
Modify LlamaApp and llama-tornado to support GUI mode.
svntax Jun 26, 2025
03ee28c
Refactor chatbox interactor to run models directly.
svntax Jun 27, 2025
bc43b28
Fix width for engine dropdown menu and remove unused import.
svntax Jun 27, 2025
d381457
Add dropdown menu for chat mode selection.
svntax Jun 28, 2025
6937602
Change USE_TORNADOVM flag into a regular member variable.
svntax Jul 4, 2025
2e2408f
Implement interactive mode for GUI chatbox.
svntax Jul 7, 2025
480336e
Disable GUI controls while an interactive chat is running.
svntax Jul 7, 2025
b898006
Merge branch 'main' of https://github.com/svntax/GPULlama3.java into …
svntax Jul 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion llama-tornado
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ class LlamaRunner:
print(f" {arg}")
print()

if args.gui:
cmd.append("--gui")

# Execute the command
try:
result = subprocess.run(cmd, check=True)
Expand Down Expand Up @@ -316,7 +319,8 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--model",
dest="model_path",
required=True,
required="--gui" not in sys.argv,
default="",
help="Path to the LLM gguf file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)",
)

Expand Down Expand Up @@ -466,6 +470,12 @@ def create_parser() -> argparse.ArgumentParser:
"--verbose", "-v", action="store_true", help="Verbose output"
)

# GUI options
gui_group = parser.add_argument_group("GUI Options")
gui_group.add_argument(
"--gui", action="store_true", help="Launch the GUI chatbox"
)

return parser


Expand Down
21 changes: 21 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@
<artifactId>tornado-runtime</artifactId>
<version>1.1.1-dev</version>
</dependency>

<dependency>
<groupId>org.openjfx</groupId>
<artifactId>javafx-controls</artifactId>
<version>21</version>
</dependency>

<dependency>
<groupId>io.github.mkpaz</groupId>
<artifactId>atlantafx-base</artifactId>
<version>2.0.1</version>
</dependency>
</dependencies>

<build>
Expand Down Expand Up @@ -68,6 +80,15 @@
</execution>
</executions>
</plugin>

<plugin>
<groupId>org.openjfx</groupId>
<artifactId>javafx-maven-plugin</artifactId>
<version>0.0.8</version>
<configuration>
<mainClass>com.example.gui.LlamaChatbox</mainClass>
</configuration>
</plugin>
</plugins>
</build>
</project>
45 changes: 37 additions & 8 deletions src/main/java/com/example/LlamaApp.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import com.example.aot.AOT;
import com.example.core.model.tensor.FloatTensor;
import com.example.gui.LlamaChatbox;
import com.example.inference.sampler.CategoricalSampler;
import com.example.inference.sampler.Sampler;
import com.example.inference.sampler.ToppSampler;
import com.example.loader.weights.ModelLoader;
import com.example.model.Model;
import com.example.tornadovm.FloatArrayUtils;
import javafx.application.Application;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

import java.io.IOException;
Expand All @@ -18,9 +20,29 @@ public class LlamaApp {
// Configuration flags for hardware acceleration and optimizations
public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
private boolean useTornadoVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode

private static LlamaApp instance;

private LlamaApp() {
}

public static LlamaApp getInstance() {
if (instance == null) {
instance = new LlamaApp();
}
return instance;
}

public void setUseTornadoVM(boolean value) {
useTornadoVM = value;
}

public boolean getUseTornadoVM() {
return useTornadoVM;
}

/**
* Creates and configures a sampler for token generation based on specified parameters.
*
Expand Down Expand Up @@ -118,7 +140,7 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
* @throws IOException if the model fails to load
* @throws IllegalStateException if AOT loading is enabled but the preloaded model is unavailable
*/
private static Model loadModel(Options options) throws IOException {
public static Model loadModel(Options options) throws IOException {
if (USE_AOT) {
Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
if (model == null) {
Expand All @@ -129,7 +151,7 @@ private static Model loadModel(Options options) throws IOException {
return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
}

private static Sampler createSampler(Model model, Options options) {
public static Sampler createSampler(Model model, Options options) {
return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
}

Expand All @@ -145,13 +167,20 @@ private static Sampler createSampler(Model model, Options options) {
*/
public static void main(String[] args) throws IOException {
Options options = Options.parseOptions(args);
Model model = loadModel(options);
Sampler sampler = createSampler(model, options);

if (options.interactive()) {
model.runInteractive(sampler, options);
if (options.guiMode()) {
// Launch the JavaFX application
Application.launch(LlamaChatbox.class, args);
} else {
model.runInstructOnce(sampler, options);
// Run the CLI logic
Model model = loadModel(options);
Sampler sampler = createSampler(model, options);

if (options.interactive()) {
model.runInteractive(sampler, options);
} else {
model.runInstructOnce(sampler, options);
}
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions src/main/java/com/example/Options.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import java.nio.file.Paths;

public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive,
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, boolean guiMode) {

public static final int DEFAULT_MAX_TOKENS = 1024;

Expand Down Expand Up @@ -41,6 +41,7 @@ static void printUsage(PrintStream out) {
out.println(" --max-tokens, -n <int> number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
out.println(" --stream <boolean> print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
out.println(" --echo <boolean> print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
out.println(" --gui <boolean> run the GUI chatbox");
out.println();
}

Expand All @@ -57,13 +58,15 @@ public static Options parseOptions(String[] args) {
boolean interactive = false;
boolean stream = true;
boolean echo = false;
boolean guiMode = false;

for (int i = 0; i < args.length; i++) {
String optionName = args[i];
require(optionName.startsWith("-"), "Invalid option %s", optionName);
switch (optionName) {
case "--interactive", "--chat", "-i" -> interactive = true;
case "--instruct" -> interactive = false;
case "--gui" -> guiMode = true;
case "--help", "-h" -> {
printUsage(System.out);
System.exit(0);
Expand Down Expand Up @@ -95,6 +98,6 @@ public static Options parseOptions(String[] args) {
}
}
}
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo);
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, guiMode);
}
}
36 changes: 36 additions & 0 deletions src/main/java/com/example/gui/ChatboxController.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.example.gui;

import javafx.concurrent.Task;
import javafx.scene.layout.Region;

public class ChatboxController {

    private final ChatboxViewBuilder viewBuilder;
    private final ChatboxInteractor interactor;

    /**
     * Wires together the MVCI triad: a shared {@code ChatboxModel}, the
     * {@code ChatboxInteractor} holding the inference logic, and the
     * {@code ChatboxViewBuilder} that constructs the scene graph. The view is
     * given {@link #runInference(Runnable)} as the action to invoke when the
     * user starts an inference run.
     */
    public ChatboxController() {
        ChatboxModel model = new ChatboxModel();
        interactor = new ChatboxInteractor(model);
        viewBuilder = new ChatboxViewBuilder(model, this::runInference);
    }

    /**
     * Runs one inference pass on a background thread so the JavaFX Application
     * Thread stays responsive.
     *
     * @param postRunAction callback run on the FX thread when the task
     *                      finishes (used by the view to re-enable controls);
     *                      invoked on BOTH success and failure so an exception
     *                      inside the interactor cannot leave the GUI
     *                      permanently disabled
     */
    private void runInference(Runnable postRunAction) {
        Task<Void> inferenceTask = new Task<>() {
            @Override
            protected Void call() {
                interactor.runLlamaTornado();
                return null;
            }
        };
        inferenceTask.setOnSucceeded(evt -> postRunAction.run());
        // Without a failure handler, a thrown exception would be swallowed by
        // the Task and postRunAction would never run.
        inferenceTask.setOnFailed(evt -> {
            Throwable cause = inferenceTask.getException();
            if (cause != null) {
                cause.printStackTrace();
            }
            postRunAction.run();
        });
        Thread inferenceThread = new Thread(inferenceTask, "inference-thread");
        // Daemon: a still-running inference must not block JVM shutdown when
        // the user closes the window.
        inferenceThread.setDaemon(true);
        inferenceThread.start();
    }

    /** Builds and returns the root node of the chatbox view. */
    public Region getView() {
        return viewBuilder.build();
    }

}
Loading