Skip to content

Commit 4cebc36

Browse files
Add test to check thrift field IDs (#602)
* Add test to check thrift field IDs --------- Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 8b841c7 commit 4cebc36

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

tests/unit/test_thrift_field_ids.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import inspect
2+
import pytest
3+
4+
from databricks.sql.thrift_api.TCLIService import ttypes
5+
6+
7+
class TestThriftFieldIds:
8+
"""
9+
Unit test to validate that all Thrift-generated field IDs comply with the maximum limit.
10+
11+
Field IDs in Thrift must stay below 3329 to avoid conflicts with reserved ranges
12+
and ensure compatibility with various Thrift implementations and protocols.
13+
"""
14+
15+
MAX_ALLOWED_FIELD_ID = 3329
16+
17+
# Known exceptions that exceed the field ID limit
18+
KNOWN_EXCEPTIONS = {
19+
('TExecuteStatementReq', 'enforceEmbeddedSchemaCorrectness'): 3353,
20+
('TSessionHandle', 'serverProtocolVersion'): 3329,
21+
}
22+
23+
def test_all_thrift_field_ids_are_within_allowed_range(self):
24+
"""
25+
Validates that all field IDs in Thrift-generated classes are within the allowed range.
26+
27+
This test prevents field ID conflicts and ensures compatibility with different
28+
Thrift implementations and protocols.
29+
"""
30+
violations = []
31+
32+
# Get all classes from the ttypes module
33+
for name, obj in inspect.getmembers(ttypes):
34+
if (inspect.isclass(obj) and
35+
hasattr(obj, 'thrift_spec') and
36+
obj.thrift_spec is not None):
37+
38+
self._check_class_field_ids(obj, name, violations)
39+
40+
if violations:
41+
error_message = self._build_error_message(violations)
42+
pytest.fail(error_message)
43+
44+
def _check_class_field_ids(self, cls, class_name, violations):
45+
"""
46+
Checks all field IDs in a Thrift class and reports violations.
47+
48+
Args:
49+
cls: The Thrift class to check
50+
class_name: Name of the class for error reporting
51+
violations: List to append violation messages to
52+
"""
53+
thrift_spec = cls.thrift_spec
54+
55+
if not isinstance(thrift_spec, (tuple, list)):
56+
return
57+
58+
for spec_entry in thrift_spec:
59+
if spec_entry is None:
60+
continue
61+
62+
# Thrift spec format: (field_id, field_type, field_name, ...)
63+
if isinstance(spec_entry, (tuple, list)) and len(spec_entry) >= 3:
64+
field_id = spec_entry[0]
65+
field_name = spec_entry[2]
66+
67+
# Skip known exceptions
68+
if (class_name, field_name) in self.KNOWN_EXCEPTIONS:
69+
continue
70+
71+
if isinstance(field_id, int) and field_id >= self.MAX_ALLOWED_FIELD_ID:
72+
violations.append(
73+
"{} field '{}' has field ID {} (exceeds maximum of {})".format(
74+
class_name, field_name, field_id, self.MAX_ALLOWED_FIELD_ID - 1
75+
)
76+
)
77+
78+
def _build_error_message(self, violations):
79+
"""
80+
Builds a comprehensive error message for field ID violations.
81+
82+
Args:
83+
violations: List of violation messages
84+
85+
Returns:
86+
Formatted error message
87+
"""
88+
error_message = (
89+
"Found Thrift field IDs that exceed the maximum allowed value of {}.\n"
90+
"This can cause compatibility issues and conflicts with reserved ID ranges.\n"
91+
"Violations found:\n".format(self.MAX_ALLOWED_FIELD_ID - 1)
92+
)
93+
94+
for violation in violations:
95+
error_message += " - {}\n".format(violation)
96+
97+
return error_message

0 commit comments

Comments
 (0)