Skip to content

Commit bbf70db

Browse files
rkuestersuleshahid
andauthored
feat(compression): add SpecBuilder for programmatic compression specs (#3133)
Add a fluent builder API for creating compression specifications without writing YAML strings. This is useful in scripts and Jupyter notebooks. Example usage: spec = (compression.SpecBuilder() .add_tensor(subgraph=0, tensor=2) .with_lut(index_bitwidth=4) .build()) BUG=#3125 Co-authored-by: suleshahid <110432064+suleshahid@users.noreply.github.com>
1 parent 41a9c8a commit bbf70db

File tree

5 files changed

+239
-1
lines changed

5 files changed

+239
-1
lines changed

python/tflite_micro/postinstall_check.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ def compression_test():
6161
# with compressible tensors, but we verify the function is importable
6262
assert callable(compression.compress)
6363

64+
# Test availability of the SpecBuilder
65+
_ = (compression.SpecBuilder().add_tensor(
66+
subgraph=0, tensor=0).with_lut(index_bitwidth=4).build())
67+
6468
return True
6569

6670

tensorflow/lite/micro/compression/BUILD

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ py_library(
2323
deps = [
2424
":compress_lib",
2525
":spec",
26+
":spec_builder",
2627
],
2728
)
2829

@@ -189,6 +190,25 @@ py_test(
189190
],
190191
)
191192

193+
py_library(
194+
name = "spec_builder",
195+
srcs = ["spec_builder.py"],
196+
deps = [
197+
":spec",
198+
],
199+
)
200+
201+
py_test(
202+
name = "spec_builder_test",
203+
size = "small",
204+
srcs = ["spec_builder_test.py"],
205+
deps = [
206+
":spec",
207+
":spec_builder",
208+
requirement("tensorflow"),
209+
],
210+
)
211+
192212
py_library(
193213
name = "test_models",
194214
srcs = ["test_models.py"],

tensorflow/lite/micro/compression/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@
2222

2323
from .compress import compress
2424
from .spec import parse_yaml
25+
from .spec_builder import SpecBuilder
2526

26-
__all__ = ["compress", "parse_yaml"]
27+
__all__ = ["compress", "parse_yaml", "SpecBuilder"]
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Builder pattern for creating compression specifications programmatically.
16+
17+
This module provides a fluent API for building compression specs without
18+
needing to write YAML strings.
19+
20+
Example usage:
21+
from tflite_micro.compression import SpecBuilder
22+
23+
spec = (SpecBuilder()
24+
.add_tensor(subgraph=0, tensor=2)
25+
.with_lut(index_bitwidth=4)
26+
.add_tensor(subgraph=0, tensor=4)
27+
.with_lut(index_bitwidth=2)
28+
.build())
29+
"""
30+
31+
from typing import List, Optional
32+
from . import spec
33+
34+
35+
class TensorBuilder:
36+
"""Builder for individual tensor compression specifications."""
37+
38+
def __init__(self, subgraph: int, tensor: int,
39+
parent_builder: 'SpecBuilder'):
40+
self.subgraph = subgraph
41+
self.tensor = tensor
42+
self.compression_methods: List[spec.CompressionMethod] = []
43+
self._parent = parent_builder
44+
45+
def with_lut(self, index_bitwidth: int) -> 'SpecBuilder':
46+
"""Add LUT compression to this tensor.
47+
48+
Args:
49+
index_bitwidth: Number of bits for the LUT index (e.g., 4 for 16 values)
50+
51+
Returns:
52+
The parent SpecBuilder for method chaining
53+
"""
54+
self.compression_methods.append(
55+
spec.LookUpTableCompression(index_bitwidth=index_bitwidth))
56+
return self._parent
57+
58+
def _build(self) -> spec.Tensor:
59+
"""Build the Tensor specification object."""
60+
return spec.Tensor(subgraph=self.subgraph,
61+
tensor=self.tensor,
62+
compression=self.compression_methods)
63+
64+
65+
class SpecBuilder:
66+
"""Fluent builder for compression specifications."""
67+
68+
def __init__(self):
69+
self._tensor_builders: List[TensorBuilder] = []
70+
self._current_tensor: Optional[TensorBuilder] = None
71+
72+
def add_tensor(self, subgraph: int, tensor: int) -> TensorBuilder:
73+
"""Add a tensor to be compressed.
74+
75+
Args:
76+
subgraph: The subgraph index containing the tensor
77+
tensor: The tensor index within the subgraph
78+
79+
Returns:
80+
A TensorBuilder for configuring compression methods
81+
"""
82+
# Finalize any current tensor
83+
if self._current_tensor is not None:
84+
self._tensor_builders.append(self._current_tensor)
85+
86+
# Create new tensor builder
87+
self._current_tensor = TensorBuilder(subgraph, tensor, self)
88+
return self._current_tensor
89+
90+
def build(self) -> List[spec.Tensor]:
91+
"""Build the final compression specification.
92+
93+
Returns:
94+
A list of Tensor specifications ready for use with compress()
95+
"""
96+
# Make sure to include the last tensor if there is one
97+
if self._current_tensor is not None:
98+
self._tensor_builders.append(self._current_tensor)
99+
self._current_tensor = None
100+
101+
return [tb._build() for tb in self._tensor_builders]
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Tests for the compression spec builder."""
16+
17+
import tensorflow as tf
18+
19+
from tflite_micro.tensorflow.lite.micro.compression import spec
20+
from tflite_micro.tensorflow.lite.micro.compression import spec_builder
21+
22+
23+
class SpecBuilderTest(tf.test.TestCase):
24+
25+
def test_basic_builder_pattern(self):
26+
"""Test basic fluent builder usage."""
27+
result = (spec_builder.SpecBuilder().add_tensor(
28+
subgraph=0, tensor=2).with_lut(index_bitwidth=4).add_tensor(
29+
subgraph=0, tensor=4).with_lut(index_bitwidth=2).build())
30+
31+
self.assertEqual(len(result), 2)
32+
33+
# Check first tensor
34+
self.assertEqual(result[0].subgraph, 0)
35+
self.assertEqual(result[0].tensor, 2)
36+
self.assertEqual(len(result[0].compression), 1)
37+
self.assertIsInstance(result[0].compression[0],
38+
spec.LookUpTableCompression)
39+
self.assertEqual(result[0].compression[0].index_bitwidth, 4)
40+
41+
# Check second tensor
42+
self.assertEqual(result[1].subgraph, 0)
43+
self.assertEqual(result[1].tensor, 4)
44+
self.assertEqual(len(result[1].compression), 1)
45+
self.assertIsInstance(result[1].compression[0],
46+
spec.LookUpTableCompression)
47+
self.assertEqual(result[1].compression[0].index_bitwidth, 2)
48+
49+
def test_non_chained_usage(self):
50+
"""Test using builder without method chaining."""
51+
builder = spec_builder.SpecBuilder()
52+
builder.add_tensor(0, 2).with_lut(4)
53+
builder.add_tensor(0, 4).with_lut(2)
54+
result = builder.build()
55+
56+
self.assertEqual(len(result), 2)
57+
self.assertEqual(result[0].tensor, 2)
58+
self.assertEqual(result[0].compression[0].index_bitwidth, 4)
59+
self.assertEqual(result[1].tensor, 4)
60+
self.assertEqual(result[1].compression[0].index_bitwidth, 2)
61+
62+
def test_empty_spec(self):
63+
"""Test building an empty spec."""
64+
result = spec_builder.SpecBuilder().build()
65+
self.assertEqual(len(result), 0)
66+
67+
def test_single_tensor(self):
68+
"""Test building a spec with just one tensor."""
69+
result = (spec_builder.SpecBuilder().add_tensor(
70+
subgraph=2, tensor=42).with_lut(index_bitwidth=16).build())
71+
72+
self.assertEqual(len(result), 1)
73+
self.assertEqual(result[0].subgraph, 2)
74+
self.assertEqual(result[0].tensor, 42)
75+
self.assertEqual(result[0].compression[0].index_bitwidth, 16)
76+
77+
def test_tensor_without_compression(self):
78+
"""Test that tensors can be added without compression methods."""
79+
builder = spec_builder.SpecBuilder()
80+
# Add tensor but don't call with_lut
81+
builder.add_tensor(0, 1)
82+
builder.add_tensor(0, 2).with_lut(4)
83+
result = builder.build()
84+
85+
self.assertEqual(len(result), 2)
86+
self.assertEqual(result[0].tensor, 1)
87+
self.assertEqual(len(result[0].compression), 0)
88+
self.assertEqual(result[1].tensor, 2)
89+
self.assertEqual(len(result[1].compression), 1)
90+
91+
def test_builder_produces_same_type_as_parse_yaml(self):
92+
"""Test that builder produces same data structure as parse_yaml."""
93+
# Build using the builder
94+
built_spec = (spec_builder.SpecBuilder().add_tensor(
95+
subgraph=0, tensor=42).with_lut(index_bitwidth=4).add_tensor(
96+
subgraph=0, tensor=55).with_lut(index_bitwidth=2).build())
97+
98+
# Parse the example YAML from spec.py
99+
parsed_spec = spec.parse_yaml(spec.EXAMPLE_YAML_SPEC)
100+
101+
# They should be equivalent
102+
self.assertEqual(len(built_spec), len(parsed_spec))
103+
for built, parsed in zip(built_spec, parsed_spec):
104+
self.assertEqual(built.subgraph, parsed.subgraph)
105+
self.assertEqual(built.tensor, parsed.tensor)
106+
self.assertEqual(len(built.compression), len(parsed.compression))
107+
self.assertEqual(built.compression[0].index_bitwidth,
108+
parsed.compression[0].index_bitwidth)
109+
110+
111+
if __name__ == "__main__":
112+
tf.test.main()

0 commit comments

Comments
 (0)