Skip to content

Commit 38ffbbc

Browse files
authored
Optimize code structure and test cases (#49)
1 parent c417257 commit 38ffbbc

File tree

11 files changed

+909
-790
lines changed

11 files changed

+909
-790
lines changed

docs/advanced/filter.md

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,18 +131,17 @@ items = await item_crud.select_models(
131131
# 获取年龄等于 30 岁或 40 岁的员工
132132
items = await item_crud.select_models(
133133
session=db,
134-
__or__=[
135-
{'age__eq': 30},
136-
{'age__eq': 40}
137-
]
134+
__or__={
135+
'age__eq': [30, 40]
136+
}
138137
)
139138

140139
# 获取年龄在 30 - 40 岁之间或薪资大于 20k 的员工
141140
items = await item_crud.select_models(
142141
session=db,
143-
__or__=[
144-
{'age__between': [30, 40]},
145-
{'payroll__gt': 20000}
146-
]
142+
__or__={
143+
'age__between': [30, 40],
144+
'payroll__gt': 20000
145+
}
147146
)
148147
```

sqlalchemy_crud_plus/crud.py

Lines changed: 82 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from typing import Any, Generic, Iterable, Sequence, Type
3+
from typing import Any, Generic, Iterable, Sequence
44

55
from sqlalchemy import (
66
Column,
77
ColumnExpressionArgument,
8-
Row,
9-
RowMapping,
108
Select,
119
delete,
1210
func,
@@ -16,13 +14,13 @@
1614
)
1715
from sqlalchemy.ext.asyncio import AsyncSession
1816

19-
from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, MultipleResultsError
17+
from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, ModelColumnError, MultipleResultsError
2018
from sqlalchemy_crud_plus.types import CreateSchema, Model, UpdateSchema
2119
from sqlalchemy_crud_plus.utils import apply_sorting, parse_filters
2220

2321

2422
class CRUDPlus(Generic[Model]):
25-
def __init__(self, model: Type[Model]):
23+
def __init__(self, model: type[Model]):
2624
self.model = model
2725
self.primary_key = self._get_primary_key()
2826

@@ -37,11 +35,11 @@ def _get_primary_key(self) -> Column | list[Column]:
3735
else:
3836
return list(primary_key)
3937

40-
def _get_pk_filter(self, pk: Any | Sequence[Any]) -> list[bool]:
38+
def _get_pk_filter(self, pk: Any | Sequence[Any]) -> list[ColumnExpressionArgument[bool]]:
4139
"""
4240
Get the primary key filter(s).
4341
44-
:param pk: Single value for simple primary key, or tuple for composite primary key.
42+
:param pk: Single value for simple primary key, or tuple for composite primary key
4543
:return:
4644
"""
4745
if isinstance(self.primary_key, list):
@@ -60,17 +58,20 @@ async def create_model(
6058
**kwargs,
6159
) -> Model:
6260
"""
63-
Create a new instance of a model
61+
Create a new instance of a model.
6462
65-
:param session: The SQLAlchemy async session.
66-
:param obj: The Pydantic schema containing data to be saved.
67-
:param flush: If `True`, flush all object changes to the database. Default is `False`.
68-
:param commit: If `True`, commits the transaction immediately. Default is `False`.
69-
:param kwargs: Additional model data not included in the pydantic schema.
63+
:param session: The SQLAlchemy async session
64+
:param obj: The Pydantic schema containing data to be saved
65+
:param flush: If `True`, flush all object changes to the database
66+
:param commit: If `True`, commits the transaction immediately
67+
:param kwargs: Additional model data not included in the pydantic schema
7068
:return:
7169
"""
72-
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)
70+
obj_data = obj.model_dump()
71+
if kwargs:
72+
obj_data.update(kwargs)
7373

74+
ins = self.model(**obj_data)
7475
session.add(ins)
7576

7677
if flush:
@@ -89,18 +90,21 @@ async def create_models(
8990
**kwargs,
9091
) -> list[Model]:
9192
"""
92-
Create new instances of a model
93+
Create new instances of a model.
9394
94-
:param session: The SQLAlchemy async session.
95-
:param objs: The Pydantic schema list containing data to be saved.
96-
:param flush: If `True`, flush all object changes to the database. Default is `False`.
97-
:param commit: If `True`, commits the transaction immediately. Default is `False`.
98-
:param kwargs: Additional model data not included in the pydantic schema.
95+
:param session: The SQLAlchemy async session
96+
:param objs: The Pydantic schema list containing data to be saved
97+
:param flush: If `True`, flush all object changes to the database
98+
:param commit: If `True`, commits the transaction immediately
99+
:param kwargs: Additional model data not included in the pydantic schema
99100
:return:
100101
"""
101102
ins_list = []
102103
for obj in objs:
103-
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)
104+
obj_data = obj.model_dump()
105+
if kwargs:
106+
obj_data.update(kwargs)
107+
ins = self.model(**obj_data)
104108
ins_list.append(ins)
105109

106110
session.add_all(ins_list)
@@ -119,19 +123,22 @@ async def count(
119123
**kwargs,
120124
) -> int:
121125
"""
122-
Counts records that match specified filters.
126+
Count records that match specified filters.
123127
124-
:param session: The sqlalchemy session to use for the operation.
125-
:param whereclause: The WHERE clauses to apply to the query.
126-
:param kwargs: Query expressions.
128+
:param session: The SQLAlchemy async session
129+
:param whereclause: Additional WHERE clauses to apply to the query
130+
:param kwargs: Filter expressions using field__operator=value syntax
127131
:return:
128132
"""
129133
filters = list(whereclause)
130134

131135
if kwargs:
132136
filters.extend(parse_filters(self.model, **kwargs))
133137

134-
stmt = select(func.count()).select_from(self.model).where(*filters)
138+
stmt = select(func.count()).select_from(self.model)
139+
if filters:
140+
stmt = stmt.where(*filters)
141+
135142
query = await session.execute(stmt)
136143
total_count = query.scalar()
137144
return total_count if total_count is not None else 0
@@ -143,11 +150,11 @@ async def exists(
143150
**kwargs,
144151
) -> bool:
145152
"""
146-
Whether the records that match the specified filter exist.
153+
Check whether records that match the specified filters exist.
147154
148-
:param session: The sqlalchemy session to use for the operation.
149-
:param whereclause: The WHERE clauses to apply to the query.
150-
:param kwargs: Query expressions.
155+
:param session: The SQLAlchemy async session
156+
:param whereclause: Additional WHERE clauses to apply to the query
157+
:param kwargs: Filter expressions using field__operator=value syntax
151158
:return:
152159
"""
153160
filter_list = list(whereclause)
@@ -174,7 +181,7 @@ async def select_model(
174181
:return:
175182
"""
176183
filters = self._get_pk_filter(pk)
177-
filters + list(whereclause)
184+
filters.extend(list(whereclause))
178185
stmt = select(self.model).where(*filters)
179186
query = await session.execute(stmt)
180187
return query.scalars().first()
@@ -235,13 +242,13 @@ async def select_models(
235242
session: AsyncSession,
236243
*whereclause: ColumnExpressionArgument[bool],
237244
**kwargs,
238-
) -> Sequence[Row[Any] | RowMapping | Any]:
245+
) -> Sequence[Model]:
239246
"""
240-
Query all rows
247+
Query all rows that match the specified filters.
241248
242-
:param session: The SQLAlchemy async session.
243-
:param whereclause: The WHERE clauses to apply to the query.
244-
:param kwargs: Query expressions.
249+
:param session: The SQLAlchemy async session
250+
:param whereclause: Additional WHERE clauses to apply to the query
251+
:param kwargs: Filter expressions using field__operator=value syntax
245252
:return:
246253
"""
247254
stmt = await self.select(*whereclause, **kwargs)
@@ -255,15 +262,15 @@ async def select_models_order(
255262
sort_orders: str | list[str] | None = None,
256263
*whereclause: ColumnExpressionArgument[bool],
257264
**kwargs,
258-
) -> Sequence[Row | RowMapping | Any] | None:
265+
) -> Sequence[Model]:
259266
"""
260-
Query all rows and sort by columns
267+
Query all rows that match the specified filters and sort by columns.
261268
262-
:param session: The SQLAlchemy async session.
263-
:param sort_columns: more details see apply_sorting
264-
:param sort_orders: more details see apply_sorting
265-
:param whereclause: The WHERE clauses to apply to the query.
266-
:param kwargs: Query expressions.
269+
:param session: The SQLAlchemy async session
270+
:param sort_columns: Column name(s) to sort by
271+
:param sort_orders: Sort order(s) ('asc' or 'desc')
272+
:param whereclause: Additional WHERE clauses to apply to the query
273+
:param kwargs: Filter expressions using field__operator=value syntax
267274
:return:
268275
"""
269276
stmt = await self.select_order(sort_columns, sort_orders, *whereclause, **kwargs)
@@ -313,21 +320,25 @@ async def update_model_by_column(
313320
**kwargs,
314321
) -> int:
315322
"""
316-
Update an instance by model column
323+
Update records by model column filters.
317324
318-
:param session: The SQLAlchemy async session.
319-
:param obj: A pydantic schema or dictionary containing the update data
320-
:param allow_multiple: If `True`, allows updating multiple records that match the filters.
321-
:param flush: If `True`, flush all object changes to the database. Default is `False`.
322-
:param commit: If `True`, commits the transaction immediately. Default is `False`.
323-
:param kwargs: Query expressions.
325+
:param session: The SQLAlchemy async session
326+
:param obj: A Pydantic schema or dictionary containing the update data
327+
:param allow_multiple: If `True`, allows updating multiple records that match the filters
328+
:param flush: If `True`, flush all object changes to the database
329+
:param commit: If `True`, commits the transaction immediately
330+
:param kwargs: Filter expressions using field__operator=value syntax
324331
:return:
325332
"""
326333
filters = parse_filters(self.model, **kwargs)
327334

328-
total_count = await self.count(session, *filters)
329-
if not allow_multiple and total_count > 1:
330-
raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.')
335+
if not filters:
336+
raise ValueError('At least one filter condition must be provided for update operation')
337+
338+
if not allow_multiple:
339+
total_count = await self.count(session, *filters)
340+
if total_count > 1:
341+
raise MultipleResultsError(f'Only one record is expected to be updated, found {total_count} records.')
331342

332343
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
333344
stmt = update(self.model).where(*filters).values(**instance_data)
@@ -379,22 +390,30 @@ async def delete_model_by_column(
379390
**kwargs,
380391
) -> int:
381392
"""
382-
Delete an instance by model column
393+
Delete records by model column filters.
383394
384-
:param session: The SQLAlchemy async session.
385-
:param allow_multiple: If `True`, allows deleting multiple records that match the filters.
395+
:param session: The SQLAlchemy async session
396+
:param allow_multiple: If `True`, allows deleting multiple records that match the filters
386397
:param logical_deletion: If `True`, enable logical deletion instead of physical deletion
387-
:param deleted_flag_column: Specify the flag column for logical deletion
388-
:param flush: If `True`, flush all object changes to the database. Default is `False`.
389-
:param commit: If `True`, commits the transaction immediately. Default is `False`.
390-
:param kwargs: Query expressions.
398+
:param deleted_flag_column: Column name for logical deletion flag
399+
:param flush: If `True`, flush all object changes to the database
400+
:param commit: If `True`, commits the transaction immediately
401+
:param kwargs: Filter expressions using field__operator=value syntax
391402
:return:
392403
"""
404+
if logical_deletion:
405+
if not hasattr(self.model, deleted_flag_column):
406+
raise ModelColumnError(f'Column {deleted_flag_column} is not found in {self.model}')
407+
393408
filters = parse_filters(self.model, **kwargs)
394409

395-
total_count = await self.count(session, *filters)
396-
if not allow_multiple and total_count > 1:
397-
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
410+
if not filters:
411+
raise ValueError('At least one filter condition must be provided for delete operation')
412+
413+
if not allow_multiple:
414+
total_count = await self.count(session, *filters)
415+
if total_count > 1:
416+
raise MultipleResultsError(f'Only one record is expected to be deleted, found {total_count} records.')
398417

399418
stmt = (
400419
update(self.model).where(*filters).values(**{deleted_flag_column: True})

0 commit comments

Comments
 (0)