Skip to content

Commit 408b886

Browse files
committed
Update Python model and types
1 parent 04f3899 commit 408b886

File tree

4 files changed

+107
-61
lines changed

4 files changed

+107
-61
lines changed

languages/python/jupyter_notebook/cs_models.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,19 @@
99
import os
1010

1111
class CsTypeDecorator(TypeDecorator):
12-
def __init__(self, table_name, column_name):
12+
def __init__(self, table, column):
1313
super().__init__()
14-
self.table_name = table_name
15-
self.column_name = column_name
14+
self.table = table
15+
self.column = column
1616

1717
def process_bind_param(self, value, dialect):
1818
if value is not None:
1919
value_dict = {
2020
"k": "pt",
2121
"p": str(value),
2222
"i": {
23-
"t": self.table_name,
24-
"c": self.column_name
23+
"t": self.table,
24+
"c": self.column
2525
},
2626
"v": 1,
2727
"q": None
@@ -106,4 +106,12 @@ def __init__(self, e_utf8_str=None, e_jsonb=None, e_int=None, e_float=None, e_da
106106
self.encrypted_boolean = e_bool
107107

108108
def __repr__(self):
109-
return f"<Example(id={self.id}, encrypted_utf8_str={self.encrypted_utf8_str}, encrypted_jsonb={self.encrypted_jsonb}, encrypted_int={self.encrypted_int}, encrypted_float={self.encrypted_float}, encrypted_date={self.encrypted_date}, encrypted_boolean={self.encrypted_boolean})>"
109+
return "<Example(" \
110+
f"id={self.id}, " \
111+
f"encrypted_utf8_str={self.encrypted_utf8_str}, " \
112+
f"encrypted_jsonb={self.encrypted_jsonb}, " \
113+
f"encrypted_int={self.encrypted_int}, " \
114+
f"encrypted_float={self.encrypted_float}, " \
115+
f"encrypted_date={self.encrypted_date}, " \
116+
f"encrypted_boolean={self.encrypted_boolean}" \
117+
")>"

languages/python/jupyter_notebook/cs_models_test.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44
from cs_models import *
55

66
class TestExampleModel(unittest.TestCase):
7+
pg_password = os.getenv('PGPASSWORD', 'postgres')
8+
pg_user = os.getenv('PGUSER', 'postgres')
9+
pg_host = os.getenv('PGHOST', 'localhost')
10+
pg_port = os.getenv('PGPORT', '6432')
11+
pg_db = os.getenv('PGDATABASE', 'cs_test_db')
12+
713
def setUp(self):
8-
# TODO: configure database URL in environment variable and use a test db (not getting_started)
9-
self.engine = create_engine('postgresql://postgres:postgres@localhost:6432/cipherstash_getting_started')
14+
self.engine = create_engine(f'postgresql://{self.pg_user}:{self.pg_password}@{self.pg_host}:{self.pg_port}/{self.pg_db}')
1015
Session = sessionmaker(bind=self.engine)
1116
self.session = Session()
1217
BaseModel.metadata.create_all(self.engine)
@@ -41,6 +46,13 @@ def test_encrypted_utf8_str(self):
4146
def test_encrypted_jsonb(self):
4247
found = self.session.query(Example).filter(Example.id == self.example.id).one()
4348
self.assertEqual(found.encrypted_jsonb, {"key": "value"})
49+
50+
def test_example_prints_value(self):
51+
self.example.id = 1
52+
self.assertEqual(
53+
str(self.example),
54+
"<Example(id=1, encrypted_utf8_str=str, encrypted_jsonb={'key': 'value'}, encrypted_int=1, encrypted_float=1.1, encrypted_date=2024-01-01, encrypted_boolean=True)>"
55+
)
4456

4557
if __name__ == '__main__':
4658
unittest.main()

languages/python/jupyter_notebook/cs_types.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@
22
from pprint import pprint
33
from datetime import datetime
44
import json
5+
from enum import Enum
56

67
class CsValue:
78
def __init__(self, v, t: str, c: str):
89
self.value = v
910
self.table = t
1011
self.column = c
1112

12-
def to_db_format(self, query_type=None):
13+
def to_db_format(self, query_type = None):
1314
data = {
1415
"k": "pt",
15-
"p": self.value_in_db_format(query_type),
16+
"p": self._value_in_db_format(query_type),
1617
"i": {
1718
"t": str(self.table),
1819
"c": str(self.column)
@@ -24,62 +25,59 @@ def to_db_format(self, query_type=None):
2425

2526
@classmethod
2627
def from_parsed_json(cls, parsed):
27-
return cls.value_from_db_format(parsed["p"])
28+
return cls._value_from_db_format(parsed["p"])
2829

2930
class CsInt(CsValue):
30-
def value_in_db_format(self, query_type = None):
31+
def _value_in_db_format(self, query_type):
3132
return str(self.value)
3233

3334
@classmethod
34-
def value_from_db_format(cls, s: str):
35+
def _value_from_db_format(cls, s: str):
3536
return int(s)
3637

3738
class CsBool(CsValue):
38-
def value_in_db_format(self, query_type = None):
39+
def _value_in_db_format(self, query_type):
3940
return str(self.value).lower()
4041

4142
@classmethod
42-
def value_from_db_format(cls, s: str):
43+
def _value_from_db_format(cls, s: str):
4344
return s.lower() == 'true'
4445

4546
class CsDate(CsValue):
46-
def value_in_db_format(self, query_type = None):
47+
def _value_in_db_format(self, query_type):
4748
return self.value.isoformat()
4849

4950
@classmethod
50-
def value_from_db_format(cls, s: str):
51+
def _value_from_db_format(cls, s: str):
5152
return datetime.fromisoformat(s).date()
5253

5354
class CsFloat(CsValue):
54-
def value_in_db_format(self, query_type = None):
55+
def _value_in_db_format(self, query_type):
5556
return str(self.value)
5657

5758
@classmethod
58-
def value_from_db_format(cls, s: str):
59+
def _value_from_db_format(cls, s: str):
5960
return float(s)
6061

6162
class CsText(CsValue):
62-
def value_in_db_format(self, query_type = None):
63+
def _value_in_db_format(self, query_type):
6364
return self.value
6465

6566
@classmethod
66-
def value_from_db_format(cls, s: str):
67+
def _value_from_db_format(cls, s: str):
6768
return s
6869

6970
class CsJsonb(CsValue):
70-
def value_in_db_format(self, query_type):
71+
def _value_in_db_format(self, query_type):
7172
if query_type == "ejson_path":
7273
return self.value
7374
else:
7475
return json.dumps(self.value)
7576

7677
@classmethod
77-
def value_from_db_format(cls, s: str):
78+
def _value_from_db_format(cls, s: str):
7879
return json.loads(s)
7980

80-
def id_map(x):
81-
return x
82-
8381
class CsRow:
8482
column_function_mapping = {
8583
'encrypted_int': CsInt.from_parsed_json,
@@ -90,9 +88,15 @@ class CsRow:
9088
'encrypted_jsonb': CsText.from_parsed_json
9189
}
9290

91+
@staticmethod
92+
def id_map(x):
93+
return x
94+
9395
def __init__(self, row):
9496
self.row = {}
9597
for k, v in row.items():
96-
self.row[k] = None if v == None else self.column_function_mapping.get(k, id_map)(v)
97-
98-
98+
if v == None:
99+
self.row[k] = None
100+
else:
101+
mapping = self.column_function_mapping.get(k, self.id_map)
102+
self.row[k] = mapping(v)

languages/python/jupyter_notebook/cs_types_test.py

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from cs_types import *
55

66
class EqlTest(unittest.TestCase):
7+
def setUp(self):
8+
self.template_dict = json.loads('{"k": "pt", "p": "1", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}')
9+
710
def test(self):
811
self.assertTrue(True)
912

@@ -14,94 +17,113 @@ def test_to_db_format(self):
1417
)
1518

1619
def test_from_parsed_json_uses_p_value(self):
17-
parsed = json.loads('{"k": "pt", "p": "1", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}')
20+
self.template_dict["p"] = "1"
1821
self.assertEqual(
19-
CsInt.from_parsed_json(parsed),
22+
CsInt.from_parsed_json(self.template_dict),
2023
1
2124
)
2225

23-
def test_cs_int_prints_value(self):
24-
cs_int = CsInt(1, "table", "column")
26+
def test_cs_int_to_db_format(self):
27+
cs_int = CsInt(123, "table", "column")
2528
self.assertEqual(
26-
cs_int.value_in_db_format(),
27-
"1"
29+
'{"k": "pt", "p": "123", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}',
30+
cs_int.to_db_format()
2831
)
2932

30-
def test_ces_int_makes_int(self):
33+
def test_cs_int_from_parsed_json(self):
34+
self.template_dict["p"] = "123"
3135
self.assertEqual(
32-
CsInt.value_from_db_format("1"),
33-
1
36+
CsInt.from_parsed_json(self.template_dict),
37+
123
3438
)
3539

36-
def test_cs_bool_prints_value_in_lower_case(self):
40+
def test_cs_bool_to_db_format_true(self):
3741
cs_bool = CsBool(True, "table", "column")
3842
self.assertEqual(
39-
cs_bool.value_in_db_format(),
40-
"true"
43+
'{"k": "pt", "p": "true", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}',
44+
cs_bool.to_db_format()
45+
)
46+
47+
def test_cs_bool_to_db_format_false(self):
48+
cs_bool = CsBool(False, "table", "column")
49+
self.assertEqual(
50+
'{"k": "pt", "p": "false", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}',
51+
cs_bool.to_db_format()
4152
)
4253

43-
def test_cs_bool_returns_bool(self):
54+
def test_cs_bool_from_parsed_json_true(self):
55+
self.template_dict["p"] = "true"
4456
self.assertEqual(
45-
CsBool.value_from_db_format("true"),
57+
CsBool.from_parsed_json(self.template_dict),
4658
True
4759
)
4860

49-
def test_cs_date_prints_value(self):
61+
def test_cs_bool_from_parsed_json_false(self):
62+
self.template_dict["p"] = "false"
63+
self.assertEqual(
64+
CsBool.from_parsed_json(self.template_dict),
65+
False
66+
)
67+
68+
def test_cs_date_to_db_format(self):
5069
cs_date = CsDate(date(2024, 11, 1), "table", "column")
5170
self.assertEqual(
52-
cs_date.value_in_db_format(),
53-
"2024-11-01"
71+
'{"k": "pt", "p": "2024-11-01", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}',
72+
cs_date.to_db_format()
5473
)
5574

56-
def test_cs_date_returns_datetime(self):
75+
def test_cs_date_from_parsed_json(self):
76+
self.template_dict["p"] = "2024-11-01"
5777
self.assertEqual(
58-
CsDate.value_from_db_format("2024-11-01"),
78+
CsDate.from_parsed_json(self.template_dict),
5979
date(2024, 11, 1)
6080
)
6181

62-
def test_cs_float_prints_value(self):
82+
def test_cs_float_to_db_format(self):
6383
cs_float = CsFloat(1.1, "table", "column")
6484
self.assertEqual(
65-
cs_float.value_in_db_format(),
66-
"1.1"
85+
'{"k": "pt", "p": "1.1", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}',
86+
cs_float.to_db_format()
6787
)
6888

69-
def test_cs_float_returns_float(self):
89+
def test_cs_float_from_parsed_json(self):
90+
self.template_dict["p"] = "1.1"
7091
self.assertEqual(
71-
CsFloat.value_from_db_format("1.1"),
92+
CsFloat.from_parsed_json(self.template_dict),
7293
1.1
7394
)
7495

75-
def test_cs_text_prints_value(self):
96+
def test_cs_text_to_db_format(self):
7697
cs_text = CsText("text", "table", "column")
7798
self.assertEqual(
78-
cs_text.value_in_db_format(),
79-
"text"
99+
'{"k": "pt", "p": "text", "i": {"t": "table", "c": "column"}, "v": 1, "q": null}',
100+
cs_text.to_db_format()
80101
)
81102

82-
def test_cs_text_returns_value(self):
103+
def test_cs_text_from_parsed_json(self):
104+
self.template_dict["p"] = "text"
83105
self.assertEqual(
84-
CsText.value_from_db_format("text"),
106+
CsText.from_parsed_json(self.template_dict),
85107
"text"
86108
)
87109

88110
def test_cs_jsonb_prints_json_string(self):
89111
cs_jsonb = CsJsonb({"a": 1}, "table", "column")
90112
self.assertEqual(
91-
cs_jsonb.value_in_db_format("ste_vec"),
113+
cs_jsonb._value_in_db_format("ste_vec"),
92114
'{"a": 1}'
93115
)
94116

95117
def test_cs_jsonb_prints_value_for_ejson_path(self):
96118
cs_jsonb = CsJsonb("$.a.b", "table", "column")
97119
self.assertEqual(
98-
cs_jsonb.value_in_db_format("ejson_path"),
120+
cs_jsonb._value_in_db_format("ejson_path"),
99121
'$.a.b'
100122
)
101123

102124
def test_cs_jsonb_returns_value(self):
103125
self.assertEqual(
104-
CsJsonb.value_from_db_format('{"a": 1}'),
126+
CsJsonb._value_from_db_format('{"a": 1}'),
105127
{"a": 1}
106128
)
107129

0 commit comments

Comments
 (0)