Skip to content

Commit 746c05d

Browse files
Ensure thrift field IDs stay within range (#276)
Ensure thrift field IDs stay within range
2 parents 12d2ced + 039c53a commit 746c05d

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package cli_service
2+
3+
import (
4+
"bufio"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"regexp"
9+
"runtime"
10+
"strconv"
11+
"strings"
12+
"testing"
13+
)
14+
15+
// TestThriftFieldIdsAreWithinAllowedRange validates that all Thrift field IDs
16+
// in cli_service.go are within the allowed range.
17+
//
18+
// Field IDs in Thrift must stay below 3329 to avoid conflicts with reserved ranges
19+
// and ensure compatibility with various Thrift implementations and protocols.
20+
func TestThriftFieldIdsAreWithinAllowedRange(t *testing.T) {
21+
const maxAllowedFieldID = 3329
22+
23+
// Get the directory of this test file
24+
_, filename, _, ok := runtime.Caller(0)
25+
if !ok {
26+
t.Fatal("Failed to get current file path")
27+
}
28+
29+
// Build path to cli_service.go
30+
testDir := filepath.Dir(filename)
31+
cliServicePath := filepath.Join(testDir, "cli_service.go")
32+
33+
violations, err := validateThriftFieldIDs(cliServicePath, maxAllowedFieldID)
34+
if err != nil {
35+
t.Fatalf("Failed to validate thrift field IDs: %v", err)
36+
}
37+
38+
if len(violations) > 0 {
39+
errorMessage := fmt.Sprintf(
40+
"Found Thrift field IDs that exceed the maximum allowed value of %d.\n"+
41+
"This can cause compatibility issues and conflicts with reserved ID ranges.\n"+
42+
"Violations found:\n",
43+
maxAllowedFieldID-1)
44+
45+
for _, violation := range violations {
46+
errorMessage += fmt.Sprintf(" - %s\n", violation)
47+
}
48+
49+
t.Fatal(errorMessage)
50+
}
51+
}
52+
53+
// validateThriftFieldIDs parses the cli_service.go file and extracts all thrift field IDs
54+
// to validate they are within the allowed range.
55+
func validateThriftFieldIDs(filePath string, maxAllowedFieldID int) ([]string, error) {
56+
file, err := os.Open(filePath)
57+
if err != nil {
58+
return nil, fmt.Errorf("failed to open file %s: %w", filePath, err)
59+
}
60+
defer file.Close()
61+
62+
var violations []string
63+
scanner := bufio.NewScanner(file)
64+
lineNumber := 0
65+
66+
// Regex to match thrift field tags
67+
// Matches patterns like: `thrift:"fieldName,123,required"` or `thrift:"fieldName,123"`
68+
thriftTagRegex := regexp.MustCompile(`thrift:"([^"]*),(\d+)(?:,([^"]*))?"`)
69+
70+
for scanner.Scan() {
71+
lineNumber++
72+
line := scanner.Text()
73+
74+
// Find all thrift tags in the line
75+
matches := thriftTagRegex.FindAllStringSubmatch(line, -1)
76+
for _, match := range matches {
77+
if len(match) >= 3 {
78+
fieldName := match[1]
79+
fieldIDStr := match[2]
80+
81+
fieldID, err := strconv.Atoi(fieldIDStr)
82+
if err != nil {
83+
// Skip invalid field IDs (shouldn't happen in generated code)
84+
continue
85+
}
86+
87+
if fieldID >= maxAllowedFieldID {
88+
// Extract struct/field context from the line
89+
context := extractFieldContext(line)
90+
violation := fmt.Sprintf(
91+
"Line %d: Field '%s' has ID %d (exceeds maximum of %d) - %s",
92+
lineNumber, fieldName, fieldID, maxAllowedFieldID-1, context)
93+
violations = append(violations, violation)
94+
}
95+
}
96+
}
97+
}
98+
99+
if err := scanner.Err(); err != nil {
100+
return nil, fmt.Errorf("error reading file: %w", err)
101+
}
102+
103+
return violations, nil
104+
}
105+
106+
// extractFieldContext extracts the field declaration context from a line of code
107+
func extractFieldContext(line string) string {
108+
// Remove leading/trailing whitespace
109+
line = strings.TrimSpace(line)
110+
111+
// Try to extract the field name and type from the line
112+
// Format is typically: FieldName Type `tags...`
113+
parts := strings.Fields(line)
114+
if len(parts) >= 2 {
115+
fieldName := parts[0]
116+
fieldType := parts[1]
117+
return fmt.Sprintf("%s %s", fieldName, fieldType)
118+
}
119+
120+
// Fallback to returning the trimmed line if we can't parse it
121+
if len(line) > 100 {
122+
return line[:100] + "..."
123+
}
124+
return line
125+
}

0 commit comments

Comments
 (0)