Skip to content

Commit 4743948

Browse files
authored
functions passed into macros that have mutliple args will parse correctly (#98)
1 parent d8a7dac commit 4743948

File tree

2 files changed

+50
-41
lines changed

2 files changed

+50
-41
lines changed

macros.go

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -126,22 +126,16 @@ var DefaultMacros Macros = Macros{
126126
"column": macroColumn,
127127
}
128128

129-
func trimAll(s []string) []string {
130-
r := make([]string, len(s))
131-
for i, v := range s {
132-
r[i] = strings.TrimSpace(v)
133-
}
134-
135-
return r
129+
type Macro struct {
130+
Name string
131+
Args []string
136132
}
137133

138-
var pair = map[rune]rune{')': '('}
139-
140134
// getMacroMatches extracts macro strings with their respective arguments from the sql input given
141135
// It manually parses the string to find the closing parenthesis of the macro (because regex has no memory)
142-
func getMacroMatches(input string, name string) ([][]string, error) {
136+
func getMacroMatches(input string, name string) ([]Macro, error) {
143137
macroName := fmt.Sprintf("\\$__%s\\b", name)
144-
matchedMacros := [][]string{}
138+
matchedMacros := []Macro{}
145139
rgx, err := regexp.Compile(macroName)
146140

147141
if err != nil {
@@ -201,19 +195,43 @@ func getMacroMatches(input string, name string) ([][]string, error) {
201195
macroEnd = matched[matchedIndex][1] - macroStart
202196
}
203197
macroString := inputCopy[0:macroEnd]
204-
macroMatch := []string{macroString}
198+
macroMatch := Macro{Name: macroString}
205199

206200
args := ""
207201
// if opening parenthesis was found, extract contents as arguments
208202
if argStart > 0 {
209203
args = inputCopy[argStart : macroEnd-1]
210204
}
211-
macroMatch = append(macroMatch, args)
205+
macroMatch.Args = parseArgs(args)
212206
matchedMacros = append(matchedMacros, macroMatch)
213207
}
214208
return matchedMacros, nil
215209
}
216210

211+
func parseArgs(args string) []string {
212+
argsArray := []string{}
213+
phrase := []rune{}
214+
bracketCount := 0
215+
for _, v := range args {
216+
phrase = append(phrase, v)
217+
if v == '(' {
218+
bracketCount++
219+
continue
220+
}
221+
if v == ')' {
222+
bracketCount--
223+
continue
224+
}
225+
if v == ',' && bracketCount == 0 {
226+
removeComma := phrase[:len(phrase)-1]
227+
argsArray = append(argsArray, string(removeComma))
228+
phrase = []rune{}
229+
}
230+
}
231+
argsArray = append(argsArray, strings.TrimSpace(string(phrase)))
232+
return argsArray
233+
}
234+
217235
// Interpolate returns an interpolated query string given a backend.DataQuery
218236
func Interpolate(driver Driver, query *Query) (string, error) {
219237
macros := Macros{}
@@ -223,41 +241,23 @@ func Interpolate(driver Driver, query *Query) (string, error) {
223241
rawSQL := query.RawSQL
224242

225243
for key, macro := range macros {
226-
matches, err := getMatches(key, rawSQL)
227-
244+
matches, err := getMacroMatches(rawSQL, key)
228245
if err != nil {
229246
return rawSQL, err
230247
}
231-
for _, match := range matches {
232-
if len(match) == 0 {
233-
// There were no matches for this macro
234-
continue
235-
}
236-
237-
args := []string{}
238-
if len(match) > 1 {
239-
// This macro has arguments
240-
args = trimAll(strings.Split(match[1], ","))
241-
}
248+
if len(matches) == 0 {
249+
continue
250+
}
242251

243-
res, err := macro(query.WithSQL(rawSQL), args)
252+
for _, match := range matches {
253+
res, err := macro(query.WithSQL(rawSQL), match.Args)
244254
if err != nil {
245255
return rawSQL, err
246256
}
247257

248-
rawSQL = strings.Replace(rawSQL, match[0], res, -1)
258+
rawSQL = strings.Replace(rawSQL, match.Name, res, -1)
249259
}
250260
}
251261

252262
return rawSQL, nil
253263
}
254-
255-
func getMatches(macroName, rawSQL string) ([][]string, error) {
256-
parsedInput, err := getMacroMatches(rawSQL, macroName)
257-
258-
if err != nil {
259-
return nil, err
260-
}
261-
262-
return parsedInput, err
263-
}

macros_test.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ func (h *MockDB) Macros() (macros Macros) {
3232
"timeGroup": func(query *Query, args []string) (out string, err error) {
3333
return "grouped!", nil
3434
},
35+
"multiParams": func(query *Query, args []string) (out string, err error) {
36+
r := "bar"
37+
for _, v := range args {
38+
r += "_" + v
39+
}
40+
return r, nil
41+
},
3542
}
3643
}
3744

@@ -71,6 +78,8 @@ func TestInterpolate(t *testing.T) {
7178
{input: "select $__column from $__table", output: "select my_col from my_table", name: "table and column macros"},
7279
{input: "select * from table where ( datetime >= $__foo() ) AND ( datetime <= $__foo() ) limit 100", output: "select * from table where ( datetime >= bar ) AND ( datetime <= bar ) limit 100", name: "macro functions inside more complex clauses"},
7380
{input: "select * from table where ( datetime >= $__foo ) AND ( datetime <= $__foo ) limit 100", output: "select * from table where ( datetime >= bar ) AND ( datetime <= bar ) limit 100", name: "macros inside more complex clauses"},
81+
{input: "select * from foo where $__multiParams(foo, bar)", output: "select * from foo where bar_foo_bar", name: "macro with multiple parameters"},
82+
{input: "select * from foo where $__params(FUNC(foo, bar))", output: "select * from foo where bar_FUNC(foo, bar)", name: "function in macro with multiple parameters"},
7483
}
7584
for i, tc := range tests {
7685
driver := MockDB{}
@@ -90,14 +99,14 @@ func TestInterpolate(t *testing.T) {
9099
func TestGetMatches(t *testing.T) {
91100
t.Run("FindAllStringSubmatch returns DefaultMacros", func(t *testing.T) {
92101
for macroName := range DefaultMacros {
93-
matches, err := getMatches(macroName, fmt.Sprintf("$__%s", macroName))
102+
matches, err := getMacroMatches(fmt.Sprintf("$__%s", macroName), macroName)
94103

95104
assert.NoError(t, err)
96-
assert.Equal(t, [][]string{{fmt.Sprintf("$__%s", macroName), ""}}, matches)
105+
assert.Equal(t, []Macro{{Name: fmt.Sprintf("$__%s", macroName), Args: []string{""}}}, matches)
97106
}
98107
})
99108
t.Run("does not return matches for macro name which is substring", func(t *testing.T) {
100-
matches, err := getMatches("timeFilter", "$__timeFilterEpoch(time_column)")
109+
matches, err := getMacroMatches("$__timeFilterEpoch(time_column)", "timeFilter")
101110

102111
assert.NoError(t, err)
103112
assert.Nil(t, matches)

0 commit comments

Comments
 (0)