Skip to content

Optimize code structure and test cases #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions docs/advanced/filter.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,17 @@ items = await item_crud.select_models(
# 获取年龄等于 30 岁或 40 岁的员工
items = await item_crud.select_models(
session=db,
__or__=[
{'age__eq': 30},
{'age__eq': 40}
]
__or__={
'age__eq': [30, 40]
}
)

# 获取年龄在 30 - 40 岁之间或薪资大于 20k 的员工
items = await item_crud.select_models(
session=db,
__or__=[
{'age__between': [30, 40]},
{'payroll__gt': 20000}
]
__or__={
'age__between': [30, 40],
'payroll__gt': 20000
}
)
```
145 changes: 82 additions & 63 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any, Generic, Iterable, Sequence, Type
from typing import Any, Generic, Iterable, Sequence

from sqlalchemy import (
Column,
ColumnExpressionArgument,
Row,
RowMapping,
Select,
delete,
func,
Expand All @@ -16,13 +14,13 @@
)
from sqlalchemy.ext.asyncio import AsyncSession

from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, MultipleResultsError
from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, ModelColumnError, MultipleResultsError
from sqlalchemy_crud_plus.types import CreateSchema, Model, UpdateSchema
from sqlalchemy_crud_plus.utils import apply_sorting, parse_filters


class CRUDPlus(Generic[Model]):
def __init__(self, model: Type[Model]):
def __init__(self, model: type[Model]):
self.model = model
self.primary_key = self._get_primary_key()

Expand All @@ -37,11 +35,11 @@ def _get_primary_key(self) -> Column | list[Column]:
else:
return list(primary_key)

def _get_pk_filter(self, pk: Any | Sequence[Any]) -> list[bool]:
def _get_pk_filter(self, pk: Any | Sequence[Any]) -> list[ColumnExpressionArgument[bool]]:
"""
Get the primary key filter(s).

:param pk: Single value for simple primary key, or tuple for composite primary key.
:param pk: Single value for simple primary key, or tuple for composite primary key
:return:
"""
if isinstance(self.primary_key, list):
Expand All @@ -60,17 +58,20 @@ async def create_model(
**kwargs,
) -> Model:
"""
Create a new instance of a model
Create a new instance of a model.

:param session: The SQLAlchemy async session.
:param obj: The Pydantic schema containing data to be saved.
:param flush: If `True`, flush all object changes to the database. Default is `False`.
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:param kwargs: Additional model data not included in the pydantic schema.
:param session: The SQLAlchemy async session
:param obj: The Pydantic schema containing data to be saved
:param flush: If `True`, flush all object changes to the database
:param commit: If `True`, commits the transaction immediately
:param kwargs: Additional model data not included in the pydantic schema
:return:
"""
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)
obj_data = obj.model_dump()
if kwargs:
obj_data.update(kwargs)

ins = self.model(**obj_data)
session.add(ins)

if flush:
Expand All @@ -89,18 +90,21 @@ async def create_models(
**kwargs,
) -> list[Model]:
"""
Create new instances of a model
Create new instances of a model.

:param session: The SQLAlchemy async session.
:param objs: The Pydantic schema list containing data to be saved.
:param flush: If `True`, flush all object changes to the database. Default is `False`.
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:param kwargs: Additional model data not included in the pydantic schema.
:param session: The SQLAlchemy async session
:param objs: The Pydantic schema list containing data to be saved
:param flush: If `True`, flush all object changes to the database
:param commit: If `True`, commits the transaction immediately
:param kwargs: Additional model data not included in the pydantic schema
:return:
"""
ins_list = []
for obj in objs:
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)
obj_data = obj.model_dump()
if kwargs:
obj_data.update(kwargs)
ins = self.model(**obj_data)
ins_list.append(ins)

session.add_all(ins_list)
Expand All @@ -119,19 +123,22 @@ async def count(
**kwargs,
) -> int:
"""
Counts records that match specified filters.
Count records that match specified filters.

:param session: The sqlalchemy session to use for the operation.
:param whereclause: The WHERE clauses to apply to the query.
:param kwargs: Query expressions.
:param session: The SQLAlchemy async session
:param whereclause: Additional WHERE clauses to apply to the query
:param kwargs: Filter expressions using field__operator=value syntax
:return:
"""
filters = list(whereclause)

if kwargs:
filters.extend(parse_filters(self.model, **kwargs))

stmt = select(func.count()).select_from(self.model).where(*filters)
stmt = select(func.count()).select_from(self.model)
if filters:
stmt = stmt.where(*filters)

query = await session.execute(stmt)
total_count = query.scalar()
return total_count if total_count is not None else 0
Expand All @@ -143,11 +150,11 @@ async def exists(
**kwargs,
) -> bool:
"""
Whether the records that match the specified filter exist.
Check whether records that match the specified filters exist.

:param session: The sqlalchemy session to use for the operation.
:param whereclause: The WHERE clauses to apply to the query.
:param kwargs: Query expressions.
:param session: The SQLAlchemy async session
:param whereclause: Additional WHERE clauses to apply to the query
:param kwargs: Filter expressions using field__operator=value syntax
:return:
"""
filter_list = list(whereclause)
Expand All @@ -174,7 +181,7 @@ async def select_model(
:return:
"""
filters = self._get_pk_filter(pk)
filters + list(whereclause)
filters.extend(list(whereclause))
stmt = select(self.model).where(*filters)
query = await session.execute(stmt)
return query.scalars().first()
Expand Down Expand Up @@ -235,13 +242,13 @@ async def select_models(
session: AsyncSession,
*whereclause: ColumnExpressionArgument[bool],
**kwargs,
) -> Sequence[Row[Any] | RowMapping | Any]:
) -> Sequence[Model]:
"""
Query all rows
Query all rows that match the specified filters.

:param session: The SQLAlchemy async session.
:param whereclause: The WHERE clauses to apply to the query.
:param kwargs: Query expressions.
:param session: The SQLAlchemy async session
:param whereclause: Additional WHERE clauses to apply to the query
:param kwargs: Filter expressions using field__operator=value syntax
:return:
"""
stmt = await self.select(*whereclause, **kwargs)
Expand All @@ -255,15 +262,15 @@ async def select_models_order(
sort_orders: str | list[str] | None = None,
*whereclause: ColumnExpressionArgument[bool],
**kwargs,
) -> Sequence[Row | RowMapping | Any] | None:
) -> Sequence[Model]:
"""
Query all rows and sort by columns
Query all rows that match the specified filters and sort by columns.

:param session: The SQLAlchemy async session.
:param sort_columns: more details see apply_sorting
:param sort_orders: more details see apply_sorting
:param whereclause: The WHERE clauses to apply to the query.
:param kwargs: Query expressions.
:param session: The SQLAlchemy async session
:param sort_columns: Column name(s) to sort by
:param sort_orders: Sort order(s) ('asc' or 'desc')
:param whereclause: Additional WHERE clauses to apply to the query
:param kwargs: Filter expressions using field__operator=value syntax
:return:
"""
stmt = await self.select_order(sort_columns, sort_orders, *whereclause, **kwargs)
Expand Down Expand Up @@ -313,21 +320,25 @@ async def update_model_by_column(
**kwargs,
) -> int:
"""
Update an instance by model column
Update records by model column filters.

:param session: The SQLAlchemy async session.
:param obj: A pydantic schema or dictionary containing the update data
:param allow_multiple: If `True`, allows updating multiple records that match the filters.
:param flush: If `True`, flush all object changes to the database. Default is `False`.
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:param kwargs: Query expressions.
:param session: The SQLAlchemy async session
:param obj: A Pydantic schema or dictionary containing the update data
:param allow_multiple: If `True`, allows updating multiple records that match the filters
:param flush: If `True`, flush all object changes to the database
:param commit: If `True`, commits the transaction immediately
:param kwargs: Filter expressions using field__operator=value syntax
:return:
"""
filters = parse_filters(self.model, **kwargs)

total_count = await self.count(session, *filters)
if not allow_multiple and total_count > 1:
raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.')
if not filters:
raise ValueError('At least one filter condition must be provided for update operation')

if not allow_multiple:
total_count = await self.count(session, *filters)
if total_count > 1:
raise MultipleResultsError(f'Only one record is expected to be updated, found {total_count} records.')

instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
stmt = update(self.model).where(*filters).values(**instance_data)
Expand Down Expand Up @@ -379,22 +390,30 @@ async def delete_model_by_column(
**kwargs,
) -> int:
"""
Delete an instance by model column
Delete records by model column filters.

:param session: The SQLAlchemy async session.
:param allow_multiple: If `True`, allows deleting multiple records that match the filters.
:param session: The SQLAlchemy async session
:param allow_multiple: If `True`, allows deleting multiple records that match the filters
:param logical_deletion: If `True`, enable logical deletion instead of physical deletion
:param deleted_flag_column: Specify the flag column for logical deletion
:param flush: If `True`, flush all object changes to the database. Default is `False`.
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:param kwargs: Query expressions.
:param deleted_flag_column: Column name for logical deletion flag
:param flush: If `True`, flush all object changes to the database
:param commit: If `True`, commits the transaction immediately
:param kwargs: Filter expressions using field__operator=value syntax
:return:
"""
if logical_deletion:
if not hasattr(self.model, deleted_flag_column):
raise ModelColumnError(f'Column {deleted_flag_column} is not found in {self.model}')

filters = parse_filters(self.model, **kwargs)

total_count = await self.count(session, *filters)
if not allow_multiple and total_count > 1:
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
if not filters:
raise ValueError('At least one filter condition must be provided for delete operation')

if not allow_multiple:
total_count = await self.count(session, *filters)
if total_count > 1:
raise MultipleResultsError(f'Only one record is expected to be deleted, found {total_count} records.')

stmt = (
update(self.model).where(*filters).values(**{deleted_flag_column: True})
Expand Down
Loading