Skip to content

Add test to check thrift field IDs #601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions tests/unit/test_thrift_field_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Unit test to validate that all Thrift-generated field IDs comply with the maximum limit.

Field IDs in Thrift must stay below 3329 to avoid conflicts with reserved ranges
and ensure compatibility with various Thrift implementations and protocols.
"""

import inspect
import pytest
import unittest

from databricks.sql.thrift_api.TCLIService import ttypes


class TestThriftFieldIds(unittest.TestCase):
"""Test suite for validating Thrift field ID constraints."""

MAX_ALLOWED_FIELD_ID = 3329

# Known exceptions that exceed the field ID limit
KNOWN_EXCEPTIONS = {
('TExecuteStatementReq', 'enforceEmbeddedSchemaCorrectness'): 3353,
('TSessionHandle', 'serverProtocolVersion'): 3329,
}

def test_all_thrift_field_ids_are_within_allowed_range(self):
"""
Validates that all field IDs in Thrift-generated classes are within the allowed range.

This test prevents field ID conflicts and ensures compatibility with different
Thrift implementations and protocols.
"""
violations = []

# Get all classes from the ttypes module
for name, obj in inspect.getmembers(ttypes):
if (inspect.isclass(obj) and
hasattr(obj, 'thrift_spec') and
obj.thrift_spec is not None):

self._check_class_field_ids(obj, name, violations)

if violations:
error_message = self._build_error_message(violations)
self.fail(error_message)

def _check_class_field_ids(self, cls, class_name, violations):
"""
Checks all field IDs in a Thrift class and reports violations.

Args:
cls: The Thrift class to check
class_name: Name of the class for error reporting
violations: List to append violation messages to
"""
thrift_spec = cls.thrift_spec

if not isinstance(thrift_spec, (tuple, list)):
return

for spec_entry in thrift_spec:
if spec_entry is None:
continue

# Thrift spec format: (field_id, field_type, field_name, ...)
if isinstance(spec_entry, (tuple, list)) and len(spec_entry) >= 3:
field_id = spec_entry[0]
field_name = spec_entry[2]

# Skip known exceptions
if (class_name, field_name) in self.KNOWN_EXCEPTIONS:
continue

if isinstance(field_id, int) and field_id >= self.MAX_ALLOWED_FIELD_ID:
violations.append(
"{} field '{}' has field ID {} (exceeds maximum of {})".format(
class_name, field_name, field_id, self.MAX_ALLOWED_FIELD_ID - 1
)
)

def _build_error_message(self, violations):
"""
Builds a comprehensive error message for field ID violations.

Args:
violations: List of violation messages

Returns:
Formatted error message
"""
error_message = (
"Found Thrift field IDs that exceed the maximum allowed value of {}.\n"
"This can cause compatibility issues and conflicts with reserved ID ranges.\n"
"Violations found:\n".format(self.MAX_ALLOWED_FIELD_ID - 1)
)

for violation in violations:
error_message += " - {}\n".format(violation)

return error_message