diff --git a/llama-tornado b/llama-tornado
index 112a7ca..5df1851 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -251,6 +251,9 @@ class LlamaRunner:
             print(f"  {arg}")
             print()
 
+        if args.gui:
+            cmd.append("--gui")
+
         # Execute the command
         try:
             result = subprocess.run(cmd, check=True)
@@ -316,7 +319,8 @@ def create_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--model",
         dest="model_path",
-        required=True,
+        required="--gui" not in sys.argv,
+        default="",
        help="Path to the LLM gguf file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)",
     )
 
@@ -466,6 +470,12 @@ def create_parser() -> argparse.ArgumentParser:
         "--verbose", "-v", action="store_true", help="Verbose output"
     )
 
+    # GUI options
+    gui_group = parser.add_argument_group("GUI Options")
+    gui_group.add_argument(
+        "--gui", action="store_true", help="Launch the GUI chatbox"
+    )
+
     return parser
 
diff --git a/pom.xml b/pom.xml
index 216dda9..32d260c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -32,6 +32,18 @@
             <artifactId>tornado-runtime</artifactId>
             <version>1.1.1-dev</version>
         </dependency>
+
+        <dependency>
+            <groupId>org.openjfx</groupId>
+            <artifactId>javafx-controls</artifactId>
+            <version>21</version>
+        </dependency>
+
+        <dependency>
+            <groupId>io.github.mkpaz</groupId>
+            <artifactId>atlantafx-base</artifactId>
+            <version>2.0.1</version>
+        </dependency>
     </dependencies>
 
@@ -68,6 +80,15 @@
             </plugin>
+
+            <plugin>
+                <groupId>org.openjfx</groupId>
+                <artifactId>javafx-maven-plugin</artifactId>
+                <version>0.0.8</version>
+                <configuration>
+                    <mainClass>com.example.gui.LlamaChatbox</mainClass>
+                </configuration>
+            </plugin>
         </plugins>
     </build>
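Note: with the javafx-maven-plugin block above in place, the GUI should also start straight from Maven during development; javafx:run is that plugin's standard goal, so something like the following is expected to work (the llama-tornado --gui path stays independent of it):

    mvn javafx:run    # runs the configured mainClass, com.example.gui.LlamaChatbox
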
diff --git a/src/main/java/com/example/LlamaApp.java b/src/main/java/com/example/LlamaApp.java
index 5ea0cb2..e739ca9 100644
--- a/src/main/java/com/example/LlamaApp.java
+++ b/src/main/java/com/example/LlamaApp.java
@@ -2,12 +2,14 @@
 import com.example.aot.AOT;
 import com.example.core.model.tensor.FloatTensor;
+import com.example.gui.LlamaChatbox;
 import com.example.inference.sampler.CategoricalSampler;
 import com.example.inference.sampler.Sampler;
 import com.example.inference.sampler.ToppSampler;
 import com.example.loader.weights.ModelLoader;
 import com.example.model.Model;
 import com.example.tornadovm.FloatArrayUtils;
+import javafx.application.Application;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 import java.io.IOException;
@@ -18,9 +20,29 @@ public class LlamaApp {
     // Configuration flags for hardware acceleration and optimizations
     public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
     public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
-    public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
+    private boolean useTornadoVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
     public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode
 
+    private static LlamaApp instance;
+
+    private LlamaApp() {
+    }
+
+    public static LlamaApp getInstance() {
+        if (instance == null) {
+            instance = new LlamaApp();
+        }
+        return instance;
+    }
+
+    public void setUseTornadoVM(boolean value) {
+        useTornadoVM = value;
+    }
+
+    public boolean getUseTornadoVM() {
+        return useTornadoVM;
+    }
+
     /**
      * Creates and configures a sampler for token generation based on specified parameters.
      *
@@ -118,7 +140,7 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
      * @throws IOException if the model fails to load
      * @throws IllegalStateException if AOT loading is enabled but the preloaded model is unavailable
      */
-    private static Model loadModel(Options options) throws IOException {
+    public static Model loadModel(Options options) throws IOException {
        if (USE_AOT) {
            Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
            if (model == null) {
@@ -129,7 +151,7 @@ private static Model loadModel(Options options) throws IOException {
         return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
     }
 
-    private static Sampler createSampler(Model model, Options options) {
+    public static Sampler createSampler(Model model, Options options) {
         return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
     }
 
@@ -145,13 +167,20 @@ private static Sampler createSampler(Model model, Options options) {
      */
     public static void main(String[] args) throws IOException {
         Options options = Options.parseOptions(args);
-        Model model = loadModel(options);
-        Sampler sampler = createSampler(model, options);
 
-        if (options.interactive()) {
-            model.runInteractive(sampler, options);
+        if (options.guiMode()) {
+            // Launch the JavaFX application
+            Application.launch(LlamaChatbox.class, args);
         } else {
-            model.runInstructOnce(sampler, options);
+            // Run the CLI logic
+            Model model = loadModel(options);
+            Sampler sampler = createSampler(model, options);
+
+            if (options.interactive()) {
+                model.runInteractive(sampler, options);
+            } else {
+                model.runInstructOnce(sampler, options);
+            }
         }
     }
 }
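Note: getInstance() above is lazy but unsynchronized, and useTornadoVM is written from the JavaFX thread (engine dropdown) and read on the inference worker thread. If that ever becomes an issue, a minimal thread-safe sketch with the same public surface (assumes the constructor stays trivial):

    public class LlamaApp {
        // Eager singleton: the JVM runs static initializers exactly once,
        // so every thread sees the same instance without locking.
        private static final LlamaApp INSTANCE = new LlamaApp();

        // volatile: a write from the JavaFX thread is guaranteed visible
        // to the worker thread that later loads the model.
        private volatile boolean useTornadoVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false"));

        private LlamaApp() {
        }

        public static LlamaApp getInstance() {
            return INSTANCE;
        }

        public void setUseTornadoVM(boolean value) {
            useTornadoVM = value;
        }

        public boolean getUseTornadoVM() {
            return useTornadoVM;
        }
    }
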
diff --git a/src/main/java/com/example/Options.java b/src/main/java/com/example/Options.java
index 284e754..ebf9546 100644
--- a/src/main/java/com/example/Options.java
+++ b/src/main/java/com/example/Options.java
@@ -5,7 +5,7 @@
 import java.nio.file.Paths;
 
 public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive,
-                      float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {
+                      float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, boolean guiMode) {
 
     public static final int DEFAULT_MAX_TOKENS = 1024;
 
@@ -41,6 +41,7 @@ static void printUsage(PrintStream out) {
         out.println("  --max-tokens, -n number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
         out.println("  --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
         out.println("  --echo print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
+        out.println("  --gui run the GUI chatbox");
         out.println();
     }
 
@@ -57,6 +58,7 @@ public static Options parseOptions(String[] args) {
         boolean interactive = false;
         boolean stream = true;
         boolean echo = false;
+        boolean guiMode = false;
 
         for (int i = 0; i < args.length; i++) {
             String optionName = args[i];
@@ -64,6 +66,7 @@ public static Options parseOptions(String[] args) {
             switch (optionName) {
                 case "--interactive", "--chat", "-i" -> interactive = true;
                 case "--instruct" -> interactive = false;
+                case "--gui" -> guiMode = true;
                 case "--help", "-h" -> {
                     printUsage(System.out);
                     System.exit(0);
@@ -95,6 +98,6 @@ public static Options parseOptions(String[] args) {
                 }
             }
         }
-        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo);
+        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, guiMode);
     }
 }
diff --git a/src/main/java/com/example/gui/ChatboxController.java b/src/main/java/com/example/gui/ChatboxController.java
new file mode 100644
index 0000000..5129e93
--- /dev/null
+++ b/src/main/java/com/example/gui/ChatboxController.java
@@ -0,0 +1,36 @@
+package com.example.gui;
+
+import javafx.concurrent.Task;
+import javafx.scene.layout.Region;
+
+public class ChatboxController {
+
+    private final ChatboxViewBuilder viewBuilder;
+    private final ChatboxInteractor interactor;
+
+    public ChatboxController() {
+        ChatboxModel model = new ChatboxModel();
+        interactor = new ChatboxInteractor(model);
+        viewBuilder = new ChatboxViewBuilder(model, this::runInference);
+    }
+
+    private void runInference(Runnable postRunAction) {
+        Task<Void> inferenceTask = new Task<>() {
+            @Override
+            protected Void call() {
+                interactor.runLlamaTornado();
+                return null;
+            }
+        };
+        inferenceTask.setOnSucceeded(evt -> {
+            postRunAction.run();
+        });
+        Thread inferenceThread = new Thread(inferenceTask);
+        inferenceThread.start();
+    }
+
+    public Region getView() {
+        return viewBuilder.build();
+    }
+
+}
\ No newline at end of file
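Note: the Task above runs on a plain user thread, so a window closed mid-inference keeps the JVM alive until the model finishes. A sketch of the same method with a daemon thread plus a failure hook (setOnFailed and setDaemon are standard javafx.concurrent and java.lang API; everything else matches the class above):

    private void runInference(Runnable postRunAction) {
        Task<Void> inferenceTask = new Task<>() {
            @Override
            protected Void call() {
                interactor.runLlamaTornado();
                return null;
            }
        };
        inferenceTask.setOnSucceeded(evt -> postRunAction.run());
        // Re-enable the Run button even if the task fails unexpectedly.
        inferenceTask.setOnFailed(evt -> postRunAction.run());
        Thread inferenceThread = new Thread(inferenceTask);
        // Daemon thread: exiting the GUI no longer waits on inference.
        inferenceThread.setDaemon(true);
        inferenceThread.start();
    }
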
diff --git a/src/main/java/com/example/gui/ChatboxInteractor.java b/src/main/java/com/example/gui/ChatboxInteractor.java
new file mode 100644
index 0000000..b740fd8
--- /dev/null
+++ b/src/main/java/com/example/gui/ChatboxInteractor.java
@@ -0,0 +1,166 @@
+package com.example.gui;
+
+import com.example.LlamaApp;
+import com.example.Options;
+import com.example.inference.sampler.Sampler;
+import com.example.model.Model;
+
+import java.io.ByteArrayOutputStream;
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class ChatboxInteractor {
+
+    private final ChatboxModel model;
+    private Model llm;
+    private Sampler sampler;
+    private Model.Response currentResponse;
+
+    public ChatboxInteractor(ChatboxModel viewModel) {
+        this.model = viewModel;
+    }
+
+    private String[] buildCommands() {
+        List<String> commands = new ArrayList<>();
+
+        ChatboxModel.Engine engine = model.getSelectedEngine();
+        LlamaApp llamaApp = LlamaApp.getInstance();
+        if (engine == ChatboxModel.Engine.TORNADO_VM) {
+            llamaApp.setUseTornadoVM(true);
+        } else {
+            llamaApp.setUseTornadoVM(false);
+        }
+
+        ChatboxModel.Mode mode = model.getSelectedMode();
+        if (mode == ChatboxModel.Mode.INTERACTIVE) {
+            commands.add("--interactive");
+        }
+
+        // Assume that models are found in a /models directory.
+        String selectedModel = model.getSelectedModel();
+        if (selectedModel == null || selectedModel.isEmpty()) {
+            model.setOutputText("Please select a model.");
+            return null;
+        }
+        String modelPath = String.format("./models/%s", selectedModel);
+        String prompt = String.format("\"%s\"", model.getPromptText());
+
+        commands.addAll(Arrays.asList("--model", modelPath));
+        if (!model.getInteractiveSessionActive()) {
+            commands.addAll(Arrays.asList("--prompt", prompt));
+        }
+
+        return commands.toArray(new String[commands.size()]);
+    }
+
+    private void cleanTornadoVMResources() {
+        if (currentResponse != null && currentResponse.tornadoVMPlan() != null) {
+            try {
+                currentResponse.tornadoVMPlan().freeTornadoExecutionPlan();
+            } catch (Exception e) {
+                System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
+            }
+        }
+    }
+
+    private void endInteractiveSession() {
+        cleanTornadoVMResources();
+        llm = null;
+        currentResponse = null;
+        model.setInteractiveSessionActive(false);
+        System.out.println("Interactive session ended");
+    }
+
+    // Load and run a model while capturing its output text to a custom stream.
+    public void runLlamaTornado() {
+        // Save the original System.out and System.err streams
+        PrintStream originalOut = System.out;
+        PrintStream originalErr = System.err;
+        try {
+            String[] commands = buildCommands();
+            if (commands == null) {
+                // commands is null if no model was found, so exit this process
+                return;
+            }
+
+            StringBuilder builder = new StringBuilder();
+
+            // Create a custom PrintStream to capture output from loading the model and running it.
+            PrintStream customStream = new PrintStream(new ByteArrayOutputStream()) {
+                @Override
+                public void println(String str) {
+                    process(str + "\n");
+                }
+
+                @Override
+                public void print(String str) {
+                    process(str);
+                }
+
+                private void process(String str) {
+                    // Capture the output stream into the GUI output area.
+                    builder.append(str);
+                    final String currentOutput = builder.toString();
+                    javafx.application.Platform.runLater(() -> model.setOutputText(currentOutput));
+                }
+            };
+
+            // Redirect System.out and System.err to the custom print stream.
+            System.setOut(customStream);
+            System.setErr(customStream);
+
+            Options options = Options.parseOptions(commands);
+
+            if (model.getInteractiveSessionActive()) {
+                builder.append(model.getOutputText()); // Include the current output to avoid clearing the entire text.
+                String userText = model.getPromptText();
+                // Display the user message with a '>' prefix
+                builder.append("> ");
+                builder.append(userText);
+                builder.append(System.getProperty("line.separator"));
+                if (List.of("quit", "exit").contains(userText)) {
+                    endInteractiveSession();
+                } else {
+                    currentResponse = llm.runInteractiveStep(sampler, options, userText, currentResponse);
+                }
+            } else {
+                builder.append("Processing... please wait");
+                builder.append(System.getProperty("line.separator"));
+
+                // Load the model and run.
+                llm = LlamaApp.loadModel(options);
+                sampler = LlamaApp.createSampler(llm, options);
+                if (options.interactive()) {
+                    // Start a new interactive session.
+                    builder.append("Interactive session started (write 'exit' or 'quit' to stop)");
+                    builder.append(System.getProperty("line.separator"));
+                    // Display the user message with a '>' prefix
+                    builder.append("> ");
+                    builder.append(model.getPromptText());
+                    builder.append(System.getProperty("line.separator"));
+                    currentResponse = llm.runInteractiveStep(sampler, options, model.getPromptText(), new Model.Response());
+                    model.setInteractiveSessionActive(true);
+                } else {
+                    llm.runInstructOnce(sampler, options);
+                    llm = null;
+                    sampler = null;
+                }
+            }
+
+        } catch (Exception e) {
+            // Catch all exceptions so that they're logged in the output area.
+            e.printStackTrace();
+            e.printStackTrace(originalOut);
+            // Make sure to end the interactive session if an exception occurs.
+            if (model.getInteractiveSessionActive()) {
+                endInteractiveSession();
+            }
+        } finally {
+            System.setOut(originalOut);
+            System.setErr(originalErr);
+        }
+    }
+
+}
\ No newline at end of file
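Note: the anonymous PrintStream above only intercepts print(String) and println(String); output reaching the stream through other overloads (printf, print(char), write(byte[])) lands in the discarded ByteArrayOutputStream instead of the GUI. A byte-level variant closes that hole; this is a sketch meant as a drop-in for the customStream construction above (java.nio.charset.StandardCharsets would be a new import):

    // Buffer every byte written through any PrintStream overload and
    // push the decoded text to the GUI whenever the stream flushes.
    ByteArrayOutputStream captured = new ByteArrayOutputStream() {
        @Override
        public void flush() {
            final String currentOutput = toString(StandardCharsets.UTF_8);
            javafx.application.Platform.runLater(() -> model.setOutputText(currentOutput));
        }
    };
    // autoFlush=true: flushes on println, on '\n' bytes, and on byte[] writes.
    PrintStream customStream = new PrintStream(captured, true, StandardCharsets.UTF_8);
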
diff --git a/src/main/java/com/example/gui/ChatboxModel.java b/src/main/java/com/example/gui/ChatboxModel.java
new file mode 100644
index 0000000..16a4cd4
--- /dev/null
+++ b/src/main/java/com/example/gui/ChatboxModel.java
@@ -0,0 +1,94 @@
+package com.example.gui;
+
+import javafx.beans.property.BooleanProperty;
+import javafx.beans.property.ObjectProperty;
+import javafx.beans.property.SimpleBooleanProperty;
+import javafx.beans.property.SimpleObjectProperty;
+import javafx.beans.property.SimpleStringProperty;
+import javafx.beans.property.StringProperty;
+
+public class ChatboxModel {
+
+    public enum Engine { TORNADO_VM, JVM }
+    public enum Mode { INSTRUCT, INTERACTIVE }
+
+    private final ObjectProperty<Engine> selectedEngine = new SimpleObjectProperty<>(Engine.TORNADO_VM);
+    private final ObjectProperty<Mode> selectedMode = new SimpleObjectProperty<>(Mode.INSTRUCT);
+    private final StringProperty selectedModel = new SimpleStringProperty("");
+    private final StringProperty promptText = new SimpleStringProperty("");
+    private final StringProperty outputText = new SimpleStringProperty("");
+    private final BooleanProperty interactiveSessionActive = new SimpleBooleanProperty(false);
+
+    public Engine getSelectedEngine() {
+        return selectedEngine.get();
+    }
+
+    public ObjectProperty<Engine> selectedEngineProperty() {
+        return selectedEngine;
+    }
+
+    public void setSelectedEngine(Engine engine) {
+        this.selectedEngine.set(engine);
+    }
+
+    public Mode getSelectedMode() {
+        return selectedMode.get();
+    }
+
+    public ObjectProperty<Mode> selectedModeProperty() {
+        return selectedMode;
+    }
+
+    public void setSelectedMode(Mode mode) {
+        this.selectedMode.set(mode);
+    }
+
+    public String getSelectedModel() {
+        return selectedModel.get();
+    }
+
+    public StringProperty selectedModelProperty() {
+        return selectedModel;
+    }
+
+    public void setSelectedModel(String selectedModel) {
+        this.selectedModel.set(selectedModel);
+    }
+
+    public String getPromptText() {
+        return promptText.get();
+    }
+
+    public StringProperty promptTextProperty() {
+        return promptText;
+    }
+
+    public void setPromptText(String text) {
+        this.promptText.set(text);
+    }
+
+    public String getOutputText() {
+        return outputText.get();
+    }
+
+    public StringProperty outputTextProperty() {
+        return outputText;
+    }
+
+    public void setOutputText(String text) {
+        this.outputText.set(text);
+    }
+
+    public boolean getInteractiveSessionActive() {
+        return interactiveSessionActive.get();
+    }
+
+    public BooleanProperty interactiveSessionActiveProperty() {
+        return interactiveSessionActive;
+    }
+
+    public void setInteractiveSessionActive(boolean value) {
+        this.interactiveSessionActive.set(value);
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/gui/ChatboxViewBuilder.java b/src/main/java/com/example/gui/ChatboxViewBuilder.java
new file mode 100644
index 0000000..ee6894a
--- /dev/null
+++ b/src/main/java/com/example/gui/ChatboxViewBuilder.java
@@ -0,0 +1,247 @@
+package com.example.gui;
+
+import atlantafx.base.theme.Styles;
+import javafx.beans.binding.Bindings;
+import javafx.beans.property.BooleanProperty;
+import javafx.beans.property.SimpleBooleanProperty;
+import javafx.beans.property.StringProperty;
+import javafx.geometry.Insets;
+import javafx.geometry.Pos;
+import javafx.scene.Node;
+import javafx.scene.control.Button;
+import javafx.scene.control.CheckBox;
+import javafx.scene.control.ComboBox;
+import javafx.scene.control.Control;
+import javafx.scene.control.Label;
+import javafx.scene.control.TextArea;
+import javafx.scene.control.TextField;
+import javafx.scene.layout.HBox;
+import javafx.scene.layout.Priority;
+import javafx.scene.layout.Region;
+import javafx.scene.layout.VBox;
+import javafx.util.Builder;
+
+import java.io.File;
+import java.util.function.Consumer;
+
+public class ChatboxViewBuilder implements Builder<Region> {
+
+    private static final int PANEL_WIDTH = 640;
+
+    private final ChatboxModel model;
+    private final BooleanProperty inferenceRunning = new SimpleBooleanProperty(false);
+    private final Consumer<Runnable> actionHandler;
+
+    public ChatboxViewBuilder(ChatboxModel model, Consumer<Runnable> actionHandler) {
+        this.model = model;
+        this.actionHandler = actionHandler;
+    }
+
+    @Override
+    public Region build() {
+        VBox results = new VBox();
+
+        HBox panels = new HBox(createLeftPanel(), createRightPanel());
+        VBox.setVgrow(panels, Priority.ALWAYS);
+        results.getChildren().add(panels);
+
+        return results;
+    }
+
+    private Node createLeftPanel() {
+        VBox panel = new VBox(12);
+        panel.setPrefWidth(PANEL_WIDTH);
+        panel.setPadding(new Insets(24, 12, 24, 24));
+        HBox.setHgrow(panel, Priority.ALWAYS);
+        panel.getChildren().addAll(
+                createHeaderLabel("TornadoVM Chat"),
+                createEngineBox(),
+                createChatModeBox(),
+                createModelSelectBox(),
+                createLabel("Prompt:"),
+                createPromptBox(),
+                createRunButton(),
+                createLabel("Output:"),
+                createOutputArea()
+        );
+        return panel;
+    }
+
+    private Node createEngineBox() {
+        ComboBox<ChatboxModel.Engine> engineDropdown = new ComboBox<>();
+        engineDropdown.disableProperty().bind(Bindings.createBooleanBinding(() -> (inferenceRunning.get() || model.getInteractiveSessionActive()),
+                inferenceRunning,
+                model.interactiveSessionActiveProperty()));
+        engineDropdown.valueProperty().bindBidirectional(model.selectedEngineProperty());
+        engineDropdown.getItems().addAll(ChatboxModel.Engine.values());
+        engineDropdown.setMaxWidth(Double.MAX_VALUE);
+        engineDropdown.setPrefWidth(152);
+        HBox box = new HBox(8, createLabel("Engine:"), engineDropdown);
+        box.setAlignment(Pos.CENTER_LEFT);
+        return box;
+    }
+
+    private Node createChatModeBox() {
+        ComboBox<ChatboxModel.Mode> modeDropdown = new ComboBox<>();
+        modeDropdown.disableProperty().bind(Bindings.createBooleanBinding(() -> (inferenceRunning.get() || model.getInteractiveSessionActive()),
+                inferenceRunning,
+                model.interactiveSessionActiveProperty()));
+        modeDropdown.valueProperty().bindBidirectional(model.selectedModeProperty());
+        modeDropdown.getItems().addAll(ChatboxModel.Mode.values());
+        modeDropdown.setMaxWidth(Double.MAX_VALUE);
+        modeDropdown.setPrefWidth(152);
+        HBox box = new HBox(8, createLabel("Chat:"), modeDropdown);
+        box.setAlignment(Pos.CENTER_LEFT);
+        return box;
+    }
+
+    private Node createModelSelectBox() {
+        ComboBox<String> modelDropdown = new ComboBox<>();
+        modelDropdown.disableProperty().bind(Bindings.createBooleanBinding(() -> (inferenceRunning.get() || model.getInteractiveSessionActive()),
+                inferenceRunning,
+                model.interactiveSessionActiveProperty()));
+        modelDropdown.valueProperty().bindBidirectional(model.selectedModelProperty());
+        HBox.setHgrow(modelDropdown, Priority.ALWAYS);
+        modelDropdown.setMaxWidth(Double.MAX_VALUE);
+
+        Button reloadButton = new Button("Reload");
+        reloadButton.getStyleClass().add(Styles.ACCENT);
+        reloadButton.setMinWidth(80);
+        reloadButton.disableProperty().bind(Bindings.createBooleanBinding(() -> (inferenceRunning.get() || model.getInteractiveSessionActive()),
+                inferenceRunning,
+                model.interactiveSessionActiveProperty()));
+        reloadButton.setOnAction(e -> {
+            // Search for a /models folder containing model files.
+            modelDropdown.getItems().clear();
+            File llama3ModelsDir = new File("./models");
+            if (llama3ModelsDir.exists() && llama3ModelsDir.isDirectory()) {
+                File[] files = llama3ModelsDir.listFiles((dir, name) -> name.endsWith(".gguf"));
+                if (files != null) {
+                    for (File file : files) {
+                        modelDropdown.getItems().add(file.getName());
+                    }
+
+                    int numModels = modelDropdown.getItems().size();
+                    String message = String.format("Found %d %s in %s", numModels, (numModels == 1 ? "model" : "models"), llama3ModelsDir.toPath().normalize().toAbsolutePath());
+                    String currentOutput = model.getOutputText();
+                    if (currentOutput.isEmpty()) {
+                        model.setOutputText(message);
+                    } else {
+                        model.setOutputText(String.format("%s\n%s", model.getOutputText(), message));
+                    }
+
+                    if (numModels == 0) {
+                        modelDropdown.getSelectionModel().clearSelection();
+                    } else {
+                        modelDropdown.getSelectionModel().select(0);
+                    }
+                }
+            }
+        });
+
+        reloadButton.fire(); // Trigger the reload once at the start.
+
+        HBox box = new HBox(8, createLabel("Model:"), modelDropdown, reloadButton);
+        box.setAlignment(Pos.CENTER_LEFT);
+        return box;
+    }
+
+    private Node createPromptBox() {
+        TextField promptField = boundTextField(model.promptTextProperty());
+        HBox.setHgrow(promptField, Priority.ALWAYS);
+        promptField.setMaxWidth(Double.MAX_VALUE);
+        return new HBox(8, promptField);
+    }
+
+    private Node createRunButton() {
+        Button runButton = new Button("Run");
+        runButton.getStyleClass().add(Styles.ACCENT);
+        runButton.setMaxWidth(Double.MAX_VALUE);
+        runButton.disableProperty().bind(inferenceRunning);
+        runButton.setOnAction(e -> {
+            inferenceRunning.set(true);
+            actionHandler.accept(() -> inferenceRunning.set(false));
+        });
+        return runButton;
+    }
+
+    private Node createOutputArea() {
+        TextArea outputArea = new TextArea();
+        outputArea.setEditable(false);
+        outputArea.setWrapText(true);
+        VBox.setVgrow(outputArea, Priority.ALWAYS);
+        model.outputTextProperty().subscribe((newValue) -> {
+            outputArea.setText(newValue);
+            // Autoscroll the text area to the bottom.
+            outputArea.positionCaret(newValue.length());
+        });
+        return outputArea;
+    }
+
+    private Node createRightPanel() {
+        VBox panel = new VBox(8);
+        panel.setPrefWidth(PANEL_WIDTH);
+        panel.setPadding(new Insets(24, 24, 24, 12));
+        HBox.setHgrow(panel, Priority.ALWAYS);
+        panel.getChildren().addAll(
+                createMonitorOutputArea(),
+                createMonitorOptionsPanel()
+        );
+        return panel;
+    }
+
+    private TextArea createMonitorOutputArea() {
+        TextArea textArea = new TextArea();
+        textArea.setEditable(false);
+        textArea.setWrapText(true);
+        VBox.setVgrow(textArea, Priority.ALWAYS);
+        return textArea;
+    }
+
+    private Node createMonitorOptionsPanel() {
+        VBox box = new VBox();
+        box.setPadding(new Insets(8));
+        box.getChildren().addAll(
+                createSubHeaderLabel("System Monitor"),
+                createSystemMonitoringCheckBoxes()
+        );
+        return box;
+    }
+
+    private Node createSystemMonitoringCheckBoxes() {
+        HBox checkBoxes = new HBox(8);
+        checkBoxes.setAlignment(Pos.CENTER_LEFT);
+        checkBoxes.getChildren().addAll(
+                new CheckBox("htop"),
+                new CheckBox("nvtop"),
+                new CheckBox("GPU-Monitor")
+        );
+        return checkBoxes;
+    }
+
+    // Helper method for creating TextField objects with bound text property
+    private TextField boundTextField(StringProperty boundProperty) {
+        TextField textField = new TextField();
+        textField.textProperty().bindBidirectional(boundProperty);
+        return textField;
+    }
+
+    // Helper method to create Label objects with a minimum width
+    private Label createLabel(String text) {
+        Label label = new Label(text);
+        label.setMinWidth(Control.USE_PREF_SIZE);
+        return label;
+    }
+
+    private Label createHeaderLabel(String text) {
+        Label label = createLabel(text);
+        label.setStyle("-fx-font-size: 16pt; -fx-font-weight: bold;");
+        return label;
+    }
+
+    private Label createSubHeaderLabel(String text) {
+        Label label = createLabel(text);
+        label.setStyle("-fx-font-size: 12pt; -fx-font-weight: bold;");
+        return label;
+    }
+}
\ No newline at end of file
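Note: the same disable condition is rebuilt four times above (engine, chat-mode and model dropdowns, plus the reload button). Since both inputs are observable booleans, BooleanProperty.or can express it directly; a small helper would remove the duplication (a sketch using the field names from the class above):

    // One expression replaces each createBooleanBinding(...) call:
    //   engineDropdown.disableProperty().bind(controlsLocked());
    private javafx.beans.binding.BooleanBinding controlsLocked() {
        return inferenceRunning.or(model.interactiveSessionActiveProperty());
    }
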
diff --git a/src/main/java/com/example/gui/LlamaChatbox.java b/src/main/java/com/example/gui/LlamaChatbox.java
new file mode 100644
index 0000000..7ac2cb2
--- /dev/null
+++ b/src/main/java/com/example/gui/LlamaChatbox.java
@@ -0,0 +1,23 @@
+package com.example.gui;
+
+import atlantafx.base.theme.CupertinoDark;
+import javafx.application.Application;
+import javafx.scene.Scene;
+import javafx.stage.Stage;
+
+public class LlamaChatbox extends Application {
+
+    @Override
+    public void start(Stage stage) {
+        Application.setUserAgentStylesheet(new CupertinoDark().getUserAgentStylesheet());
+        ChatboxController controller = new ChatboxController();
+        Scene scene = new Scene(controller.getView(), 800, 600);
+        stage.setTitle("TornadoVM Chat");
+        stage.setScene(scene);
+        stage.show();
+    }
+
+    public static void main(String[] args) {
+        launch(args);
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/com/example/loader/weights/ModelLoader.java b/src/main/java/com/example/loader/weights/ModelLoader.java
index c4f7751..e96b8da 100644
--- a/src/main/java/com/example/loader/weights/ModelLoader.java
+++ b/src/main/java/com/example/loader/weights/ModelLoader.java
@@ -96,7 +96,8 @@ public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Co
         GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
         GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
 
-        if (LlamaApp.USE_TORNADOVM) {
+        LlamaApp llamaApp = LlamaApp.getInstance();
+        if (llamaApp.getUseTornadoVM()) {
             System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
             return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
         } else {
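Note: this is the consumer side of the singleton change. The GPU/CPU choice is now read per load instead of being frozen into a static final at class-load time, which is what lets the GUI's engine dropdown take effect. Roughly, the flow this diff sets up (all calls exist in the diff itself):

    LlamaApp app = LlamaApp.getInstance();
    app.setUseTornadoVM(true);            // GUI thread, on Engine.TORNADO_VM
    boolean gpu = app.getUseTornadoVM();  // worker thread, inside loadWeights
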
diff --git a/src/main/java/com/example/model/Model.java b/src/main/java/com/example/model/Model.java
index e42349b..bbf2a46 100644
--- a/src/main/java/com/example/model/Model.java
+++ b/src/main/java/com/example/model/Model.java
@@ -1,5 +1,6 @@
 package com.example.model;
 
+import com.example.LlamaApp;
 import com.example.Options;
 import com.example.auxiliary.LastRunMetrics;
 import com.example.inference.InferenceEngine;
@@ -17,7 +18,6 @@
 import java.util.function.IntConsumer;
 
 import static com.example.LlamaApp.SHOW_PERF_INTERACTIVE;
-import static com.example.LlamaApp.USE_TORNADOVM;
 
 public interface Model {
     Configuration configuration();
@@ -54,6 +54,9 @@ default void runInteractive(Sampler sampler, Options options) {
         // Initialize TornadoVM plan once at the beginning if GPU path is enabled
         TornadoVMMasterPlan tornadoVMPlan = null;
 
+        // Get the LlamaApp singleton to read configuration values
+        LlamaApp llamaApp = LlamaApp.getInstance();
+
         try {
             while (true) {
                 System.out.print("> ");
@@ -68,7 +71,7 @@ default void runInteractive(Sampler sampler, Options options) {
                     state = createNewState();
                 }
 
-                if (USE_TORNADOVM && tornadoVMPlan == null) {
+                if (llamaApp.getUseTornadoVM() && tornadoVMPlan == null) {
                     tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
                 }
 
@@ -86,7 +89,7 @@ default void runInteractive(Sampler sampler, Options options) {
                 };
 
                 // Choose between GPU and CPU path based on configuration
-                if (USE_TORNADOVM) {
+                if (llamaApp.getUseTornadoVM()) {
                     // GPU path using TornadoVM
                     responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
                             options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
@@ -121,7 +124,7 @@ default void runInteractive(Sampler sampler, Options options) {
             }
         } finally {
             // Clean up TornadoVM resources when exiting the chat loop
-            if (USE_TORNADOVM && tornadoVMPlan != null) {
+            if (llamaApp.getUseTornadoVM() && tornadoVMPlan != null) {
                 try {
                     tornadoVMPlan.freeTornadoExecutionPlan();
                 } catch (Exception e) {
@@ -131,6 +134,104 @@ default void runInteractive(Sampler sampler, Options options) {
         }
     }
 
+    /**
+     * Model agnostic implementation for interactive GUI mode.
+     * Takes a single user input and returns the model's response, allowing the GUI to manage the chat loop.
+     *
+     * @param sampler The sampler for token generation
+     * @param options The inference options
+     * @param userText The user's input text
+     * @param previousResponse A Response object referencing an ongoing chat
+     * @return A Response object containing the model's output and updated state
+     */
+    default Response runInteractiveStep(Sampler sampler, Options options, String userText, Response previousResponse) {
+        ChatFormat chatFormat = ChatFormat.create(tokenizer());
+        List<Integer> conversationTokens = previousResponse.conversationTokens();
+        State state = previousResponse.state();
+        TornadoVMMasterPlan tornadoVMPlan = previousResponse.tornadoVMPlan();
+        String responseText = "";
+
+        int startPosition = conversationTokens.size();
+
+        // For the first message, set up the conversation tokens if empty
+        if (conversationTokens.isEmpty()) {
+            conversationTokens.add(chatFormat.getBeginOfText());
+            if (options.systemPrompt() != null) {
+                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
+            }
+        }
+
+        // Get the LlamaApp singleton to read configuration values
+        LlamaApp llamaApp = LlamaApp.getInstance();
+
+        if (state == null) {
+            // State allocation can take some time for large context sizes
+            state = createNewState();
+        }
+
+        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
+        if (llamaApp.getUseTornadoVM() && tornadoVMPlan == null) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+        }
+
+        conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
+        conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+        Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+        List<Integer> responseTokens;
+        IntConsumer tokenConsumer = token -> {
+            if (options.stream()) {
+                if (tokenizer().shouldDisplayToken(token)) {
+                    System.out.print(tokenizer().decode(List.of(token)));
+                }
+            }
+        };
+
+        // Choose between GPU and CPU path based on configuration
+        if (llamaApp.getUseTornadoVM()) {
+            // GPU path using TornadoVM
+            responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+                    options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+        } else {
+            // CPU path
+            responseTokens = InferenceEngine.generateTokens(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
+                    sampler, options.echo(), tokenConsumer);
+        }
+
+        // Include stop token in the prompt history, but not in the response displayed to the user.
+        conversationTokens.addAll(responseTokens);
+        Integer stopToken = null;
+        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+            stopToken = responseTokens.getLast();
+            responseTokens.removeLast();
+        }
+        if (!options.stream()) {
+            responseText = tokenizer().decode(responseTokens);
+            System.out.println(responseText);
+        }
+        if (stopToken == null) {
+            System.err.println("\n Ran out of context length...\n Increase context length by passing to llama-tornado --max-tokens XXX");
+            return new Response(responseText, state, conversationTokens, tornadoVMPlan);
+        }
+        System.out.print("\n");
+
+        // Optionally print performance metrics after each response
+        if (SHOW_PERF_INTERACTIVE) {
+            LastRunMetrics.printMetrics();
+        }
+
+        return new Response(responseText, state, conversationTokens, tornadoVMPlan);
+    }
+
+    /**
+     * Simple data model for Model responses, used to keep track of conversation history and state.
+     */
+    record Response(String responseText, State state, List<Integer> conversationTokens, TornadoVMMasterPlan tornadoVMPlan) {
+        public Response() {
+            this("", null, new ArrayList<>(), null);
+        }
+    }
+
     /**
      * Model agnostic default implementation for instruct mode.
      * @param sampler
      */
@@ -140,6 +241,7 @@ default void runInstructOnce(Sampler sampler, Options options) {
         State state = createNewState();
         ChatFormat chatFormat = ChatFormat.create(tokenizer());
         TornadoVMMasterPlan tornadoVMPlan = null;
+        LlamaApp llamaApp = LlamaApp.getInstance();
 
         List<Integer> promptTokens = new ArrayList<>();
         promptTokens.add(chatFormat.getBeginOfText());
@@ -162,7 +264,7 @@ default void runInstructOnce(Sampler sampler, Options options) {
 
         Set<Integer> stopTokens = chatFormat.getStopTokens();
 
-        if (USE_TORNADOVM) {
+        if (llamaApp.getUseTornadoVM()) {
             tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
             // Call generateTokensGPU without the token consumer parameter
             responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null,