diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index ccae54b4..3b2eb379 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -807,6 +807,7 @@ def paginate( *, page: int | None = None, per_page: int | None = None, + default_per_page: int = 20, max_per_page: int | None = None, error_out: bool = True, count: bool = True, @@ -843,6 +844,7 @@ def paginate( session=self.session(), page=page, per_page=per_page, + default_per_page=default_per_page, max_per_page=max_per_page, error_out=error_out, count=count, diff --git a/src/flask_sqlalchemy/pagination.py b/src/flask_sqlalchemy/pagination.py index 3d49d6e0..0ebfd613 100644 --- a/src/flask_sqlalchemy/pagination.py +++ b/src/flask_sqlalchemy/pagination.py @@ -47,15 +47,21 @@ def __init__( self, page: int | None = None, per_page: int | None = None, + default_per_page: int = 20, max_per_page: int | None = 100, error_out: bool = True, count: bool = True, **kwargs: t.Any, ) -> None: self._query_args = kwargs + + self.default_per_page: int = default_per_page + """The default number of items on a page.""" + page, per_page = self._prepare_page_args( page=page, per_page=per_page, + default_per_page=self.default_per_page, max_per_page=max_per_page, error_out=error_out, ) @@ -92,6 +98,7 @@ def _prepare_page_args( *, page: int | None = None, per_page: int | None = None, + default_per_page: int, max_per_page: int | None = None, error_out: bool = True, ) -> tuple[int, int]: @@ -112,13 +119,13 @@ def _prepare_page_args( if error_out: abort(404) - per_page = 20 + per_page = default_per_page else: if page is None: page = 1 if per_page is None: - per_page = 20 + per_page = default_per_page if max_per_page is not None: per_page = min(per_page, max_per_page) @@ -133,7 +140,7 @@ def _prepare_page_args( if error_out: abort(404) else: - per_page = 20 + per_page = default_per_page return page, per_page diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 14e24a9e..344cd0ae 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -12,14 +12,25 @@ class RangePagination(Pagination): def __init__( - self, total: int | None = 150, page: int = 1, per_page: int = 10 + self, + total: int | None = 150, + page: int = 1, + per_page: int | None = 10, + default_per_page: int = 20, + error_out: bool = True, ) -> None: if total is None: self._data = range(150) else: self._data = range(total) - super().__init__(total=total, page=page, per_page=per_page) + super().__init__( + total=total, + page=page, + per_page=per_page, + default_per_page=default_per_page, + error_out=error_out, + ) if total is None: self.total = None @@ -37,6 +48,7 @@ def test_first_page() -> None: p = RangePagination() assert p.page == 1 assert p.per_page == 10 + assert p.default_per_page == 20 assert p.total == 150 assert p.pages == 15 assert not p.has_prev @@ -74,6 +86,16 @@ def test_item_numbers_0() -> None: assert p.last == 0 +def test_default_per_page_invalid_per_page() -> None: + p = RangePagination(per_page=0, default_per_page=10, error_out=False) + assert p.per_page == 10 + + +def test_default_per_page_none() -> None: + p = RangePagination(per_page=None) + assert p.per_page == 20 + + @pytest.mark.parametrize("total", [0, None]) def test_0_pages(total: int | None) -> None: p = RangePagination(total=total)