Skip to content

Commit 9da063f

Browse files
authored
Add multi conditions and sort select (#2)
* Add multi conditions and sort select * Update library information
1 parent 2af3678 commit 9da063f

File tree

4 files changed

+110
-5
lines changed

4 files changed

+110
-5
lines changed

pyproject.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlalchemy-crud-plus"
3-
description = "Asynchronous CRUD Operation based on SQLAlChemy2 Model"
3+
description = "Asynchronous CRUD operation based on SQLAlchemy2 model"
44
dynamic = [
55
"version",
66
]
@@ -11,10 +11,21 @@ dependencies = [
1111
"sqlalchemy>=2.0.0",
1212
"pydantic>=2.0",
1313
]
14+
classifiers = [
15+
"License :: OSI Approved :: MIT License",
16+
"Programming Language :: Python :: 3 :: Only",
17+
"Programming Language :: Python :: 3.10",
18+
"Programming Language :: Python :: 3.11",
19+
"Programming Language :: Python :: 3.12",
20+
]
1421
requires-python = ">=3.10"
1522
readme = "README.md"
1623
license = { text = "MIT" }
1724

25+
[project.urls]
26+
homepage = "https://github.com/fastapi-practices/sqlalchemy-crud-plus"
27+
repository = "https://github.com/fastapi-practices/sqlalchemy-crud-plus"
28+
1829
[tool.pdm.dev-dependencies]
1930
lint = [
2031
"ruff>=0.4.0",

sqlalchemy_crud_plus/crud.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from typing import Any, Generic, Sequence, Type, TypeVar
3+
from typing import Any, Generic, Literal, Sequence, Type, TypeVar
44

55
from pydantic import BaseModel
6-
from sqlalchemy import Row, RowMapping, select
6+
from sqlalchemy import Row, RowMapping, and_, asc, desc, or_, select
77
from sqlalchemy import delete as sa_delete
88
from sqlalchemy import update as sa_update
99
from sqlalchemy.ext.asyncio import AsyncSession
1010

11-
from sqlalchemy_crud_plus.errors import ModelColumnError
11+
from sqlalchemy_crud_plus.errors import ModelColumnError, SelectExpressionError
1212

1313
_Model = TypeVar('_Model')
1414
_CreateSchema = TypeVar('_CreateSchema', bound=BaseModel)
@@ -59,7 +59,34 @@ async def select_model_by_column(self, session: AsyncSession, column: str, colum
5959
query = await session.execute(select(self.model).where(model_column == column_value)) # type: ignore
6060
return query.scalars().first()
6161
else:
62-
raise ModelColumnError('Model column not found')
62+
raise ModelColumnError(f'Model column {column} is not found')
63+
64+
async def select_model_by_columns(
65+
self, session: AsyncSession, expression: Literal['and', 'or'] = 'and', **conditions
66+
) -> _Model | None:
67+
"""
68+
Query by columns
69+
70+
:param session:
71+
:param expression:
72+
:param conditions: Query conditions, format:column1=value1, column2=value2
73+
:return:
74+
"""
75+
where_list = []
76+
for column, value in conditions.items():
77+
if hasattr(self.model, column):
78+
model_column = getattr(self.model, column)
79+
where_list.append(model_column == value)
80+
else:
81+
raise ModelColumnError(f'Model column {column} is not found')
82+
match expression:
83+
case 'and':
84+
query = await session.execute(select(self.model).where(and_(*where_list)))
85+
case 'or':
86+
query = await session.execute(select(self.model).where(or_(*where_list)))
87+
case _:
88+
raise SelectExpressionError(f'select expression {expression} is not supported')
89+
return query.scalars().first()
6390

6491
async def select_models(self, session: AsyncSession) -> Sequence[Row | RowMapping | Any] | None:
6592
"""
@@ -71,6 +98,41 @@ async def select_models(self, session: AsyncSession) -> Sequence[Row | RowMappin
7198
query = await session.execute(select(self.model))
7299
return query.scalars().all()
73100

101+
async def select_models_order(
102+
self,
103+
session: AsyncSession,
104+
*columns,
105+
model_sort: Literal['skip', 'asc', 'desc'] = 'skip',
106+
) -> Sequence[Row | RowMapping | Any] | None:
107+
"""
108+
Query all rows asc or desc
109+
110+
:param session:
111+
:param columns:
112+
:param model_sort:
113+
:return:
114+
"""
115+
if model_sort != 'skip':
116+
if len(columns) != 1:
117+
raise SelectExpressionError('ACS and DESC only allow you to specify one column for sorting')
118+
sort_list = []
119+
for column in columns:
120+
if hasattr(self.model, column):
121+
model_column = getattr(self.model, column)
122+
sort_list.append(model_column)
123+
else:
124+
raise ModelColumnError(f'Model column {column} is not found')
125+
match model_sort:
126+
case 'skip':
127+
query = await session.execute(select(self.model).order_by(*sort_list))
128+
case 'asc':
129+
query = await session.execute(select(self.model).order_by(asc(*sort_list)))
130+
case 'desc':
131+
query = await session.execute(select(self.model).order_by(desc(*sort_list)))
132+
case _:
133+
raise SelectExpressionError(f'select sort expression {model_sort} is not supported')
134+
return query.scalars().all()
135+
74136
async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], **kwargs) -> int:
75137
"""
76138
Update an instance of a model

sqlalchemy_crud_plus/errors.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,10 @@ class ModelColumnError(SQLAlchemyCRUDPlusException):
1515

1616
def __init__(self, msg: str) -> None:
1717
super().__init__(msg)
18+
19+
20+
class SelectExpressionError(SQLAlchemyCRUDPlusException):
21+
"""Error raised when a select expression is invalid."""
22+
23+
def __init__(self, msg: str) -> None:
24+
super().__init__(msg)

tests/test_crud.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ async def test_select_model_by_column():
5454
assert result.name == f'test_name_{i}'
5555

5656

57+
@pytest.mark.asyncio
58+
async def test_select_model_by_columns():
59+
await create_test_model()
60+
async with async_db_session() as session:
61+
crud = CRUDPlus(Ins)
62+
for i in range(1, 10):
63+
result = await crud.select_model_by_columns(session, id=f'{i}', name=f'test_name_{i}')
64+
assert result.name == f'test_name_{i}'
65+
result = await crud.select_model_by_columns(session, 'or', id=f'{i}', name='null')
66+
assert result.name == f'test_name_{i}'
67+
68+
5769
@pytest.mark.asyncio
5870
async def test_select_models():
5971
await create_test_model()
@@ -63,6 +75,19 @@ async def test_select_models():
6375
assert len(result) == 9
6476

6577

78+
@pytest.mark.asyncio
79+
async def test_select_models_order():
80+
await create_test_model()
81+
async with async_db_session() as session:
82+
crud = CRUDPlus(Ins)
83+
result = await crud.select_models_order(session, 'id', 'name')
84+
assert result[0].id == 1
85+
result = await crud.select_models_order(session, 'id', model_sort='asc')
86+
assert result[0].id == 1
87+
result = await crud.select_models_order(session, 'id', model_sort='desc')
88+
assert result[0].id == 9
89+
90+
6691
@pytest.mark.asyncio
6792
async def test_update_model():
6893
await create_test_model()

0 commit comments

Comments
 (0)