1
+ import inspect
2
+ import pytest
3
+
4
+ from databricks .sql .thrift_api .TCLIService import ttypes
5
+
6
+
7
+ class TestThriftFieldIds :
8
+ """
9
+ Unit test to validate that all Thrift-generated field IDs comply with the maximum limit.
10
+
11
+ Field IDs in Thrift must stay below 3329 to avoid conflicts with reserved ranges
12
+ and ensure compatibility with various Thrift implementations and protocols.
13
+ """
14
+
15
+ MAX_ALLOWED_FIELD_ID = 3329
16
+
17
+ # Known exceptions that exceed the field ID limit
18
+ KNOWN_EXCEPTIONS = {
19
+ ('TExecuteStatementReq' , 'enforceEmbeddedSchemaCorrectness' ): 3353 ,
20
+ ('TSessionHandle' , 'serverProtocolVersion' ): 3329 ,
21
+ }
22
+
23
+ def test_all_thrift_field_ids_are_within_allowed_range (self ):
24
+ """
25
+ Validates that all field IDs in Thrift-generated classes are within the allowed range.
26
+
27
+ This test prevents field ID conflicts and ensures compatibility with different
28
+ Thrift implementations and protocols.
29
+ """
30
+ violations = []
31
+
32
+ # Get all classes from the ttypes module
33
+ for name , obj in inspect .getmembers (ttypes ):
34
+ if (inspect .isclass (obj ) and
35
+ hasattr (obj , 'thrift_spec' ) and
36
+ obj .thrift_spec is not None ):
37
+
38
+ self ._check_class_field_ids (obj , name , violations )
39
+
40
+ if violations :
41
+ error_message = self ._build_error_message (violations )
42
+ pytest .fail (error_message )
43
+
44
+ def _check_class_field_ids (self , cls , class_name , violations ):
45
+ """
46
+ Checks all field IDs in a Thrift class and reports violations.
47
+
48
+ Args:
49
+ cls: The Thrift class to check
50
+ class_name: Name of the class for error reporting
51
+ violations: List to append violation messages to
52
+ """
53
+ thrift_spec = cls .thrift_spec
54
+
55
+ if not isinstance (thrift_spec , (tuple , list )):
56
+ return
57
+
58
+ for spec_entry in thrift_spec :
59
+ if spec_entry is None :
60
+ continue
61
+
62
+ # Thrift spec format: (field_id, field_type, field_name, ...)
63
+ if isinstance (spec_entry , (tuple , list )) and len (spec_entry ) >= 3 :
64
+ field_id = spec_entry [0 ]
65
+ field_name = spec_entry [2 ]
66
+
67
+ # Skip known exceptions
68
+ if (class_name , field_name ) in self .KNOWN_EXCEPTIONS :
69
+ continue
70
+
71
+ if isinstance (field_id , int ) and field_id >= self .MAX_ALLOWED_FIELD_ID :
72
+ violations .append (
73
+ "{} field '{}' has field ID {} (exceeds maximum of {})" .format (
74
+ class_name , field_name , field_id , self .MAX_ALLOWED_FIELD_ID - 1
75
+ )
76
+ )
77
+
78
+ def _build_error_message (self , violations ):
79
+ """
80
+ Builds a comprehensive error message for field ID violations.
81
+
82
+ Args:
83
+ violations: List of violation messages
84
+
85
+ Returns:
86
+ Formatted error message
87
+ """
88
+ error_message = (
89
+ "Found Thrift field IDs that exceed the maximum allowed value of {}.\n "
90
+ "This can cause compatibility issues and conflicts with reserved ID ranges.\n "
91
+ "Violations found:\n " .format (self .MAX_ALLOWED_FIELD_ID - 1 )
92
+ )
93
+
94
+ for violation in violations :
95
+ error_message += " - {}\n " .format (violation )
96
+
97
+ return error_message
0 commit comments