
Add JUnit 5 support and initial unit test for Tokenizer #26


Open · wants to merge 5 commits into main
28 changes: 28 additions & 0 deletions README.md
@@ -464,6 +464,34 @@ Click [here](https://github.com/beehive-lab/GPULlama3.java/tree/main/docs/TORNAD

Click [here](https://github.com/beehive-lab/GPULlama3.java/tree/main/docs/GPULlama3_ROADMAP.md) to see the roadmap of the project.


## Run All Tests

You can run all unit tests using the following Maven command:

```bash
mvn test
```

Sample output:

```
-------------------------------------------------------
 T E S T S
-------------------------------------------------------
Running com.example.tokenizer.impl.MistralTokenizerTest
Running com.example.tokenizer.impl.TokenizerInterfaceTest

Tests run: 12, Failures: 0, Errors: 0, Skipped: 0
```

To run tests inside an IDE (e.g., IntelliJ), right-click on the test classes and choose Run.
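
To run a single test class from the command line, Maven Surefire's `-Dtest` filter can be used, for example with one of the classes added here:

```bash
# Run only the Mistral tokenizer tests
mvn test -Dtest=MistralTokenizerTest
```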

## Test Coverage
The following tokenizer unit tests are included:

| **Test Class** | **Description** |
|--------------------------|---------------------------------------------------------------------------------------------------------------------------------------------|
| `MistralTokenizerTest` | Verifies Mistral tokenizer functionality including byte fallback (`<0xXX>`), special token handling, encoding and decoding logic |
| `TokenizerInterfaceTest` | Unit tests for utility methods like `replaceControlCharacters`, ensuring printable and safe token rendering |

-----------

## Acknowledgments
11 changes: 11 additions & 0 deletions pom.xml
@@ -32,6 +32,12 @@
<artifactId>tornado-runtime</artifactId>
<version>1.1.1-dev</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>5.10.2</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
@@ -68,6 +74,11 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.2.5</version>
</plugin>
</plugins>
</build>
</project>
80 changes: 80 additions & 0 deletions src/test/java/com/example/tokenizer/impl/MistralTokenizerTest.java
@@ -0,0 +1,80 @@
package com.example.tokenizer.impl;

import com.example.tokenizer.vocabulary.Vocabulary;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.*;

import static org.junit.jupiter.api.Assertions.*;

class MistralTokenizerTest {

private Vocabulary vocabulary;
private MistralTokenizer tokenizer;

@BeforeEach
void setup() {
List<String> baseTokens = List.of("▁h", "e", "l", "o", "▁", "▁hello");
List<String> byteFallbackTokens = new ArrayList<>();

for (int i = 0; i < 256; i++) {
byteFallbackTokens.add(String.format("<0x%02X>", i));
}

List<String> allTokens = new ArrayList<>();
allTokens.addAll(baseTokens);
allTokens.addAll(byteFallbackTokens);

String[] tokens = allTokens.toArray(new String[0]);
float[] scores = new float[tokens.length];
Arrays.fill(scores, 0.0f); // dummy scores

int[] tokenTypes = new int[tokens.length];
Arrays.fill(tokenTypes, 1); // mark all normal
tokenTypes[baseTokens.size()] = 0; // mark <0x00> as special

Map<String, Object> metadata = new HashMap<>();
metadata.put("tokenizer.ggml.token_type", tokenTypes);

vocabulary = new Vocabulary(tokens, scores);
tokenizer = new MistralTokenizer(metadata, vocabulary);
}

@Test
void testEncodeSimpleText() {
List<Integer> tokens = tokenizer.encodeAsList("hello");
assertNotNull(tokens);
assertFalse(tokens.isEmpty());
Comment on lines +47 to +48

Collaborator: I would expect a check based on the expected result rather than checking if it's not null and empty.
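
A sketch of a stricter assertion (assuming `decode()` restores the original text, including stripping the leading `▁` word marker, which has not been verified against this implementation) could round-trip the encoded ids instead of only checking for non-emptiness:

```java
@Test
void testEncodeSimpleText() {
    List<Integer> tokens = tokenizer.encodeAsList("hello");
    // Round-trip check: avoids hard-coding ids that depend on the merge order
    // of this toy vocabulary, but assumes decode() is the inverse of encodeAsList().
    assertEquals("hello", tokenizer.decode(tokens));
}
```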

}

@Test
void testRegexPatternReturnsNull() {
assertNull(tokenizer.regexPattern());
}
Comment on lines +51 to +54

Collaborator: regexPattern() should not return null.
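
If `regexPattern()` is indeed expected to be non-null for this tokenizer, the test could be inverted along these lines (a sketch; it assumes the method returns a regex `String`, which has not been checked against the implementation):

```java
@Test
void testRegexPatternIsValid() {
    String pattern = tokenizer.regexPattern();
    assertNotNull(pattern);
    // If a pattern is provided, it should at least compile as a Java regex.
    assertDoesNotThrow(() -> java.util.regex.Pattern.compile(pattern));
}
```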


@Test
void testSpecialTokenDetection() {
assertTrue(tokenizer.isSpecialToken(6));
assertFalse(tokenizer.isSpecialToken(0));
}

@Test
void testShouldDisplayToken() {
assertTrue(tokenizer.shouldDisplayToken(0));
assertFalse(tokenizer.shouldDisplayToken(6));
}

@Test
void testDecodeSpecialByteFallbackToken() {
List<Integer> tokens = List.of(6); // token <0x00>
String result = tokenizer.decode(tokens);
assertEquals("\u0000", result); // ASCII for <0x00>
}

@Test
void testEncodeEmptyInput() {
List<Integer> tokens = tokenizer.encodeAsList("");
assertTrue(tokens.isEmpty(), "Should return empty token list for empty input");
}
}
56 changes: 56 additions & 0 deletions src/test/java/com/example/tokenizer/impl/TokenizerTest.java
Collaborator: This TokenizerTest covers the static replaceControlCharacters() methods, which are implemented in the Tokenizer interface, not LlamaTokenizer. For this JUnit support to make sense, I would suggest adding coverage for LlamaTokenizer as well.

Author: sure

@@ -0,0 +1,56 @@
package com.example.tokenizer.impl;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.*;

class TokenizerInterfaceTest {

@Test
void testReplaceControlCharactersWithCodePoints() {
int[] input = {'H', 'e', '\n', 0x07, 'l', 'o'}; // 0x07 = BEL (control character)
String result = Tokenizer.replaceControlCharacters(input);

assertEquals("He\n\\u0007lo", result); // \n allowed, BEL escaped
}

@Test
void testReplaceControlCharactersWithString() {
String input = "He\n\u0007lo"; // \u0007 is a bell character (non-printable control char)
String result = Tokenizer.replaceControlCharacters(input);

assertEquals("He\n\\u0007lo", result);
}

@Test
void testReplaceControlCharactersWithOnlyPrintableChars() {
String input = "Hello, World!";
String result = Tokenizer.replaceControlCharacters(input);

assertEquals(input, result);
}

@Test
void testReplaceControlCharactersWithMultipleControlChars() {
String input = "\u0001\u0002A\nB\u0003"; // \u0001, \u0002, \u0003 are control chars
String result = Tokenizer.replaceControlCharacters(input);

assertEquals("\\u0001\\u0002A\nB\\u0003", result);
}

@Test
void testReplaceControlCharactersEmptyInput() {
String input = "";
String result = Tokenizer.replaceControlCharacters(input);

assertEquals("", result);
}

@Test
void testReplaceControlCharactersNullSafe() {
// Add this test if you plan to make it null-safe.
assertThrows(NullPointerException.class, () -> {
Tokenizer.replaceControlCharacters((String) null);
});
}
}