Skip to content

Add mor and __gor__ filters #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions docs/advanced/filter.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ items = await item_crud.select_models(

运算符需要多个值,且仅允许元组,列表,集合

```python
# 获取年龄在 30 - 40 岁之间的员工
```python title="__between"
# 获取年龄在 30 - 40 岁之间且名字在目标列表的员工
items = await item_crud.select_models(
session=db,
age__between=[30, 40],
name__in=['bob', 'lucy'],
)
```

Expand All @@ -86,7 +87,7 @@ items = await item_crud.select_models(
可以通过将多个过滤器链接在一起来实现 AND 子句

```python
# 获取年龄在 30 以上,薪资大于 2w 的员工
# 获取年龄在 30 以上,薪资大于 20k 的员工
items = await item_crud.select_models(
session=db,
age__gt=30,
Expand All @@ -100,14 +101,48 @@ items = await item_crud.select_models(

每个键都应是库已支持的过滤器,仅允许字典

```python
```python title="__or"
# 获取年龄在 40 岁以上或 30 岁以下的员工
items = await item_crud.select_models(
session=db,
age__or={'gt': 40, 'lt': 30},
)
```

## MOR

!!! note

`or` 过滤器的高级用法,每个键都应是库已支持的过滤器,仅允许字典

```python title="__mor"
# 获取年龄等于 30 岁和 40 岁的员工
items = await item_crud.select_models(
session=db,
age__mor={'eq': [30, 40]}, # (1)
)
```

1. 原因:在 python 字典中,不允许存在相同的键值;<br/>
场景:我有一个列,需要多个相同条件但不同条件值的查询,此时,你应该使用 `mor` 过滤器,正如此示例一样使用它

## GOR

!!! note

`or` 过滤器的更高级用法,每个值都应是一个已受支持的条件过滤器,它应该是一个数组

```python title="__gor__"
# 获取年龄在 30 - 40 岁之间且薪资大于 20k 的员工
items = await item_crud.select_models(
session=db,
__gor__=[
{'age__between': [30, 40]},
{'payroll__gt': 20000}
]
)
```

## 算数

!!! note
Expand All @@ -119,9 +154,9 @@ items = await item_crud.select_models(
`condition`:此值将作为运算后的比较值,比较条件取决于使用的过滤器

```python
# 获取薪资打八折以后仍高于 15000 的员工
# 获取薪资打八折以后仍高于 20k 的员工
items = await item_crud.select_models(
session=db,
payroll__mul={'value': 0.8, 'condition': {'gt': 15000}},
payroll__mul={'value': 0.8, 'condition': {'gt': 20000}},
)
```
103 changes: 74 additions & 29 deletions sqlalchemy_crud_plus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sqlalchemy import ColumnElement, Select, and_, asc, desc, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm.util import AliasedClass

from sqlalchemy_crud_plus.errors import ColumnSortError, ModelColumnError, SelectOperatorError
Expand Down Expand Up @@ -70,7 +71,7 @@ def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = Tr
raise SelectOperatorError(f'Nested arithmetic operations are not allowed: {operator}')

sqlalchemy_filter = _SUPPORTED_FILTERS.get(operator)
if sqlalchemy_filter is None:
if sqlalchemy_filter is None and operator not in ['or', 'mor', '__gor']:
warnings.warn(
f'The operator <{operator}> is not yet supported, only {", ".join(_SUPPORTED_FILTERS.keys())}.',
SyntaxWarning,
Expand All @@ -80,48 +81,92 @@ def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = Tr
return sqlalchemy_filter


def get_column(model: Type[Model] | AliasedClass, field_name: str):
def get_column(model: Type[Model] | AliasedClass, field_name: str) -> InstrumentedAttribute | None:
column = getattr(model, field_name, None)
if column is None:
raise ModelColumnError(f'Column {field_name} is not found in {model}')
return column


def _create_or_filters(column: str, op: str, value: Any) -> list[ColumnElement | None]:
or_filters = []
if op == 'or':
for or_op, or_value in value.items():
sqlalchemy_filter = get_sqlalchemy_filter(or_op, or_value)
if sqlalchemy_filter is not None:
or_filters.append(sqlalchemy_filter(column)(or_value))
elif op == 'mor':
for or_op, or_values in value.items():
for or_value in or_values:
sqlalchemy_filter = get_sqlalchemy_filter(or_op, or_value)
if sqlalchemy_filter is not None:
or_filters.append(sqlalchemy_filter(column)(or_value))
return or_filters


def _create_arithmetic_filters(column: str, op: str, value: Any) -> list[ColumnElement | None]:
arithmetic_filters = []
if isinstance(value, dict) and {'value', 'condition'}.issubset(value):
arithmetic_value = value['value']
condition = value['condition']
sqlalchemy_filter = get_sqlalchemy_filter(op, arithmetic_value)
if sqlalchemy_filter is not None:
for cond_op, cond_value in condition.items():
arithmetic_filter = get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
arithmetic_filters.append(
arithmetic_filter(sqlalchemy_filter(column)(arithmetic_value))(cond_value)
if cond_op != 'between'
else arithmetic_filter(sqlalchemy_filter(column)(arithmetic_value))(*cond_value)
)
return arithmetic_filters


def _create_and_filters(column: str, op: str, value: Any) -> list[ColumnElement | None]:
and_filters = []
sqlalchemy_filter = get_sqlalchemy_filter(op, value)
if sqlalchemy_filter is not None:
and_filters.append(sqlalchemy_filter(column)(value) if op != 'between' else sqlalchemy_filter(column)(*value))
return and_filters


def parse_filters(model: Type[Model] | AliasedClass, **kwargs) -> list[ColumnElement]:
filters = []

def process_filters(target_column: str, target_op: str, target_value: Any):
# OR / MOR
or_filters = _create_or_filters(target_column, target_op, target_value)
if or_filters:
filters.append(or_(*or_filters))

# ARITHMETIC
arithmetic_filters = _create_arithmetic_filters(target_column, target_op, target_value)
if arithmetic_filters:
filters.append(and_(*arithmetic_filters))
else:
# AND
and_filters = _create_and_filters(target_column, target_op, target_value)
if and_filters:
filters.append(*and_filters)

for key, value in kwargs.items():
if '__' in key:
field_name, op = key.rsplit('__', 1)
column = get_column(model, field_name)
if op == 'or':
or_filters = [
sqlalchemy_filter(column)(or_value)
for or_op, or_value in value.items()
if (sqlalchemy_filter := get_sqlalchemy_filter(or_op, or_value)) is not None
]
filters.append(or_(*or_filters))
elif isinstance(value, dict) and {'value', 'condition'}.issubset(value):
advanced_value = value['value']
condition = value['condition']
sqlalchemy_filter = get_sqlalchemy_filter(op, advanced_value)
if sqlalchemy_filter is not None:
condition_filters = []
for cond_op, cond_value in condition.items():
condition_filter = get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
condition_filters.append(
condition_filter(sqlalchemy_filter(column)(advanced_value))(cond_value)
if cond_op != 'between'
else condition_filter(sqlalchemy_filter(column)(advanced_value))(*cond_value)
)
filters.append(and_(*condition_filters))

# OR GROUP
if field_name == '__gor' and op == '':
_or_filters = []
for field_or in value:
for _key, _value in field_or.items():
_field_name, _op = _key.rsplit('__', 1)
_column = get_column(model, _field_name)
process_filters(_column, _op, _value)
if _or_filters:
filters.append(or_(*_or_filters))
else:
sqlalchemy_filter = get_sqlalchemy_filter(op, value)
if sqlalchemy_filter is not None:
filters.append(
sqlalchemy_filter(column)(value) if op != 'between' else sqlalchemy_filter(column)(*value)
)
column = get_column(model, field_name)
process_filters(column, op, value)
else:
# NON FILTER
column = get_column(model, key)
filters.append(column == value)

Expand Down
25 changes: 24 additions & 1 deletion tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def test_select_model_by_column_with_ne(create_test_model, async_db_sessio
async def test_select_model_by_column_with_between(create_test_model, async_db_session):
async with async_db_session() as session:
crud = CRUDPlus(Ins)
result = await crud.select_model_by_column(session, id__between=(0, 11))
result = await crud.select_model_by_column(session, id__between=(0, 10))
assert result.id == 1


Expand Down Expand Up @@ -338,6 +338,29 @@ async def test_select_model_by_column_with_or(create_test_model, async_db_sessio
assert result.id == 1


@pytest.mark.asyncio
async def test_select_model_by_column_with_mor(create_test_model, async_db_session):
async with async_db_session() as session:
crud = CRUDPlus(Ins)
result = await crud.select_model_by_column(session, id__mor={'eq': [1, 2, 3, 4, 5, 6, 7, 8, 9]})
assert result.id == 1


@pytest.mark.asyncio
async def test_select_model_by_column_with___gor__(create_test_model, async_db_session):
async with async_db_session() as session:
crud = CRUDPlus(Ins)
result = await crud.select_model_by_column(
session,
__gor__=[
{'id__eq': 1},
{'name__mor': {'endswith': ['1', '2']}},
{'id__mul': {'value': 1, 'condition': {'eq': 1}}},
],
)
assert result.id == 1


@pytest.mark.asyncio
async def test_select(create_test_model):
crud = CRUDPlus(Ins)
Expand Down