Skip to content

Commit 2144aab

Browse files
Add test to check thrift field IDs
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent c123af3 commit 2144aab

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

tests/unit/test_thrift_field_ids.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
Unit test to validate that all Thrift-generated field IDs comply with the maximum limit.
3+
4+
Field IDs in Thrift must stay below 3329 to avoid conflicts with reserved ranges
5+
and ensure compatibility with various Thrift implementations and protocols.
6+
"""
7+
8+
import inspect
9+
import pytest
10+
import unittest
11+
12+
from databricks.sql.thrift_api.TCLIService import ttypes
13+
14+
15+
class TestThriftFieldIds(unittest.TestCase):
16+
"""Test suite for validating Thrift field ID constraints."""
17+
18+
MAX_ALLOWED_FIELD_ID = 3329
19+
20+
# Known exceptions that exceed the field ID limit
21+
KNOWN_EXCEPTIONS = {
22+
('TExecuteStatementReq', 'enforceEmbeddedSchemaCorrectness'): 3353,
23+
('TSessionHandle', 'serverProtocolVersion'): 3329,
24+
}
25+
26+
def test_all_thrift_field_ids_are_within_allowed_range(self):
27+
"""
28+
Validates that all field IDs in Thrift-generated classes are within the allowed range.
29+
30+
This test prevents field ID conflicts and ensures compatibility with different
31+
Thrift implementations and protocols.
32+
"""
33+
violations = []
34+
35+
# Get all classes from the ttypes module
36+
for name, obj in inspect.getmembers(ttypes):
37+
if (inspect.isclass(obj) and
38+
hasattr(obj, 'thrift_spec') and
39+
obj.thrift_spec is not None):
40+
41+
self._check_class_field_ids(obj, name, violations)
42+
43+
if violations:
44+
error_message = self._build_error_message(violations)
45+
self.fail(error_message)
46+
47+
def _check_class_field_ids(self, cls, class_name, violations):
48+
"""
49+
Checks all field IDs in a Thrift class and reports violations.
50+
51+
Args:
52+
cls: The Thrift class to check
53+
class_name: Name of the class for error reporting
54+
violations: List to append violation messages to
55+
"""
56+
thrift_spec = cls.thrift_spec
57+
58+
if not isinstance(thrift_spec, (tuple, list)):
59+
return
60+
61+
for spec_entry in thrift_spec:
62+
if spec_entry is None:
63+
continue
64+
65+
# Thrift spec format: (field_id, field_type, field_name, ...)
66+
if isinstance(spec_entry, (tuple, list)) and len(spec_entry) >= 3:
67+
field_id = spec_entry[0]
68+
field_name = spec_entry[2]
69+
70+
# Skip known exceptions
71+
if (class_name, field_name) in self.KNOWN_EXCEPTIONS:
72+
continue
73+
74+
if isinstance(field_id, int) and field_id >= self.MAX_ALLOWED_FIELD_ID:
75+
violations.append(
76+
"{} field '{}' has field ID {} (exceeds maximum of {})".format(
77+
class_name, field_name, field_id, self.MAX_ALLOWED_FIELD_ID - 1
78+
)
79+
)
80+
81+
def _build_error_message(self, violations):
82+
"""
83+
Builds a comprehensive error message for field ID violations.
84+
85+
Args:
86+
violations: List of violation messages
87+
88+
Returns:
89+
Formatted error message
90+
"""
91+
error_message = (
92+
"Found Thrift field IDs that exceed the maximum allowed value of {}.\n"
93+
"This can cause compatibility issues and conflicts with reserved ID ranges.\n"
94+
"Violations found:\n".format(self.MAX_ALLOWED_FIELD_ID - 1)
95+
)
96+
97+
for violation in violations:
98+
error_message += " - {}\n".format(violation)
99+
100+
return error_message

0 commit comments

Comments
 (0)