Skip to content

Commit 7ccd7e6

Browse files
author
Charles Larivier
committed
feat: add aggregations.py
Signed-off-by: Charles Larivier <charles@dribbble.com>
1 parent a144a78 commit 7ccd7e6

File tree

2 files changed

+61
-16
lines changed

2 files changed

+61
-16
lines changed

src/metabase/mbql/aggregations.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,52 @@
1+
from typing import List
2+
13
from metabase.mbql.base import Mbql
24

35

46
class Aggregation(Mbql):
5-
mbql: str
7+
function: str
68

7-
def compile(self):
8-
return [self.mbql]
9+
def compile(self) -> List:
10+
return [self.function, super(Aggregation, self).compile()]
911

1012

11-
class ColumnAggregation(Aggregation):
12-
def __init__(self, field_id: int):
13-
self.field_id = field_id
13+
class Count(Aggregation):
14+
function = "count"
1415

15-
def compile(self):
16-
return [self.mbql, ["field", self.field_id, None]]
16+
def __init__(self, id: int = None):
17+
self.id = id
1718

19+
def compile(self) -> List:
20+
return [self.function]
1821

19-
class Count(Aggregation):
20-
mbql = "count"
22+
23+
class Sum(Aggregation):
24+
function = "sum"
25+
26+
27+
class Average(Aggregation):
28+
function = "avg"
29+
30+
31+
class Distinct(Aggregation):
32+
function = "distinct"
33+
34+
35+
class CumulativeSum(Aggregation):
36+
function = "cum-sum"
37+
38+
39+
class CumulativeCount(Aggregation):
40+
function = "cum-count"
2141

2242

23-
class Sum(ColumnAggregation):
24-
mbql = "sum"
43+
class StandardDeviation(Aggregation):
44+
function = "stddev"
2545

2646

27-
class Max(ColumnAggregation):
28-
mbql = "max"
47+
class Min(Aggregation):
48+
function = "min"
2949

3050

31-
class Min(ColumnAggregation):
32-
mbql = "min"
51+
class Max(Aggregation):
52+
function = "max"

tests/mbql/test_aggregations.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from unittest import TestCase
2+
3+
from metabase.mbql.aggregations import Aggregation, Count
4+
5+
6+
class AggregationTests(TestCase):
7+
def test_aggregation(self):
8+
"""Ensure Aggregation.compile() returns [self.function, ['field', self.id, self.option']]."""
9+
10+
class Mock(Aggregation):
11+
function = "mock"
12+
13+
aggregation = Mock(id=2)
14+
self.assertEqual(["mock", ["field", 2, None]], aggregation.compile())
15+
16+
aggregation = Mock(id=2, option={"foo": "bar"})
17+
self.assertEqual(["mock", ["field", 2, {"foo": "bar"}]], aggregation.compile())
18+
19+
def test_count(self):
20+
"""Ensure Count optionally accepts an id attribute."""
21+
count = Count()
22+
self.assertEqual(["count"], count.compile())
23+
24+
count = Count(id=5)
25+
self.assertEqual(["count"], count.compile())

0 commit comments

Comments
 (0)