Skip to content

Commit 64c0111

Browse files
authored
fix: Improve "in" CEL operator behavior with null (#4611)
1 parent d39afd3 commit 64c0111

File tree

6 files changed

+84
-60
lines changed

6 files changed

+84
-60
lines changed

keep/api/core/cel_to_sql/sql_providers/base.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from typing import Any, List
2-
from types import NoneType
32

43
from sqlalchemy import Dialect, String
54

@@ -185,12 +184,6 @@ def literal_proc(self, value: Any) -> str:
185184
def _get_order_by_field(self, cel_sort_by: str) -> str:
186185
return self.get_field_expression(cel_sort_by)
187186

188-
def _get_default_value_for_type(self, type: type) -> str:
189-
if type is str or type is NoneType:
190-
return "'__@NULL@__'" # This is a workaround for handling NULL values in SQL
191-
192-
return "NULL"
193-
194187
def __build_sql_filter(self, abstract_node: Node, stack: list[Node]) -> str:
195188
stack.append(abstract_node)
196189
result = None
@@ -229,8 +222,11 @@ def __build_sql_filter(self, abstract_node: Node, stack: list[Node]) -> str:
229222
def json_extract_as_text(self, column: str, path: list[str]) -> str:
230223
raise NotImplementedError("Extracting JSON is not implemented. Must be implemented in the child class.")
231224

232-
def coalesce(self, args: List[str]) -> str:
233-
raise NotImplementedError("COALESCE is not implemented. Must be implemented in the child class.")
225+
def coalesce(self, args):
226+
if len(args) == 1:
227+
return args[0]
228+
229+
return f"COALESCE({', '.join(args)})"
234230

235231
def cast(self, expression_to_cast: str, to_type: type, force=False) -> str:
236232
raise NotImplementedError("CAST is not implemented. Must be implemented in the child class.")
@@ -311,6 +307,9 @@ def _visit_comparison_node(self, comparison_node: ComparisonNode, stack: list[No
311307
return result
312308

313309
def _visit_equal(self, first_operand: str, second_operand: str) -> str:
310+
if second_operand == "NULL":
311+
return f"{first_operand} IS NULL"
312+
314313
return f"{first_operand} = {second_operand}"
315314

316315
def _visit_not_equal(self, first_operand: str, second_operand: str) -> str:
@@ -344,13 +343,41 @@ def _visit_in(self, first_operand: Node, array: list[ConstantNode], stack: list[
344343
else:
345344
first_operand_str = self.__build_sql_filter(first_operand, stack)
346345

347-
return f"{first_operand_str} in ({ ', '.join([self._visit_constant_node(c.value) for c in array])})"
346+
constant_nodes_without_none = []
347+
is_none_found = False
348+
349+
for item in array:
350+
if isinstance(item, ConstantNode):
351+
if item.value is None:
352+
is_none_found = True
353+
continue
354+
constant_nodes_without_none.append(item)
355+
356+
or_queries = []
357+
358+
if len(constant_nodes_without_none) > 0:
359+
or_queries.append(
360+
f"{first_operand_str} in ({ ', '.join([self._visit_constant_node(c.value) for c in constant_nodes_without_none])})"
361+
)
362+
363+
if is_none_found:
364+
or_queries.append(self._visit_equal(first_operand_str, "NULL"))
365+
366+
if len(or_queries) == 0:
367+
return self._visit_constant_node(False)
368+
369+
final_query = or_queries[0]
370+
371+
for query in or_queries[1:]:
372+
final_query = self._visit_logical_or(final_query, query)
373+
374+
return final_query
348375

349376
# endregion
350377

351-
def _visit_constant_node(self, value: str) -> str:
378+
def _visit_constant_node(self, value: Any) -> str:
352379
if value is None:
353-
return self._get_default_value_for_type(NoneType)
380+
return "NULL"
354381
if isinstance(value, str):
355382
return self.literal_proc(value)
356383
if isinstance(value, bool):
@@ -373,8 +400,6 @@ def _visit_multiple_fields_node(self, multiple_fields_node: MultipleFieldsNode,
373400
if len(coalesce_args) == 1:
374401
return coalesce_args[0]
375402

376-
coalesce_args.append(self._get_default_value_for_type(cast_to))
377-
378403
return self.coalesce(coalesce_args)
379404

380405
def _visit_member_access_node(self, member_access_node: MemberAccessNode, stack) -> str:

keep/api/core/cel_to_sql/sql_providers/mysql.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@ def _visit_constant_node(self, value: str) -> str:
9292

9393
return super()._visit_constant_node(value)
9494

95-
def coalesce(self, args):
96-
return f"COALESCE({', '.join(args)})"
97-
9895
def _visit_contains_method_calling(
9996
self, property_path: str, method_args: List[ConstantNode]
10097
) -> str:

keep/api/core/cel_to_sql/sql_providers/postgresql.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,6 @@ def json_extract_as_text(self, column: str, path: list[str]) -> str:
1717
json_property_path = " -> ".join(all_columns[:-1])
1818
return f"({json_property_path}) ->> {all_columns[-1]}" # (json_column -> 'labels' -> tags) ->> 'service'
1919

20-
def coalesce(self, args):
21-
coalesce_args = args
22-
23-
if len(args) == 1:
24-
coalesce_args += ["NULL"]
25-
26-
return f"COALESCE({', '.join(args)})"
27-
2820
def cast(self, expression_to_cast: str, to_type, force=False):
2921
if to_type is str:
3022
to_type_str = "TEXT"

keep/api/core/cel_to_sql/sql_providers/sqlite.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,6 @@ def json_extract_as_text(self, column: str, path: list[str]) -> str:
1111
property_path_str = ".".join([f'"{item}"' for item in path])
1212
return f"json_extract({column}, '$.{property_path_str}')"
1313

14-
def coalesce(self, args):
15-
coalesce_args = args
16-
17-
if len(args) == 1:
18-
coalesce_args += ["NULL"]
19-
20-
return f"COALESCE({', '.join(coalesce_args)})"
21-
2214
def cast(self, expression_to_cast: str, to_type, force=False):
2315
if to_type is str:
2416
to_type_str = "TEXT"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "keep"
3-
version = "0.41.29"
3+
version = "0.41.30"
44
description = "Alerting. for developers, by developers."
55
authors = ["Keep Alerting LTD"]
66
packages = [{include = "keep"}]

tests/cel_to_sql/cel-to-sql-test-cases.json

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,47 @@
11
[
2+
{
3+
"input_cel": "alert.severity == null",
4+
"description": "Equality with null",
5+
"expected_sql_dialect_based": {
6+
"mysql": "COALESCE(JSON_UNQUOTE(JSON_EXTRACT(alert_enrichments, '$.\"severity\"')), JSON_UNQUOTE(JSON_EXTRACT(alert_event, '$.\"severity\"'))) IS NULL",
7+
"postgresql": "COALESCE((alert_enrichments) ->> 'severity', (alert_event) ->> 'severity') IS NULL",
8+
"sqlite": "COALESCE(json_extract(alert_enrichments, '$.\"severity\"'), json_extract(alert_event, '$.\"severity\"')) IS NULL"
9+
}
10+
},
211
{
312
"input_cel": "alert.severity == 'HIGH'",
413
"description": "Queried field refers to multiple JSON columns",
514
"expected_sql_dialect_based": {
6-
"mysql": "COALESCE(JSON_UNQUOTE(JSON_EXTRACT(alert_enrichments, '$.\"severity\"')), JSON_UNQUOTE(JSON_EXTRACT(alert_event, '$.\"severity\"')), '__@NULL@__') = 'HIGH'",
7-
"postgresql": "COALESCE(((alert_enrichments) ->> 'severity')::TEXT, ((alert_event) ->> 'severity')::TEXT, '__@NULL@__') = 'HIGH'",
8-
"sqlite": "COALESCE(CAST(json_extract(alert_enrichments, '$.\"severity\"') as TEXT), CAST(json_extract(alert_event, '$.\"severity\"') as TEXT), '__@NULL@__') = 'HIGH'"
15+
"mysql": "COALESCE(JSON_UNQUOTE(JSON_EXTRACT(alert_enrichments, '$.\"severity\"')), JSON_UNQUOTE(JSON_EXTRACT(alert_event, '$.\"severity\"'))) = 'HIGH'",
16+
"postgresql": "COALESCE(((alert_enrichments) ->> 'severity')::TEXT, ((alert_event) ->> 'severity')::TEXT) = 'HIGH'",
17+
"sqlite": "COALESCE(CAST(json_extract(alert_enrichments, '$.\"severity\"') as TEXT), CAST(json_extract(alert_event, '$.\"severity\"') as TEXT)) = 'HIGH'"
918
}
1019
},
1120
{
1221
"input_cel": "name != 'Payments incident'",
13-
"description": "Queried field refers to multipl columns",
22+
"description": "Queried field refers to multiple columns",
23+
"expected_sql_dialect_based": {
24+
"mysql": "COALESCE(user_generated_name, ai_generated_name) != 'Payments incident'",
25+
"postgresql": "COALESCE(user_generated_name, ai_generated_name) != 'Payments incident'",
26+
"sqlite": "COALESCE(user_generated_name, ai_generated_name) != 'Payments incident'"
27+
}
28+
},
29+
{
30+
"input_cel": "name in ['Payments incident', 'API incident', 'Network incident', null]",
31+
"description": "IN operator along with NULL",
1432
"expected_sql_dialect_based": {
15-
"mysql": "COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') != 'Payments incident'",
16-
"postgresql": "COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') != 'Payments incident'",
17-
"sqlite": "COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') != 'Payments incident'"
33+
"mysql": "(COALESCE(user_generated_name, ai_generated_name) in ('Payments incident', 'API incident', 'Network incident') OR COALESCE(user_generated_name, ai_generated_name) IS NULL)",
34+
"postgresql": "(COALESCE(user_generated_name, ai_generated_name) in ('Payments incident', 'API incident', 'Network incident') OR COALESCE(user_generated_name, ai_generated_name) IS NULL)",
35+
"sqlite": "(COALESCE(user_generated_name, ai_generated_name) in ('Payments incident', 'API incident', 'Network incident') OR COALESCE(user_generated_name, ai_generated_name) IS NULL)"
1836
}
1937
},
2038
{
2139
"input_cel": "!(name in ['Payments incident', 'API incident', 'Network incident', null])",
22-
"description": "IN operator along with NOT",
40+
"description": "IN operator along with NOT and NULL",
2341
"expected_sql_dialect_based": {
24-
"mysql": "NOT (COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') in ('Payments incident', 'API incident', 'Network incident', '__@NULL@__'))",
25-
"postgresql": "NOT (COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') in ('Payments incident', 'API incident', 'Network incident', '__@NULL@__'))",
26-
"sqlite": "NOT (COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') in ('Payments incident', 'API incident', 'Network incident', '__@NULL@__'))"
42+
"mysql": "NOT ((COALESCE(user_generated_name, ai_generated_name) in ('Payments incident', 'API incident', 'Network incident') OR COALESCE(user_generated_name, ai_generated_name) IS NULL))",
43+
"postgresql": "NOT ((COALESCE(user_generated_name, ai_generated_name) in ('Payments incident', 'API incident', 'Network incident') OR COALESCE(user_generated_name, ai_generated_name) IS NULL))",
44+
"sqlite": "NOT ((COALESCE(user_generated_name, ai_generated_name) in ('Payments incident', 'API incident', 'Network incident') OR COALESCE(user_generated_name, ai_generated_name) IS NULL))"
2745
}
2846
},
2947
{
@@ -84,18 +102,18 @@
84102
"input_cel": "alert.randomDate >= '2025-01-30T10:00:09.553Z'",
85103
"description": "Comparison operator with dates for a JSON multiple columns",
86104
"expected_sql_dialect_based": {
87-
"sqlite": "COALESCE(json_extract(alert_enrichments, '$.\"randomDate\"'), json_extract(alert_event, '$.\"randomDate\"'), NULL) >= datetime('2025-01-30 10:00:09')",
88-
"mysql": "COALESCE(JSON_UNQUOTE(JSON_EXTRACT(alert_enrichments, '$.\"randomDate\"')), JSON_UNQUOTE(JSON_EXTRACT(alert_event, '$.\"randomDate\"')), NULL) >= CAST('2025-01-30 10:00:09' as DATETIME)",
89-
"postgresql": "COALESCE(((alert_enrichments) ->> 'randomDate')::TIMESTAMP, ((alert_event) ->> 'randomDate')::TIMESTAMP, NULL) >= CAST('2025-01-30 10:00:09' as TIMESTAMP)"
105+
"sqlite": "COALESCE(json_extract(alert_enrichments, '$.\"randomDate\"'), json_extract(alert_event, '$.\"randomDate\"')) >= datetime('2025-01-30 10:00:09')",
106+
"mysql": "COALESCE(JSON_UNQUOTE(JSON_EXTRACT(alert_enrichments, '$.\"randomDate\"')), JSON_UNQUOTE(JSON_EXTRACT(alert_event, '$.\"randomDate\"'))) >= CAST('2025-01-30 10:00:09' as DATETIME)",
107+
"postgresql": "COALESCE(((alert_enrichments) ->> 'randomDate')::TIMESTAMP, ((alert_event) ->> 'randomDate')::TIMESTAMP) >= CAST('2025-01-30 10:00:09' as TIMESTAMP)"
90108
}
91109
},
92110
{
93111
"input_cel": "alert.count > 7.84",
94112
"description": "Greater than with float",
95113
"expected_sql_dialect_based": {
96-
"sqlite": "COALESCE(CAST(json_extract(alert_enrichments, '$.\"count\"') as REAL), CAST(json_extract(alert_event, '$.\"count\"') as REAL), NULL) > 7.84",
97-
"mysql": "COALESCE(JSON_UNQUOTE(JSON_EXTRACT(alert_enrichments, '$.\"count\"')), JSON_UNQUOTE(JSON_EXTRACT(alert_event, '$.\"count\"')), NULL) > 7.84",
98-
"postgresql": "COALESCE(((alert_enrichments) ->> 'count')::FLOAT, ((alert_event) ->> 'count')::FLOAT, NULL) > 7.84"
114+
"sqlite": "COALESCE(CAST(json_extract(alert_enrichments, '$.\"count\"') as REAL), CAST(json_extract(alert_event, '$.\"count\"') as REAL)) > 7.84",
115+
"mysql": "COALESCE(JSON_UNQUOTE(JSON_EXTRACT(alert_enrichments, '$.\"count\"')), JSON_UNQUOTE(JSON_EXTRACT(alert_event, '$.\"count\"'))) > 7.84",
116+
"postgresql": "COALESCE(((alert_enrichments) ->> 'count')::FLOAT, ((alert_event) ->> 'count')::FLOAT) > 7.84"
99117
}
100118
},
101119
{
@@ -129,9 +147,9 @@
129147
"input_cel": "name == 'Payments incident' && severity <= 'critical'",
130148
"description": "AND with less than than comparison operator with enum when constat is the last value in enum",
131149
"expected_sql_dialect_based": {
132-
"sqlite": "COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') = 'Payments incident'",
133-
"mysql": "COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') = 'Payments incident'",
134-
"postgresql": "COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') = 'Payments incident'"
150+
"sqlite": "COALESCE(user_generated_name, ai_generated_name) = 'Payments incident'",
151+
"mysql": "COALESCE(user_generated_name, ai_generated_name) = 'Payments incident'",
152+
"postgresql": "COALESCE(user_generated_name, ai_generated_name) = 'Payments incident'"
135153
}
136154
},
137155
{
@@ -147,18 +165,18 @@
147165
"input_cel": "name == 'Payments incident' && severity > 'critical'",
148166
"description": "AND with greater than comparison operator with enum when constat is the last value in enum",
149167
"expected_sql_dialect_based": {
150-
"sqlite": "(COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') = 'Payments incident' AND false)",
151-
"mysql": "(COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') = 'Payments incident' AND FALSE)",
152-
"postgresql": "(COALESCE(user_generated_name, ai_generated_name, '__@NULL@__') = 'Payments incident' AND false)"
168+
"sqlite": "(COALESCE(user_generated_name, ai_generated_name) = 'Payments incident' AND false)",
169+
"mysql": "(COALESCE(user_generated_name, ai_generated_name) = 'Payments incident' AND FALSE)",
170+
"postgresql": "(COALESCE(user_generated_name, ai_generated_name) = 'Payments incident' AND false)"
153171
}
154172
},
155173
{
156174
"input_cel": "alert.count <= 100",
157175
"description": "Less than or equal with integer",
158176
"expected_sql_dialect_based": {
159-
"sqlite": "COALESCE(CAST(json_extract(alert_enrichments, '$.\"count\"') as REAL), CAST(json_extract(alert_event, '$.\"count\"') as REAL), NULL) <= 100",
160-
"mysql": "COALESCE(JSON_UNQUOTE(JSON_EXTRACT(alert_enrichments, '$.\"count\"')), JSON_UNQUOTE(JSON_EXTRACT(alert_event, '$.\"count\"')), NULL) <= 100",
161-
"postgresql": "COALESCE(((alert_enrichments) ->> 'count')::FLOAT, ((alert_event) ->> 'count')::FLOAT, NULL) <= 100"
177+
"sqlite": "COALESCE(CAST(json_extract(alert_enrichments, '$.\"count\"') as REAL), CAST(json_extract(alert_event, '$.\"count\"') as REAL)) <= 100",
178+
"mysql": "COALESCE(JSON_UNQUOTE(JSON_EXTRACT(alert_enrichments, '$.\"count\"')), JSON_UNQUOTE(JSON_EXTRACT(alert_event, '$.\"count\"'))) <= 100",
179+
"postgresql": "COALESCE(((alert_enrichments) ->> 'count')::FLOAT, ((alert_event) ->> 'count')::FLOAT) <= 100"
162180
}
163181
},
164182
{

0 commit comments

Comments
 (0)