Commit 833e9ea

Update weight type handling in configureQuantizedMatrixVectorFinalWeight
Add the `F16` weight type alongside `Q8_0` in the switch. Both cases now fall through to the `Q4_0` branch, which removes the duplicated `projection` task configuration and its `break` statement for better readability and maintainability.
1 parent 23bc075 commit 833e9ea

1 file changed: +1 -4 lines changed

src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java

Lines changed: 1 addition & 4 deletions
@@ -184,11 +184,8 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
     // @formatter:on
     private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
         switch (weights.weightType) {
+            case F16:
             case Q8_0:
-                logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
-                        context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
-                        config.dim, config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); //
-                break;
             case Q4_0:
                 logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
                         context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
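
The fall-through grouping used in this change can be illustrated in isolation. The following is a minimal sketch, not the project's code: the `WeightTypeDispatch` class, its `WeightType` enum (including the extra `F32` constant), the `selectProjectionKernel` method, and the behavior of the `default` branch are all hypothetical stand-ins, since the diff does not show the rest of the switch. It only demonstrates that stacked `case` labels with no body share the first branch body that follows them.

```java
public class WeightTypeDispatch {

    // Hypothetical stand-in for the project's weight type enum.
    // F32 is included only to exercise the default branch.
    public enum WeightType { F16, Q8_0, Q4_0, F32 }

    // Returns the name of the kernel that would be scheduled for the
    // final projection. F16 and Q8_0 have no body of their own, so they
    // fall through to the Q4_0 branch and share its configuration.
    public static String selectProjectionKernel(WeightType type) {
        switch (type) {
            case F16:
            case Q8_0:
            case Q4_0:
                return "matrixVectorGeneric";
            default:
                // Assumption: the real default branch is not visible in the diff.
                throw new UnsupportedOperationException("Unsupported weight type: " + type);
        }
    }

    public static void main(String[] args) {
        WeightType[] grouped = { WeightType.F16, WeightType.Q8_0, WeightType.Q4_0 };
        for (WeightType type : grouped) {
            System.out.println(type + " -> " + selectProjectionKernel(type));
        }
    }
}
```

Grouping the labels keeps a single copy of the task configuration, so a later change to the kernel arguments only has to be made in one place.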