Skip to content

Commit 94c1467

Browse files
committed
Add simple (WIP) tests to python code
1 parent 562539c commit 94c1467

File tree

4 files changed

+214
-33
lines changed

4 files changed

+214
-33
lines changed

languages/python/jupyter_notebook/cs_models.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from sqlalchemy.types import TypeDecorator, String, Integer, Date, Boolean, Float
33
from sqlalchemy import create_engine, select, text
44
from sqlalchemy.exc import IntegrityError
5+
from datetime import datetime
56
import json
67

78
import sys
@@ -23,6 +24,7 @@ def process_bind_param(self, value, dialect):
2324
"c": self.column_name
2425
},
2526
"v": 1,
27+
"q": None
2628
}
2729
value = json.dumps(value_dict)
2830
return value
@@ -35,43 +37,51 @@ def process_result_value(self, value, dialect):
3537
class EncryptedInt(CsTypeDecorator):
3638
impl = String
3739

38-
def __init__(self, *args, **kwargs):
39-
super().__init__(*args, **kwargs)
40+
def process_result_value(self, value, dialect):
41+
if value is None:
42+
return None
43+
return int(value['p'])
44+
4045

4146
class EncryptedBoolean(CsTypeDecorator):
4247
impl = String
4348

44-
def __init__(self, *args, **kwargs):
45-
super().__init__(*args, **kwargs)
46-
4749
def process_bind_param(self, value, dialect):
4850
if value is not None:
4951
value = str(value).lower()
5052
return super().process_bind_param(value, dialect)
5153

54+
def process_result_value(self, value, dialect):
55+
if value is None:
56+
return None
57+
return value['p'] == 'true'
58+
5259
class EncryptedDate(CsTypeDecorator):
5360
impl = String
5461

55-
def __init__(self, *args, **kwargs):
56-
super().__init__(*args, **kwargs)
62+
def process_result_value(self, value, dialect):
63+
if value is None:
64+
return None
65+
return datetime.fromisoformat(value['p']).date()
5766

5867
class EncryptedFloat(CsTypeDecorator):
5968
impl = String
6069

61-
def __init__(self, *args, **kwargs):
62-
super().__init__(*args, **kwargs)
70+
def process_result_value(self, value, dialect):
71+
if value is None:
72+
return None
73+
return float(value['p'])
6374

6475
class EncryptedUtf8Str(CsTypeDecorator):
6576
impl = String
6677

67-
def __init__(self, *args, **kwargs):
68-
super().__init__(*args, **kwargs)
69-
7078
class EncryptedJsonb(CsTypeDecorator):
7179
impl = String
7280

73-
def __init__(self, *args, **kwargs):
74-
super().__init__(*args, **kwargs)
81+
def process_result_value(self, value, dialect):
82+
if value is None:
83+
return None
84+
return json.loads(value['p'])
7585

7686
class BaseModel(DeclarativeBase):
7787
pass
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import unittest
2+
from datetime import date
3+
4+
from cs_models import *
5+
6+
class TestExampleModel(unittest.TestCase):
7+
def setUp(self):
8+
# TODO: configure database URL in environment variable and use a test db (not getting_started)
9+
self.engine = create_engine('postgresql://postgres:postgres@localhost:6432/cipherstash_getting_started')
10+
Session = sessionmaker(bind=self.engine)
11+
self.session = Session()
12+
BaseModel.metadata.create_all(self.engine)
13+
14+
self.session.query(Example).delete()
15+
self.example = Example(
16+
e_int=1, e_utf8_str="str", e_jsonb='{"key": "value"}', e_float=1.1, e_date=date(2024, 1, 1), e_bool=True
17+
)
18+
self.session.add(self.example)
19+
self.session.commit()
20+
21+
def test_encrypted_int(self):
22+
found = self.session.query(Example).filter(Example.id == self.example.id).one()
23+
self.assertEqual(found.encrypted_int, 1)
24+
25+
def test_encrypted_boolean(self):
26+
found = self.session.query(Example).filter(Example.id == self.example.id).one()
27+
self.assertEqual(found.encrypted_boolean, True)
28+
29+
def test_encrypted_date(self):
30+
found = self.session.query(Example).filter(Example.id == self.example.id).one()
31+
self.assertEqual(found.encrypted_date, date(2024, 1, 1))
32+
33+
def test_encrypted_float(self):
34+
found = self.session.query(Example).filter(Example.id == self.example.id).one()
35+
self.assertEqual(found.encrypted_float, 1.1)
36+
37+
def test_encrypted_utf8_str(self):
38+
found = self.session.query(Example).filter(Example.id == self.example.id).one()
39+
self.assertEqual(found.encrypted_utf8_str, "str")
40+
41+
def test_encrypted_jsonb(self):
42+
found = self.session.query(Example).filter(Example.id == self.example.id).one()
43+
self.assertEqual(found.encrypted_jsonb, {"key": "value"})
44+
45+
if __name__ == '__main__':
46+
unittest.main()

languages/python/jupyter_notebook/cs_types.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
#!/usr/bin/env python
2-
3-
import psycopg2
41
from psycopg2.extras import RealDictCursor
52
from pprint import pprint
63
from datetime import datetime
@@ -12,71 +9,69 @@ def __init__(self, v, t: str, c: str):
129
self.table = t
1310
self.column = c
1411

15-
def to_db_format(self):
12+
def to_db_format(self, query_type=None):
1613
data = {
1714
"k": "pt",
18-
"p": self.value_in_db_format(),
15+
"p": self.value_in_db_format(query_type),
1916
"i": {
2017
"t": str(self.table),
2118
"c": str(self.column)
2219
},
2320
"v": 1,
21+
"q": query_type,
2422
}
2523
return json.dumps(data)
2624

27-
# TODO: Unused at the moment
28-
@classmethod
29-
def from_json_str(cls, json_str):
30-
parsed = json.loads(json_str)
31-
return cls.from_parsed_json(parsed)
32-
3325
@classmethod
3426
def from_parsed_json(cls, parsed):
3527
return cls.value_from_db_format(parsed["p"])
3628

3729
class CsInt(CsValue):
38-
def value_in_db_format(self):
30+
def value_in_db_format(self, query_type = None):
3931
return str(self.value)
4032

4133
@classmethod
4234
def value_from_db_format(cls, s: str):
4335
return int(s)
4436

4537
class CsBool(CsValue):
46-
def value_in_db_format(self):
38+
def value_in_db_format(self, query_type = None):
4739
return str(self.value).lower()
4840

4941
@classmethod
5042
def value_from_db_format(cls, s: str):
5143
return s.lower() == 'true'
5244

5345
class CsDate(CsValue):
54-
def value_in_db_format(self):
46+
def value_in_db_format(self, query_type = None):
5547
return self.value.isoformat()
5648

5749
@classmethod
5850
def value_from_db_format(cls, s: str):
59-
return datetime.fromisoformat(s)
51+
return datetime.fromisoformat(s).date()
6052

6153
class CsFloat(CsValue):
62-
def value_in_db_format(self):
54+
def value_in_db_format(self, query_type = None):
6355
return str(self.value)
6456

6557
@classmethod
6658
def value_from_db_format(cls, s: str):
6759
return float(s)
6860

6961
class CsText(CsValue):
70-
def value_in_db_format(self):
62+
def value_in_db_format(self, query_type = None):
7163
return self.value
7264

7365
@classmethod
7466
def value_from_db_format(cls, s: str):
7567
return s
7668

7769
class CsJsonb(CsValue):
78-
def value_in_db_format(self):
79-
return json.dumps(self.value)
70+
def value_in_db_format(self, query_type):
71+
if query_type == "ejson_path":
72+
return self.value
73+
else:
74+
return json.dumps(self.value)
8075

8176
@classmethod
8277
def value_from_db_format(cls, s: str):
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import unittest
2+
import json
3+
from datetime import date
4+
from cs_types import *
5+
6+
class EqlTest(unittest.TestCase):
7+
def test(self):
8+
self.assertTrue(True)
9+
10+
def test_to_db_format(self):
11+
self.assertEqual(
12+
CsInt(1, "table", "column").to_db_format(),
13+
'{"k": "pt", "p": "1", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}'
14+
)
15+
16+
def test_from_parsed_json_uses_p_value(self):
17+
parsed = json.loads('{"k": "pt", "p": "1", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}')
18+
self.assertEqual(
19+
CsInt.from_parsed_json(parsed),
20+
1
21+
)
22+
23+
def test_cs_int_prints_value(self):
24+
cs_int = CsInt(1, "table", "column")
25+
self.assertEqual(
26+
cs_int.value_in_db_format(),
27+
"1"
28+
)
29+
30+
def test_ces_int_makes_int(self):
31+
self.assertEqual(
32+
CsInt.value_from_db_format("1"),
33+
1
34+
)
35+
36+
def test_cs_bool_prints_value_in_lower_case(self):
37+
cs_bool = CsBool(True, "table", "column")
38+
self.assertEqual(
39+
cs_bool.value_in_db_format(),
40+
"true"
41+
)
42+
43+
def test_cs_bool_returns_bool(self):
44+
self.assertEqual(
45+
CsBool.value_from_db_format("true"),
46+
True
47+
)
48+
49+
def test_cs_date_prints_value(self):
50+
cs_date = CsDate(date(2024, 11, 1), "table", "column")
51+
self.assertEqual(
52+
cs_date.value_in_db_format(),
53+
"2024-11-01"
54+
)
55+
56+
def test_cs_date_returns_datetime(self):
57+
self.assertEqual(
58+
CsDate.value_from_db_format("2024-11-01"),
59+
date(2024, 11, 1)
60+
)
61+
62+
def test_cs_float_prints_value(self):
63+
cs_float = CsFloat(1.1, "table", "column")
64+
self.assertEqual(
65+
cs_float.value_in_db_format(),
66+
"1.1"
67+
)
68+
69+
def test_cs_float_returns_float(self):
70+
self.assertEqual(
71+
CsFloat.value_from_db_format("1.1"),
72+
1.1
73+
)
74+
75+
def test_cs_text_prints_value(self):
76+
cs_text = CsText("text", "table", "column")
77+
self.assertEqual(
78+
cs_text.value_in_db_format(),
79+
"text"
80+
)
81+
82+
def test_cs_text_returns_value(self):
83+
self.assertEqual(
84+
CsText.value_from_db_format("text"),
85+
"text"
86+
)
87+
88+
def test_cs_jsonb_prints_json_string(self):
89+
cs_jsonb = CsJsonb({"a": 1}, "table", "column")
90+
self.assertEqual(
91+
cs_jsonb.value_in_db_format("ste_vec"),
92+
'{"a": 1}'
93+
)
94+
95+
def test_cs_jsonb_prints_value_for_ejson_path(self):
96+
cs_jsonb = CsJsonb("$.a.b", "table", "column")
97+
self.assertEqual(
98+
cs_jsonb.value_in_db_format("ejson_path"),
99+
'$.a.b'
100+
)
101+
102+
def test_cs_jsonb_returns_value(self):
103+
self.assertEqual(
104+
CsJsonb.value_from_db_format('{"a": 1}'),
105+
{"a": 1}
106+
)
107+
108+
def test_cs_row_makes_row(self):
109+
cs_row = CsRow(
110+
{"encrypted_int": json.loads(CsInt(1, "table", "column").to_db_format()),
111+
"encrypted_boolean": json.loads(CsBool(True, "table", "column").to_db_format()),
112+
"encrypted_date": json.loads(CsDate(date(2024, 11, 1), "table", "column").to_db_format()),
113+
"encrypted_float": json.loads(CsFloat(1.1, "table", "column").to_db_format()),
114+
"encrypted_utf8_str": json.loads(CsText("text", "table", "column").to_db_format()),
115+
"encrypted_jsonb": json.loads(CsJsonb('{"a": 1}', "table", "column").to_db_format())
116+
})
117+
118+
self.assertEqual(
119+
cs_row.row,
120+
{"encrypted_int": 1,
121+
"encrypted_boolean": True,
122+
"encrypted_date": date(2024, 11, 1),
123+
"encrypted_float": 1.1,
124+
"encrypted_utf8_str": "text",
125+
"encrypted_jsonb": '"{\\"a\\": 1}"'
126+
}
127+
)
128+
129+
if __name__ == '__main__':
130+
unittest.main()

0 commit comments

Comments
 (0)