|
| 1 | +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# |
| 15 | +"""Tests for the compression spec builder.""" |
| 16 | + |
| 17 | +import tensorflow as tf |
| 18 | + |
| 19 | +from tflite_micro.tensorflow.lite.micro.compression import spec |
| 20 | +from tflite_micro.tensorflow.lite.micro.compression import spec_builder |
| 21 | + |
| 22 | + |
| 23 | +class SpecBuilderTest(tf.test.TestCase): |
| 24 | + |
| 25 | + def test_basic_builder_pattern(self): |
| 26 | + """Test basic fluent builder usage.""" |
| 27 | + result = (spec_builder.SpecBuilder().add_tensor( |
| 28 | + subgraph=0, tensor=2).with_lut(index_bitwidth=4).add_tensor( |
| 29 | + subgraph=0, tensor=4).with_lut(index_bitwidth=2).build()) |
| 30 | + |
| 31 | + self.assertEqual(len(result), 2) |
| 32 | + |
| 33 | + # Check first tensor |
| 34 | + self.assertEqual(result[0].subgraph, 0) |
| 35 | + self.assertEqual(result[0].tensor, 2) |
| 36 | + self.assertEqual(len(result[0].compression), 1) |
| 37 | + self.assertIsInstance(result[0].compression[0], |
| 38 | + spec.LookUpTableCompression) |
| 39 | + self.assertEqual(result[0].compression[0].index_bitwidth, 4) |
| 40 | + |
| 41 | + # Check second tensor |
| 42 | + self.assertEqual(result[1].subgraph, 0) |
| 43 | + self.assertEqual(result[1].tensor, 4) |
| 44 | + self.assertEqual(len(result[1].compression), 1) |
| 45 | + self.assertIsInstance(result[1].compression[0], |
| 46 | + spec.LookUpTableCompression) |
| 47 | + self.assertEqual(result[1].compression[0].index_bitwidth, 2) |
| 48 | + |
| 49 | + def test_non_chained_usage(self): |
| 50 | + """Test using builder without method chaining.""" |
| 51 | + builder = spec_builder.SpecBuilder() |
| 52 | + builder.add_tensor(0, 2).with_lut(4) |
| 53 | + builder.add_tensor(0, 4).with_lut(2) |
| 54 | + result = builder.build() |
| 55 | + |
| 56 | + self.assertEqual(len(result), 2) |
| 57 | + self.assertEqual(result[0].tensor, 2) |
| 58 | + self.assertEqual(result[0].compression[0].index_bitwidth, 4) |
| 59 | + self.assertEqual(result[1].tensor, 4) |
| 60 | + self.assertEqual(result[1].compression[0].index_bitwidth, 2) |
| 61 | + |
| 62 | + def test_empty_spec(self): |
| 63 | + """Test building an empty spec.""" |
| 64 | + result = spec_builder.SpecBuilder().build() |
| 65 | + self.assertEqual(len(result), 0) |
| 66 | + |
| 67 | + def test_single_tensor(self): |
| 68 | + """Test building a spec with just one tensor.""" |
| 69 | + result = (spec_builder.SpecBuilder().add_tensor( |
| 70 | + subgraph=2, tensor=42).with_lut(index_bitwidth=16).build()) |
| 71 | + |
| 72 | + self.assertEqual(len(result), 1) |
| 73 | + self.assertEqual(result[0].subgraph, 2) |
| 74 | + self.assertEqual(result[0].tensor, 42) |
| 75 | + self.assertEqual(result[0].compression[0].index_bitwidth, 16) |
| 76 | + |
| 77 | + def test_tensor_without_compression(self): |
| 78 | + """Test that tensors can be added without compression methods.""" |
| 79 | + builder = spec_builder.SpecBuilder() |
| 80 | + # Add tensor but don't call with_lut |
| 81 | + builder.add_tensor(0, 1) |
| 82 | + builder.add_tensor(0, 2).with_lut(4) |
| 83 | + result = builder.build() |
| 84 | + |
| 85 | + self.assertEqual(len(result), 2) |
| 86 | + self.assertEqual(result[0].tensor, 1) |
| 87 | + self.assertEqual(len(result[0].compression), 0) |
| 88 | + self.assertEqual(result[1].tensor, 2) |
| 89 | + self.assertEqual(len(result[1].compression), 1) |
| 90 | + |
| 91 | + def test_builder_produces_same_type_as_parse_yaml(self): |
| 92 | + """Test that builder produces same data structure as parse_yaml.""" |
| 93 | + # Build using the builder |
| 94 | + built_spec = (spec_builder.SpecBuilder().add_tensor( |
| 95 | + subgraph=0, tensor=42).with_lut(index_bitwidth=4).add_tensor( |
| 96 | + subgraph=0, tensor=55).with_lut(index_bitwidth=2).build()) |
| 97 | + |
| 98 | + # Parse the example YAML from spec.py |
| 99 | + parsed_spec = spec.parse_yaml(spec.EXAMPLE_YAML_SPEC) |
| 100 | + |
| 101 | + # They should be equivalent |
| 102 | + self.assertEqual(len(built_spec), len(parsed_spec)) |
| 103 | + for built, parsed in zip(built_spec, parsed_spec): |
| 104 | + self.assertEqual(built.subgraph, parsed.subgraph) |
| 105 | + self.assertEqual(built.tensor, parsed.tensor) |
| 106 | + self.assertEqual(len(built.compression), len(parsed.compression)) |
| 107 | + self.assertEqual(built.compression[0].index_bitwidth, |
| 108 | + parsed.compression[0].index_bitwidth) |
| 109 | + |
| 110 | + |
| 111 | +if __name__ == "__main__": |
| 112 | + tf.test.main() |
0 commit comments