diff --git a/llama-tornado b/llama-tornado
index 112a7ca..5df1851 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -251,6 +251,9 @@ class LlamaRunner:
print(f" {arg}")
print()
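+    # Forward --gui to the Java process so LlamaApp launches the JavaFX chatbox.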
+    if args.gui:
+        cmd.append("--gui")
+
# Execute the command
try:
result = subprocess.run(cmd, check=True)
@@ -316,7 +319,8 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--model",
dest="model_path",
- required=True,
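+        # --model is only mandatory for CLI runs; in GUI mode the model is picked in the UI.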
+ required="--gui" not in sys.argv,
+ default="",
help="Path to the LLM gguf file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)",
)
@@ -466,6 +470,12 @@ def create_parser() -> argparse.ArgumentParser:
"--verbose", "-v", action="store_true", help="Verbose output"
)
+ # GUI options
+ gui_group = parser.add_argument_group("GUI Options")
+ gui_group.add_argument(
+ "--gui", action="store_true", help="Launch the GUI chatbox"
+ )
+
return parser
diff --git a/pom.xml b/pom.xml
index 216dda9..32d260c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -32,6 +32,18 @@
            <artifactId>tornado-runtime</artifactId>
            <version>1.1.1-dev</version>
        </dependency>
+
+        <dependency>
+            <groupId>org.openjfx</groupId>
+            <artifactId>javafx-controls</artifactId>
+            <version>21</version>
+        </dependency>
+
+        <dependency>
+            <groupId>io.github.mkpaz</groupId>
+            <artifactId>atlantafx-base</artifactId>
+            <version>2.0.1</version>
+        </dependency>
@@ -68,6 +80,15 @@
+
+            <plugin>
+                <groupId>org.openjfx</groupId>
+                <artifactId>javafx-maven-plugin</artifactId>
+                <version>0.0.8</version>
+                <configuration>
+                    <mainClass>com.example.gui.LlamaChatbox</mainClass>
+                </configuration>
+            </plugin>
diff --git a/src/main/java/com/example/LlamaApp.java b/src/main/java/com/example/LlamaApp.java
index 5ea0cb2..e739ca9 100644
--- a/src/main/java/com/example/LlamaApp.java
+++ b/src/main/java/com/example/LlamaApp.java
@@ -2,12 +2,14 @@
import com.example.aot.AOT;
import com.example.core.model.tensor.FloatTensor;
+import com.example.gui.LlamaChatbox;
import com.example.inference.sampler.CategoricalSampler;
import com.example.inference.sampler.Sampler;
import com.example.inference.sampler.ToppSampler;
import com.example.loader.weights.ModelLoader;
import com.example.model.Model;
import com.example.tornadovm.FloatArrayUtils;
+import javafx.application.Application;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import java.io.IOException;
@@ -18,9 +20,29 @@ public class LlamaApp {
// Configuration flags for hardware acceleration and optimizations
public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
- public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
+ private boolean useTornadoVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode
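+ // Lazily-initialized singleton so the GUI can toggle TornadoVM at runtime,
+ // replacing the old static final USE_TORNADOVM flag.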
+ private static LlamaApp instance;
+
+ private LlamaApp() {
+ }
+
+ public static LlamaApp getInstance() {
+ if (instance == null) {
+ instance = new LlamaApp();
+ }
+ return instance;
+ }
+
+ public void setUseTornadoVM(boolean value) {
+ useTornadoVM = value;
+ }
+
+ public boolean getUseTornadoVM() {
+ return useTornadoVM;
+ }
+
/**
* Creates and configures a sampler for token generation based on specified parameters.
*
@@ -118,7 +140,7 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
* @throws IOException if the model fails to load
* @throws IllegalStateException if AOT loading is enabled but the preloaded model is unavailable
*/
- private static Model loadModel(Options options) throws IOException {
+ public static Model loadModel(Options options) throws IOException {
if (USE_AOT) {
Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
if (model == null) {
@@ -129,7 +151,7 @@ private static Model loadModel(Options options) throws IOException {
return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
}
- private static Sampler createSampler(Model model, Options options) {
+ public static Sampler createSampler(Model model, Options options) {
return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
}
@@ -145,13 +167,20 @@ private static Sampler createSampler(Model model, Options options) {
*/
public static void main(String[] args) throws IOException {
Options options = Options.parseOptions(args);
- Model model = loadModel(options);
- Sampler sampler = createSampler(model, options);
- if (options.interactive()) {
- model.runInteractive(sampler, options);
+ if (options.guiMode()) {
+ // Launch the JavaFX application
+ Application.launch(LlamaChatbox.class, args);
} else {
- model.runInstructOnce(sampler, options);
+ // Run the CLI logic
+ Model model = loadModel(options);
+ Sampler sampler = createSampler(model, options);
+
+ if (options.interactive()) {
+ model.runInteractive(sampler, options);
+ } else {
+ model.runInstructOnce(sampler, options);
+ }
}
}
}
diff --git a/src/main/java/com/example/Options.java b/src/main/java/com/example/Options.java
index 284e754..ebf9546 100644
--- a/src/main/java/com/example/Options.java
+++ b/src/main/java/com/example/Options.java
@@ -5,7 +5,7 @@
import java.nio.file.Paths;
public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive,
- float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {
+ float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, boolean guiMode) {
public static final int DEFAULT_MAX_TOKENS = 1024;
@@ -41,6 +41,7 @@ static void printUsage(PrintStream out) {
out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
out.println(" --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
out.println(" --echo print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
+ out.println(" --gui run the GUI chatbox");
out.println();
}
@@ -57,6 +58,7 @@ public static Options parseOptions(String[] args) {
boolean interactive = false;
boolean stream = true;
boolean echo = false;
+ boolean guiMode = false;
for (int i = 0; i < args.length; i++) {
String optionName = args[i];
@@ -64,6 +66,7 @@ public static Options parseOptions(String[] args) {
switch (optionName) {
case "--interactive", "--chat", "-i" -> interactive = true;
case "--instruct" -> interactive = false;
+ case "--gui" -> guiMode = true;
case "--help", "-h" -> {
printUsage(System.out);
System.exit(0);
@@ -95,6 +98,6 @@ public static Options parseOptions(String[] args) {
}
}
}
- return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo);
+ return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, guiMode);
}
}
diff --git a/src/main/java/com/example/gui/ChatboxController.java b/src/main/java/com/example/gui/ChatboxController.java
new file mode 100644
index 0000000..5129e93
--- /dev/null
+++ b/src/main/java/com/example/gui/ChatboxController.java
@@ -0,0 +1,36 @@
+package com.example.gui;
+
+import javafx.concurrent.Task;
+import javafx.scene.layout.Region;
+
+public class ChatboxController {
+
+ private final ChatboxViewBuilder viewBuilder;
+ private final ChatboxInteractor interactor;
+
+ public ChatboxController() {
+ ChatboxModel model = new ChatboxModel();
+ interactor = new ChatboxInteractor(model);
+ viewBuilder = new ChatboxViewBuilder(model, this::runInference);
+ }
+
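+ // Run inference in a background Task so the JavaFX Application Thread stays
+ // responsive; postRunAction (e.g. re-enabling the Run button) fires on success.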
+ private void runInference(Runnable postRunAction) {
+ Task<Void> inferenceTask = new Task<>() {
+ @Override
+ protected Void call() {
+ interactor.runLlamaTornado();
+ return null;
+ }
+ };
+ inferenceTask.setOnSucceeded(evt -> {
+ postRunAction.run();
+ });
+ Thread inferenceThread = new Thread(inferenceTask);
+ inferenceThread.start();
+ }
+
+ public Region getView() {
+ return viewBuilder.build();
+ }
+
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/gui/ChatboxInteractor.java b/src/main/java/com/example/gui/ChatboxInteractor.java
new file mode 100644
index 0000000..b740fd8
--- /dev/null
+++ b/src/main/java/com/example/gui/ChatboxInteractor.java
@@ -0,0 +1,164 @@
+package com.example.gui;
+
+import com.example.LlamaApp;
+import com.example.Options;
+import com.example.inference.sampler.Sampler;
+import com.example.model.Model;
+
+import java.io.ByteArrayOutputStream;
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class ChatboxInteractor {
+
+ private final ChatboxModel model;
+ private Model llm;
+ private Sampler sampler;
+ private Model.Response currentResponse;
+
+ public ChatboxInteractor(ChatboxModel viewModel) {
+ this.model = viewModel;
+ }
+
+ private String[] buildCommands() {
+ List<String> commands = new ArrayList<>();
+
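+ // Apply the engine choice before the model is loaded; ModelLoader reads this flag via the singleton.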
+ ChatboxModel.Engine engine = model.getSelectedEngine();
+ LlamaApp llamaApp = LlamaApp.getInstance();
+ llamaApp.setUseTornadoVM(engine == ChatboxModel.Engine.TORNADO_VM);
+
+ ChatboxModel.Mode mode = model.getSelectedMode();
+ if (mode == ChatboxModel.Mode.INTERACTIVE) {
+ commands.add("--interactive");
+ }
+
+ // Assume that models are found in a /models directory.
+ String selectedModel = model.getSelectedModel();
+ if (selectedModel == null || selectedModel.isEmpty()) {
+ model.setOutputText("Please select a model.");
+ return null;
+ }
+ String modelPath = String.format("./models/%s", selectedModel);
+ String prompt = model.getPromptText(); // Passed directly to Options.parseOptions, so no shell quoting is needed.
+
+ commands.addAll(Arrays.asList("--model", modelPath));
+ if (!model.getInteractiveSessionActive()) {
+ commands.addAll(Arrays.asList("--prompt", prompt));
+ }
+
+ return commands.toArray(new String[commands.size()]);
+ }
+
+ private void cleanTornadoVMResources() {
+ if (currentResponse != null && currentResponse.tornadoVMPlan() != null) {
+ try {
+ currentResponse.tornadoVMPlan().freeTornadoExecutionPlan();
+ } catch (Exception e) {
+ System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
+ }
+ }
+ }
+
+ private void endInteractiveSession() {
+ cleanTornadoVMResources();
+ llm = null;
+ currentResponse = null;
+ model.setInteractiveSessionActive(false);
+ System.out.println("Interactive session ended");
+ }
+
+ // Load and run a model while capturing its output text to a custom stream.
+ public void runLlamaTornado() {
+ // Save the original System.out and System.err streams
+ PrintStream originalOut = System.out;
+ PrintStream originalErr = System.err;
+ try {
+ String[] commands = buildCommands();
+ if (commands == null) {
+ // buildCommands() returns null when no model is selected, so bail out early
+ return;
+ }
+
+ StringBuilder builder = new StringBuilder();
+
+ // Create a custom PrintStream to capture output from loading the model and running it.
+ PrintStream customStream = new PrintStream(new ByteArrayOutputStream()) {
+ @Override
+ public void println(String str) {
+ process(str + "\n");
+ }
+
+ @Override
+ public void print(String str) {
+ process(str);
+ }
+
+ private void process(String str) {
+ // Capture the output stream into the GUI output area.
+ builder.append(str);
+ final String currentOutput = builder.toString();
+ javafx.application.Platform.runLater(() -> model.setOutputText(currentOutput));
+ }
+ };
+
+ // Redirect System.out to the custom print stream.
+ System.setOut(customStream);
+ System.setErr(customStream);
+
+ Options options = Options.parseOptions(commands);
+
+ if (model.getInteractiveSessionActive()) {
+ builder.append(model.getOutputText()); // Include the current output to avoid clearing the entire text.
+ String userText = model.getPromptText();
+ // Display the user message with a '>' prefix
+ builder.append("> ");
+ builder.append(userText);
+ builder.append(System.getProperty("line.separator"));
+ if (List.of("quit", "exit").contains(userText)) {
+ endInteractiveSession();
+ } else {
+ currentResponse = llm.runInteractiveStep(sampler, options, userText, currentResponse);
+ }
+ } else {
+ builder.append("Processing... please wait");
+ builder.append(System.getProperty("line.separator"));
+
+ // Load the model and run.
+ llm = LlamaApp.loadModel(options);
+ sampler = LlamaApp.createSampler(llm, options);
+ if (options.interactive()) {
+ // Start a new interactive session.
+ builder.append("Interactive session started (write 'exit' or 'quit' to stop)");
+ builder.append(System.getProperty("line.separator"));
+ // Display the user message with a '>' prefix
+ builder.append("> ");
+ builder.append(model.getPromptText());
+ builder.append(System.getProperty("line.separator"));
+ currentResponse = llm.runInteractiveStep(sampler, options, model.getPromptText(), new Model.Response());
+ model.setInteractiveSessionActive(true);
+ } else {
+ llm.runInstructOnce(sampler, options);
+ llm = null;
+ sampler = null;
+ }
+ }
+
+ } catch (Exception e) {
+ // Catch all exceptions so that they're logged in the output area.
+ e.printStackTrace();
+ e.printStackTrace(originalOut);
+ // Make sure to end the interactive session if an exception occurs.
+ if (model.getInteractiveSessionActive()) {
+ endInteractiveSession();
+ }
+ } finally {
+ System.setOut(originalOut);
+ System.setErr(originalErr);
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/gui/ChatboxModel.java b/src/main/java/com/example/gui/ChatboxModel.java
new file mode 100644
index 0000000..16a4cd4
--- /dev/null
+++ b/src/main/java/com/example/gui/ChatboxModel.java
@@ -0,0 +1,94 @@
+package com.example.gui;
+
+import javafx.beans.property.BooleanProperty;
+import javafx.beans.property.ObjectProperty;
+import javafx.beans.property.SimpleBooleanProperty;
+import javafx.beans.property.SimpleObjectProperty;
+import javafx.beans.property.SimpleStringProperty;
+import javafx.beans.property.StringProperty;
+
+public class ChatboxModel {
+
+ public enum Engine { TORNADO_VM, JVM }
+ public enum Mode { INSTRUCT, INTERACTIVE }
+
+ private final ObjectProperty<Engine> selectedEngine = new SimpleObjectProperty<>(Engine.TORNADO_VM);
+ private final ObjectProperty<Mode> selectedMode = new SimpleObjectProperty<>(Mode.INSTRUCT);
+ private final StringProperty selectedModel = new SimpleStringProperty("");
+ private final StringProperty promptText = new SimpleStringProperty("");
+ private final StringProperty outputText = new SimpleStringProperty("");
+ private final BooleanProperty interactiveSessionActive = new SimpleBooleanProperty(false);
+
+ public Engine getSelectedEngine() {
+ return selectedEngine.get();
+ }
+
+ public ObjectProperty<Engine> selectedEngineProperty() {
+ return selectedEngine;
+ }
+
+ public void setSelectedEngine(Engine engine) {
+ this.selectedEngine.set(engine);
+ }
+
+ public Mode getSelectedMode() {
+ return selectedMode.get();
+ }
+
+ public ObjectProperty<Mode> selectedModeProperty() {
+ return selectedMode;
+ }
+
+ public void setSelectedMode(Mode mode) {
+ this.selectedMode.set(mode);
+ }
+
+ public String getSelectedModel() {
+ return selectedModel.get();
+ }
+
+ public StringProperty selectedModelProperty() {
+ return selectedModel;
+ }
+
+ public void setSelectedModel(String selectedModel) {
+ this.selectedModel.set(selectedModel);
+ }
+
+ public String getPromptText() {
+ return promptText.get();
+ }
+
+ public StringProperty promptTextProperty() {
+ return promptText;
+ }
+
+ public void setPromptText(String text) {
+ this.promptText.set(text);
+ }
+
+ public String getOutputText() {
+ return outputText.get();
+ }
+
+ public StringProperty outputTextProperty() {
+ return outputText;
+ }
+
+ public void setOutputText(String text) {
+ this.outputText.set(text);
+ }
+
+ public boolean getInteractiveSessionActive() {
+ return interactiveSessionActive.get();
+ }
+
+ public BooleanProperty interactiveSessionActiveProperty() {
+ return interactiveSessionActive;
+ }
+
+ public void setInteractiveSessionActive(boolean value) {
+ this.interactiveSessionActive.set(value);
+ }
+
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/gui/ChatboxViewBuilder.java b/src/main/java/com/example/gui/ChatboxViewBuilder.java
new file mode 100644
index 0000000..ee6894a
--- /dev/null
+++ b/src/main/java/com/example/gui/ChatboxViewBuilder.java
@@ -0,0 +1,247 @@
+package com.example.gui;
+
+import atlantafx.base.theme.Styles;
+import javafx.beans.binding.Bindings;
+import javafx.beans.property.BooleanProperty;
+import javafx.beans.property.SimpleBooleanProperty;
+import javafx.beans.property.StringProperty;
+import javafx.geometry.Insets;
+import javafx.geometry.Pos;
+import javafx.scene.Node;
+import javafx.scene.control.Button;
+import javafx.scene.control.CheckBox;
+import javafx.scene.control.ComboBox;
+import javafx.scene.control.Control;
+import javafx.scene.control.Label;
+import javafx.scene.control.TextArea;
+import javafx.scene.control.TextField;
+import javafx.scene.layout.HBox;
+import javafx.scene.layout.Priority;
+import javafx.scene.layout.Region;
+import javafx.scene.layout.VBox;
+import javafx.util.Builder;
+
+import java.io.File;
+import java.util.function.Consumer;
+
+public class ChatboxViewBuilder implements Builder<Region> {
+
+ private static final int PANEL_WIDTH = 640;
+
+ private final ChatboxModel model;
+ private final BooleanProperty inferenceRunning = new SimpleBooleanProperty(false);
+ private final Consumer<Runnable> actionHandler;
+
+ public ChatboxViewBuilder(ChatboxModel model, Consumer<Runnable> actionHandler) {
+ this.model = model;
+ this.actionHandler = actionHandler;
+ }
+
+ @Override
+ public Region build() {
+ VBox results = new VBox();
+
+ HBox panels = new HBox(createLeftPanel(), createRightPanel());
+ VBox.setVgrow(panels, Priority.ALWAYS);
+ results.getChildren().add(panels);
+
+ return results;
+ }
+
+ private Node createLeftPanel() {
+ VBox panel = new VBox(12);
+ panel.setPrefWidth(PANEL_WIDTH);
+ panel.setPadding(new Insets(24, 12, 24, 24));
+ HBox.setHgrow(panel, Priority.ALWAYS);
+ panel.getChildren().addAll(
+ createHeaderLabel("TornadoVM Chat"),
+ createEngineBox(),
+ createChatModeBox(),
+ createModelSelectBox(),
+ createLabel("Prompt:"),
+ createPromptBox(),
+ createRunButton(),
+ createLabel("Output:"),
+ createOutputArea()
+ );
+ return panel;
+ }
+
+ private Node createEngineBox() {
+ ComboBox<ChatboxModel.Engine> engineDropdown = new ComboBox<>();
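+ // Lock the selection controls while inference is running or an interactive session is open.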
+ engineDropdown.disableProperty().bind(Bindings.createBooleanBinding(() -> (inferenceRunning.get() || model.getInteractiveSessionActive()),
+ inferenceRunning,
+ model.interactiveSessionActiveProperty()));
+ engineDropdown.valueProperty().bindBidirectional(model.selectedEngineProperty());
+ engineDropdown.getItems().addAll(ChatboxModel.Engine.values());
+ engineDropdown.setMaxWidth(Double.MAX_VALUE);
+ engineDropdown.setPrefWidth(152);
+ HBox box = new HBox(8, createLabel("Engine:"), engineDropdown);
+ box.setAlignment(Pos.CENTER_LEFT);
+ return box;
+ }
+
+ private Node createChatModeBox() {
+ ComboBox<ChatboxModel.Mode> modeDropdown = new ComboBox<>();
+ modeDropdown.disableProperty().bind(Bindings.createBooleanBinding(() -> (inferenceRunning.get() || model.getInteractiveSessionActive()),
+ inferenceRunning,
+ model.interactiveSessionActiveProperty()));
+ modeDropdown.valueProperty().bindBidirectional(model.selectedModeProperty());
+ modeDropdown.getItems().addAll(ChatboxModel.Mode.values());
+ modeDropdown.setMaxWidth(Double.MAX_VALUE);
+ modeDropdown.setPrefWidth(152);
+ HBox box = new HBox(8, createLabel("Chat:"), modeDropdown);
+ box.setAlignment(Pos.CENTER_LEFT);
+ return box;
+ }
+
+ private Node createModelSelectBox() {
+ ComboBox<String> modelDropdown = new ComboBox<>();
+ modelDropdown.disableProperty().bind(Bindings.createBooleanBinding(() -> (inferenceRunning.get() || model.getInteractiveSessionActive()),
+ inferenceRunning,
+ model.interactiveSessionActiveProperty()));
+ modelDropdown.valueProperty().bindBidirectional(model.selectedModelProperty());
+ HBox.setHgrow(modelDropdown, Priority.ALWAYS);
+ modelDropdown.setMaxWidth(Double.MAX_VALUE);
+
+ Button reloadButton = new Button("Reload");
+ reloadButton.getStyleClass().add(Styles.ACCENT);
+ reloadButton.setMinWidth(80);
+ reloadButton.disableProperty().bind(Bindings.createBooleanBinding(() -> (inferenceRunning.get() || model.getInteractiveSessionActive()),
+ inferenceRunning,
+ model.interactiveSessionActiveProperty()));
+ reloadButton.setOnAction(e -> {
+ // Search for a /models folder containing model files.
+ modelDropdown.getItems().clear();
+ File llama3ModelsDir = new File("./models");
+ if (llama3ModelsDir.exists() && llama3ModelsDir.isDirectory()) {
+ File[] files = llama3ModelsDir.listFiles((dir, name) -> name.endsWith(".gguf"));
+ if (files != null) {
+ for (File file : files) {
+ modelDropdown.getItems().add(file.getName());
+ }
+
+ int numModels = modelDropdown.getItems().size();
+ String message = String.format("Found %d %s in %s", numModels, (numModels == 1 ? "model" : "models"), llama3ModelsDir.toPath().normalize().toAbsolutePath());
+ String currentOutput = model.getOutputText();
+ if (currentOutput.isEmpty()) {
+ model.setOutputText(message);
+ } else {
+ model.setOutputText(String.format("%s\n%s", model.getOutputText(), message));
+ }
+
+ if (numModels == 0) {
+ modelDropdown.getSelectionModel().clearSelection();
+ } else {
+ modelDropdown.getSelectionModel().select(0);
+ }
+ }
+ }
+ });
+
+ reloadButton.fire(); // Trigger the reload once at the start.
+
+ HBox box = new HBox(8, createLabel("Model:"), modelDropdown, reloadButton);
+ box.setAlignment(Pos.CENTER_LEFT);
+ return box;
+ }
+
+ private Node createPromptBox() {
+ TextField promptField = boundTextField(model.promptTextProperty());
+ HBox.setHgrow(promptField, Priority.ALWAYS);
+ promptField.setMaxWidth(Double.MAX_VALUE);
+ return new HBox(8, promptField);
+ }
+
+ private Node createRunButton() {
+ Button runButton = new Button("Run");
+ runButton.getStyleClass().add(Styles.ACCENT);
+ runButton.setMaxWidth(Double.MAX_VALUE);
+ runButton.disableProperty().bind(inferenceRunning);
+ runButton.setOnAction(e -> {
+ inferenceRunning.set(true);
+ actionHandler.accept(() -> inferenceRunning.set(false));
+ });
+ return runButton;
+ }
+
+ private Node createOutputArea() {
+ TextArea outputArea = new TextArea();
+ outputArea.setEditable(false);
+ outputArea.setWrapText(true);
+ VBox.setVgrow(outputArea, Priority.ALWAYS);
+ model.outputTextProperty().subscribe((newValue) -> {
+ outputArea.setText(newValue);
+ // Autoscroll the text area to the bottom.
+ outputArea.positionCaret(newValue.length());
+ });
+ return outputArea;
+ }
+
+ private Node createRightPanel() {
+ VBox panel = new VBox(8);
+ panel.setPrefWidth(PANEL_WIDTH);
+ panel.setPadding(new Insets(24, 24, 24, 12));
+ HBox.setHgrow(panel, Priority.ALWAYS);
+ panel.getChildren().addAll(
+ createMonitorOutputArea(),
+ createMonitorOptionsPanel()
+ );
+ return panel;
+ }
+
+ private TextArea createMonitorOutputArea() {
+ TextArea textArea = new TextArea();
+ textArea.setEditable(false);
+ textArea.setWrapText(true);
+ VBox.setVgrow(textArea, Priority.ALWAYS);
+ return textArea;
+ }
+
+ private Node createMonitorOptionsPanel() {
+ VBox box = new VBox();
+ box.setPadding(new Insets(8));
+ box.getChildren().addAll(
+ createSubHeaderLabel("System Monitor"),
+ createSystemMonitoringCheckBoxes()
+ );
+ return box;
+ }
+
+ private Node createSystemMonitoringCheckBoxes() {
+ HBox checkBoxes = new HBox(8);
+ checkBoxes.setAlignment(Pos.CENTER_LEFT);
+ checkBoxes.getChildren().addAll(
+ new CheckBox("htop"),
+ new CheckBox("nvtop"),
+ new CheckBox("GPU-Monitor")
+ );
+ return checkBoxes;
+ }
+
+ // Helper method for creating TextField objects with bound text property
+ private TextField boundTextField(StringProperty boundProperty) {
+ TextField textField = new TextField();
+ textField.textProperty().bindBidirectional(boundProperty);
+ return textField;
+ }
+
+ // Helper method to create Label objects with a minimum width
+ private Label createLabel(String text) {
+ Label label = new Label(text);
+ label.setMinWidth(Control.USE_PREF_SIZE);
+ return label;
+ }
+
+ private Label createHeaderLabel(String text) {
+ Label label = createLabel(text);
+ label.setStyle("-fx-font-size: 16pt; -fx-font-weight: bold;");
+ return label;
+ }
+
+ private Label createSubHeaderLabel(String text) {
+ Label label = createLabel(text);
+ label.setStyle("-fx-font-size: 12pt; -fx-font-weight: bold;");
+ return label;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/gui/LlamaChatbox.java b/src/main/java/com/example/gui/LlamaChatbox.java
new file mode 100644
index 0000000..7ac2cb2
--- /dev/null
+++ b/src/main/java/com/example/gui/LlamaChatbox.java
@@ -0,0 +1,23 @@
+package com.example.gui;
+
+import atlantafx.base.theme.CupertinoDark;
+import javafx.application.Application;
+import javafx.scene.Scene;
+import javafx.stage.Stage;
+
+public class LlamaChatbox extends Application {
+
+ @Override
+ public void start(Stage stage) {
+ Application.setUserAgentStylesheet(new CupertinoDark().getUserAgentStylesheet());
+ ChatboxController controller = new ChatboxController();
+ Scene scene = new Scene(controller.getView(), 800, 600);
+ stage.setTitle("TornadoVM Chat");
+ stage.setScene(scene);
+ stage.show();
+ }
+
+ public static void main(String[] args) {
+ launch(args);
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/loader/weights/ModelLoader.java b/src/main/java/com/example/loader/weights/ModelLoader.java
index c4f7751..e96b8da 100644
--- a/src/main/java/com/example/loader/weights/ModelLoader.java
+++ b/src/main/java/com/example/loader/weights/ModelLoader.java
@@ -96,7 +96,8 @@ public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Co
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
- if (LlamaApp.USE_TORNADOVM) {
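+ // Read the runtime flag from the singleton instead of the old static constant.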
+ LlamaApp llamaApp = LlamaApp.getInstance();
+ if (llamaApp.getUseTornadoVM()) {
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
} else {
diff --git a/src/main/java/com/example/model/Model.java b/src/main/java/com/example/model/Model.java
index e42349b..bbf2a46 100644
--- a/src/main/java/com/example/model/Model.java
+++ b/src/main/java/com/example/model/Model.java
@@ -1,5 +1,6 @@
package com.example.model;
+import com.example.LlamaApp;
import com.example.Options;
import com.example.auxiliary.LastRunMetrics;
import com.example.inference.InferenceEngine;
@@ -17,7 +18,6 @@
import java.util.function.IntConsumer;
import static com.example.LlamaApp.SHOW_PERF_INTERACTIVE;
-import static com.example.LlamaApp.USE_TORNADOVM;
public interface Model {
Configuration configuration();
@@ -54,6 +54,9 @@ default void runInteractive(Sampler sampler, Options options) {
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
TornadoVMMasterPlan tornadoVMPlan = null;
+ // Get the LlamaApp singleton to read configuration values
+ LlamaApp llamaApp = LlamaApp.getInstance();
+
try {
while (true) {
System.out.print("> ");
@@ -68,7 +71,7 @@ default void runInteractive(Sampler sampler, Options options) {
state = createNewState();
}
- if (USE_TORNADOVM && tornadoVMPlan == null) {
+ if (llamaApp.getUseTornadoVM() && tornadoVMPlan == null) {
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
}
@@ -86,7 +89,7 @@ default void runInteractive(Sampler sampler, Options options) {
};
// Choose between GPU and CPU path based on configuration
- if (USE_TORNADOVM) {
+ if (llamaApp.getUseTornadoVM()) {
// GPU path using TornadoVM
responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
@@ -121,7 +124,7 @@ default void runInteractive(Sampler sampler, Options options) {
}
} finally {
// Clean up TornadoVM resources when exiting the chat loop
- if (USE_TORNADOVM && tornadoVMPlan != null) {
+ if (llamaApp.getUseTornadoVM() && tornadoVMPlan != null) {
try {
tornadoVMPlan.freeTornadoExecutionPlan();
} catch (Exception e) {
@@ -131,6 +134,104 @@ default void runInteractive(Sampler sampler, Options options) {
}
}
+ /**
+ * Model agnostic implementation for interactive GUI mode.
+ * Takes a single user input and returns the model's response, allowing the GUI to manage the chat loop.
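+ * <p>Illustrative usage (variable names here are examples, not part of the API):
+ * <pre>{@code
+ * Model.Response r = model.runInteractiveStep(sampler, options, "Hello!", new Model.Response());
+ * r = model.runInteractiveStep(sampler, options, "And then?", r); // feed the response back to continue
+ * }</pre>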
+ *
+ * @param sampler The sampler for token generation
+ * @param options The inference options
+ * @param userText The user's input text
+ * @param previousResponse A Response object referencing an ongoing chat
+ * @return A Response object containing the model's output and updated state
+ */
+ default Response runInteractiveStep(Sampler sampler, Options options, String userText, Response previousResponse) {
+ ChatFormat chatFormat = ChatFormat.create(tokenizer());
+ List<Integer> conversationTokens = previousResponse.conversationTokens();
+ State state = previousResponse.state();
+ TornadoVMMasterPlan tornadoVMPlan = previousResponse.tornadoVMPlan();
+ String responseText = "";
+
+ int startPosition = conversationTokens.size();
+
+ // For the first message, set up the conversation tokens if empty
+ if (conversationTokens.isEmpty()) {
+ conversationTokens.add(chatFormat.getBeginOfText());
+ if (options.systemPrompt() != null) {
+ conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
+ }
+ }
+
+ // Get the LlamaApp singleton to read configuration values
+ LlamaApp llamaApp = LlamaApp.getInstance();
+
+ if (state == null) {
+ // State allocation can take some time for large context sizes
+ state = createNewState();
+ }
+
+ // Initialize TornadoVM plan once at the beginning if GPU path is enabled
+ if (llamaApp.getUseTornadoVM() && tornadoVMPlan == null) {
+ tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+ }
+
+ conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
+ conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+ Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+ List<Integer> responseTokens;
+ IntConsumer tokenConsumer = token -> {
+ if (options.stream()) {
+ if (tokenizer().shouldDisplayToken(token)) {
+ System.out.print(tokenizer().decode(List.of(token)));
+ }
+ }
+ };
+
+ // Choose between GPU and CPU path based on configuration
+ if (llamaApp.getUseTornadoVM()) {
+ // GPU path using TornadoVM
+ responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+ options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+ } else {
+ // CPU path
+ responseTokens = InferenceEngine.generateTokens(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
+ sampler, options.echo(), tokenConsumer);
+ }
+
+ // Include stop token in the prompt history, but not in the response displayed to the user.
+ conversationTokens.addAll(responseTokens);
+ Integer stopToken = null;
+ if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+ stopToken = responseTokens.getLast();
+ responseTokens.removeLast();
+ }
+ if (!options.stream()) {
+ responseText = tokenizer().decode(responseTokens);
+ System.out.println(responseText);
+ }
+ if (stopToken == null) {
+ System.err.println("\n Ran out of context length...\n Increase context length by passing to llama-tornado --max-tokens XXX");
+ return new Response(responseText, state, conversationTokens, tornadoVMPlan);
+ }
+ System.out.print("\n");
+
+ // Optionally print performance metrics after each response
+ if (SHOW_PERF_INTERACTIVE) {
+ LastRunMetrics.printMetrics();
+ }
+
+ return new Response(responseText, state, conversationTokens, tornadoVMPlan);
+ }
+
+ /**
+ * Simple data model for Model responses, used to keep track of conversation history and state.
+ */
+ record Response(String responseText, State state, List<Integer> conversationTokens, TornadoVMMasterPlan tornadoVMPlan) {
+ public Response() {
+ this("", null, new ArrayList<>(), null);
+ }
+ }
+
/**
* Model agnostic default implementation for instruct mode.
* @param sampler
@@ -140,6 +241,7 @@ default void runInstructOnce(Sampler sampler, Options options) {
State state = createNewState();
ChatFormat chatFormat = ChatFormat.create(tokenizer());
TornadoVMMasterPlan tornadoVMPlan = null;
+ LlamaApp llamaApp = LlamaApp.getInstance();
List<Integer> promptTokens = new ArrayList<>();
promptTokens.add(chatFormat.getBeginOfText());
@@ -162,7 +264,7 @@ default void runInstructOnce(Sampler sampler, Options options) {
Set<Integer> stopTokens = chatFormat.getStopTokens();
- if (USE_TORNADOVM) {
+ if (llamaApp.getUseTornadoVM()) {
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
// Call generateTokensGPU without the token consumer parameter
responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null,