Skip to content

Commit 36cb486

Browse files
author
Charles Larivier
committed
feat: add query.py
Signed-off-by: Charles Larivier <charles@dribbble.com>
1 parent 6c8de95 commit 36cb486

File tree

2 files changed

+100
-5
lines changed

2 files changed

+100
-5
lines changed

src/metabase/mbql/query.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,40 @@
11
from dataclasses import dataclass, field
2-
from typing import List, Union
2+
from typing import List
33

44
from metabase.mbql.aggregations import Aggregation
5+
from metabase.mbql.filter import Filter
56
from metabase.mbql.groupby import GroupBy
67

78

89
@dataclass
910
class Query:
1011
table_id: int
11-
aggregations: List[Union[Aggregation, Metric]]
12+
aggregations: List[Aggregation]
1213
group_by: List[GroupBy] = field(default_factory=list)
1314
filters: List[Filter] = field(default_factory=list)
1415

1516
def compile(self):
1617
return {
1718
"source-table": self.table_id,
18-
"aggregation": [agg.compile() for agg in self.aggregations],
19-
"breakout": [],
20-
"filter": [],
19+
"aggregation": self._aggregations,
20+
"breakout": self._group_by,
21+
"filter": self._filters,
2122
}
23+
24+
@property
25+
def _aggregations(self):
26+
return [aggregation.compile() for aggregation in self.aggregations]
27+
28+
@property
29+
def _group_by(self):
30+
return [group.compile() for group in self.group_by]
31+
32+
@property
33+
def _filters(self):
34+
if len(self.filters) == 0:
35+
return self.filters
36+
37+
if len(self.filters) == 1:
38+
return self.filters[0].compile()
39+
40+
return ["and"] + [filt.compile() for filt in self.filters]

tests/mbql/test_query.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from unittest import TestCase
2+
3+
from metabase.mbql.aggregations import Count, Max
4+
from metabase.mbql.filter import Equal
5+
from metabase.mbql.groupby import GroupBy
6+
from metabase.mbql.query import Query
7+
8+
9+
class QueryTests(TestCase):
10+
def test_compile(self):
11+
"""Ensure Query.compile() returns valid MBQL."""
12+
query = Query(
13+
table_id=14,
14+
aggregations=[Count(), Max(5)],
15+
group_by=[GroupBy(14)],
16+
filters=[Equal(2, 5), Equal(5, "foo")],
17+
)
18+
19+
self.assertEqual(
20+
{
21+
"source-table": 14,
22+
"aggregation": [["count"], ["max", ["field", 5, None]]],
23+
"breakout": [["field", 14, None]],
24+
"filter": [
25+
"and",
26+
["=", ["field", 2, None], 5],
27+
["=", ["field", 5, None], "foo"],
28+
],
29+
},
30+
query.compile(),
31+
)
32+
33+
def test__aggregations(self):
34+
"""Ensure Query._aggregations returns a list of compiled Aggregation."""
35+
query = Query(table_id=12, aggregations=[])
36+
self.assertEqual([], query._aggregations)
37+
38+
query = Query(
39+
table_id=12,
40+
aggregations=[Count(), Max(5)],
41+
)
42+
self.assertEqual([["count"], ["max", ["field", 5, None]]], query._aggregations)
43+
44+
def test__group_by(self):
45+
"""Ensure Query._group_by returns a list of compiled GroupBy."""
46+
query = Query(table_id=12, aggregations=[Count()], group_by=[])
47+
self.assertEqual([], query._group_by)
48+
49+
query = Query(
50+
table_id=12,
51+
aggregations=[Count()],
52+
group_by=[GroupBy(5)],
53+
)
54+
self.assertEqual([["field", 5, None]], query._group_by)
55+
56+
def test__filters(self):
57+
"""Ensure Query._filters returns a list of compiled Filter."""
58+
query = Query(table_id=12, aggregations=[Count()], filters=[])
59+
self.assertEqual([], query._filters)
60+
61+
query = Query(
62+
table_id=12,
63+
aggregations=[Count()],
64+
filters=[Equal(5, 2)],
65+
)
66+
self.assertListEqual(["=", ["field", 5, None], 2], query._filters)
67+
68+
query = Query(
69+
table_id=12,
70+
aggregations=[Count()],
71+
filters=[Equal(5, 2), Equal(6, 1)],
72+
)
73+
self.assertEqual(
74+
["and", ["=", ["field", 5, None], 2], ["=", ["field", 6, None], 1]],
75+
query._filters,
76+
)

0 commit comments

Comments
 (0)