-
Notifications
You must be signed in to change notification settings - Fork 15
Add JUnit 5 support and initial unit test for Tokenizer #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ea7a3a2
a883978
69dee5f
6b36927
08addb6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package com.example.tokenizer.impl; | ||
|
||
import com.example.tokenizer.vocabulary.Vocabulary; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.*; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
|
||
class MistralTokenizerTest { | ||
|
||
private Vocabulary vocabulary; | ||
private MistralTokenizer tokenizer; | ||
|
||
@BeforeEach | ||
void setup() { | ||
List<String> baseTokens = List.of("▁h", "e", "l", "o", "▁", "▁hello"); | ||
List<String> byteFallbackTokens = new ArrayList<>(); | ||
|
||
for (int i = 0; i < 256; i++) { | ||
byteFallbackTokens.add(String.format("<0x%02X>", i)); | ||
} | ||
|
||
List<String> allTokens = new ArrayList<>(); | ||
allTokens.addAll(baseTokens); | ||
allTokens.addAll(byteFallbackTokens); | ||
|
||
String[] tokens = allTokens.toArray(new String[0]); | ||
float[] scores = new float[tokens.length]; | ||
Arrays.fill(scores, 0.0f); // dummy scores | ||
|
||
int[] tokenTypes = new int[tokens.length]; | ||
Arrays.fill(tokenTypes, 1); // mark all normal | ||
tokenTypes[baseTokens.size()] = 0; // mark <0x00> as special | ||
|
||
Map<String, Object> metadata = new HashMap<>(); | ||
metadata.put("tokenizer.ggml.token_type", tokenTypes); | ||
|
||
vocabulary = new Vocabulary(tokens, scores); | ||
tokenizer = new MistralTokenizer(metadata, vocabulary); | ||
} | ||
|
||
@Test | ||
void testEncodeSimpleText() { | ||
List<Integer> tokens = tokenizer.encodeAsList("hello"); | ||
assertNotNull(tokens); | ||
assertFalse(tokens.isEmpty()); | ||
} | ||
|
||
@Test | ||
void testRegexPatternReturnsNull() { | ||
assertNull(tokenizer.regexPattern()); | ||
} | ||
Comment on lines
+51
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. regexPattern() should not return null |
||
|
||
@Test | ||
void testSpecialTokenDetection() { | ||
assertTrue(tokenizer.isSpecialToken(6)); | ||
assertFalse(tokenizer.isSpecialToken(0)); | ||
} | ||
|
||
@Test | ||
void testShouldDisplayToken() { | ||
assertTrue(tokenizer.shouldDisplayToken(0)); | ||
assertFalse(tokenizer.shouldDisplayToken(6)); | ||
} | ||
|
||
@Test | ||
void testDecodeSpecialByteFallbackToken() { | ||
List<Integer> tokens = List.of(6); // token <0x00> | ||
String result = tokenizer.decode(tokens); | ||
assertEquals("\u0000", result); // ASCII for <0x00> | ||
} | ||
|
||
@Test | ||
void testEncodeEmptyInput() { | ||
List<Integer> tokens = tokenizer.encodeAsList(""); | ||
assertTrue(tokens.isEmpty(), "Should return empty token list for empty input"); | ||
} | ||
} |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This TokenizerTest covers the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package com.example.tokenizer.impl; | ||
|
||
import org.junit.jupiter.api.Test; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
|
||
class TokenizerInterfaceTest { | ||
|
||
@Test | ||
void testReplaceControlCharactersWithCodePoints() { | ||
int[] input = {'H', 'e', '\n', 0x07, 'l', 'o'}; // 0x07 = BEL (control character) | ||
String result = Tokenizer.replaceControlCharacters(input); | ||
|
||
assertEquals("He\n\\u0007lo", result); // \n allowed, BEL escaped | ||
} | ||
|
||
@Test | ||
void testReplaceControlCharactersWithString() { | ||
String input = "He\n\u0007lo"; // \u0007 is a bell character (non-printable control char) | ||
String result = Tokenizer.replaceControlCharacters(input); | ||
|
||
assertEquals("He\n\\u0007lo", result); | ||
} | ||
|
||
@Test | ||
void testReplaceControlCharactersWithOnlyPrintableChars() { | ||
String input = "Hello, World!"; | ||
String result = Tokenizer.replaceControlCharacters(input); | ||
|
||
assertEquals(input, result); | ||
} | ||
|
||
@Test | ||
void testReplaceControlCharactersWithMultipleControlChars() { | ||
String input = "\u0001\u0002A\nB\u0003"; // \u0001, \u0002, \u0003 are control chars | ||
String result = Tokenizer.replaceControlCharacters(input); | ||
|
||
assertEquals("\\u0001\\u0002A\nB\\u0003", result); | ||
} | ||
|
||
@Test | ||
void testReplaceControlCharactersEmptyInput() { | ||
String input = ""; | ||
String result = Tokenizer.replaceControlCharacters(input); | ||
|
||
assertEquals("", result); | ||
} | ||
|
||
@Test | ||
void testReplaceControlCharactersNullSafe() { | ||
// Add this test if you plan to make it null-safe. | ||
assertThrows(NullPointerException.class, () -> { | ||
Tokenizer.replaceControlCharacters((String) null); | ||
}); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect a check based on the expected result rather than checking if it's not null and empty.