Compare commits

...

141 Commits

Author SHA1 Message Date
8a65754f9f fix: cleanup + simplify 2026-04-07 05:46:25 -04:00
f9579e6df0 docs: add authentication example 2026-04-07 05:46:25 -04:00
2f6036c762 feat(security): add oauth helpers 2026-04-07 05:46:25 -04:00
f37bf59996 feat: add security module 2026-04-07 05:46:25 -04:00
0ed93d62c8 Version 3.0.3 2026-04-07 05:43:10 -04:00
d3vyce
2a49814818 feat: normalize enum facet values to member names and accept name strings in filter_by (#235) 2026-04-07 11:42:24 +02:00
d3vyce
f8e090c7c3 ci: restrict ci.yml workflow permissions to contents: read (#233) 2026-04-07 10:44:59 +02:00
d3vyce
54decaf3e1 fix: coerce plain int values to enum member when filtering on IntEnum column (#231) 2026-04-07 10:43:01 +02:00
6b127d9645 Version 3.0.2 2026-04-04 19:13:09 -04:00
d3vyce
8bed96f4bf fix: apply default_load_options after create() and update() to prevent MissingGreenlet on relationship access (#229) 2026-04-04 20:47:05 +02:00
d3vyce
74d15e13bc feat: support relation tuples in order_fields for cross-table sorting (#227)
* feat: support relation tuples in order_fields for cross-table sorting

* docs: update crud module
2026-04-04 17:00:14 +02:00
d3vyce
e38d8d2d4f feat: expose sort_columns in paginated response (#225) 2026-04-04 14:46:36 +02:00
9b74f162ab Version 3.0.1 2026-04-03 05:44:10 -04:00
d3vyce
ab125c6ea1 docs: rework versioning to keep latest feature version per major (#223) 2026-04-03 11:43:36 +02:00
d3vyce
e388e26858 fix: widen JoinType to accept aliased and polymorphic targets (#221) 2026-04-02 22:58:52 +02:00
d3vyce
04da241294 fix: coerce string values to bool for Boolean facet field filtering (#219) 2026-04-02 22:58:07 +02:00
d3vyce
bbe63edc46 Version 3.0.0 (#201)
* chore: remove deprecated code

* docs: update v3 migration guide

* fix: pytest warnings

* Version 3.0.0

* fix: docs workflows
2026-04-02 11:21:31 +02:00
d3vyce
0b17c77dee fix: deduplicate relationship joins when searchable_fields and facet_fields reference the same model (#217) 2026-04-02 11:09:26 +02:00
dependabot[bot]
bce71bfd42 ⬆ Bump ty from 0.0.25 to 0.0.27 (#215)
Bumps [ty](https://github.com/astral-sh/ty) from 0.0.25 to 0.0.27.
- [Release notes](https://github.com/astral-sh/ty/releases)
- [Changelog](https://github.com/astral-sh/ty/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ty/compare/0.0.25...0.0.27)

---
updated-dependencies:
- dependency-name: ty
  dependency-version: 0.0.27
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-02 11:01:08 +02:00
dependabot[bot]
2f1eb4d468 ⬆ Bump fastapi from 0.135.1 to 0.135.3 (#214)
Bumps [fastapi](https://github.com/fastapi/fastapi) from 0.135.1 to 0.135.3.
- [Release notes](https://github.com/fastapi/fastapi/releases)
- [Commits](https://github.com/fastapi/fastapi/compare/0.135.1...0.135.3)

---
updated-dependencies:
- dependency-name: fastapi
  dependency-version: 0.135.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-02 11:00:59 +02:00
dependabot[bot]
1f06eab11d ⬆ Bump ruff from 0.15.7 to 0.15.8 (#213)
Bumps [ruff](https://github.com/astral-sh/ruff) from 0.15.7 to 0.15.8.
- [Release notes](https://github.com/astral-sh/ruff/releases)
- [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ruff/compare/0.15.7...0.15.8)

---
updated-dependencies:
- dependency-name: ruff
  dependency-version: 0.15.8
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-02 11:00:44 +02:00
dependabot[bot]
fac9aa6f60 ⬆ Bump zensical from 0.0.30 to 0.0.31 (#212)
Bumps [zensical](https://github.com/zensical/zensical) from 0.0.30 to 0.0.31.
- [Release notes](https://github.com/zensical/zensical/releases)
- [Commits](https://github.com/zensical/zensical/compare/v0.0.30...v0.0.31)

---
updated-dependencies:
- dependency-name: zensical
  dependency-version: 0.0.31
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-02 11:00:35 +02:00
dependabot[bot]
f310466697 ⬆ Bump codecov/codecov-action from 5 to 6 (#211)
Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 5 to 6.
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/codecov/codecov-action/compare/v5...v6)

---
updated-dependencies:
- dependency-name: codecov/codecov-action
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-02 11:00:22 +02:00
d3vyce
32059dcb02 feat: consolidate *_params dependencies into per-paginate-style methods with feature toggles (#209) 2026-04-01 20:53:14 +02:00
d3vyce
f027981e80 feat: add search_column parameter and search_columns response field for targeted search (#207) 2026-04-01 18:10:56 +02:00
d3vyce
5c1487c24a fix: suppress UPDATE callbacks for objects deleted in the same transaction (#205) 2026-03-31 21:40:18 +02:00
d3vyce
ebaa61525f fix: handle boolean and ARRAY column types in filter_by facet filtering (#203) 2026-03-31 21:36:54 +02:00
dependabot[bot]
4829cfba73 ⬆ Bump pygments from 2.19.2 to 2.20.0 (#199)
Bumps [pygments](https://github.com/pygments/pygments) from 2.19.2 to 2.20.0.
- [Release notes](https://github.com/pygments/pygments/releases)
- [Changelog](https://github.com/pygments/pygments/blob/master/CHANGES)
- [Commits](https://github.com/pygments/pygments/compare/2.19.2...2.20.0)

---
updated-dependencies:
- dependency-name: pygments
  dependency-version: 2.20.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 20:05:52 +02:00
d3vyce
9ca2da4213 docs: add documentation versioning (#125)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-03-30 18:55:38 +02:00
d3vyce
0b3f097012 fix: batch insert normalizes away omitted nullable columns (#198) 2026-03-30 18:50:05 +02:00
d3vyce
1890d696bf feat: rework async event system (#196)
* feat: rework async event system

* docs: add v3 migration guide

* feat: add cache

* enhancements
2026-03-30 18:24:36 +02:00
d3vyce
104285c6e5 Create CNAME 2026-03-29 13:57:15 +02:00
d3vyce
f5afbbe37f fix: snapshot nullable columns correctly in WatchedFieldsMixin callback (#194) 2026-03-28 18:47:06 +01:00
d3vyce
f4698bea8a fix: normalize batch insert rows to prevent silent data loss for nullable columns (#192) 2026-03-27 19:20:41 +01:00
d3vyce
5215b921ae fix: facet keys always use full relation chain (#190) 2026-03-27 18:56:58 +01:00
9dad59e25d tests: close generator instead of session in test_get_db_yields_async_session 2026-03-26 16:09:44 -04:00
d3vyce
29326ab532 perf: batch insert fixtures (#188) 2026-03-26 20:29:25 +01:00
d3vyce
04afef7e33 feat(fixtures): fixtures multi-variant contexts, custom Enum support, and context-filtered loading (#187) 2026-03-26 20:19:41 +01:00
d3vyce
666c621fda fix: create_db_session commits via real transaction, not savepoint (#184) 2026-03-26 19:57:40 +01:00
460b760fa4 Version 2.4.3 2026-03-26 07:58:32 -04:00
dependabot[bot]
65d0b0e0b1 ⬆ Bump ty from 0.0.23 to 0.0.25 (#178)
* ⬆ Bump ty from 0.0.23 to 0.0.25

Bumps [ty](https://github.com/astral-sh/ty) from 0.0.23 to 0.0.25.
- [Release notes](https://github.com/astral-sh/ty/releases)
- [Changelog](https://github.com/astral-sh/ty/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ty/compare/0.0.23...0.0.25)

---
updated-dependencies:
- dependency-name: ty
  dependency-version: 0.0.25
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix: ty warnings

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: d3vyce <nicolas.sudres@proton.me>
2026-03-26 12:57:29 +01:00
dependabot[bot]
2d49cd32db ⬆ Update uv-build requirement from <0.11.0,>=0.10 to >=0.10,<0.12.0 (#181) 2026-03-26 09:43:52 +01:00
dependabot[bot]
a5dd756d87 ⬆ Bump actions/deploy-pages from 4 to 5 (#177) 2026-03-26 09:43:40 +01:00
dependabot[bot]
781cfb66c9 ⬆ Bump zensical from 0.0.27 to 0.0.29 (#179) 2026-03-26 09:42:53 +01:00
dependabot[bot]
91b84f8146 ⬆ Bump ruff from 0.15.6 to 0.15.7 (#180) 2026-03-26 09:42:38 +01:00
dependabot[bot]
396e381ac3 ⬆ Bump pytest-cov from 7.0.0 to 7.1.0 (#182) 2026-03-26 09:41:59 +01:00
d3vyce
b4eb4c1ca9 fix: force auto-begin in create_db_dependency so lock_tables always uses savepoints (#176) 2026-03-25 19:26:28 +01:00
c90717754f chore: add prek (pre-commit alternative) 2026-03-25 14:24:30 -04:00
337985ef38 Version 2.4.2 2026-03-24 15:39:14 -04:00
d3vyce
b5e6dfe6fe refactor: test suite cleanup and simplification (#174) 2026-03-24 20:38:35 +01:00
d3vyce
6681b7ade7 fix: defer on_create/on_update/on_delete dispatch until outermost transaction commits (#172) 2026-03-24 19:56:03 +01:00
d3vyce
6981c33dc8 fix: inherit @watch field filter from parent classes via MRO traversal (#170) 2026-03-23 19:08:17 +01:00
d3vyce
0c7a99039c fix: await any awaitable callback return value, not only coroutines (#168) 2026-03-23 18:58:48 +01:00
d3vyce
bcb5b0bfda fix: suppress on_create/on_delete for objects created and deleted within the same transaction (#166) 2026-03-23 18:51:28 +01:00
100e1c1aa9 Version 2.4.1 2026-03-21 11:48:17 -04:00
d3vyce
db6c7a565f feat: add offset_params, cursor_params and paginate_params FastAPI dependency factories (#162) 2026-03-21 16:44:11 +01:00
d3vyce
768e405554 fix: use URL-safe base64 encoding for cursor tokens (#160) 2026-03-21 15:33:17 +01:00
d3vyce
f0223ebde4 feat: add pages computed field to OffsetPagination schema (#159) 2026-03-21 15:24:11 +01:00
d3vyce
f8c9bf69fe feat: add include_total flag to offset pagination to skip COUNT query (#158) 2026-03-21 15:16:22 +01:00
6d6fae5538 Version 2.4.0 2026-03-20 15:54:14 -04:00
d3vyce
fc9cd1f034 fix: resolve MissingGreenlet error when accessing self attributes in WatchedFieldsMixin callbacks (#154) 2026-03-20 20:52:11 +01:00
d3vyce
f82225f995 feat: add WatchedFieldsMixin (#148)
* feat/add WatchedFieldsMixin and watch_fields decorator for field-change monitoring

* docs: add WatchedFieldsMixin

* feat: add on_event, on_create and on_delete

* docs: update README
2026-03-19 19:19:33 +01:00
dependabot[bot]
e62612a93a ⬆ Bump coverage from 7.13.4 to 7.13.5 (#149)
Bumps [coverage](https://github.com/coveragepy/coveragepy) from 7.13.4 to 7.13.5.
- [Release notes](https://github.com/coveragepy/coveragepy/releases)
- [Changelog](https://github.com/coveragepy/coveragepy/blob/main/CHANGES.rst)
- [Commits](https://github.com/coveragepy/coveragepy/compare/7.13.4...7.13.5)

---
updated-dependencies:
- dependency-name: coverage
  dependency-version: 7.13.5
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-19 19:11:03 +01:00
dependabot[bot]
56f0ea291e ⬆ Bump zensical from 0.0.26 to 0.0.27 (#150)
Bumps [zensical](https://github.com/zensical/zensical) from 0.0.26 to 0.0.27.
- [Release notes](https://github.com/zensical/zensical/releases)
- [Commits](https://github.com/zensical/zensical/compare/v0.0.26...v0.0.27)

---
updated-dependencies:
- dependency-name: zensical
  dependency-version: 0.0.27
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-19 19:10:52 +01:00
dependabot[bot]
ee896009ee ⬆ Bump ty from 0.0.21 to 0.0.23 (#151)
Bumps [ty](https://github.com/astral-sh/ty) from 0.0.21 to 0.0.23.
- [Release notes](https://github.com/astral-sh/ty/releases)
- [Changelog](https://github.com/astral-sh/ty/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ty/compare/0.0.21...0.0.23)

---
updated-dependencies:
- dependency-name: ty
  dependency-version: 0.0.23
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-19 19:10:41 +01:00
dependabot[bot]
65bf928e12 ⬆ Bump ruff from 0.15.5 to 0.15.6 (#152)
Bumps [ruff](https://github.com/astral-sh/ruff) from 0.15.5 to 0.15.6.
- [Release notes](https://github.com/astral-sh/ruff/releases)
- [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ruff/compare/0.15.5...0.15.6)

---
updated-dependencies:
- dependency-name: ruff
  dependency-version: 0.15.6
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-19 19:10:28 +01:00
d3vyce
2e9c6c0c90 feat: support Annotated[AsyncSession, Depends(...)] in PathDependency and BodyDependency (#146) 2026-03-18 20:10:58 +01:00
2c494fcd17 Version 2.3.0 2026-03-15 12:52:44 -04:00
d3vyce
fd7269a372 refactor: CursorDirection enum, cursor value parsing and __class_getitem__ caching (#144) 2026-03-15 17:48:50 +01:00
d3vyce
c863744012 feat: PaginatedResponse[T] as discriminated union annotation (#142) 2026-03-15 17:41:33 +01:00
d3vyce
aedcbf4e04 feat: add UUIDv7Mixin (#140) 2026-03-14 20:41:46 +01:00
d3vyce
19c013bdec feat: unified paginate() endpoint with typed pagination responses (#134)
* feat: unified paginate() endpoint with typed pagination responses

* docs: unified paginate() endpoint

* fix: add tests
2026-03-14 20:23:20 +01:00
d3vyce
81407c3038 fix: use clock_timestamp() instead of now() for Mixin to ensure unique values (#138) 2026-03-14 17:30:11 +01:00
d3vyce
0fb00d44da fix: prev_cursor does not navigate backward (#136) 2026-03-14 17:18:53 +01:00
d3vyce
19232d3436 feat: add AsyncCrud subclass style and base_class param to CrudFactory (#132) 2026-03-12 22:46:51 +01:00
1eafcb3873 Version 2.2.1 2026-03-12 15:39:44 -04:00
dependabot[bot]
0d67fbb58d ⬆ Bump ty from 0.0.20 to 0.0.21 (#127)
* ⬆ Bump ty from 0.0.20 to 0.0.21

Bumps [ty](https://github.com/astral-sh/ty) from 0.0.20 to 0.0.21.
- [Release notes](https://github.com/astral-sh/ty/releases)
- [Changelog](https://github.com/astral-sh/ty/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ty/compare/0.0.20...0.0.21)

---
updated-dependencies:
- dependency-name: ty
  dependency-version: 0.0.21
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix: ty warnings

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: d3vyce <nicolas.sudres@proton.me>
2026-03-12 20:29:02 +01:00
dependabot[bot]
a59f098930 ⬆ Bump zensical from 0.0.24 to 0.0.26 (#126)
Bumps [zensical](https://github.com/zensical/zensical) from 0.0.24 to 0.0.26.
- [Release notes](https://github.com/zensical/zensical/releases)
- [Commits](https://github.com/zensical/zensical/compare/v0.0.24...v0.0.26)

---
updated-dependencies:
- dependency-name: zensical
  dependency-version: 0.0.26
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-12 20:22:02 +01:00
dependabot[bot]
96e34ba8af ⬆ Bump ruff from 0.15.4 to 0.15.5 (#128)
Bumps [ruff](https://github.com/astral-sh/ruff) from 0.15.4 to 0.15.5.
- [Release notes](https://github.com/astral-sh/ruff/releases)
- [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ruff/compare/0.15.4...0.15.5)

---
updated-dependencies:
- dependency-name: ruff
  dependency-version: 0.15.5
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-12 20:21:50 +01:00
d3vyce
26d649791f feat: auto-include primary key in CrudFactory searchable_fields (#130) 2026-03-12 20:21:32 +01:00
dde5183e68 Version 2.2.0 2026-03-10 14:53:55 -04:00
d3vyce
e4250a9910 feat: bring first() to parity with get()/get_or_none() — add with_for_update and schema support (#123) 2026-03-10 19:34:18 +01:00
d3vyce
4800941934 fix: cascade delete M2M association rows via ORM session (#121) 2026-03-10 19:18:16 +01:00
0cc21d2012 Version 2.1.0 2026-03-09 12:45:52 -04:00
d3vyce
a3245d50f0 docs: clarify metrics module usage (#117) 2026-03-09 17:17:11 +01:00
d3vyce
baebf022f6 feat: move db related function from pytest to db module (#119) 2026-03-09 17:15:50 +01:00
dependabot[bot]
96d445e3f3 ⬆ Bump zensical from 0.0.23 to 0.0.24 (#112)
Bumps [zensical](https://github.com/zensical/zensical) from 0.0.23 to 0.0.24.
- [Release notes](https://github.com/zensical/zensical/releases)
- [Commits](https://github.com/zensical/zensical/compare/v0.0.23...v0.0.24)

---
updated-dependencies:
- dependency-name: zensical
  dependency-version: 0.0.24
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-05 14:57:58 +01:00
dependabot[bot]
80306e1af3 ⬆ Bump ruff from 0.15.2 to 0.15.4 (#113)
Bumps [ruff](https://github.com/astral-sh/ruff) from 0.15.2 to 0.15.4.
- [Release notes](https://github.com/astral-sh/ruff/releases)
- [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ruff/compare/0.15.2...0.15.4)

---
updated-dependencies:
- dependency-name: ruff
  dependency-version: 0.15.4
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-05 14:57:48 +01:00
dependabot[bot]
fd999b63f1 ⬆ Bump ty from 0.0.18 to 0.0.20 (#114)
Bumps [ty](https://github.com/astral-sh/ty) from 0.0.18 to 0.0.20.
- [Release notes](https://github.com/astral-sh/ty/releases)
- [Changelog](https://github.com/astral-sh/ty/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ty/compare/0.0.18...0.0.20)

---
updated-dependencies:
- dependency-name: ty
  dependency-version: 0.0.20
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-05 14:57:37 +01:00
dependabot[bot]
c0f352b914 ⬆ Bump fastapi from 0.133.1 to 0.135.1 (#115)
Bumps [fastapi](https://github.com/fastapi/fastapi) from 0.133.1 to 0.135.1.
- [Release notes](https://github.com/fastapi/fastapi/releases)
- [Commits](https://github.com/fastapi/fastapi/compare/0.133.1...0.135.1)

---
updated-dependencies:
- dependency-name: fastapi
  dependency-version: 0.135.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-05 14:57:26 +01:00
c4c760484b Version 2.0.0 2026-03-04 11:20:46 -05:00
d3vyce
432e0722e0 refactor: remove deprecated parameter and function/cleanup code (#101)
* refactor: remove deprecated parameter and function

* refactor: centralize type aliases in types.py and simplify crud layer

* test: add missing tests for fixtures/utils.py

* refactor: simplify and deduplicate across crud, metrics, cli, and
exceptions

* docs: fix old Paginate references

* docs: add migration + fix icon

* docs: update README/migration to v2
2026-03-04 17:19:38 +01:00
d3vyce
e732e54518 feat: add models module (#109)
* feat: add models module

* docs: add models module
2026-03-04 15:14:55 +01:00
d3vyce
05b5a2c876 feat: rework Exception/ApiError (#107)
* feat: rework Exception/ApiError

* docs: update exceptions module

* fix: docstring
2026-03-02 16:34:29 +01:00
d3vyce
4a020c56d1 feat: Raise NotFoundError instead of LookupError in wait_for_row_change (#105) 2026-03-02 14:28:37 +01:00
56d365d14b Version 1.3.0 2026-03-01 05:22:16 -05:00
d3vyce
a257d85d45 Add sort_params helper in CrudFactory (#103)
* feat: add sort_params helper in CrudFactory

* docs: add sorting

* fix: change sort_by to order_by
2026-03-01 11:20:43 +01:00
117675d02f Version 1.2.1 2026-02-27 13:57:03 -05:00
d3vyce
d7ad7308c5 Add examples in documentations (#99)
* docs: fix crud

* docs: update README features

* docs: add pagination/search example

* docs: update zensical.toml

* docs: cleanup

* docs: update status to Stable + update description

* docs: add example run commands
2026-02-27 19:56:09 +01:00
8d57bf9525 Version 1.2.0 2026-02-26 09:34:35 -05:00
d3vyce
5a08ec2f57 feat: add faceted search in CrudFactory (#97)
* feat: add faceted search in CrudFactory

* feat: add filter_params_schema in CrudFactory

* fix: add missing Raises in build_search_filters docstring

* fix: faceted search

* fix: cov

* fix: documentation/filter_params
2026-02-26 15:23:07 +01:00
dependabot[bot]
433dc55fcd ⬆ Bump typer from 0.24.0 to 0.24.1 (#92)
Bumps [typer](https://github.com/fastapi/typer) from 0.24.0 to 0.24.1.
- [Release notes](https://github.com/fastapi/typer/releases)
- [Changelog](https://github.com/fastapi/typer/blob/master/docs/release-notes.md)
- [Commits](https://github.com/fastapi/typer/compare/0.24.0...0.24.1)

---
updated-dependencies:
- dependency-name: typer
  dependency-version: 0.24.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-26 10:53:00 +01:00
dependabot[bot]
0b2abd8c43 ⬆ Bump ruff from 0.15.1 to 0.15.2 (#93)
Bumps [ruff](https://github.com/astral-sh/ruff) from 0.15.1 to 0.15.2.
- [Release notes](https://github.com/astral-sh/ruff/releases)
- [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ruff/compare/0.15.1...0.15.2)

---
updated-dependencies:
- dependency-name: ruff
  dependency-version: 0.15.2
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-26 10:52:47 +01:00
dependabot[bot]
07c99be89b ⬆ Bump mkdocstrings-python from 2.0.2 to 2.0.3 (#94)
Bumps [mkdocstrings-python](https://github.com/mkdocstrings/python) from 2.0.2 to 2.0.3.
- [Release notes](https://github.com/mkdocstrings/python/releases)
- [Changelog](https://github.com/mkdocstrings/python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/mkdocstrings/python/compare/2.0.2...2.0.3)

---
updated-dependencies:
- dependency-name: mkdocstrings-python
  dependency-version: 2.0.3
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-26 10:52:37 +01:00
dependabot[bot]
9b75cc7dfc ⬆ Bump ty from 0.0.17 to 0.0.18 (#95)
Bumps [ty](https://github.com/astral-sh/ty) from 0.0.17 to 0.0.18.
- [Release notes](https://github.com/astral-sh/ty/releases)
- [Changelog](https://github.com/astral-sh/ty/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ty/compare/0.0.17...0.0.18)

---
updated-dependencies:
- dependency-name: ty
  dependency-version: 0.0.18
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-26 10:52:23 +01:00
dependabot[bot]
6144b383eb ⬆ Bump fastapi from 0.129.0 to 0.133.1 (#96)
Bumps [fastapi](https://github.com/fastapi/fastapi) from 0.129.0 to 0.133.1.
- [Release notes](https://github.com/fastapi/fastapi/releases)
- [Commits](https://github.com/fastapi/fastapi/compare/0.129.0...0.133.1)

---
updated-dependencies:
- dependency-name: fastapi
  dependency-version: 0.133.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-26 10:52:02 +01:00
7ec407834a Version 1.1.2 2026-02-23 14:59:33 -05:00
d3vyce
7da34f33a2 fix: handle Date, Float, Numeric cursor column types in cursor_paginate (#90) 2026-02-23 20:58:43 +01:00
8c8911fb27 Version 1.1.1 2026-02-23 11:04:10 -05:00
d3vyce
c0c3b38054 fix: NameError in cli/config.py (#88) 2026-02-23 14:40:36 +01:00
e17d385910 Version 1.1.0 2026-02-23 08:09:59 -05:00
d3vyce
6cf7df55ef feat: add cursor based pagination in CrudFactory (#86) 2026-02-23 13:51:34 +01:00
d3vyce
7482bc5dad feat: add schema parameter to CRUD methods for typed response serialization (#84) 2026-02-23 10:02:52 +01:00
d3vyce
9d07dfea85 feat: add opt-in default_load_options parameter in CrudFactory (#82)
* feat: add opt-in default_load_options parameter in CrudFactory

* docs: add Relationship loading in CRUD
2026-02-21 12:35:15 +01:00
d3vyce
31678935aa Version 1.0.0 (#80)
* docs: fix typos

* chore: build docs only when release

* Version 1.0.0
2026-02-20 14:09:01 +01:00
d3vyce
823a0b3e36 docs: add analytics (#78) 2026-02-19 18:05:30 +01:00
1591cd3d64 fix: documentation deploy workflow 2026-02-19 10:46:56 -05:00
d3vyce
6714ceeb92 chore: documentation (#76)
* chore: update docstring example to use python code block

* docs: add documentation

* feat: add docs build + fix other workdlows

* fix: add missing return type
2026-02-19 16:43:38 +01:00
d3vyce
73fae04333 chore: cleanup before v1 (#73)
* chore: move dependencies module to the project root

* chore:update README

* chore: clean conftest

* chore: remove old code + comment

* fix: uv.lock dependencies
2026-02-19 11:49:57 +01:00
d3vyce
32ed36e102 feat: add proper optional-dependencies for each modules (#75) 2026-02-19 11:08:18 +01:00
dependabot[bot]
48567310bc ⬆ Bump ty from 0.0.16 to 0.0.17 (#68)
Bumps [ty](https://github.com/astral-sh/ty) from 0.0.16 to 0.0.17.
- [Release notes](https://github.com/astral-sh/ty/releases)
- [Changelog](https://github.com/astral-sh/ty/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ty/compare/0.0.16...0.0.17)

---
updated-dependencies:
- dependency-name: ty
  dependency-version: 0.0.17
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-19 10:34:31 +01:00
dependabot[bot]
de51ed4675 ⬆ Bump fastapi from 0.128.8 to 0.129.0 (#69)
Bumps [fastapi](https://github.com/fastapi/fastapi) from 0.128.8 to 0.129.0.
- [Release notes](https://github.com/fastapi/fastapi/releases)
- [Commits](https://github.com/fastapi/fastapi/compare/0.128.8...0.129.0)

---
updated-dependencies:
- dependency-name: fastapi
  dependency-version: 0.129.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-19 10:34:17 +01:00
dependabot[bot]
794767edbb ⬆ Bump ruff from 0.15.0 to 0.15.1 (#70)
Bumps [ruff](https://github.com/astral-sh/ruff) from 0.15.0 to 0.15.1.
- [Release notes](https://github.com/astral-sh/ruff/releases)
- [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ruff/compare/0.15.0...0.15.1)

---
updated-dependencies:
- dependency-name: ruff
  dependency-version: 0.15.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-19 10:34:06 +01:00
dependabot[bot]
9c136f05bb ⬆ Bump typer from 0.23.0 to 0.24.0 (#71)
Bumps [typer](https://github.com/fastapi/typer) from 0.23.0 to 0.24.0.
- [Release notes](https://github.com/fastapi/typer/releases)
- [Changelog](https://github.com/fastapi/typer/blob/master/docs/release-notes.md)
- [Commits](https://github.com/fastapi/typer/compare/0.23.0...0.24.0)

---
updated-dependencies:
- dependency-name: typer
  dependency-version: 0.24.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-19 10:33:54 +01:00
3299a439fe Version 0.10.0 2026-02-17 07:28:10 -05:00
d3vyce
d5b22a72fd feat: add a metrics module (#67) 2026-02-17 13:24:53 +01:00
d3vyce
c32f2e18be feat: add many to many support in CrudFactory (#65) 2026-02-15 15:57:15 +01:00
d971261f98 Version 0.9.0 2026-02-14 14:38:58 -05:00
d3vyce
74a54b7396 feat: add optional data field in ApiError (#63) 2026-02-14 20:37:50 +01:00
d3vyce
19805ab376 feat: add dependency_overrides parameter to create_async_client (#61) 2026-02-13 18:11:11 +01:00
d3vyce
d4498e2063 feat: add cleanup parameter to create_db_session (#60) 2026-02-13 18:03:28 +01:00
f59c1a17e2 Version 0.8.1 2026-02-12 18:18:37 +01:00
dependabot[bot]
8982ba18e3 ⬆ Bump typer from 0.21.1 to 0.23.0 (#54)
Bumps [typer](https://github.com/fastapi/typer) from 0.21.1 to 0.23.0.
- [Release notes](https://github.com/fastapi/typer/releases)
- [Changelog](https://github.com/fastapi/typer/blob/master/docs/release-notes.md)
- [Commits](https://github.com/fastapi/typer/compare/0.21.1...0.23.0)

---
updated-dependencies:
- dependency-name: typer
  dependency-version: 0.23.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-12 18:13:04 +01:00
dependabot[bot]
71fe6f478f ⬆ Bump fastapi from 0.128.1 to 0.128.8 (#52)
Bumps [fastapi](https://github.com/fastapi/fastapi) from 0.128.1 to 0.128.8.
- [Release notes](https://github.com/fastapi/fastapi/releases)
- [Commits](https://github.com/fastapi/fastapi/compare/0.128.1...0.128.8)

---
updated-dependencies:
- dependency-name: fastapi
  dependency-version: 0.128.8
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-12 18:10:41 +01:00
dependabot[bot]
1cfbf14986 ⬆ Bump coverage from 7.13.3 to 7.13.4 (#53)
Bumps [coverage](https://github.com/coveragepy/coveragepy) from 7.13.3 to 7.13.4.
- [Release notes](https://github.com/coveragepy/coveragepy/releases)
- [Changelog](https://github.com/coveragepy/coveragepy/blob/main/CHANGES.rst)
- [Commits](https://github.com/coveragepy/coveragepy/compare/7.13.3...7.13.4)

---
updated-dependencies:
- dependency-name: coverage
  dependency-version: 7.13.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-12 18:10:27 +01:00
dependabot[bot]
e3ff535b7e ⬆ Bump ty from 0.0.14 to 0.0.16 (#55)
Bumps [ty](https://github.com/astral-sh/ty) from 0.0.14 to 0.0.16.
- [Release notes](https://github.com/astral-sh/ty/releases)
- [Changelog](https://github.com/astral-sh/ty/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ty/compare/0.0.14...0.0.16)

---
updated-dependencies:
- dependency-name: ty
  dependency-version: 0.0.16
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-12 18:10:02 +01:00
d3vyce
8825c772ce doc: add missing docstring + add missing feature to README (#57) 2026-02-12 18:09:39 +01:00
d3vyce
c8c263ca8f fix: don't use default DB if pytest-xdist is not present (#51) 2026-02-11 17:07:15 +01:00
2020fa2f92 Version 0.8.0 2026-02-10 15:53:13 -05:00
d3vyce
1ea316bef4 feat: add wait_for_row_change db helper (#49) 2026-02-10 21:46:59 +01:00
d3vyce
ced1a655f2 feat: add support for pytest-xdist (#47) 2026-02-10 21:29:17 +01:00
109 changed files with 19657 additions and 910 deletions

View File

@@ -20,7 +20,7 @@ jobs:
run: uv python install 3.14
- name: Install dependencies
run: uv sync
run: uv sync --group dev
- name: Build
run: uv build

View File

@@ -6,6 +6,9 @@ on:
pull_request:
branches: [main]
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
@@ -24,7 +27,7 @@ jobs:
run: uv python install 3.13
- name: Install dependencies
run: uv sync --extra dev
run: uv sync --group dev
- name: Run Ruff linter
run: uv run ruff check .
@@ -45,7 +48,7 @@ jobs:
run: uv python install 3.13
- name: Install dependencies
run: uv sync --extra dev
run: uv sync --group dev
- name: Run ty
run: uv run ty check
@@ -83,7 +86,7 @@ jobs:
run: uv python install ${{ matrix.python-version }}
- name: Install dependencies
run: uv sync --extra dev
run: uv sync --group dev
- name: Run tests with coverage
env:
@@ -93,7 +96,7 @@ jobs:
- name: Upload coverage to Codecov
if: matrix.python-version == '3.14'
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v6
with:
token: ${{ secrets.CODECOV_TOKEN }}
report_type: coverage
@@ -102,7 +105,7 @@ jobs:
- name: Upload test results to Codecov
if: matrix.python-version == '3.14'
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v6
with:
token: ${{ secrets.CODECOV_TOKEN }}
report_type: test_results

53
.github/workflows/docs.yml vendored Normal file
View File

@@ -0,0 +1,53 @@
name: Documentation
on:
release:
types: [published]
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Set up Python
run: uv python install 3.13
- run: uv sync --group dev
- name: Configure git
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Deploy documentation
run: |
VERSION=${GITHUB_REF_NAME#v}
MAJOR=$(echo "$VERSION" | cut -d. -f1)
DEPLOY_VERSION="v$(echo "$VERSION" | cut -d. -f1-2)"
# On new major: keep only the latest feature version of the previous major
PREV_MAJOR=$((MAJOR - 1))
OLD_FEATURE_VERSIONS=$(uv run mike list 2>/dev/null | grep -oE "^v${PREV_MAJOR}\.[0-9]+" || true)
if [ -n "$OLD_FEATURE_VERSIONS" ]; then
LATEST_PREV=$(echo "$OLD_FEATURE_VERSIONS" | sort -t. -k2 -n | tail -1)
echo "$OLD_FEATURE_VERSIONS" | while read -r OLD_V; do
if [ "$OLD_V" != "$LATEST_PREV" ]; then
echo "Deleting $OLD_V"
uv run mike delete "$OLD_V"
fi
done
fi
uv run mike deploy --update-aliases "$DEPLOY_VERSION" stable
uv run mike set-default stable
git push origin gh-pages

34
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,34 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-added-large-files
args: ["--maxkb=750"]
exclude: ^uv.lock$
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: local
hooks:
- id: local-ruff-check
name: ruff check
entry: uv run ruff check --force-exclude --fix --exit-non-zero-on-fix .
require_serial: true
language: unsupported
types: [python]
- id: local-ruff-format
name: ruff format
entry: uv run ruff format --force-exclude --exit-non-zero-on-format .
require_serial: true
language: unsupported
types: [python]
- id: local-ty
name: ty check
entry: uv run ty check
require_serial: true
language: unsupported
pass_filenames: false

View File

@@ -1,6 +1,6 @@
# FastAPI Toolsets
FastAPI Toolsets provides production-ready utilities for FastAPI applications built with async SQLAlchemy and PostgreSQL. It includes generic CRUD operations, a fixture system with dependency resolution, a Django-like CLI, standardized API responses, and structured exception handling with automatic OpenAPI documentation.
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
[![CI](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml/badge.svg)](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
[![codecov](https://codecov.io/gh/d3vyce/fastapi-toolsets/graph/badge.svg)](https://codecov.io/gh/d3vyce/fastapi-toolsets)
@@ -20,17 +20,45 @@ FastAPI Toolsets provides production-ready utilities for FastAPI applications bu
## Installation
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, model mixins, logging):
```bash
uv add fastapi-toolsets
```
Install only the extras you need:
```bash
uv add "fastapi-toolsets[cli]"
uv add "fastapi-toolsets[metrics]"
uv add "fastapi-toolsets[pytest]"
```
Or install everything:
```bash
uv add "fastapi-toolsets[all]"
```
## Features
- **CRUD**: Generic async CRUD operations with `CrudFactory`
- **Fixtures**: Fixture system with dependency management, context support and pytest integration
- **CLI**: Django-like command-line interface with fixture management and custom commands support
- **Standardized API Responses**: Consistent response format across your API
### Core
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`)
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
### Optional
- **CLI**: Django-like command-line interface with fixture management and custom commands support
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
## License

View File

@@ -0,0 +1 @@
# Authentication

View File

@@ -0,0 +1,185 @@
# Pagination & search
This example builds an articles listing endpoint that supports **offset pagination**, **cursor pagination**, **full-text search**, **faceted filtering**, and **sorting** — all from a single `CrudFactory` definition.
## Models
```python title="models.py"
--8<-- "docs_src/examples/pagination_search/models.py"
```
## Schemas
```python title="schemas.py"
--8<-- "docs_src/examples/pagination_search/schemas.py"
```
## Crud
Declare `searchable_fields`, `facet_fields`, and `order_fields` once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory). All endpoints built from this class share the same defaults and can override them per call.
```python title="crud.py"
--8<-- "docs_src/examples/pagination_search/crud.py"
```
## Session dependency
```python title="db.py"
--8<-- "docs_src/examples/pagination_search/db.py"
```
!!! info "Deploy a Postgres DB with docker"
```bash
docker run -d --name postgres -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres -p 5432:5432 postgres:18-alpine
```
## App
```python title="app.py"
--8<-- "docs_src/examples/pagination_search/app.py"
```
## Routes
```python title="routes.py:1:16"
--8<-- "docs_src/examples/pagination_search/routes.py:1:16"
```
### Offset pagination
Best for admin panels or any UI that needs a total item count and numbered pages.
```python title="routes.py:19:37"
--8<-- "docs_src/examples/pagination_search/routes.py:19:37"
```
**Example request**
```
GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&order_by=title&order=asc
```
**Example response**
```json
{
"status": "SUCCESS",
"pagination_type": "offset",
"data": [
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
],
"pagination": {
"total_count": 42,
"pages": 5,
"page": 2,
"items_per_page": 10,
"has_more": true
},
"filter_attributes": {
"status": ["archived", "draft", "published"],
"name": ["backend", "frontend", "python"]
}
}
```
`filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client.
To skip the `COUNT(*)` query for better performance on large tables, pass `include_total=False`. `pagination.total_count` will be `null` in the response, while `has_more` remains accurate.
### Cursor pagination
Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades.
```python title="routes.py:40:58"
--8<-- "docs_src/examples/pagination_search/routes.py:40:58"
```
**Example request**
```
GET /articles/cursor?items_per_page=10&status=published&order_by=created_at&order=desc
```
**Example response**
```json
{
"status": "SUCCESS",
"pagination_type": "cursor",
"data": [
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
],
"pagination": {
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
"prev_cursor": null,
"items_per_page": 10,
"has_more": true
},
"filter_attributes": {
"status": ["published"],
"name": ["backend", "python"]
}
}
```
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page.
### Unified endpoint (both strategies)
!!! info "Added in `v2.3.0`"
[`paginate()`](../module/crud.md#unified-paginate--both-strategies-on-one-endpoint) lets a single endpoint support both strategies via a `pagination_type` query parameter. The `pagination_type` field in the response acts as a discriminator for frontend tooling.
```python title="routes.py:61:79"
--8<-- "docs_src/examples/pagination_search/routes.py:61:79"
```
**Offset request** (default)
```
GET /articles/?pagination_type=offset&page=1&items_per_page=10
```
```json
{
"status": "SUCCESS",
"pagination_type": "offset",
"data": ["..."],
"pagination": { "total_count": 42, "pages": 5, "page": 1, "items_per_page": 10, "has_more": true }
}
```
**Cursor request**
```
GET /articles/?pagination_type=cursor&items_per_page=10
GET /articles/?pagination_type=cursor&items_per_page=10&cursor=eyJ2YWx1ZSI6...
```
```json
{
"status": "SUCCESS",
"pagination_type": "cursor",
"data": ["..."],
"pagination": { "next_cursor": "eyJ2YWx1ZSI6...", "prev_cursor": null, "items_per_page": 10, "has_more": true }
}
```
## Search behaviour
Both endpoints inherit the same `searchable_fields` declared on `ArticleCrud`:
Search is **case-insensitive** and uses a `LIKE %query%` pattern. Pass a [`SearchConfig`](../reference/crud.md#fastapi_toolsets.crud.search.SearchConfig) instead of a plain string to control case sensitivity or switch to `match_mode="all"` (AND across all fields instead of OR).
```python
from fastapi_toolsets.crud import SearchConfig
# Both title AND body must contain "fastapi"
result = await ArticleCrud.offset_paginate(
session,
search=SearchConfig(query="fastapi", case_sensitive=True, match_mode="all"),
search_fields=[Article.title, Article.body],
)
```

69
docs/index.md Normal file
View File

@@ -0,0 +1,69 @@
# FastAPI Toolsets
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
[![CI](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml/badge.svg)](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
[![codecov](https://codecov.io/gh/d3vyce/fastapi-toolsets/graph/badge.svg)](https://codecov.io/gh/d3vyce/fastapi-toolsets)
[![ty](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ty/main/assets/badge/v0.json)](https://github.com/astral-sh/ty)
[![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
---
**Documentation**: [https://fastapi-toolsets.d3vyce.fr](https://fastapi-toolsets.d3vyce.fr)
**Source Code**: [https://github.com/d3vyce/fastapi-toolsets](https://github.com/d3vyce/fastapi-toolsets)
---
## Installation
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, model mixins, logging):
```bash
uv add fastapi-toolsets
```
Install only the extras you need:
```bash
uv add "fastapi-toolsets[cli]"
uv add "fastapi-toolsets[metrics]"
uv add "fastapi-toolsets[pytest]"
```
Or install everything:
```bash
uv add "fastapi-toolsets[all]"
```
## Features
### Core
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`).
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations.
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
### Optional
- **CLI**: Django-like command-line interface with fixture management and custom commands support
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
## License
MIT License - see [LICENSE](LICENSE) for details.
## Contributing
Contributions are welcome! Please feel free to submit issues and pull requests.

137
docs/migration/v2.md Normal file
View File

@@ -0,0 +1,137 @@
# Migrating to v2.0
This page covers every breaking change introduced in **v2.0** and the steps required to update your code.
---
## CRUD
### `schema` is now required in `offset_paginate()` and `cursor_paginate()`
Calls that omit `schema` will now raise a `TypeError` at runtime.
Previously `schema` was optional; omitting it returned raw SQLAlchemy model instances inside the response. It is now a required keyword argument and the response always contains serialized schema instances.
=== "Before (`v1`)"
```python
# schema omitted — returned raw model instances
result = await UserCrud.offset_paginate(session=session, page=1)
result = await UserCrud.cursor_paginate(session=session, cursor=token)
```
=== "Now (`v2`)"
```python
result = await UserCrud.offset_paginate(session=session, page=1, schema=UserRead)
result = await UserCrud.cursor_paginate(session=session, cursor=token, schema=UserRead)
```
### `as_response` removed from `create()`, `get()`, and `update()`
Passing `as_response` to these methods will raise a `TypeError` at runtime.
The `as_response=True` shorthand is replaced by passing a `schema` directly. The return value is a `Response[schema]` when `schema` is provided, or the raw model instance when it is not.
=== "Before (`v1`)"
```python
user = await UserCrud.create(session=session, obj=data, as_response=True)
user = await UserCrud.get(session=session, filters=filters, as_response=True)
user = await UserCrud.update(session=session, obj=data, filters, as_response=True)
```
=== "Now (`v2`)"
```python
user = await UserCrud.create(session=session, obj=data, schema=UserRead)
user = await UserCrud.get(session=session, filters=filters, schema=UserRead)
user = await UserCrud.update(session=session, obj=data, filters, schema=UserRead)
```
### `delete()`: `as_response` renamed and return type changed
`as_response` is gone, and the plain (non-response) call no longer returns `True`.
Two changes were made to `delete()`:
1. The `as_response` parameter is renamed to `return_response`.
2. When called without `return_response=True`, the method now returns `None` on success instead of `True`.
=== "Before (`v1`)"
```python
ok = await UserCrud.delete(session=session, filters=filters)
if ok: # True on success
...
response = await UserCrud.delete(session=session, filters=filters, as_response=True)
```
=== "Now (`v2`)"
```python
await UserCrud.delete(session=session, filters=filters) # returns None
response = await UserCrud.delete(session=session, filters=filters, return_response=True)
```
### `paginate()` alias removed
Any call to `crud.paginate(...)` will raise `AttributeError` at runtime.
The `paginate` shorthand was an alias for `offset_paginate`. It has been removed; call `offset_paginate` directly.
=== "Before (`v1`)"
```python
result = await UserCrud.paginate(session=session, page=2, items_per_page=20, schema=UserRead)
```
=== "Now (`v2`)"
```python
result = await UserCrud.offset_paginate(session=session, page=2, items_per_page=20, schema=UserRead)
```
---
## Exceptions
### Missing `api_error` raises `TypeError` at class definition time
Unfinished or stub exception subclasses that previously compiled fine will now fail on import.
In `v1`, a subclass without `api_error` would only fail when the exception was raised. In `v2`, `__init_subclass__` validates this at class definition time.
=== "Before (`v1`)"
```python
class MyError(ApiException):
pass # fine until raised
```
=== "Now (`v2`)"
```python
class MyError(ApiException):
pass # TypeError: MyError must define an 'api_error' class attribute.
```
For shared base classes that are not meant to be raised directly, use `abstract=True`:
```python
class BillingError(ApiException, abstract=True):
"""Base for all billing-related errors — not raised directly."""
class PaymentRequiredError(BillingError):
api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402")
```
---
## Schemas
### `Pagination` alias removed
`Pagination` was already deprecated in `v1` and is fully removed in `v2`, you now need to use [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) or [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination).

180
docs/migration/v3.md Normal file
View File

@@ -0,0 +1,180 @@
# Migrating to v3.0
This page covers every breaking change introduced in **v3.0** and the steps required to update your code.
---
## CRUD
### Facet keys now always use the full relationship chain
In `v2`, relationship facet fields used only the terminal column key (e.g. `"name"` for `Role.name`) and only prepended the relationship name when two facet fields shared the same column key. In `v3`, facet keys **always** include the full relationship chain joined by `__`, regardless of collisions.
=== "Before (`v2`)"
```
User.status -> status
(User.role, Role.name) -> name
(User.role, Role.permission, Permission.name) -> name
```
=== "Now (`v3`)"
```
User.status -> status
(User.role, Role.name) -> role__name
(User.role, Role.permission, Permission.name) -> role__permission__name
```
---
### `*_params` dependencies consolidated into per-paginate methods
The six individual dependency methods (`offset_params`, `cursor_params`, `paginate_params`, `filter_params`, `search_params`, `order_params`) have been **removed** and replaced by three consolidated methods that bundle pagination, search, filter, and order into a single `Depends()` call.
| Removed | Replacement |
|---|---|
| `offset_params()` + `filter_params()` + `search_params()` + `order_params()` | `offset_paginate_params()` |
| `cursor_params()` + `filter_params()` + `search_params()` + `order_params()` | `cursor_paginate_params()` |
| `paginate_params()` + `filter_params()` + `search_params()` + `order_params()` | `paginate_params()` |
Each new method accepts `search`, `filter`, and `order` boolean toggles (all `True` by default) to disable features you don't need.
=== "Before (`v2`)"
```python
from fastapi_toolsets.crud import OrderByClause
@router.get("/offset")
async def list_articles_offset(
session: SessionDep,
params: Annotated[dict, Depends(ArticleCrud.offset_params(default_page_size=20))],
filter_by: Annotated[dict, Depends(ArticleCrud.filter_params())],
order_by: Annotated[OrderByClause | None, Depends(ArticleCrud.order_params(default_field=Article.created_at))],
search: str | None = None,
) -> OffsetPaginatedResponse[ArticleRead]:
return await ArticleCrud.offset_paginate(
session=session,
**params,
search=search,
filter_by=filter_by or None,
order_by=order_by,
schema=ArticleRead,
)
```
=== "Now (`v3`)"
```python
@router.get("/offset")
async def list_articles_offset(
session: SessionDep,
params: Annotated[
dict,
Depends(
ArticleCrud.offset_paginate_params(
default_page_size=20,
default_order_field=Article.created_at,
)
),
],
) -> OffsetPaginatedResponse[ArticleRead]:
return await ArticleCrud.offset_paginate(session=session, **params, schema=ArticleRead)
```
The same pattern applies to `cursor_paginate_params()` and `paginate_params()`. To disable a feature, pass the toggle:
```python
# No search or ordering, only pagination + filtering
ArticleCrud.offset_paginate_params(search=False, order=False)
```
---
## Models
The lifecycle event system has been rewritten. Callbacks are now registered with a module-level [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator and dispatched by [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession), replacing the mixin-based approach from `v2`.
### `WatchedFieldsMixin` and `@watch` removed
Importing `WatchedFieldsMixin` or `watch` will raise `ImportError`.
Model method callbacks (`on_create`, `on_delete`, `on_update`) and the `@watch` decorator are replaced by:
1. **`__watched_fields__`** — a plain class attribute to restrict which field changes trigger `UPDATE` events (replaces `@watch`).
2. **`@listens_for`** — a module-level decorator to register callbacks for one or more [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) types (replaces `on_create` / `on_delete` / `on_update` methods).
=== "Before (`v2`)"
```python
from fastapi_toolsets.models import WatchedFieldsMixin, watch
@watch("status")
class Order(Base, UUIDMixin, WatchedFieldsMixin):
__tablename__ = "orders"
status: Mapped[str]
async def on_create(self):
await notify_new_order(self.id)
async def on_update(self, changes):
if "status" in changes:
await notify_status_change(self.id, changes["status"])
async def on_delete(self):
await notify_order_cancelled(self.id)
```
=== "Now (`v3`)"
```python
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
class Order(Base, UUIDMixin):
__tablename__ = "orders"
__watched_fields__ = ("status",)
status: Mapped[str]
@listens_for(Order, [ModelEvent.CREATE])
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
await notify_new_order(order.id)
@listens_for(Order, [ModelEvent.UPDATE])
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
if "status" in changes:
await notify_status_change(order.id, changes["status"])
@listens_for(Order, [ModelEvent.DELETE])
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
await notify_order_cancelled(order.id)
```
### `EventSession` now required
Without `EventSession`, lifecycle callbacks will silently stop firing.
Callbacks are now dispatched inside `EventSession.commit()` rather than via background tasks. Pass it as the session class when creating your session factory:
=== "Before (`v2`)"
```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
```
=== "Now (`v3`)"
```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from fastapi_toolsets.models import EventSession
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
```
!!! note
If you use `create_db_session` from `fastapi_toolsets.pytest`, the session already uses `EventSession` — no changes needed in tests.

93
docs/module/cli.md Normal file
View File

@@ -0,0 +1,93 @@
# CLI
Typer-based command-line interface for managing your FastAPI application, with built-in fixture commands integration.
## Installation
=== "uv"
``` bash
uv add "fastapi-toolsets[cli]"
```
=== "pip"
``` bash
pip install "fastapi-toolsets[cli]"
```
## Overview
The `cli` module provides a `manager` entry point built with [Typer](https://typer.tiangolo.com/). It allow custom commands to be added in addition of the fixture commands when a [`FixtureRegistry`](../reference/fixtures.md#fastapi_toolsets.fixtures.registry.FixtureRegistry) and a database context are configured.
## Configuration
Configure the CLI in your `pyproject.toml`:
```toml
[tool.fastapi-toolsets]
cli = "myapp.cli:cli" # Custom Typer app
fixtures = "myapp.fixtures:registry" # FixtureRegistry instance
db_context = "myapp.db:db_context" # Async context manager for sessions
```
All fields are optional. Without configuration, the `manager` command still works but no command are available.
## Usage
```bash
# Manager commands
manager --help
Usage: manager [OPTIONS] COMMAND [ARGS]...
FastAPI utilities CLI.
╭─ Options ────────────────────────────────────────────────────────────────────────╮
│ --install-completion Install completion for the current shell. │
│ --show-completion Show completion for the current shell, to copy it │
│ or customize the installation. │
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────╯
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
│ check-db │
│ fixtures Manage database fixtures. │
╰──────────────────────────────────────────────────────────────────────────────────╯
# Fixtures commands
manager fixtures --help
Usage: manager fixtures [OPTIONS] COMMAND [ARGS]...
Manage database fixtures.
╭─ Options ────────────────────────────────────────────────────────────────────────╮
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────╯
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
│ list List all registered fixtures. │
│ load Load fixtures into the database. │
╰──────────────────────────────────────────────────────────────────────────────────╯
```
## Custom CLI
You can extend the CLI by providing your own Typer app. The `manager` entry point will merge your app's commands with the built-in ones:
```python
# myapp/cli.py
import typer
cli = typer.Typer()
@cli.command()
def hello():
print("Hello from my app!")
```
```toml
[tool.fastapi-toolsets]
cli = "myapp.cli:cli"
```
---
[:material-api: API Reference](../reference/cli.md)

644
docs/module/crud.md Normal file
View File

@@ -0,0 +1,644 @@
# CRUD
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support.
!!! info
This module has been coded and tested to be compatible with PostgreSQL only.
## Overview
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), a base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
## Creating a CRUD class
### Factory style
```python
from fastapi_toolsets.crud import CrudFactory
from myapp.models import User
UserCrud = CrudFactory(model=User)
```
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model. This is the most concise option for straightforward CRUD with no custom logic.
### Subclass style
!!! info "Added in `v2.3.0`"
```python
from fastapi_toolsets.crud.factory import AsyncCrud
from myapp.models import User
class UserCrud(AsyncCrud[User]):
model = User
searchable_fields = [User.username, User.email]
default_load_options = [selectinload(User.role)]
```
Subclassing [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud) directly is the preferred style when you need to add custom methods or when the configuration is complex enough to benefit from a named class body.
### Adding custom methods
```python
class UserCrud(AsyncCrud[User]):
model = User
@classmethod
async def get_active(cls, session: AsyncSession) -> list[User]:
return await cls.get_multi(session, filters=[User.is_active == True])
```
### Sharing a custom base across multiple models
Define a generic base class with the shared methods, then subclass it for each model:
```python
from typing import Generic, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from fastapi_toolsets.crud.factory import AsyncCrud
T = TypeVar("T", bound=DeclarativeBase)
class AuditedCrud(AsyncCrud[T], Generic[T]):
"""Base CRUD with custom function"""
@classmethod
async def get_active(cls, session: AsyncSession):
return await cls.get_multi(session, filters=[cls.model.is_active == True])
class UserCrud(AuditedCrud[User]):
model = User
searchable_fields = [User.username, User.email]
```
You can also use the factory shorthand with the same base by passing `base_class`:
```python
UserCrud = CrudFactory(User, base_class=AuditedCrud)
```
## Basic operations
!!! info "`get_or_none` added in `v2.2`"
```python
# Create
user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
# Get one (raises NotFoundError if not found)
user = await UserCrud.get(session=session, filters=[User.id == user_id])
# Get one or None (never raises)
user = await UserCrud.get_or_none(session=session, filters=[User.id == user_id])
# Get first or None
user = await UserCrud.first(session=session, filters=[User.email == email])
# Get multiple
users = await UserCrud.get_multi(session=session, filters=[User.is_active == True])
# Update
user = await UserCrud.update(session=session, obj=UserUpdateSchema(username="bob"), filters=[User.id == user_id])
# Delete
await UserCrud.delete(session=session, filters=[User.id == user_id])
# Count / exists
count = await UserCrud.count(session=session, filters=[User.is_active == True])
exists = await UserCrud.exists(session=session, filters=[User.email == email])
```
## Fetching a single record
Three methods fetch a single record — choose based on how you want to handle the "not found" case and whether you need strict uniqueness:
| Method | Not found | Multiple results |
|---|---|---|
| `get` | raises `NotFoundError` | raises `MultipleResultsFound` |
| `get_or_none` | returns `None` | raises `MultipleResultsFound` |
| `first` | returns `None` | returns the first match silently |
Use `get` when the record must exist (e.g. a detail endpoint that should return 404):
```python
user = await UserCrud.get(session=session, filters=[User.id == user_id])
```
Use `get_or_none` when the record may not exist but you still want strict uniqueness enforcement:
```python
user = await UserCrud.get_or_none(session=session, filters=[User.email == email])
if user is None:
... # handle missing case without catching an exception
```
Use `first` when you only care about any one match and don't need uniqueness:
```python
user = await UserCrud.first(session=session, filters=[User.is_active == True])
```
## Pagination
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
Three pagination methods are available. All return a typed response whose `pagination_type` field tells clients which strategy was used.
| | `offset_paginate` | `cursor_paginate` | `paginate` |
|---|---|---|---|
| Return type | `OffsetPaginatedResponse` | `CursorPaginatedResponse` | either, based on `pagination_type` param |
| Total count | Yes | No | / |
| Jump to arbitrary page | Yes | No | / |
| Performance on deep pages | Degrades | Constant | / |
| Stable under concurrent inserts | No | Yes | / |
| Use case | Admin panels, numbered pagination | Feeds, APIs, infinite scroll | single endpoint, both strategies |
### Offset pagination
```python
from typing import Annotated
from fastapi import Depends
@router.get("")
async def get_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
) -> OffsetPaginatedResponse[UserRead]:
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
```
The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) method returns an [`OffsetPaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPaginatedResponse):
```json
{
"status": "SUCCESS",
"pagination_type": "offset",
"data": ["..."],
"pagination": {
"total_count": 100,
"pages": 5,
"page": 1,
"items_per_page": 20,
"has_more": true
}
}
```
#### Skipping the COUNT query
!!! info "Added in `v2.4.1`"
By default `offset_paginate` runs two queries: one for the page items and one `COUNT(*)` for `total_count`. On large tables the `COUNT` can be expensive. Pass `include_total=False` to `offset_paginate_params()` to skip it:
```python
@router.get("")
async def get_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.offset_paginate_params(include_total=False))],
) -> OffsetPaginatedResponse[UserRead]:
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
```
### Cursor pagination
```python
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.cursor_paginate_params())],
) -> CursorPaginatedResponse[UserRead]:
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
```
The [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) method returns a [`CursorPaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPaginatedResponse):
```json
{
"status": "SUCCESS",
"pagination_type": "cursor",
"data": ["..."],
"pagination": {
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
"prev_cursor": null,
"items_per_page": 20,
"has_more": true
}
}
```
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page. `prev_cursor` is set on pages 2+ and points back to the first item of the current page. Both are `null` when there is no adjacent page.
#### Choosing a cursor column
The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) via the `cursor_column` parameter. It must be monotonically ordered for stable results:
- Auto-increment integer PKs
- UUID v7 PKs
- Timestamps
!!! warning
Random UUID v4 PKs are **not** suitable as cursor columns because their ordering is non-deterministic.
!!! note
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
The cursor value is URL-safe base64-encoded (no padding) when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported:
| SQLAlchemy type | Python type |
|---|---|
| `Integer`, `BigInteger`, `SmallInteger` | `int` |
| `Uuid` | `uuid.UUID` |
| `DateTime` | `datetime.datetime` |
| `Date` | `datetime.date` |
| `Float`, `Numeric` | `decimal.Decimal` |
```python
# Paginate by the primary key
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
# Paginate by a timestamp column instead
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
```
### Unified endpoint (both strategies)
!!! info "Added in `v2.3.0`"
[`paginate()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) dispatches to `offset_paginate` or `cursor_paginate` based on a `pagination_type` query parameter, letting you expose **one endpoint** that supports both strategies. The `pagination_type` field in the response tells clients which strategy was used, enabling frontend discriminated-union typing.
```python
from fastapi_toolsets.schemas import PaginatedResponse
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.paginate_params())],
) -> PaginatedResponse[UserRead]:
return await UserCrud.paginate(session, **params, schema=UserRead)
```
```
GET /users?pagination_type=offset&page=2&items_per_page=10
GET /users?pagination_type=cursor&cursor=eyJ2YWx1ZSI6...&items_per_page=10
```
## Search
Two search strategies are available, both compatible with [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate).
| | Full-text search | Faceted search |
|---|---|---|
| Input | Free-text string | Exact column values |
| Relationship support | Yes | Yes |
| Use case | Search bars | Filter dropdowns |
!!! info "You can use both search strategies in the same endpoint!"
### Full-text search
!!! info "Added in `v2.2.1`"
The model's primary key is always included in `searchable_fields` automatically, so searching by ID works out of the box without any configuration. When no `searchable_fields` are declared, only the primary key is searched.
Declare `searchable_fields` on the CRUD class. Relationship traversal is supported via tuples:
```python
PostCrud = CrudFactory(
model=Post,
searchable_fields=[
Post.title,
Post.content,
(Post.author, User.username), # search across relationship
],
)
```
You can override `searchable_fields` per call with `search_fields`:
```python
result = await UserCrud.offset_paginate(
session=session,
search_fields=[User.country],
)
```
Or via the dependency to narrow which fields are exposed as query parameters:
```python
params = UserCrud.offset_paginate_params(search_fields=[Post.title])
```
This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
```python
@router.get("")
async def get_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
) -> OffsetPaginatedResponse[UserRead]:
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
```
```python
@router.get("")
async def get_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.cursor_paginate_params())],
) -> CursorPaginatedResponse[UserRead]:
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
```
The dependency adds two query parameters to the endpoint:
| Parameter | Type |
| --------------- | ------------- |
| `search` | `str \| null` |
| `search_column` | `str \| null` |
```
GET /posts?search=hello → search all configured columns
GET /posts?search=hello&search_column=title → search only Post.title
```
The available search column keys are returned in the `search_columns` field of [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse). Use them to populate a column picker in the UI, or to validate `search_column` values on the client side:
```json
{
"status": "SUCCESS",
"data": ["..."],
"pagination": { "..." },
"search_columns": ["content", "author__username", "title"]
}
```
!!! info "Key format uses `__` as a separator for relationship chains."
A direct column `Post.title` produces `"title"`. A relationship tuple `(Post.author, User.username)` produces `"author__username"`. An unknown `search_column` value raises [`InvalidSearchColumnError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidSearchColumnError) (HTTP 422).
### Faceted search
!!! info "Added in `v1.2`"
Declare `facet_fields` on the CRUD class to return distinct column values alongside paginated results. This is useful for populating filter dropdowns or building faceted search UIs. Relationship traversal is supported via tuples, using the same syntax as `searchable_fields`:
```python
UserCrud = CrudFactory(
model=User,
facet_fields=[
User.status,
User.country,
(User.role, Role.name), # value from a related model
],
)
```
You can override `facet_fields` per call:
```python
result = await UserCrud.offset_paginate(
session=session,
facet_fields=[User.country],
)
```
Or via the dependency to narrow which fields are exposed as query parameters:
```python
params = UserCrud.offset_paginate_params(facet_fields=[User.country])
```
Facet filtering is built into the consolidated params dependencies. When `filter=True` (the default), each facet field is exposed as a query parameter and values are collected into `filter_by` automatically:
```python
from typing import Annotated
from fastapi import Depends
@router.get("", response_model_exclude_none=True)
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
) -> OffsetPaginatedResponse[UserRead]:
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
```
```python
@router.get("", response_model_exclude_none=True)
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.cursor_paginate_params())],
) -> CursorPaginatedResponse[UserRead]:
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
```
Both single-value and multi-value query parameters work:
```
GET /users?status=active → filter_by={"status": ["active"]}
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
GET /users?role__name=admin&role__name=editor → filter_by={"role__name": ["admin", "editor"]} (IN clause)
```
`filter_by` and `filters` can be combined — both are applied with AND logic.
The distinct values for each facet field are returned in the `filter_attributes` field of [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse). Use them to populate filter dropdowns in the UI, or to validate `filter_by` keys on the client side:
```json
{
"status": "SUCCESS",
"data": ["..."],
"pagination": { "..." },
"filter_attributes": {
"status": ["active", "inactive"],
"country": ["DE", "FR", "US"],
"role__name": ["admin", "editor", "viewer"]
}
}
```
!!! info "Key format uses `__` as a separator for relationship chains."
A direct column `User.status` produces `"status"`. A relationship tuple `(User.role, Role.name)` produces `"role__name"`. A deeper chain `(User.role, Role.permission, Permission.name)` produces `"role__permission__name"`. An unknown `filter_by` key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) (HTTP 422).
## Sorting
!!! info "Added in `v1.3`"
Declare `order_fields` on the CRUD class. Relationship traversal is supported via tuples, using the same syntax as `searchable_fields` and `facet_fields`:
```python
UserCrud = CrudFactory(
model=User,
order_fields=[
User.name,
User.created_at,
(User.role, Role.name), # sort by a related model column
],
)
```
You can override `order_fields` per call:
```python
result = await UserCrud.offset_paginate(
session=session,
order_fields=[User.name],
)
```
Or via the dependency to narrow which fields are exposed as query parameters:
```python
params = UserCrud.offset_paginate_params(order_fields=[User.name])
```
Sorting is built into the consolidated params dependencies. When `order=True` (the default), `order_by` and `order` query parameters are exposed and resolved into an `OrderByClause` automatically:
```python
from typing import Annotated
from fastapi import Depends
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
) -> OffsetPaginatedResponse[UserRead]:
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
```
```python
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.cursor_paginate_params())],
) -> CursorPaginatedResponse[UserRead]:
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
```
The dependency adds two query parameters to the endpoint:
| Parameter | Type |
| ---------- | --------------- |
| `order_by` | `str \| null` |
| `order` | `asc` or `desc` |
```
GET /users?order_by=name&order=asc → ORDER BY users.name ASC
GET /users?order_by=role__name&order=desc → LEFT JOIN roles ON ... ORDER BY roles.name DESC
```
!!! info "Relationship tuples are joined automatically."
When a relation field is selected, the related table is LEFT OUTER JOINed automatically. An unknown `order_by` value raises [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) (HTTP 422).
The available sort keys are returned in the `order_columns` field of [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse). Use them to populate a sort picker in the UI, or to validate `order_by` values on the client side:
```json
{
"status": "SUCCESS",
"data": ["..."],
"pagination": { "..." },
"order_columns": ["created_at", "name", "role__name"]
}
```
!!! info "Key format uses `__` as a separator for relationship chains."
A direct column `User.name` produces `"name"`. A relationship tuple `(User.role, Role.name)` produces `"role__name"`.
## Relationship loading
!!! info "Added in `v1.1`"
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
!!! warning
Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead.
```python
from sqlalchemy.orm import selectinload
ArticleCrud = CrudFactory(
model=Article,
default_load_options=[
selectinload(Article.category),
selectinload(Article.tags),
],
)
```
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `offset_paginate`, `cursor_paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
```python
# Only loads category, tags are not loaded
article = await ArticleCrud.get(
session=session,
filters=[Article.id == article_id],
load_options=[selectinload(Article.category)],
)
# Loads nothing — useful for write-then-refresh flows or lightweight checks
articles = await ArticleCrud.get_multi(session=session, load_options=[])
```
## Many-to-many relationships
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:
```python
PostCrud = CrudFactory(
model=Post,
m2m_fields={"tag_ids": Post.tags},
)
post = await PostCrud.create(session=session, obj=PostCreateSchema(title="Hello", tag_ids=[1, 2, 3]))
```
## Upsert
Atomic `INSERT ... ON CONFLICT DO UPDATE` using PostgreSQL:
```python
await UserCrud.upsert(
session=session,
obj=UserCreateSchema(email="alice@example.com", username="alice"),
index_elements=[User.email],
set_={"username"},
)
```
## Response serialization
!!! info "Added in `v1.1`"
Pass a Pydantic schema class to `create`, `get`, `update`, or `offset_paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
```python
class UserRead(PydanticBase):
id: UUID
username: str
@router.get(
"/{uuid}",
responses=generate_error_responses(NotFoundError),
)
async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
return await crud.UserCrud.get(
session=session,
filters=[User.id == uuid],
schema=UserRead,
)
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(crud.UserCrud.offset_paginate_params())],
) -> OffsetPaginatedResponse[UserRead]:
return await crud.UserCrud.offset_paginate(session=session, **params, schema=UserRead)
```
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
---
[:material-api: API Reference](../reference/crud.md)

123
docs/module/db.md Normal file
View File

@@ -0,0 +1,123 @@
# DB
SQLAlchemy async session management with transactions, table locking, and row-change polling.
!!! info
This module has been coded and tested to be compatible with PostgreSQL only.
## Overview
The `db` module provides helpers to create FastAPI dependencies and context managers for `AsyncSession`, along with utilities for nested transactions, table lock and polling for row changes.
## Session dependency
Use [`create_db_dependency`](../reference/db.md#fastapi_toolsets.db.create_db_dependency) to create a FastAPI dependency that yields a session and auto-commits on success:
```python
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from fastapi_toolsets.db import create_db_dependency
engine = create_async_engine(url="postgresql+asyncpg://...", future=True)
session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
get_db = create_db_dependency(session_maker=session_maker)
@router.get("/users")
async def list_users(session: AsyncSession = Depends(get_db)):
...
```
## Session context manager
Use [`create_db_context`](../reference/db.md#fastapi_toolsets.db.create_db_context) for sessions outside request handlers (e.g. background tasks, CLI commands):
```python
from fastapi_toolsets.db import create_db_context
db_context = create_db_context(session_maker=session_maker)
async def seed():
async with db_context() as session:
...
```
## Nested transactions
[`get_transaction`](../reference/db.md#fastapi_toolsets.db.get_transaction) handles savepoints automatically, allowing safe nesting:
```python
from fastapi_toolsets.db import get_transaction
async def create_user_with_role(session=session):
async with get_transaction(session=session):
...
async with get_transaction(session=session): # uses savepoint
...
```
## Table locking
[`lock_tables`](../reference/db.md#fastapi_toolsets.db.lock_tables) acquires PostgreSQL table-level locks before executing critical sections:
```python
from fastapi_toolsets.db import lock_tables
async with lock_tables(session=session, tables=[User], mode="EXCLUSIVE"):
# No other transaction can modify User until this block exits
...
```
Available lock modes are defined in [`LockMode`](../reference/db.md#fastapi_toolsets.db.LockMode): `ACCESS_SHARE`, `ROW_SHARE`, `ROW_EXCLUSIVE`, `SHARE_UPDATE_EXCLUSIVE`, `SHARE`, `SHARE_ROW_EXCLUSIVE`, `EXCLUSIVE`, `ACCESS_EXCLUSIVE`.
## Row-change polling
[`wait_for_row_change`](../reference/db.md#fastapi_toolsets.db.wait_for_row_change) polls a row until a specific column changes value, useful for waiting on async side effects:
```python
from fastapi_toolsets.db import wait_for_row_change
# Wait up to 30s for order.status to change
await wait_for_row_change(
session=session,
model=Order,
pk_value=order_id,
columns=[Order.status],
interval=1.0,
timeout=30.0,
)
```
## Creating a database
!!! info "Added in `v2.1`"
[`create_database`](../reference/db.md#fastapi_toolsets.db.create_database) creates a database at a given URL. It connects to *server_url* and issues a `CREATE DATABASE` statement:
```python
from fastapi_toolsets.db import create_database
SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
await create_database(db_name="myapp_test", server_url=SERVER_URL)
```
For test isolation with automatic cleanup, use [`create_worker_database`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_worker_database) from the `pytest` module instead — it handles drop-before, create, and drop-after automatically.
## Cleaning up tables
!!! info "Added in `v2.1`"
[`cleanup_tables`](../reference/db.md#fastapi_toolsets.db.cleanup_tables) truncates all tables:
```python
from fastapi_toolsets.db import cleanup_tables
@pytest.fixture(autouse=True)
async def clean(db_session):
yield
await cleanup_tables(session=db_session, base=Base)
```
---
[:material-api: API Reference](../reference/db.md)

View File

@@ -0,0 +1,61 @@
# Dependencies
FastAPI dependency factories for automatic model resolution from path and body parameters.
## Overview
The `dependencies` module provides two factory functions that create FastAPI dependencies to fetch a model instance from the database automatically — either from a path parameter or from a request body field — and inject it directly into your route handler.
## `PathDependency`
[`PathDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.PathDependency) resolves a model from a URL path parameter and injects it into the route handler. Raises [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) automatically if the record does not exist.
```python
from fastapi_toolsets.dependencies import PathDependency
# Plain callable
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
# Annotated
SessionDep = Annotated[AsyncSession, Depends(get_db)]
UserDep = PathDependency(model=User, field=User.id, session_dep=SessionDep)
@router.get("/users/{user_id}")
async def get_user(user: User = UserDep):
return user
```
By default the parameter name is inferred from the field (`user_id` for `User.id`). You can override it:
```python
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db, param_name="id")
@router.get("/users/{id}")
async def get_user(user: User = UserDep):
return user
```
## `BodyDependency`
[`BodyDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.BodyDependency) resolves a model from a field in the request body. Useful when a body contains a foreign key and you want the full object injected:
```python
from fastapi_toolsets.dependencies import BodyDependency
# Plain callable
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
# Annotated
SessionDep = Annotated[AsyncSession, Depends(get_db)]
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=SessionDep, body_field="role_id")
@router.post("/users")
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
user = User(username=body.username, role=role)
...
```
---
[:material-api: API Reference](../reference/dependencies.md)

131
docs/module/exceptions.md Normal file
View File

@@ -0,0 +1,131 @@
# Exceptions
Structured API exceptions with consistent error responses and automatic OpenAPI documentation.
## Overview
The `exceptions` module provides a set of pre-built HTTP exceptions and a FastAPI exception handler that formats all errors — including validation errors — into a uniform [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse).
## Setup
Register the exception handlers on your FastAPI app at startup:
```python
from fastapi import FastAPI
from fastapi_toolsets.exceptions import init_exceptions_handlers
app = FastAPI()
init_exceptions_handlers(app=app)
```
This registers handlers for:
- [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below
- `HTTPException` — Starlette/FastAPI HTTP errors
- `RequestValidationError` — Pydantic request validation (422)
- `ResponseValidationError` — Pydantic response validation (422)
- `Exception` — unhandled errors (500)
It also patches `app.openapi()` to replace the default Pydantic 422 schema with a structured example matching the `ErrorResponse` format.
## Built-in exceptions
| Exception | Status | Default message |
|-----------|--------|-----------------|
| [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized |
| [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden |
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not Found |
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No Searchable Fields |
| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid Facet Filter |
| [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) | 422 | Invalid Order Field |
### Per-instance overrides
All built-in exceptions accept optional keyword arguments to customise the response for a specific raise site without changing the class defaults:
| Argument | Effect |
|----------|--------|
| `detail` | Overrides both `str(exc)` (log output) and the `message` field in the response body |
| `desc` | Overrides the `description` field |
| `data` | Overrides the `data` field |
```python
raise NotFoundError(detail="User 42 not found", desc="No user with that ID exists in the database.")
```
## Custom exceptions
Subclass [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) and define an `api_error` class variable:
```python
from fastapi_toolsets.exceptions import ApiException
from fastapi_toolsets.schemas import ApiError
class PaymentRequiredError(ApiException):
api_error = ApiError(
code=402,
msg="Payment Required",
desc="Your subscription has expired.",
err_code="BILLING-402",
)
```
!!! warning
Subclasses that do not define `api_error` raise a `TypeError` at **class creation time**, not at raise time.
### Custom `__init__`
Override `__init__` to compute `detail`, `desc`, or `data` dynamically, then delegate to `super().__init__()`:
```python
class OrderValidationError(ApiException):
api_error = ApiError(
code=422,
msg="Order Validation Failed",
desc="One or more order fields are invalid.",
err_code="ORDER-422",
)
def __init__(self, *field_errors: str) -> None:
super().__init__(
f"{len(field_errors)} validation error(s)",
desc=", ".join(field_errors),
data={"errors": [{"message": e} for e in field_errors]},
)
```
### Intermediate base classes
Use `abstract=True` when creating a shared base that is not meant to be raised directly:
```python
class BillingError(ApiException, abstract=True):
"""Base for all billing-related errors."""
class PaymentRequiredError(BillingError):
api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402")
class SubscriptionExpiredError(BillingError):
api_error = ApiError(code=402, msg="Subscription Expired", desc="...", err_code="BILLING-402-EXP")
```
## OpenAPI response documentation
Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec:
```python
from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError, ForbiddenError
@router.get(
"/users/{id}",
responses=generate_error_responses(NotFoundError, ForbiddenError),
)
async def get_user(...): ...
```
Multiple exceptions sharing the same HTTP status code are grouped under one entry, each appearing as a named example keyed by its `err_code`. This keeps the OpenAPI UI readable when several error variants map to the same status.
---
[:material-api: API Reference](../reference/exceptions.md)

198
docs/module/fixtures.md Normal file
View File

@@ -0,0 +1,198 @@
# Fixtures
Dependency-aware database seeding with context-based loading strategies.
## Overview
The `fixtures` module lets you define named fixtures with dependencies between them, then load them into the database in the correct order. Fixtures can be scoped to contexts (e.g. base data, testing data) so that only the relevant ones are loaded for each environment.
## Defining fixtures
```python
from fastapi_toolsets.fixtures import FixtureRegistry, Context
fixtures = FixtureRegistry()
@fixtures.register
def roles():
return [
Role(id=1, name="admin"),
Role(id=2, name="user"),
]
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
def test_users():
return [
User(id=1, username="alice", role_id=1),
User(id=2, username="bob", role_id=2),
]
```
Dependencies declared via `depends_on` are resolved topologically — `roles` will always be loaded before `test_users`.
## Loading fixtures
By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures_by_context):
```python
from fastapi_toolsets.fixtures import load_fixtures_by_context
async with db_context() as session:
await load_fixtures_by_context(session, fixtures, Context.TESTING)
```
Directly by name with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
```python
from fastapi_toolsets.fixtures import load_fixtures
async with db_context() as session:
await load_fixtures(session, fixtures, "roles", "test_users")
```
Both functions return a `dict[str, list[...]]` mapping each fixture name to the list of loaded instances.
## Contexts
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
| Context | Description |
|---------|-------------|
| `Context.BASE` | Core data required in all environments |
| `Context.TESTING` | Data only loaded during tests |
| `Context.DEVELOPMENT` | Data only loaded in development |
| `Context.PRODUCTION` | Data only loaded in production |
A fixture with no `contexts` defined takes `Context.BASE` by default.
### Custom contexts
Plain strings and any `Enum` subclass are accepted wherever a `Context` enum is expected.
```python
from enum import Enum
class AppContext(str, Enum):
STAGING = "staging"
DEMO = "demo"
@fixtures.register(contexts=[AppContext.STAGING])
def staging_data():
return [Config(key="feature_x", enabled=True)]
await load_fixtures_by_context(session, fixtures, AppContext.STAGING)
```
### Default context for a registry
Pass `contexts` to `FixtureRegistry` to set a default for all fixtures registered in it:
```python
testing_registry = FixtureRegistry(contexts=[Context.TESTING])
@testing_registry.register # implicitly contexts=[Context.TESTING]
def test_orders():
return [Order(id=1, total=99)]
```
### Same fixture name, multiple context variants
The same fixture name may be registered under different (non-overlapping) context sets. When multiple contexts are loaded together, all matching variants are merged:
```python
@fixtures.register(contexts=[Context.BASE])
def users():
return [User(id=1, username="admin")]
@fixtures.register(contexts=[Context.TESTING])
def users():
return [User(id=2, username="tester")]
# loads both admin and tester
await load_fixtures_by_context(session, fixtures, Context.BASE, Context.TESTING)
```
Registering two variants with overlapping context sets raises `ValueError`.
## Load strategies
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
| Strategy | Description |
|----------|-------------|
| `LoadStrategy.INSERT` | Insert only, fail on duplicates |
| `LoadStrategy.MERGE` | Insert or update on conflict (default) |
| `LoadStrategy.SKIP_EXISTING` | Skip rows that already exist |
```python
await load_fixtures_by_context(
session, fixtures, Context.BASE, strategy=LoadStrategy.SKIP_EXISTING
)
```
## Merging registries
Split fixture definitions across modules and merge them:
```python
from myapp.fixtures.dev import dev_fixtures
from myapp.fixtures.prod import prod_fixtures
fixtures = FixtureRegistry()
fixtures.include_registry(registry=dev_fixtures)
fixtures.include_registry(registry=prod_fixtures)
```
Fixtures with the same name are allowed as long as their context sets do not overlap. Conflicting contexts raise `ValueError`.
## Looking up fixture instances
[`get_obj_by_attr`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.get_obj_by_attr) retrieves a specific instance from a fixture function by attribute value — useful when building cross-fixture `depends_on` relationships:
```python
from fastapi_toolsets.fixtures import get_obj_by_attr
@fixtures.register(depends_on=["roles"])
def users():
admin_role = get_obj_by_attr(roles, "name", "admin")
return [User(id=1, username="alice", role_id=admin_role.id)]
```
Raises `StopIteration` if no matching instance is found.
## Pytest integration
Use [`register_fixtures`](../reference/pytest.md#fastapi_toolsets.pytest.plugin.register_fixtures) to expose each fixture in your registry as an injectable pytest fixture named `fixture_{name}` by default:
```python
# conftest.py
import pytest
from fastapi_toolsets.pytest import create_db_session, register_fixtures
from app.fixtures import registry
from app.models import Base
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
@pytest.fixture
async def db_session():
async with create_db_session(database_url=DATABASE_URL, base=Base, cleanup=True) as session:
yield session
register_fixtures(registry=registry, namespace=globals())
```
```python
# test_users.py
async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Role]):
...
```
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
## CLI integration
Fixtures can be triggered from the CLI. See the [CLI module](cli.md) for setup instructions.
---
[:material-api: API Reference](../reference/fixtures.md)

39
docs/module/logger.md Normal file
View File

@@ -0,0 +1,39 @@
# Logger
Lightweight logging utilities with consistent formatting and uvicorn integration.
## Overview
The `logger` module provides two helpers: one to configure the root logger (and uvicorn loggers) at startup, and one to retrieve a named logger anywhere in your codebase.
## Setup
Call [`configure_logging`](../reference/logger.md#fastapi_toolsets.logger.configure_logging) once at application startup:
```python
from fastapi_toolsets.logger import configure_logging
configure_logging(level="INFO")
```
This sets up a stdout handler with a consistent format and also configures uvicorn's access and error loggers so all log output shares the same style.
## Getting a logger
```python
from fastapi_toolsets.logger import get_logger
logger = get_logger(name=__name__)
logger.info("User created")
```
When called without arguments, [`get_logger`](../reference/logger.md#fastapi_toolsets.logger.get_logger) auto-detects the caller's module name via frame inspection:
```python
# Equivalent to get_logger(name=__name__)
logger = get_logger()
```
---
[:material-api: API Reference](../reference/logger.md)

109
docs/module/metrics.md Normal file
View File

@@ -0,0 +1,109 @@
# Metrics
Prometheus metrics integration with a decorator-based registry and multi-process support.
## Installation
=== "uv"
``` bash
uv add "fastapi-toolsets[metrics]"
```
=== "pip"
``` bash
pip install "fastapi-toolsets[metrics]"
```
## Overview
The `metrics` module provides a [`MetricsRegistry`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry) to declare Prometheus metrics with decorators, and an [`init_metrics`](../reference/metrics.md#fastapi_toolsets.metrics.handler.init_metrics) function to mount a `/metrics` endpoint on your FastAPI app.
## Setup
```python
from fastapi import FastAPI
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
app = FastAPI()
metrics = MetricsRegistry()
init_metrics(app=app, registry=metrics)
```
This mounts the `/metrics` endpoint that Prometheus can scrape.
## Declaring metrics
### Providers
Providers are called once at startup by `init_metrics`. The return value (the Prometheus metric object) is stored in the registry and can be retrieved later with [`registry.get(name)`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry.get).
Use providers when you want **deferred initialization**: the Prometheus metric is not registered with the global `CollectorRegistry` until `init_metrics` runs, not at import time. This is particularly useful for testing — importing the module in a test suite without calling `init_metrics` leaves no metrics registered, avoiding cross-test pollution.
It is also useful when metrics are defined across multiple modules and merged with `include_registry`: any code that needs a metric can call `metrics.get()` on the shared registry instead of importing the metric directly from its origin module.
If neither of these applies to you, declaring metrics at module level (e.g. `HTTP_REQUESTS = Counter(...)`) is simpler and equally valid.
```python
from prometheus_client import Counter, Histogram
@metrics.register
def http_requests():
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
@metrics.register
def request_duration():
return Histogram("request_duration_seconds", "Request duration")
```
To use a provider's metric elsewhere (e.g. in a middleware), call `metrics.get()` inside the handler — **not** at module level, as providers are only initialized when `init_metrics` runs:
```python
async def metrics_middleware(request: Request, call_next):
response = await call_next(request)
metrics.get("http_requests").labels(
method=request.method, status=response.status_code
).inc()
return response
```
### Collectors
Collectors are called on every scrape. Use them for metrics that reflect current state (e.g. gauges).
!!! warning "Declare the metric at module level"
Do **not** instantiate the Prometheus metric inside the collector function. Doing so recreates it on every scrape, raising `ValueError: Duplicated timeseries in CollectorRegistry`. Declare it once at module level instead:
```python
from prometheus_client import Gauge
_queue_depth = Gauge("queue_depth", "Current queue depth")
@metrics.register(collect=True)
def collect_queue_depth():
_queue_depth.set(get_current_queue_depth())
```
## Merging registries
Split metrics definitions across modules and merge them:
```python
from myapp.metrics.http import http_metrics
from myapp.metrics.db import db_metrics
metrics = MetricsRegistry()
metrics.include_registry(registry=http_metrics)
metrics.include_registry(registry=db_metrics)
```
## Multi-process mode
Multi-process support is enabled automatically when the `PROMETHEUS_MULTIPROC_DIR` environment variable is set. No code changes are required.
!!! warning "Environment variable name"
The correct variable is `PROMETHEUS_MULTIPROC_DIR` (not `PROMETHEUS_MULTIPROCESS_DIR`).
---
[:material-api: API Reference](../reference/metrics.md)

234
docs/module/models.md Normal file
View File

@@ -0,0 +1,234 @@
# Models
!!! info "Added in `v2.0`"
Reusable SQLAlchemy 2.0 mixins for common column patterns, designed to be composed freely on any `DeclarativeBase` model.
## Overview
The `models` module provides mixins that each add a single, well-defined column behaviour. They work with standard SQLAlchemy 2.0 declarative syntax and are fully compatible with `AsyncSession`.
```python
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
class Article(Base, UUIDMixin, TimestampMixin):
__tablename__ = "articles"
title: Mapped[str]
content: Mapped[str]
```
All timestamp columns are timezone-aware (`TIMESTAMPTZ`). All defaults are server-side (`clock_timestamp()`), so they are also applied when inserting rows via raw SQL outside the ORM.
## Mixins
### [`UUIDMixin`](../reference/models.md#fastapi_toolsets.models.UUIDMixin)
Adds a `id: UUID` primary key generated server-side by PostgreSQL using `gen_random_uuid()`. The value is retrieved via `RETURNING` after insert, so it is available on the Python object immediately after `flush()`.
!!! warning "Requires PostgreSQL 13+"
```python
from fastapi_toolsets.models import UUIDMixin
class User(Base, UUIDMixin):
__tablename__ = "users"
username: Mapped[str]
# id is None before flush
user = User(username="alice")
session.add(user)
await session.flush()
print(user.id) # UUID('...')
```
### [`UUIDv7Mixin`](../reference/models.md#fastapi_toolsets.models.UUIDv7Mixin)
!!! info "Added in `v2.3`"
Adds a `id: UUID` primary key generated server-side by PostgreSQL using `uuidv7()`. It's a time-ordered UUID format that encodes a millisecond-precision timestamp in the most significant bits, making it naturally sortable and index-friendly.
!!! warning "Requires PostgreSQL 18+"
```python
from fastapi_toolsets.models import UUIDv7Mixin
class Event(Base, UUIDv7Mixin):
__tablename__ = "events"
name: Mapped[str]
# id is None before flush
event = Event(name="user.signup")
session.add(event)
await session.flush()
print(event.id) # UUID('019...')
```
### [`CreatedAtMixin`](../reference/models.md#fastapi_toolsets.models.CreatedAtMixin)
Adds a `created_at: datetime` column set to `clock_timestamp()` on insert. The column has no `onupdate` hook — it is intentionally immutable after the row is created.
```python
from fastapi_toolsets.models import UUIDMixin, CreatedAtMixin
class Order(Base, UUIDMixin, CreatedAtMixin):
__tablename__ = "orders"
total: Mapped[float]
```
### [`UpdatedAtMixin`](../reference/models.md#fastapi_toolsets.models.UpdatedAtMixin)
Adds an `updated_at: datetime` column set to `clock_timestamp()` on insert and automatically updated to `clock_timestamp()` on every ORM-level update (via SQLAlchemy's `onupdate` hook).
```python
from fastapi_toolsets.models import UUIDMixin, UpdatedAtMixin
class Post(Base, UUIDMixin, UpdatedAtMixin):
__tablename__ = "posts"
title: Mapped[str]
post = Post(title="Hello")
await session.flush()
await session.refresh(post)
post.title = "Hello World"
await session.flush()
await session.refresh(post)
print(post.updated_at)
```
!!! note
`updated_at` is updated by SQLAlchemy at ORM flush time. If you update rows via raw SQL (e.g. `UPDATE posts SET ...`), the column will **not** be updated automatically — use a database trigger if you need that guarantee.
### [`TimestampMixin`](../reference/models.md#fastapi_toolsets.models.TimestampMixin)
Convenience mixin that combines [`CreatedAtMixin`](../reference/models.md#fastapi_toolsets.models.CreatedAtMixin) and [`UpdatedAtMixin`](../reference/models.md#fastapi_toolsets.models.UpdatedAtMixin). Equivalent to inheriting both.
```python
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
class Article(Base, UUIDMixin, TimestampMixin):
__tablename__ = "articles"
title: Mapped[str]
```
## Lifecycle events
The event system provides lifecycle callbacks that fire **after commit**. If the transaction rolls back, no callback fires.
### Setup
Event dispatch requires [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession). Pass it as the session class when creating your session factory:
```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from fastapi_toolsets.models import EventSession
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
```
!!! info "Callbacks fire on `session.commit()` only — not on savepoints."
Savepoints created by [`get_transaction`](db.md) or `begin_nested()` do **not**
trigger callbacks. All events accumulated across flushes are dispatched once
when the outermost `commit()` is called.
### Events
Three event types are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
| Event | Trigger |
|---|---|
| `ModelEvent.CREATE` | After `INSERT` commit |
| `ModelEvent.DELETE` | After `DELETE` commit |
| `ModelEvent.UPDATE` | After `UPDATE` commit on a watched field |
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
### Watched fields
Set `__watched_fields__` on the model to restrict which field changes trigger `UPDATE` events. It must be a `tuple[str, ...]` — any other type raises `TypeError`:
| Class attribute | `UPDATE` behaviour |
|---|---|
| `__watched_fields__ = ("status", "role")` | Only fires when `status` or `role` changes |
| *(not set)* | Fires when **any** mapped field changes |
`__watched_fields__` is inherited through the class hierarchy via normal Python MRO. A subclass can override it:
```python
class Order(Base, UUIDMixin):
__watched_fields__ = ("status",)
...
class UrgentOrder(Order):
# inherits __watched_fields__ = ("status",)
...
class PriorityOrder(Order):
__watched_fields__ = ("priority",)
# overrides parent — UPDATE fires only for priority changes
...
```
### Registering handlers
Register handlers with the [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator. Every callback receives three arguments: the model instance, the [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) that triggered it, and a `changes` dict (`None` for `CREATE` and `DELETE`):
```python
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
class Order(Base, UUIDMixin):
__tablename__ = "orders"
__watched_fields__ = ("status",)
status: Mapped[str]
@listens_for(Order, [ModelEvent.CREATE])
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
await notify_new_order(order.id)
@listens_for(Order, [ModelEvent.DELETE])
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
await notify_order_cancelled(order.id)
@listens_for(Order, [ModelEvent.UPDATE])
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
if "status" in changes:
await notify_status_change(order.id, changes["status"])
```
Multiple handlers can be registered for the same model and event. Handlers registered on a parent class also fire for subclass instances.
A single handler can listen for multiple events at once. When `event_types` is omitted, the handler fires for all events:
```python
@listens_for(Order, [ModelEvent.CREATE, ModelEvent.UPDATE])
async def on_order_changed(order: Order, event_type: ModelEvent, changes: dict | None):
await invalidate_cache(order.id)
@listens_for(Order) # all events
async def on_any_order_event(order: Order, event_type: ModelEvent, changes: dict | None):
await audit_log(order.id, event_type)
```
### Field changes format
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included. For `CREATE` and `DELETE` events, `changes` is `None`:
```python
# CREATE / DELETE → changes is None
# status changed → {"status": {"old": "pending", "new": "shipped"}}
# two fields changed → {"status": {...}, "assigned_to": {...}}
```
!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit."
---
[:material-api: API Reference](../reference/models.md)

95
docs/module/pytest.md Normal file
View File

@@ -0,0 +1,95 @@
# Pytest
Testing helpers for FastAPI applications with async client, database sessions, and parallel worker support.
## Installation
=== "uv"
``` bash
uv add "fastapi-toolsets[pytest]"
```
=== "pip"
``` bash
pip install "fastapi-toolsets[pytest]"
```
## Overview
The `pytest` module provides utilities for setting up async test clients, managing test database sessions, and supporting parallel test execution with `pytest-xdist`.
## Creating an async client
Use [`create_async_client`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_async_client) to get an `httpx.AsyncClient` configured for your FastAPI app:
```python
from fastapi_toolsets.pytest import create_async_client
@pytest.fixture
async def http_client(db_session):
async def _override_get_db():
yield db_session
async with create_async_client(
app=app,
base_url="http://127.0.0.1/api/v1",
dependency_overrides={get_db: _override_get_db},
) as c:
yield c
```
## Database sessions in tests
Use [`create_db_session`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_db_session) to create an isolated `AsyncSession` for a test, combined with [`create_worker_database`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_worker_database) to set up a per-worker database:
```python
from fastapi_toolsets.pytest import create_worker_database, create_db_session
@pytest.fixture(scope="session")
async def worker_db_url():
async with create_worker_database(
database_url=str(settings.SQLALCHEMY_DATABASE_URI)
) as url:
yield url
@pytest.fixture
async def db_session(worker_db_url):
async with create_db_session(
database_url=worker_db_url, base=Base, cleanup=True
) as session:
yield session
```
!!! info
In this example, the database is reset between each test using the argument `cleanup=True`.
Use [`worker_database_url`](../reference/pytest.md#fastapi_toolsets.pytest.utils.worker_database_url) to derive the per-worker URL manually if needed:
```python
from fastapi_toolsets.pytest import worker_database_url
url = worker_database_url("postgresql+asyncpg://user:pass@localhost/test_db", default_test_db="test")
# e.g. "postgresql+asyncpg://user:pass@localhost/test_db_gw0" under xdist
```
## Parallel testing with pytest-xdist
The examples above are already compatible with parallel test execution with `pytest-xdist`.
## Cleaning up tables
If you want to manually clean up a database you can use [`cleanup_tables`](../reference/db.md#fastapi_toolsets.db.cleanup_tables), this will truncate all tables between tests for fast isolation:
```python
from fastapi_toolsets.db import cleanup_tables
@pytest.fixture(autouse=True)
async def clean(db_session):
yield
await cleanup_tables(session=db_session, base=Base)
```
---
[:material-api: API Reference](../reference/pytest.md)

140
docs/module/schemas.md Normal file
View File

@@ -0,0 +1,140 @@
# Schemas
Standardized Pydantic response models for consistent API responses across your FastAPI application.
## Overview
The `schemas` module provides generic response wrappers that enforce a uniform response structure. All models use `from_attributes=True` for ORM compatibility and `validate_assignment=True` for runtime type safety.
## Response models
### [`Response[T]`](../reference/schemas.md#fastapi_toolsets.schemas.Response)
The most common wrapper for a single resource response.
```python
from fastapi_toolsets.schemas import Response
@router.get("/users/{id}")
async def get_user(user: User = UserDep) -> Response[UserSchema]:
return Response(data=user, message="User retrieved")
```
### Paginated response models
Three classes wrap paginated list results. Pick the one that matches your endpoint's strategy:
| Class | `pagination` type | `pagination_type` field | Use when |
|---|---|---|---|
| [`OffsetPaginatedResponse[T]`](#offsetpaginatedresponset) | `OffsetPagination` | `"offset"` (fixed) | endpoint always uses offset |
| [`CursorPaginatedResponse[T]`](#cursorpaginatedresponset) | `CursorPagination` | `"cursor"` (fixed) | endpoint always uses cursor |
| [`PaginatedResponse[T]`](#paginatedresponset) | `OffsetPagination \| CursorPagination` | — | unified endpoint supporting both strategies |
#### [`OffsetPaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPaginatedResponse)
!!! info "Added in `v2.3.0`"
Use as the return type when the endpoint always uses [`offset_paginate`](crud.md#offset-pagination). The `pagination` field is guaranteed to be an [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) object; the response always includes a `pagination_type: "offset"` discriminator.
```python
from fastapi_toolsets.schemas import OffsetPaginatedResponse
@router.get("/users")
async def list_users(
page: int = 1,
items_per_page: int = 20,
) -> OffsetPaginatedResponse[UserSchema]:
return await UserCrud.offset_paginate(
session, page=page, items_per_page=items_per_page, schema=UserSchema
)
```
**Response shape:**
```json
{
"status": "SUCCESS",
"pagination_type": "offset",
"data": ["..."],
"pagination": {
"total_count": 100,
"page": 1,
"items_per_page": 20,
"has_more": true
}
}
```
#### [`CursorPaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPaginatedResponse)
!!! info "Added in `v2.3.0`"
Use as the return type when the endpoint always uses [`cursor_paginate`](crud.md#cursor-pagination). The `pagination` field is guaranteed to be a [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination) object; the response always includes a `pagination_type: "cursor"` discriminator.
```python
from fastapi_toolsets.schemas import CursorPaginatedResponse
@router.get("/events")
async def list_events(
cursor: str | None = None,
items_per_page: int = 20,
) -> CursorPaginatedResponse[EventSchema]:
return await EventCrud.cursor_paginate(
session, cursor=cursor, items_per_page=items_per_page, schema=EventSchema
)
```
**Response shape:**
```json
{
"status": "SUCCESS",
"pagination_type": "cursor",
"data": ["..."],
"pagination": {
"next_cursor": "eyJpZCI6IDQyfQ==",
"prev_cursor": null,
"items_per_page": 20,
"has_more": true
}
}
```
#### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse)
Return type for endpoints that support **both** pagination strategies via a `pagination_type` query parameter (using [`paginate()`](crud.md#unified-paginate--both-strategies-on-one-endpoint)).
When used as a return annotation, `PaginatedResponse[T]` automatically expands to `Annotated[Union[CursorPaginatedResponse[T], OffsetPaginatedResponse[T]], Field(discriminator="pagination_type")]`, so FastAPI emits a proper `oneOf` + discriminator in the OpenAPI schema with no extra boilerplate:
```python
from fastapi_toolsets.crud import PaginationType
from fastapi_toolsets.schemas import PaginatedResponse
@router.get("/users")
async def list_users(
pagination_type: PaginationType = PaginationType.OFFSET,
page: int = 1,
cursor: str | None = None,
items_per_page: int = 20,
) -> PaginatedResponse[UserSchema]:
return await UserCrud.paginate(
session,
pagination_type=pagination_type,
page=page,
cursor=cursor,
items_per_page=items_per_page,
schema=UserSchema,
)
```
#### Pagination metadata models
The optional `filter_attributes` field is populated when `facet_fields` are configured on the CRUD class (see [Filter attributes](crud.md#filter-attributes-facets)). It is `None` by default and can be hidden from API responses with `response_model_exclude_none=True`.
### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse)
Returned automatically by the exceptions handler.
---
[:material-api: API Reference](../reference/schemas.md)

267
docs/module/security.md Normal file
View File

@@ -0,0 +1,267 @@
# Security
Composable authentication helpers for FastAPI that use `Security()` for OpenAPI documentation and accept user-provided validator functions with full type flexibility.
## Overview
The `security` module provides four auth source classes and a `MultiAuth` factory. Each class wraps a FastAPI security scheme for OpenAPI and accepts a validator function called as:
```python
await validator(credential, **kwargs)
```
where `kwargs` are the extra keyword arguments provided at instantiation (roles, permissions, enums, etc.). The validator returns the authenticated identity (e.g. a `User` model) which becomes the route dependency value.
```python
from fastapi import Security
from fastapi_toolsets.security import BearerTokenAuth
async def verify_token(token: str, *, role: str) -> User:
user = await db.get_by_token(token)
if not user or user.role != role:
raise UnauthorizedError()
return user
bearer_admin = BearerTokenAuth(verify_token, role="admin")
@app.get("/admin")
async def admin_route(user: User = Security(bearer_admin)):
return user
```
## Auth sources
### [`BearerTokenAuth`](../reference/security.md#fastapi_toolsets.security.BearerTokenAuth)
Reads the `Authorization: Bearer <token>` header. Wraps `HTTPBearer` for OpenAPI.
```python
from fastapi_toolsets.security import BearerTokenAuth
bearer = BearerTokenAuth(validator=verify_token)
@app.get("/me")
async def me(user: User = Security(bearer)):
return user
```
#### Token prefix
The optional `prefix` parameter restricts a `BearerTokenAuth` instance to tokens
that start with a given string. The prefix is **kept** in the value passed to the
validator — store and compare tokens with their prefix included.
This lets you deploy multiple `BearerTokenAuth` instances in the same application
and disambiguate them efficiently in `MultiAuth`:
```python
user_bearer = BearerTokenAuth(verify_user, prefix="user_") # matches "Bearer user_..."
org_bearer = BearerTokenAuth(verify_org, prefix="org_") # matches "Bearer org_..."
```
Use [`generate_token()`](#token-generation) to create correctly-prefixed tokens.
#### Token generation
`BearerTokenAuth.generate_token()` produces a secure random token ready to store
in your database and return to the client. If a prefix is configured it is
prepended automatically:
```python
bearer = BearerTokenAuth(verify_token, prefix="user_")
token = bearer.generate_token() # e.g. "user_Xk3mN..."
await db.store_token(user_id, token)
return {"access_token": token, "token_type": "bearer"}
```
The client sends `Authorization: Bearer user_Xk3mN...` and the validator receives
the full token (prefix included) to compare against the stored value.
### [`CookieAuth`](../reference/security.md#fastapi_toolsets.security.CookieAuth)
Reads a named cookie. Wraps `APIKeyCookie` for OpenAPI.
```python
from fastapi_toolsets.security import CookieAuth
cookie_auth = CookieAuth("session", validator=verify_session)
@app.get("/me")
async def me(user: User = Security(cookie_auth)):
return user
```
### [`OAuth2Auth`](../reference/security.md#fastapi_toolsets.security.OAuth2Auth)
Reads the `Authorization: Bearer <token>` header and registers the token endpoint
in OpenAPI via `OAuth2PasswordBearer`.
```python
from fastapi_toolsets.security import OAuth2Auth
oauth2_auth = OAuth2Auth(token_url="/token", validator=verify_token)
@app.get("/me")
async def me(user: User = Security(oauth2_auth)):
return user
```
### [`OpenIDAuth`](../reference/security.md#fastapi_toolsets.security.OpenIDAuth)
Reads the `Authorization: Bearer <token>` header and registers the OpenID Connect
discovery URL in OpenAPI via `OpenIdConnect`. Token validation is fully delegated
to your validator — use any OIDC / JWT library (`authlib`, `python-jose`, `PyJWT`).
```python
from fastapi_toolsets.security import OpenIDAuth
async def verify_google_token(token: str, *, audience: str) -> User:
payload = jwt.decode(token, google_public_keys, algorithms=["RS256"],
audience=audience)
return User(email=payload["email"], name=payload["name"])
google_auth = OpenIDAuth(
"https://accounts.google.com/.well-known/openid-configuration",
verify_google_token,
audience="my-client-id",
)
@app.get("/me")
async def me(user: User = Security(google_auth)):
return user
```
The discovery URL is used **only for OpenAPI documentation** — no requests are made
to it by this class. You are responsible for fetching and caching the provider's
public keys in your validator.
Multiple providers work naturally with `MultiAuth`:
```python
multi = MultiAuth(google_auth, github_auth)
@app.get("/data")
async def data(user: User = Security(multi)):
return user
```
## Typed validator kwargs
All auth classes forward extra instantiation keyword arguments to the validator.
Arguments can be any type — enums, strings, integers, etc. The validator returns
the authenticated identity, which FastAPI injects directly into the route handler.
```python
async def verify_token(token: str, *, role: Role, permission: str) -> User:
user = await decode_token(token)
if user.role != role or permission not in user.permissions:
raise UnauthorizedError()
return user
bearer = BearerTokenAuth(verify_token, role=Role.ADMIN, permission="billing:read")
```
Each auth instance is self-contained — create a separate instance per distinct
requirement instead of passing requirements through `Security(scopes=[...])`.
### Using `.require()` inline
If declaring a new top-level variable per role feels verbose, use `.require()` to
create a configured clone directly in the route decorator. The original instance
is not mutated:
```python
bearer = BearerTokenAuth(verify_token)
@app.get("/admin/stats")
async def admin_stats(user: User = Security(bearer.require(role=Role.ADMIN))):
return {"message": f"Hello admin {user.name}"}
@app.get("/profile")
async def profile(user: User = Security(bearer.require(role=Role.USER))):
return {"id": user.id, "name": user.name}
```
`.require()` kwargs are merged over existing ones — new values win on conflict.
The `prefix` (for `BearerTokenAuth`) and cookie name (for `CookieAuth`) are
always preserved.
`.require()` instances work transparently inside `MultiAuth`:
```python
multi = MultiAuth(
user_bearer.require(role=Role.USER),
org_bearer.require(role=Role.ADMIN),
)
```
## MultiAuth
[`MultiAuth`](../reference/security.md#fastapi_toolsets.security.MultiAuth) combines
multiple auth sources into a single callable. Sources are tried in order; the
first one that finds a credential wins.
```python
from fastapi_toolsets.security import MultiAuth
multi = MultiAuth(user_bearer, org_bearer, cookie_auth)
@app.get("/data")
async def data_route(user = Security(multi)):
return user
```
### Using `.require()` on MultiAuth
`MultiAuth` also supports `.require()`, which propagates the kwargs to every
source that implements it. Sources that do not (e.g. custom `AuthSource`
subclasses) are passed through unchanged:
```python
multi = MultiAuth(bearer, cookie)
@app.get("/admin")
async def admin(user: User = Security(multi.require(role=Role.ADMIN))):
return user
```
This is equivalent to calling `.require()` on each source individually:
```python
# These two are identical
multi.require(role=Role.ADMIN)
MultiAuth(
bearer.require(role=Role.ADMIN),
cookie.require(role=Role.ADMIN),
)
```
### Prefix-based dispatch
Because `extract()` is pure string matching (no I/O), prefix-based source
selection is essentially free. Only the matching source's validator (which may
involve DB or network I/O) is ever called:
```python
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
multi = MultiAuth(user_bearer, org_bearer)
# "Bearer user_alice" → only verify_user runs, receives "user_alice"
# "Bearer org_acme" → only verify_org runs, receives "org_acme"
```
Tokens are stored and compared **with their prefix** — use `generate_token()` on
each source to issue correctly-prefixed tokens:
```python
user_token = user_bearer.generate_token() # "user_..."
org_token = org_bearer.generate_token() # "org_..."
```
---
[:material-api: API Reference](../reference/security.md)

7
docs/overrides/main.html Normal file
View File

@@ -0,0 +1,7 @@
{% extends "base.html" %} {% block extrahead %}
<script
defer
src="https://analytics.d3vyce.fr/script.js"
data-website-id="338b8816-7b99-4c6a-82f3-15595be3fd47"
></script>
{{ super() }} {% endblock %}

27
docs/reference/cli.md Normal file
View File

@@ -0,0 +1,27 @@
# `cli`
Here's the reference for the CLI configuration helpers used to load settings from `pyproject.toml`.
You can import them directly from `fastapi_toolsets.cli.config`:
```python
from fastapi_toolsets.cli.config import (
import_from_string,
get_config_value,
get_fixtures_registry,
get_db_context,
get_custom_cli,
)
```
## ::: fastapi_toolsets.cli.config.import_from_string
## ::: fastapi_toolsets.cli.config.get_config_value
## ::: fastapi_toolsets.cli.config.get_fixtures_registry
## ::: fastapi_toolsets.cli.config.get_db_context
## ::: fastapi_toolsets.cli.config.get_custom_cli
## ::: fastapi_toolsets.cli.utils.async_command

20
docs/reference/crud.md Normal file
View File

@@ -0,0 +1,20 @@
# `crud`
Here's the reference for the CRUD classes, factory, and search utilities.
You can import the main symbols from `fastapi_toolsets.crud`:
```python
from fastapi_toolsets.crud import CrudFactory, AsyncCrud
from fastapi_toolsets.crud.search import SearchConfig, get_searchable_fields, build_search_filters
```
## ::: fastapi_toolsets.crud.factory.AsyncCrud
## ::: fastapi_toolsets.crud.factory.CrudFactory
## ::: fastapi_toolsets.crud.search.SearchConfig
## ::: fastapi_toolsets.crud.search.get_searchable_fields
## ::: fastapi_toolsets.crud.search.build_search_filters

34
docs/reference/db.md Normal file
View File

@@ -0,0 +1,34 @@
# `db`
Here's the reference for all database session utilities, transaction helpers, and locking functions.
You can import them directly from `fastapi_toolsets.db`:
```python
from fastapi_toolsets.db import (
LockMode,
cleanup_tables,
create_database,
create_db_dependency,
create_db_context,
get_transaction,
lock_tables,
wait_for_row_change,
)
```
## ::: fastapi_toolsets.db.LockMode
## ::: fastapi_toolsets.db.create_db_dependency
## ::: fastapi_toolsets.db.create_db_context
## ::: fastapi_toolsets.db.get_transaction
## ::: fastapi_toolsets.db.lock_tables
## ::: fastapi_toolsets.db.wait_for_row_change
## ::: fastapi_toolsets.db.create_database
## ::: fastapi_toolsets.db.cleanup_tables

View File

@@ -0,0 +1,13 @@
# `dependencies`
Here's the reference for the FastAPI dependency factory functions.
You can import them directly from `fastapi_toolsets.dependencies`:
```python
from fastapi_toolsets.dependencies import PathDependency, BodyDependency
```
## ::: fastapi_toolsets.dependencies.PathDependency
## ::: fastapi_toolsets.dependencies.BodyDependency

View File

@@ -0,0 +1,40 @@
# `exceptions`
Here's the reference for all exception classes and handler utilities.
You can import them directly from `fastapi_toolsets.exceptions`:
```python
from fastapi_toolsets.exceptions import (
ApiException,
UnauthorizedError,
ForbiddenError,
NotFoundError,
ConflictError,
NoSearchableFieldsError,
InvalidFacetFilterError,
InvalidOrderFieldError,
generate_error_responses,
init_exceptions_handlers,
)
```
## ::: fastapi_toolsets.exceptions.exceptions.ApiException
## ::: fastapi_toolsets.exceptions.exceptions.UnauthorizedError
## ::: fastapi_toolsets.exceptions.exceptions.ForbiddenError
## ::: fastapi_toolsets.exceptions.exceptions.NotFoundError
## ::: fastapi_toolsets.exceptions.exceptions.ConflictError
## ::: fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
## ::: fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers

View File

@@ -0,0 +1,31 @@
# `fixtures`
Here's the reference for the fixture registry, enums, and loading utilities.
You can import them directly from `fastapi_toolsets.fixtures`:
```python
from fastapi_toolsets.fixtures import (
Context,
LoadStrategy,
Fixture,
FixtureRegistry,
load_fixtures,
load_fixtures_by_context,
get_obj_by_attr,
)
```
## ::: fastapi_toolsets.fixtures.enum.Context
## ::: fastapi_toolsets.fixtures.enum.LoadStrategy
## ::: fastapi_toolsets.fixtures.registry.Fixture
## ::: fastapi_toolsets.fixtures.registry.FixtureRegistry
## ::: fastapi_toolsets.fixtures.utils.load_fixtures
## ::: fastapi_toolsets.fixtures.utils.load_fixtures_by_context
## ::: fastapi_toolsets.fixtures.utils.get_obj_by_attr

13
docs/reference/logger.md Normal file
View File

@@ -0,0 +1,13 @@
# `logger`
Here's the reference for the logging utilities.
You can import them directly from `fastapi_toolsets.logger`:
```python
from fastapi_toolsets.logger import configure_logging, get_logger
```
## ::: fastapi_toolsets.logger.configure_logging
## ::: fastapi_toolsets.logger.get_logger

15
docs/reference/metrics.md Normal file
View File

@@ -0,0 +1,15 @@
# `metrics`
Here's the reference for the Prometheus metrics registry and endpoint handler.
You can import them directly from `fastapi_toolsets.metrics`:
```python
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
```
## ::: fastapi_toolsets.metrics.registry.Metric
## ::: fastapi_toolsets.metrics.registry.MetricsRegistry
## ::: fastapi_toolsets.metrics.handler.init_metrics

34
docs/reference/models.md Normal file
View File

@@ -0,0 +1,34 @@
# `models`
Here's the reference for the SQLAlchemy model mixins provided by the `models` module.
You can import them directly from `fastapi_toolsets.models`:
```python
from fastapi_toolsets.models import (
EventSession,
ModelEvent,
UUIDMixin,
UUIDv7Mixin,
CreatedAtMixin,
UpdatedAtMixin,
TimestampMixin,
listens_for,
)
```
## ::: fastapi_toolsets.models.EventSession
## ::: fastapi_toolsets.models.ModelEvent
## ::: fastapi_toolsets.models.UUIDMixin
## ::: fastapi_toolsets.models.UUIDv7Mixin
## ::: fastapi_toolsets.models.CreatedAtMixin
## ::: fastapi_toolsets.models.UpdatedAtMixin
## ::: fastapi_toolsets.models.TimestampMixin
## ::: fastapi_toolsets.models.listens_for

26
docs/reference/pytest.md Normal file
View File

@@ -0,0 +1,26 @@
# `pytest`
Here's the reference for all testing utilities and pytest fixtures.
You can import them directly from `fastapi_toolsets.pytest`:
```python
from fastapi_toolsets.pytest import (
register_fixtures,
create_async_client,
create_db_session,
worker_database_url,
create_worker_database,
cleanup_tables,
)
```
## ::: fastapi_toolsets.pytest.plugin.register_fixtures
## ::: fastapi_toolsets.pytest.utils.create_async_client
## ::: fastapi_toolsets.pytest.utils.create_db_session
## ::: fastapi_toolsets.pytest.utils.worker_database_url
## ::: fastapi_toolsets.pytest.utils.create_worker_database

46
docs/reference/schemas.md Normal file
View File

@@ -0,0 +1,46 @@
# `schemas`
Here's the reference for all response models and types provided by the `schemas` module.
You can import them directly from `fastapi_toolsets.schemas`:
```python
from fastapi_toolsets.schemas import (
PydanticBase,
ResponseStatus,
ApiError,
BaseResponse,
Response,
ErrorResponse,
OffsetPagination,
CursorPagination,
PaginationType,
PaginatedResponse,
OffsetPaginatedResponse,
CursorPaginatedResponse,
)
```
## ::: fastapi_toolsets.schemas.PydanticBase
## ::: fastapi_toolsets.schemas.ResponseStatus
## ::: fastapi_toolsets.schemas.ApiError
## ::: fastapi_toolsets.schemas.BaseResponse
## ::: fastapi_toolsets.schemas.Response
## ::: fastapi_toolsets.schemas.ErrorResponse
## ::: fastapi_toolsets.schemas.OffsetPagination
## ::: fastapi_toolsets.schemas.CursorPagination
## ::: fastapi_toolsets.schemas.PaginationType
## ::: fastapi_toolsets.schemas.PaginatedResponse
## ::: fastapi_toolsets.schemas.OffsetPaginatedResponse
## ::: fastapi_toolsets.schemas.CursorPaginatedResponse

View File

@@ -0,0 +1,28 @@
# `security`
Here's the reference for the authentication helpers provided by the `security` module.
You can import them directly from `fastapi_toolsets.security`:
```python
from fastapi_toolsets.security import (
AuthSource,
BearerTokenAuth,
CookieAuth,
OAuth2Auth,
OpenIDAuth,
MultiAuth,
)
```
## ::: fastapi_toolsets.security.AuthSource
## ::: fastapi_toolsets.security.BearerTokenAuth
## ::: fastapi_toolsets.security.CookieAuth
## ::: fastapi_toolsets.security.OAuth2Auth
## ::: fastapi_toolsets.security.OpenIDAuth
## ::: fastapi_toolsets.security.MultiAuth

0
docs_src/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,9 @@
from fastapi import FastAPI
from fastapi_toolsets.exceptions import init_exceptions_handlers
from .routes import router
app = FastAPI()
init_exceptions_handlers(app=app)
app.include_router(router=router)

View File

@@ -0,0 +1,9 @@
from fastapi_toolsets.crud import CrudFactory
from .models import OAuthAccount, OAuthProvider, Team, User, UserToken
TeamCrud = CrudFactory(model=Team)
UserCrud = CrudFactory(model=User)
UserTokenCrud = CrudFactory(model=UserToken)
OAuthProviderCrud = CrudFactory(model=OAuthProvider)
OAuthAccountCrud = CrudFactory(model=OAuthAccount)

View File

@@ -0,0 +1,15 @@
from fastapi import Depends
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from fastapi_toolsets.db import create_db_context, create_db_dependency
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
engine = create_async_engine(url=DATABASE_URL, future=True)
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
get_db = create_db_dependency(session_maker=async_session_maker)
get_db_context = create_db_context(session_maker=async_session_maker)
SessionDep = Depends(get_db)

View File

@@ -0,0 +1,105 @@
import enum
from datetime import datetime
from uuid import UUID
from sqlalchemy import (
Boolean,
DateTime,
Enum,
ForeignKey,
Integer,
String,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from fastapi_toolsets.models import TimestampMixin, UUIDMixin
class Base(DeclarativeBase, UUIDMixin):
type_annotation_map = {
str: String(),
int: Integer(),
UUID: PG_UUID(as_uuid=True),
datetime: DateTime(timezone=True),
}
class UserRole(enum.Enum):
admin = "admin"
moderator = "moderator"
user = "user"
class Team(Base, TimestampMixin):
__tablename__ = "teams"
name: Mapped[str] = mapped_column(String, unique=True, index=True)
users: Mapped[list["User"]] = relationship(back_populates="team")
class User(Base, TimestampMixin):
__tablename__ = "users"
username: Mapped[str] = mapped_column(String, unique=True, index=True)
email: Mapped[str | None] = mapped_column(
String, unique=True, index=True, nullable=True
)
hashed_password: Mapped[str | None] = mapped_column(String, nullable=True)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
role: Mapped[UserRole] = mapped_column(Enum(UserRole), default=UserRole.user)
team_id: Mapped[UUID | None] = mapped_column(ForeignKey("teams.id"), nullable=True)
team: Mapped["Team | None"] = relationship(back_populates="users")
oauth_accounts: Mapped[list["OAuthAccount"]] = relationship(back_populates="user")
tokens: Mapped[list["UserToken"]] = relationship(back_populates="user")
class UserToken(Base, TimestampMixin):
"""API tokens for a user (multiple allowed)."""
__tablename__ = "user_tokens"
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
# Store hashed token value
token_hash: Mapped[str] = mapped_column(String, unique=True, index=True)
name: Mapped[str | None] = mapped_column(String, nullable=True)
expires_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
user: Mapped["User"] = relationship(back_populates="tokens")
class OAuthProvider(Base, TimestampMixin):
"""Configurable OAuth2 / OpenID Connect provider."""
__tablename__ = "oauth_providers"
slug: Mapped[str] = mapped_column(String, unique=True, index=True)
name: Mapped[str] = mapped_column(String)
client_id: Mapped[str] = mapped_column(String)
client_secret: Mapped[str] = mapped_column(String)
discovery_url: Mapped[str] = mapped_column(String, nullable=False)
scopes: Mapped[str] = mapped_column(String, default="openid email profile")
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
accounts: Mapped[list["OAuthAccount"]] = relationship(back_populates="provider")
class OAuthAccount(Base, TimestampMixin):
"""OAuth2 / OpenID Connect account linked to a user."""
__tablename__ = "oauth_accounts"
__table_args__ = (
UniqueConstraint("provider_id", "subject", name="uq_oauth_provider_subject"),
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
provider_id: Mapped[UUID] = mapped_column(ForeignKey("oauth_providers.id"))
# OAuth `sub` / OpenID subject identifier
subject: Mapped[str] = mapped_column(String)
user: Mapped["User"] = relationship(back_populates="oauth_accounts")
provider: Mapped["OAuthProvider"] = relationship(back_populates="accounts")

View File

@@ -0,0 +1,122 @@
from typing import Annotated
from uuid import UUID
import bcrypt
from fastapi import APIRouter, Form, HTTPException, Response, Security
from fastapi_toolsets.dependencies import PathDependency
from .crud import UserCrud, UserTokenCrud
from .db import SessionDep
from .models import OAuthProvider, User, UserToken
from .schemas import (
ApiTokenCreateRequest,
ApiTokenResponse,
RegisterRequest,
UserCreate,
UserResponse,
)
from .security import auth, cookie_auth, create_api_token
ProviderDep = PathDependency(
model=OAuthProvider,
field=OAuthProvider.slug,
session_dep=SessionDep,
param_name="slug",
)
def hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def verify_password(plain: str, hashed: str) -> bool:
return bcrypt.checkpw(plain.encode(), hashed.encode())
router = APIRouter(prefix="/auth")
@router.post("/register", response_model=UserResponse, status_code=201)
async def register(body: RegisterRequest, session: SessionDep):
existing = await UserCrud.first(
session=session, filters=[User.username == body.username]
)
if existing:
raise HTTPException(status_code=409, detail="Username already taken")
user = await UserCrud.create(
session=session,
obj=UserCreate(
username=body.username,
email=body.email,
hashed_password=hash_password(body.password),
),
)
return user
@router.post("/token", status_code=204)
async def login(
session: SessionDep,
response: Response,
username: Annotated[str, Form()],
password: Annotated[str, Form()],
):
user = await UserCrud.first(session=session, filters=[User.username == username])
if (
not user
or not user.hashed_password
or not verify_password(password, user.hashed_password)
):
raise HTTPException(status_code=401, detail="Invalid credentials")
if not user.is_active:
raise HTTPException(status_code=403, detail="Account disabled")
cookie_auth.set_cookie(response, str(user.id))
@router.post("/logout", status_code=204)
async def logout(response: Response):
cookie_auth.delete_cookie(response)
@router.get("/me", response_model=UserResponse)
async def me(user: User = Security(auth)):
return user
@router.post("/tokens", response_model=ApiTokenResponse, status_code=201)
async def create_token(
body: ApiTokenCreateRequest,
user: User = Security(auth),
):
raw, token_row = await create_api_token(
user.id, name=body.name, expires_at=body.expires_at
)
return ApiTokenResponse(
id=token_row.id,
name=token_row.name,
expires_at=token_row.expires_at,
created_at=token_row.created_at,
token=raw,
)
@router.delete("/tokens/{token_id}", status_code=204)
async def revoke_token(
session: SessionDep,
token_id: UUID,
user: User = Security(auth),
):
if not await UserTokenCrud.first(
session=session,
filters=[UserToken.id == token_id, UserToken.user_id == user.id],
):
raise HTTPException(status_code=404, detail="Token not found")
await UserTokenCrud.delete(
session=session,
filters=[UserToken.id == token_id, UserToken.user_id == user.id],
)

View File

@@ -0,0 +1,64 @@
from datetime import datetime
from uuid import UUID
from pydantic import EmailStr
from fastapi_toolsets.schemas import PydanticBase
class RegisterRequest(PydanticBase):
username: str
password: str
email: EmailStr | None = None
class UserResponse(PydanticBase):
id: UUID
username: str
email: str | None
role: str
is_active: bool
model_config = {"from_attributes": True}
class ApiTokenCreateRequest(PydanticBase):
name: str | None = None
expires_at: datetime | None = None
class ApiTokenResponse(PydanticBase):
id: UUID
name: str | None
expires_at: datetime | None
created_at: datetime
# Only populated on creation
token: str | None = None
model_config = {"from_attributes": True}
class OAuthProviderResponse(PydanticBase):
slug: str
name: str
model_config = {"from_attributes": True}
class UserCreate(PydanticBase):
username: str
email: str | None = None
hashed_password: str | None = None
class UserTokenCreate(PydanticBase):
user_id: UUID
token_hash: str
name: str | None = None
expires_at: datetime | None = None
class OAuthAccountCreate(PydanticBase):
user_id: UUID
provider_id: UUID
subject: str

View File

@@ -0,0 +1,100 @@
import hashlib
from datetime import datetime, timezone
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy.orm import selectinload
from fastapi_toolsets.exceptions import UnauthorizedError
from fastapi_toolsets.security import (
APIKeyHeaderAuth,
BearerTokenAuth,
CookieAuth,
MultiAuth,
)
from .crud import UserCrud, UserTokenCrud
from .db import get_db_context
from .models import User, UserRole, UserToken
from .schemas import UserTokenCreate
SESSION_COOKIE = "session"
SECRET_KEY = "123456789"
def _hash_token(token: str) -> str:
return hashlib.sha256(token.encode()).hexdigest()
async def _verify_token(token: str, role: UserRole | None = None) -> User:
async with get_db_context() as db:
user_token = await UserTokenCrud.first(
session=db,
filters=[UserToken.token_hash == _hash_token(token)],
load_options=[selectinload(UserToken.user)],
)
if user_token is None or not user_token.user.is_active:
raise UnauthorizedError()
if user_token.expires_at and user_token.expires_at < datetime.now(timezone.utc):
raise UnauthorizedError()
user = user_token.user
if role is not None and user.role != role:
raise HTTPException(status_code=403, detail="Insufficient permissions")
return user
async def _verify_cookie(user_id: str, role: UserRole | None = None) -> User:
async with get_db_context() as db:
user = await UserCrud.first(
session=db,
filters=[User.id == UUID(user_id)],
)
if not user or not user.is_active:
raise UnauthorizedError()
if role is not None and user.role != role:
raise HTTPException(status_code=403, detail="Insufficient permissions")
return user
bearer_auth = BearerTokenAuth(
validator=_verify_token,
prefix="ctf_",
)
header_auth = APIKeyHeaderAuth(
name="X-API-Key",
validator=_verify_token,
)
cookie_auth = CookieAuth(
name=SESSION_COOKIE,
validator=_verify_cookie,
secret_key=SECRET_KEY,
)
auth = MultiAuth(bearer_auth, header_auth, cookie_auth)
async def create_api_token(
user_id: UUID,
*,
name: str | None = None,
expires_at: datetime | None = None,
) -> tuple[str, UserToken]:
raw = bearer_auth.generate_token()
async with get_db_context() as db:
token_row = await UserTokenCrud.create(
session=db,
obj=UserTokenCreate(
user_id=user_id,
token_hash=_hash_token(raw),
name=name,
expires_at=expires_at,
),
)
return raw, token_row

View File

@@ -0,0 +1,9 @@
from fastapi import FastAPI
from fastapi_toolsets.exceptions import init_exceptions_handlers
from .routes import router
app = FastAPI()
init_exceptions_handlers(app=app)
app.include_router(router=router)

View File

@@ -0,0 +1,21 @@
from fastapi_toolsets.crud import CrudFactory
from .models import Article, Category
ArticleCrud = CrudFactory(
model=Article,
cursor_column=Article.created_at,
searchable_fields=[ # default fields for full-text search
Article.title,
Article.body,
(Article.category, Category.name),
],
facet_fields=[ # fields exposed as filter dropdowns
Article.status,
(Article.category, Category.name),
],
order_fields=[ # fields exposed for client-driven ordering
Article.title,
Article.created_at,
],
)

View File

@@ -0,0 +1,17 @@
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from fastapi_toolsets.db import create_db_context, create_db_dependency
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
engine = create_async_engine(url=DATABASE_URL, future=True)
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
get_db = create_db_dependency(session_maker=async_session_maker)
get_db_context = create_db_context(session_maker=async_session_maker)
SessionDep = Annotated[AsyncSession, Depends(get_db)]

View File

@@ -0,0 +1,34 @@
import uuid
from sqlalchemy import Boolean, ForeignKey, String, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from fastapi_toolsets.models import CreatedAtMixin
class Base(DeclarativeBase):
pass
class Category(Base):
__tablename__ = "categories"
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(64), unique=True)
articles: Mapped[list["Article"]] = relationship(back_populates="category")
class Article(Base, CreatedAtMixin):
__tablename__ = "articles"
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
title: Mapped[str] = mapped_column(String(256))
body: Mapped[str] = mapped_column(Text)
status: Mapped[str] = mapped_column(String(32))
published: Mapped[bool] = mapped_column(Boolean, default=False)
category_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("categories.id"), nullable=True
)
category: Mapped["Category | None"] = relationship(back_populates="articles")

View File

@@ -0,0 +1,79 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi_toolsets.schemas import (
CursorPaginatedResponse,
OffsetPaginatedResponse,
PaginatedResponse,
)
from .crud import ArticleCrud
from .db import SessionDep
from .models import Article
from .schemas import ArticleRead
router = APIRouter(prefix="/articles")
@router.get("/offset")
async def list_articles_offset(
session: SessionDep,
params: Annotated[
dict,
Depends(
ArticleCrud.offset_paginate_params(
default_page_size=20,
max_page_size=100,
default_order_field=Article.created_at,
)
),
],
) -> OffsetPaginatedResponse[ArticleRead]:
return await ArticleCrud.offset_paginate(
session=session,
**params,
schema=ArticleRead,
)
@router.get("/cursor")
async def list_articles_cursor(
session: SessionDep,
params: Annotated[
dict,
Depends(
ArticleCrud.cursor_paginate_params(
default_page_size=20,
max_page_size=100,
default_order_field=Article.created_at,
)
),
],
) -> CursorPaginatedResponse[ArticleRead]:
return await ArticleCrud.cursor_paginate(
session=session,
**params,
schema=ArticleRead,
)
@router.get("/")
async def list_articles(
session: SessionDep,
params: Annotated[
dict,
Depends(
ArticleCrud.paginate_params(
default_page_size=20,
max_page_size=100,
default_order_field=Article.created_at,
)
),
],
) -> PaginatedResponse[ArticleRead]:
return await ArticleCrud.paginate(
session,
**params,
schema=ArticleRead,
)

View File

@@ -0,0 +1,13 @@
import datetime
import uuid
from fastapi_toolsets.schemas import PydanticBase
class ArticleRead(PydanticBase):
id: uuid.UUID
created_at: datetime.datetime
title: str
status: str
published: bool
category_id: uuid.UUID | None

View File

@@ -1,7 +1,7 @@
[project]
name = "fastapi-toolsets"
version = "0.7.1"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
version = "3.0.3"
description = "Production-ready utilities for FastAPI applications"
readme = "README.md"
license = "MIT"
license-files = ["LICENSE"]
@@ -11,7 +11,7 @@ authors = [
]
keywords = ["fastapi", "sqlalchemy", "postgresql"]
classifiers = [
"Development Status :: 4 - Beta",
"Development Status :: 5 - Production/Stable",
"Framework :: AsyncIO",
"Framework :: FastAPI",
"Framework :: Pydantic",
@@ -31,12 +31,10 @@ classifiers = [
"Typing :: Typed",
]
dependencies = [
"fastapi>=0.100.0",
"sqlalchemy[asyncio]>=2.0",
"asyncpg>=0.29.0",
"fastapi>=0.100.0",
"pydantic>=2.0",
"typer>=0.9.0",
"httpx>=0.25.0",
"sqlalchemy[asyncio]>=2.0",
]
[project.urls]
@@ -46,23 +44,53 @@ Repository = "https://github.com/d3vyce/fastapi-toolsets"
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
[project.optional-dependencies]
test = [
"pytest>=8.0.0",
"pytest-anyio>=0.0.0",
"coverage>=7.0.0",
"pytest-cov>=4.0.0",
cli = [
"typer>=0.9.0",
]
dev = [
"fastapi-toolsets[test]",
"ruff>=0.1.0",
"ty>=0.0.1a0",
metrics = [
"prometheus_client>=0.20.0",
]
pytest = [
"httpx>=0.25.0",
"pytest-xdist>=3.0.0",
"pytest>=8.0.0",
]
all = [
"fastapi-toolsets[cli,metrics,pytest]",
]
[project.scripts]
manager = "fastapi_toolsets.cli.app:cli"
[dependency-groups]
dev = [
{include-group = "tests"},
{include-group = "docs"},
{include-group = "docs-src"},
"fastapi-toolsets[all]",
"prek>=0.3.8",
"ruff>=0.1.0",
"ty>=0.0.1a0",
]
tests = [
"coverage>=7.0.0",
"httpx>=0.25.0",
"pytest-anyio>=0.0.0",
"pytest-cov>=4.0.0",
"pytest-xdist>=3.0.0",
"pytest>=8.0.0",
]
docs = [
"mike",
"mkdocstrings-python>=2.0.2",
"zensical>=0.0.30",
]
docs-src = [
"bcrypt>=4.0.0",
]
[build-system]
requires = ["uv_build>=0.10,<0.11.0"]
requires = ["uv_build>=0.10,<0.12.0"]
build-backend = "uv_build"
[tool.pytest.ini_options]
@@ -81,3 +109,6 @@ exclude_lines = [
"if TYPE_CHECKING:",
"raise NotImplementedError",
]
[tool.uv.sources]
mike = { git = "https://github.com/squidfunk/mike.git", tag = "2.2.0+zensical-0.1.0" }

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success")
"""
__version__ = "0.7.1"
__version__ = "3.0.3"

View File

@@ -0,0 +1,9 @@
"""Optional dependency helpers."""
def require_extra(package: str, extra: str) -> None:
"""Raise *ImportError* with an actionable install instruction."""
raise ImportError(
f"'{package}' is required to use this module. "
f"Install it with: pip install fastapi-toolsets[{extra}]"
)

View File

@@ -1,6 +1,11 @@
"""Main CLI application."""
import typer
try:
import typer
except ImportError:
from .._imports import require_extra
require_extra(package="typer", extra="cli")
from ..logger import configure_logging
from .config import get_custom_cli

View File

@@ -72,7 +72,7 @@ async def load(
registry = get_fixtures_registry()
db_context = get_db_context()
context_list = [c.value for c in contexts] if contexts else [Context.BASE]
context_list = list(contexts) if contexts else [Context.BASE]
ordered = registry.resolve_context_dependencies(*context_list)

View File

@@ -1,12 +1,18 @@
"""CLI configuration and dynamic imports."""
from __future__ import annotations
import importlib
import sys
from typing import TYPE_CHECKING, Any, Literal, overload
import typer
from .pyproject import find_pyproject, load_pyproject
if TYPE_CHECKING:
from ..fixtures import FixtureRegistry
def _ensure_project_in_path():
"""Add project root to sys.path if not installed in editable mode."""
@@ -17,11 +23,17 @@ def _ensure_project_in_path():
sys.path.insert(0, project_root)
def import_from_string(import_path: str):
"""Import an object from a string path like 'module.submodule:attribute'.
def import_from_string(import_path: str) -> Any:
"""Import an object from a dotted string path.
Args:
import_path: Import path in ``"module.submodule:attribute"`` format
Returns:
The imported attribute
Raises:
typer.BadParameter: If the import path is invalid or import fails.
typer.BadParameter: If the import path is invalid or import fails
"""
if ":" not in import_path:
raise typer.BadParameter(
@@ -45,7 +57,13 @@ def import_from_string(import_path: str):
return getattr(module, attr_name)
def get_config_value(key: str, required: bool = False):
@overload
def get_config_value(key: str, required: Literal[True]) -> Any: ... # pragma: no cover
@overload
def get_config_value(
key: str, required: bool = False
) -> Any | None: ... # pragma: no cover
def get_config_value(key: str, required: bool = False) -> Any | None:
"""Get a configuration value from pyproject.toml.
Args:
@@ -70,7 +88,7 @@ def get_config_value(key: str, required: bool = False):
return value
def get_fixtures_registry():
def get_fixtures_registry() -> FixtureRegistry:
"""Import and return the fixtures registry from config."""
from ..fixtures import FixtureRegistry
@@ -85,7 +103,7 @@ def get_fixtures_registry():
return registry
def get_db_context():
def get_db_context() -> Any:
"""Import and return the db_context function from config."""
import_path = get_config_value("db_context", required=True)
return import_from_string(import_path)

View File

@@ -10,6 +10,12 @@ def find_pyproject(start_path: Path | None = None) -> Path | None:
"""Find pyproject.toml by walking up the directory tree.
Similar to how pytest, black, and ruff discover their config files.
Args:
start_path: Directory to start searching from. Defaults to cwd.
Returns:
Path to pyproject.toml, or None if not found.
"""
path = (start_path or Path.cwd()).resolve()

View File

@@ -13,11 +13,13 @@ def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
"""Decorator to run an async function as a sync CLI command.
Example:
```python
@fixture_cli.command("load")
@async_command
async def load(ctx: typer.Context) -> None:
async with get_db_context() as session:
await load_fixtures(session, registry)
```
"""
@functools.wraps(func)

View File

@@ -1,15 +1,37 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory
from .search import (
SearchConfig,
get_searchable_fields,
from ..exceptions import (
InvalidFacetFilterError,
InvalidSearchColumnError,
NoSearchableFieldsError,
UnsupportedFacetTypeError,
)
from ..schemas import PaginationType
from ..types import (
FacetFieldType,
JoinType,
M2MFieldType,
OrderByClause,
OrderFieldType,
SearchFieldType,
)
from .factory import AsyncCrud, CrudFactory
from .search import SearchConfig, get_searchable_fields
__all__ = [
"AsyncCrud",
"CrudFactory",
"FacetFieldType",
"get_searchable_fields",
"InvalidFacetFilterError",
"InvalidSearchColumnError",
"JoinType",
"M2MFieldType",
"NoSearchableFieldsError",
"OrderByClause",
"OrderFieldType",
"PaginationType",
"SearchConfig",
"SearchFieldType",
"UnsupportedFacetTypeError",
]

File diff suppressed because it is too large Load Diff

View File

@@ -1,20 +1,38 @@
"""Search utilities for AsyncCrud."""
import asyncio
import functools
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, or_
from sqlalchemy import String, and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.types import (
ARRAY,
Boolean,
Date,
DateTime,
Enum,
Integer,
Numeric,
Time,
Uuid,
)
from ..exceptions import NoSearchableFieldsError
from ..exceptions import (
InvalidFacetFilterError,
InvalidSearchColumnError,
NoSearchableFieldsError,
UnsupportedFacetTypeError,
)
from ..types import FacetFieldType, SearchFieldType
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
@dataclass
class SearchConfig:
@@ -33,6 +51,7 @@ class SearchConfig:
match_mode: Literal["any", "all"] = "any"
@functools.lru_cache(maxsize=128)
def get_searchable_fields(
model: type[DeclarativeBase],
*,
@@ -78,6 +97,7 @@ def build_search_filters(
search: str | SearchConfig,
search_fields: Sequence[SearchFieldType] | None = None,
default_fields: Sequence[SearchFieldType] | None = None,
search_column: str | None = None,
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
"""Build SQLAlchemy filter conditions for search.
@@ -86,22 +106,24 @@ def build_search_filters(
search: Search string or SearchConfig
search_fields: Fields specified per-call (takes priority)
default_fields: Default fields (from ClassVar)
search_column: Optional key to narrow search to a single field.
Must match one of the resolved search field keys.
Returns:
Tuple of (filter_conditions, joins_needed)
Raises:
NoSearchableFieldsError: If no searchable field has been configured
"""
# Normalize input
if isinstance(search, str):
config = SearchConfig(query=search, fields=search_fields)
else:
config = search
if search_fields is not None:
config = SearchConfig(
query=config.query,
fields=search_fields,
case_sensitive=config.case_sensitive,
match_mode=config.match_mode,
)
config = (
replace(search, fields=search_fields)
if search_fields is not None
else search
)
if not config.query or not config.query.strip():
return [], []
@@ -112,6 +134,14 @@ def build_search_filters(
if not fields:
raise NoSearchableFieldsError(model)
# Narrow to a single column when search_column is specified
if search_column is not None:
keys = search_field_keys(fields)
index = {k: f for k, f in zip(keys, fields)}
if search_column not in index:
raise InvalidSearchColumnError(search_column, sorted(index))
fields = [index[search_column]]
query = config.query.strip()
filters: list[ColumnElement[bool]] = []
joins: list[InstrumentedAttribute[Any]] = []
@@ -136,7 +166,7 @@ def build_search_filters(
else:
filters.append(column_as_string.ilike(f"%{query}%"))
if not filters:
if not filters: # pragma: no cover
return [], []
# Combine based on match_mode
@@ -144,3 +174,211 @@ def build_search_filters(
return [or_(*filters)], joins
else:
return filters, joins
def search_field_keys(fields: Sequence[SearchFieldType]) -> list[str]:
"""Return a human-readable key for each search field."""
return facet_keys(fields)
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
"""Return a key for each facet field.
Args:
facet_fields: Sequence of facet fields — either direct columns or
relationship tuples ``(rel, ..., column)``.
Returns:
A list of string keys, one per facet field, in the same order.
"""
keys: list[str] = []
for field in facet_fields:
if isinstance(field, tuple):
keys.append("__".join(el.key for el in field))
else:
keys.append(field.key)
return keys
async def build_facets(
session: "AsyncSession",
model: type[DeclarativeBase],
facet_fields: Sequence[FacetFieldType],
*,
base_filters: "list[ColumnElement[bool]] | None" = None,
base_joins: list[InstrumentedAttribute[Any]] | None = None,
) -> dict[str, list[Any]]:
"""Return distinct values for each facet field, respecting current filters.
Args:
session: DB async session
model: SQLAlchemy model class
facet_fields: Columns or relationship tuples to facet on
base_filters: Filter conditions already applied to the main query (search + caller filters)
base_joins: Relationship joins already applied to the main query
Returns:
Dict mapping column key to sorted list of distinct non-None values
"""
existing_join_keys: set[str] = {str(j) for j in (base_joins or [])}
keys = facet_keys(facet_fields)
async def _query_facet(field: FacetFieldType, key: str) -> tuple[str, list[Any]]:
if isinstance(field, tuple):
# Relationship chain: (User.role, Role.name) — last element is the column
rels = field[:-1]
column = field[-1]
else:
rels = ()
column = field
col_type = column.property.columns[0].type
is_array = isinstance(col_type, ARRAY)
if is_array:
unnested = func.unnest(column).label(column.key)
q = select(unnested).select_from(model).distinct()
else:
q = select(column).select_from(model).distinct()
# Apply base joins (deduplicated) — needed here independently
seen_joins: set[str] = set()
for rel in base_joins or []:
rel_key = str(rel)
if rel_key not in seen_joins:
seen_joins.add(rel_key)
q = q.outerjoin(rel)
# Add any extra joins required by this facet field that aren't already applied
for rel in rels:
rel_key = str(rel)
if rel_key not in existing_join_keys and rel_key not in seen_joins:
seen_joins.add(rel_key)
q = q.outerjoin(rel)
if base_filters:
q = q.where(and_(*base_filters))
if is_array:
q = q.order_by(unnested)
else:
q = q.order_by(column)
result = await session.execute(q)
col_type = column.property.columns[0].type
enum_class = getattr(col_type, "enum_class", None)
values = [
row[0].name
if (enum_class is not None and isinstance(row[0], enum_class))
else row[0]
for row in result.all()
if row[0] is not None
]
return key, values
pairs = await asyncio.gather(
*[_query_facet(f, k) for f, k in zip(facet_fields, keys)]
)
return dict(pairs)
_EQUALITY_TYPES = (String, Integer, Numeric, Date, DateTime, Time, Enum, Uuid)
"""Column types that support equality / IN filtering in build_filter_by."""
def _coerce_bool(value: Any) -> bool:
"""Coerce a string value to a Python bool for Boolean column filtering."""
if isinstance(value, bool):
return value
if isinstance(value, str):
if value.lower() == "true":
return True
if value.lower() == "false":
return False
raise ValueError(f"Cannot coerce {value!r} to bool")
def build_filter_by(
filter_by: dict[str, Any],
facet_fields: Sequence[FacetFieldType],
) -> tuple["list[ColumnElement[bool]]", list[InstrumentedAttribute[Any]]]:
"""Translate a {column_key: value} dict into SQLAlchemy filter conditions.
Args:
filter_by: Mapping of column key to scalar value or list of values
facet_fields: Declared facet fields to validate keys against
Returns:
Tuple of (filter_conditions, joins_needed)
Raises:
InvalidFacetFilterError: If a key in filter_by is not a declared facet field
"""
index: dict[
str, tuple[InstrumentedAttribute[Any], list[InstrumentedAttribute[Any]]]
] = {}
for key, field in zip(facet_keys(facet_fields), facet_fields):
if isinstance(field, tuple):
rels = list(field[:-1])
column = field[-1]
else:
rels = []
column = field
index[key] = (column, rels)
valid_keys = set(index)
filters: list[ColumnElement[bool]] = []
joins: list[InstrumentedAttribute[Any]] = []
added_join_keys: set[str] = set()
for key, value in filter_by.items():
if key not in index:
raise InvalidFacetFilterError(key, valid_keys)
column, rels = index[key]
for rel in rels:
rel_key = str(rel)
if rel_key not in added_join_keys:
joins.append(rel)
added_join_keys.add(rel_key)
col_type = column.property.columns[0].type
if isinstance(col_type, Boolean):
coerce = _coerce_bool
if isinstance(value, list):
filters.append(column.in_([coerce(v) for v in value]))
else:
filters.append(column == coerce(value))
elif isinstance(col_type, ARRAY):
if isinstance(value, list):
filters.append(column.overlap(value))
else:
filters.append(column.any(value))
elif isinstance(col_type, Enum):
enum_class = col_type.enum_class
if enum_class is not None:
def _coerce_enum(v: Any) -> Any:
if isinstance(v, enum_class):
return v
return enum_class[v] # lookup by name: "PENDING", "RED"
if isinstance(value, list):
filters.append(column.in_([_coerce_enum(v) for v in value]))
else:
filters.append(column == _coerce_enum(value))
else: # pragma: no cover
if isinstance(value, list):
filters.append(column.in_(value))
else:
filters.append(column == value)
elif isinstance(col_type, _EQUALITY_TYPES):
if isinstance(value, list):
filters.append(column.in_(value))
else:
filters.append(column == value)
else:
raise UnsupportedFacetTypeError(key, type(col_type).__name__)
return filters, joins

View File

@@ -1,25 +1,35 @@
"""Database utilities: sessions, transactions, and locks."""
import asyncio
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from enum import Enum
from typing import Any, TypeVar
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from .exceptions import NotFoundError
__all__ = [
"LockMode",
"cleanup_tables",
"create_database",
"create_db_context",
"create_db_dependency",
"lock_tables",
"get_transaction",
"lock_tables",
"wait_for_row_change",
]
_SessionT = TypeVar("_SessionT", bound=AsyncSession)
def create_db_dependency(
session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AsyncGenerator[_SessionT, None]]:
"""Create a FastAPI dependency for database sessions.
Creates a dependency function that yields a session and auto-commits
@@ -32,6 +42,7 @@ def create_db_dependency(
An async generator function usable with FastAPI's Depends()
Example:
```python
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from fastapi_toolsets.db import create_db_dependency
@@ -43,10 +54,12 @@ def create_db_dependency(
@app.get("/users")
async def list_users(session: AsyncSession = Depends(get_db)):
...
```
"""
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async def get_db() -> AsyncGenerator[_SessionT, None]:
async with session_maker() as session:
await session.connection()
yield session
if session.in_transaction():
await session.commit()
@@ -55,8 +68,8 @@ def create_db_dependency(
def create_db_context(
session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AbstractAsyncContextManager[_SessionT]]:
"""Create a context manager for database sessions.
Creates a context manager for use outside of FastAPI request handlers,
@@ -69,6 +82,7 @@ def create_db_context(
An async context manager function
Example:
```python
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from fastapi_toolsets.db import create_db_context
@@ -80,6 +94,7 @@ def create_db_context(
async with get_db_context() as session:
user = await UserCrud.get(session, [User.id == 1])
...
```
"""
get_db = create_db_dependency(session_maker)
return asynccontextmanager(get_db)
@@ -101,9 +116,11 @@ async def get_transaction(
The session within the transaction context
Example:
```python
async with get_transaction(session):
session.add(model)
# Auto-commits on exit, rolls back on exception
```
"""
if session.in_transaction():
async with session.begin_nested():
@@ -155,6 +172,7 @@ async def lock_tables(
SQLAlchemyError: If lock cannot be acquired within timeout
Example:
```python
from fastapi_toolsets.db import lock_tables, LockMode
async with lock_tables(session, [User, Account]):
@@ -166,6 +184,7 @@ async def lock_tables(
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
# Exclusive lock - no other transactions can access
await process_order(session, order_id)
```
"""
table_names = ",".join(table.__tablename__ for table in tables)
@@ -173,3 +192,150 @@ async def lock_tables(
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
yield session
async def create_database(
db_name: str,
*,
server_url: str,
) -> None:
"""Create a database.
Connects to *server_url* using ``AUTOCOMMIT`` isolation and issues a
``CREATE DATABASE`` statement for *db_name*.
Args:
db_name: Name of the database to create.
server_url: URL used for server-level DDL (must point to an existing
database on the same server).
Example:
```python
from fastapi_toolsets.db import create_database
SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
await create_database("myapp_test", server_url=SERVER_URL)
```
"""
engine = create_async_engine(server_url, isolation_level="AUTOCOMMIT")
try:
async with engine.connect() as conn:
await conn.execute(text(f"CREATE DATABASE {db_name}"))
finally:
await engine.dispose()
async def cleanup_tables(
session: AsyncSession,
base: type[DeclarativeBase],
) -> None:
"""Truncate all tables for fast between-test cleanup.
Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
across every table in *base*'s metadata, which is significantly faster
than dropping and re-creating tables between tests.
This is a no-op when the metadata contains no tables.
Args:
session: An active async database session.
base: SQLAlchemy DeclarativeBase class containing model metadata.
Example:
```python
@pytest.fixture
async def db_session(worker_db_url):
async with create_db_session(worker_db_url, Base) as session:
yield session
await cleanup_tables(session, Base)
```
"""
tables = base.metadata.sorted_tables
if not tables:
return
table_names = ", ".join(f'"{t.name}"' for t in tables)
await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
await session.commit()
_M = TypeVar("_M", bound=DeclarativeBase)
async def wait_for_row_change(
session: AsyncSession,
model: type[_M],
pk_value: Any,
*,
columns: list[str] | None = None,
interval: float = 0.5,
timeout: float | None = None,
) -> _M:
"""Poll a database row until a change is detected.
Queries the row every ``interval`` seconds and returns the model instance
once a change is detected in any column (or only the specified ``columns``).
Args:
session: AsyncSession instance
model: SQLAlchemy model class
pk_value: Primary key value of the row to watch
columns: Optional list of column names to watch. If None, all columns
are watched.
interval: Polling interval in seconds (default: 0.5)
timeout: Maximum time to wait in seconds. None means wait forever.
Returns:
The refreshed model instance with updated values
Raises:
NotFoundError: If the row does not exist or is deleted during polling
TimeoutError: If timeout expires before a change is detected
Example:
```python
from fastapi_toolsets.db import wait_for_row_change
# Wait for any column to change
updated = await wait_for_row_change(session, User, user_id)
# Watch specific columns with a timeout
updated = await wait_for_row_change(
session, User, user_id,
columns=["status", "email"],
interval=1.0,
timeout=30.0,
)
```
"""
instance = await session.get(model, pk_value)
if instance is None:
raise NotFoundError(f"{model.__name__} with pk={pk_value!r} not found")
if columns is not None:
watch_cols = columns
else:
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
initial = {col: getattr(instance, col) for col in watch_cols}
elapsed = 0.0
while True:
await asyncio.sleep(interval)
elapsed += interval
if timeout is not None and elapsed >= timeout:
raise TimeoutError(
f"No change detected on {model.__name__} "
f"with pk={pk_value!r} within {timeout}s"
)
session.expunge(instance)
instance = await session.get(model, pk_value)
if instance is None:
raise NotFoundError(f"{model.__name__} with pk={pk_value!r} was deleted")
current = {col: getattr(instance, col) for col in watch_cols}
if current != initial:
return instance

View File

@@ -1,17 +1,27 @@
"""Dependency factories for FastAPI routes."""
import inspect
from collections.abc import AsyncGenerator, Callable
from typing import Any, TypeVar, cast
import typing
from collections.abc import Callable
from typing import Any, cast
from fastapi import Depends
from fastapi.params import Depends as DependsClass
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from ..crud import CrudFactory
from .crud import CrudFactory
from .types import ModelType, SessionDependency
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
__all__ = ["BodyDependency", "PathDependency"]
def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]:
"""Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed."""
if typing.get_origin(session_dep) is typing.Annotated:
for arg in typing.get_args(session_dep)[1:]:
if isinstance(arg, DependsClass):
return arg.dependency
return session_dep
def PathDependency(
@@ -36,13 +46,16 @@ def PathDependency(
NotFoundError: If no matching record is found
Example:
```python
UserDep = PathDependency(User, User.id, session_dep=get_db)
@router.get("/user/{id}")
async def get(
user: User = UserDep,
): ...
```
"""
session_callable = _unwrap_session_dep(session_dep)
crud = CrudFactory(model)
name = (
param_name
@@ -52,7 +65,7 @@ def PathDependency(
python_type = field.type.python_type
async def dependency(
session: AsyncSession = Depends(session_dep), **kwargs: Any
session: AsyncSession = Depends(session_callable), **kwargs: Any
) -> ModelType:
value = kwargs[name]
return await crud.get(session, filters=[field == value])
@@ -69,7 +82,7 @@ def PathDependency(
"session",
inspect.Parameter.KEYWORD_ONLY,
annotation=AsyncSession,
default=Depends(session_dep),
default=Depends(session_callable),
),
]
),
@@ -100,6 +113,7 @@ def BodyDependency(
NotFoundError: If no matching record is found
Example:
```python
UserDep = BodyDependency(
User, User.ctfd_id, session_dep=get_db, body_field="user_id"
)
@@ -108,12 +122,14 @@ def BodyDependency(
async def assign(
user: User = UserDep,
): ...
```
"""
session_callable = _unwrap_session_dep(session_dep)
crud = CrudFactory(model)
python_type = field.type.python_type
async def dependency(
session: AsyncSession = Depends(session_dep), **kwargs: Any
session: AsyncSession = Depends(session_callable), **kwargs: Any
) -> ModelType:
value = kwargs[body_field]
return await crud.get(session, filters=[field == value])
@@ -130,7 +146,7 @@ def BodyDependency(
"session",
inspect.Parameter.KEYWORD_ONLY,
annotation=AsyncSession,
default=Depends(session_dep),
default=Depends(session_callable),
),
]
),

View File

@@ -1,5 +0,0 @@
"""FastAPI dependency factories for database objects."""
from .factory import BodyDependency, PathDependency
__all__ = ["BodyDependency", "PathDependency"]

View File

@@ -1,11 +1,17 @@
"""Standardized API exceptions and error response handlers."""
from .exceptions import (
ApiError,
ApiException,
ConflictError,
ForbiddenError,
InvalidFacetFilterError,
InvalidOrderFieldError,
InvalidSearchColumnError,
NoSearchableFieldsError,
NotFoundError,
UnauthorizedError,
UnsupportedFacetTypeError,
generate_error_responses,
)
from .handler import init_exceptions_handlers
@@ -17,7 +23,11 @@ __all__ = [
"ForbiddenError",
"generate_error_responses",
"init_exceptions_handlers",
"InvalidFacetFilterError",
"InvalidOrderFieldError",
"InvalidSearchColumnError",
"NoSearchableFieldsError",
"NotFoundError",
"UnauthorizedError",
"UnsupportedFacetTypeError",
]

View File

@@ -6,30 +6,46 @@ from ..schemas import ApiError, ErrorResponse, ResponseStatus
class ApiException(Exception):
"""Base exception for API errors with structured response.
Subclass this to create custom API exceptions with consistent error format.
The exception handler will use api_error to generate the response.
Example:
class CustomError(ApiException):
api_error = ApiError(
code=400,
msg="Bad Request",
desc="The request was invalid.",
err_code="CUSTOM-400",
)
"""
"""Base exception for API errors with structured response."""
api_error: ClassVar[ApiError]
def __init__(self, detail: str | None = None):
def __init_subclass__(cls, abstract: bool = False, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if not abstract and not hasattr(cls, "api_error"):
raise TypeError(
f"{cls.__name__} must define an 'api_error' class attribute. "
"Pass abstract=True when creating intermediate base classes."
)
def __init__(
self,
detail: str | None = None,
*,
desc: str | None = None,
data: Any = None,
) -> None:
"""Initialize the exception.
Args:
detail: Optional override for the error message
detail: Optional human-readable message
desc: Optional per-instance override for the ``description`` field
in the HTTP response body.
data: Optional per-instance override for the ``data`` field in the
HTTP response body.
"""
super().__init__(detail or self.api_error.msg)
updates: dict[str, Any] = {}
if detail is not None:
updates["msg"] = detail
if desc is not None:
updates["desc"] = desc
if data is not None:
updates["data"] = data
if updates:
object.__setattr__(
self, "api_error", self.__class__.api_error.model_copy(update=updates)
)
super().__init__(self.api_error.msg)
class UnauthorizedError(ApiException):
@@ -76,49 +92,6 @@ class ConflictError(ApiException):
)
class InsufficientRolesError(ForbiddenError):
"""User does not have the required roles."""
api_error = ApiError(
code=403,
msg="Insufficient Roles",
desc="You do not have the required roles to access this resource.",
err_code="RBAC-403",
)
def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
self.required_roles = required_roles
self.user_roles = user_roles
desc = f"Required roles: {', '.join(required_roles)}"
if user_roles is not None:
desc += f". User has: {', '.join(user_roles) if user_roles else 'no roles'}"
super().__init__(desc)
class UserNotFoundError(NotFoundError):
"""User was not found."""
api_error = ApiError(
code=404,
msg="User Not Found",
desc="The requested user was not found.",
err_code="USER-404",
)
class RoleNotFoundError(NotFoundError):
"""Role was not found."""
api_error = ApiError(
code=404,
msg="Role Not Found",
desc="The requested role was not found.",
err_code="ROLE-404",
)
class NoSearchableFieldsError(ApiException):
"""Raised when search is requested but no searchable fields are available."""
@@ -130,12 +103,124 @@ class NoSearchableFieldsError(ApiException):
)
def __init__(self, model: type) -> None:
"""Initialize the exception.
Args:
model: The model class that has no searchable fields configured.
"""
self.model = model
detail = (
f"No searchable fields found for model '{model.__name__}'. "
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
super().__init__(
desc=(
f"No searchable fields found for model '{model.__name__}'. "
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
)
)
class InvalidFacetFilterError(ApiException):
"""Raised when filter_by contains a key not declared in facet_fields."""
api_error = ApiError(
code=400,
msg="Invalid Facet Filter",
desc="One or more filter_by keys are not declared as facet fields.",
err_code="FACET-400",
)
def __init__(self, key: str, valid_keys: set[str]) -> None:
"""Initialize the exception.
Args:
key: The unknown filter key provided by the caller.
valid_keys: Set of valid keys derived from the declared facet_fields.
"""
self.key = key
self.valid_keys = valid_keys
super().__init__(
desc=(
f"'{key}' is not a declared facet field. "
f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}."
)
)
class UnsupportedFacetTypeError(ApiException):
"""Raised when a facet field has a column type not supported by filter_by."""
api_error = ApiError(
code=400,
msg="Unsupported Facet Type",
desc="The column type is not supported for facet filtering.",
err_code="FACET-TYPE-400",
)
def __init__(self, key: str, col_type: str) -> None:
"""Initialize the exception.
Args:
key: The facet field key.
col_type: The unsupported column type name.
"""
self.key = key
self.col_type = col_type
super().__init__(
desc=(
f"Facet field '{key}' has unsupported column type '{col_type}'. "
f"Supported types: String, Integer, Numeric, Boolean, "
f"Date, DateTime, Time, Enum, Uuid, ARRAY."
)
)
class InvalidSearchColumnError(ApiException):
"""Raised when search_column is not one of the configured searchable fields."""
api_error = ApiError(
code=400,
msg="Invalid Search Column",
desc="The requested search column is not a configured searchable field.",
err_code="SEARCH-COL-400",
)
def __init__(self, column: str, valid_columns: list[str]) -> None:
"""Initialize the exception.
Args:
column: The unknown search column provided by the caller.
valid_columns: List of valid search column keys.
"""
self.column = column
self.valid_columns = valid_columns
super().__init__(
desc=(
f"'{column}' is not a searchable column. "
f"Valid columns: {valid_columns}."
)
)
class InvalidOrderFieldError(ApiException):
"""Raised when order_by contains a field not in the allowed order fields."""
api_error = ApiError(
code=422,
msg="Invalid Order Field",
desc="The requested order field is not allowed for this resource.",
err_code="SORT-422",
)
def __init__(self, field: str, valid_fields: list[str]) -> None:
"""Initialize the exception.
Args:
field: The unknown order field provided by the caller.
valid_fields: List of valid field names.
"""
self.field = field
self.valid_fields = valid_fields
super().__init__(
desc=f"'{field}' is not an allowed order field. Valid fields: {valid_fields}."
)
super().__init__(detail)
def generate_error_responses(
@@ -143,42 +228,39 @@ def generate_error_responses(
) -> dict[int | str, dict[str, Any]]:
"""Generate OpenAPI response documentation for exceptions.
Use this to document possible error responses for an endpoint.
Args:
*errors: Exception classes that inherit from ApiException
*errors: Exception classes that inherit from ApiException.
Returns:
Dict suitable for FastAPI's responses parameter
Example:
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
@app.get(
"/admin",
responses=generate_error_responses(UnauthorizedError, ForbiddenError)
)
async def admin_endpoint():
...
Dict suitable for FastAPI's ``responses`` parameter.
"""
responses: dict[int | str, dict[str, Any]] = {}
for error in errors:
api_error = error.api_error
code = api_error.code
responses[api_error.code] = {
"model": ErrorResponse,
"description": api_error.msg,
"content": {
"application/json": {
"example": {
"data": None,
"status": ResponseStatus.FAIL.value,
"message": api_error.msg,
"description": api_error.desc,
"error_code": api_error.err_code,
if code not in responses:
responses[code] = {
"model": ErrorResponse,
"description": api_error.msg,
"content": {
"application/json": {
"examples": {},
}
}
},
}
responses[code]["content"]["application/json"]["examples"][
api_error.err_code
] = {
"summary": api_error.msg,
"value": {
"data": api_error.data,
"status": ResponseStatus.FAIL.value,
"message": api_error.msg,
"description": api_error.desc,
"error_code": api_error.err_code,
},
}

View File

@@ -1,50 +1,69 @@
"""Exception handlers for FastAPI applications."""
from collections.abc import Callable
from typing import Any
from fastapi import FastAPI, Request, Response, status
from fastapi.exceptions import RequestValidationError, ResponseValidationError
from fastapi.openapi.utils import get_openapi
from fastapi.exceptions import (
HTTPException,
RequestValidationError,
ResponseValidationError,
)
from fastapi.responses import JSONResponse
from ..schemas import ResponseStatus
from ..schemas import ErrorResponse, ResponseStatus
from .exceptions import ApiException
_VALIDATION_LOCATION_PARAMS: frozenset[str] = frozenset(
{"body", "query", "path", "header", "cookie"}
)
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
Args:
app: FastAPI application instance.
Returns:
The same FastAPI instance (for chaining).
"""
_register_exception_handlers(app)
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
_original_openapi = app.openapi
app.openapi = lambda: _patched_openapi(app, _original_openapi) # type: ignore[method-assign] # ty:ignore[invalid-assignment]
return app
def _register_exception_handlers(app: FastAPI) -> None:
"""Register all exception handlers on a FastAPI application.
Args:
app: FastAPI application instance
Example:
from fastapi import FastAPI
from fastapi_toolsets.exceptions import init_exceptions_handlers
app = FastAPI()
init_exceptions_handlers(app)
"""
"""Register all exception handlers on a FastAPI application."""
@app.exception_handler(ApiException)
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
"""Handle custom API exceptions with structured response."""
api_error = exc.api_error
error_response = ErrorResponse(
data=api_error.data,
message=api_error.msg,
description=api_error.desc,
error_code=api_error.err_code,
)
return JSONResponse(
status_code=api_error.code,
content={
"data": None,
"status": ResponseStatus.FAIL.value,
"message": api_error.msg,
"description": api_error.desc,
"error_code": api_error.err_code,
},
content=error_response.model_dump(),
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
"""Handle Starlette/FastAPI HTTPException with a consistent error format."""
detail = exc.detail if isinstance(exc.detail, str) else "HTTP Error"
error_response = ErrorResponse(
message=detail,
error_code=f"HTTP-{exc.status_code}",
)
return JSONResponse(
status_code=exc.status_code,
content=error_response.model_dump(),
headers=getattr(exc, "headers", None),
)
@app.exception_handler(RequestValidationError)
@@ -64,15 +83,14 @@ def _register_exception_handlers(app: FastAPI) -> None:
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
"""Handle all unhandled exceptions with a generic 500 response."""
error_response = ErrorResponse(
message="Internal Server Error",
description="An unexpected error occurred. Please try again later.",
error_code="SERVER-500",
)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"data": None,
"status": ResponseStatus.FAIL.value,
"message": "Internal Server Error",
"description": "An unexpected error occurred. Please try again later.",
"error_code": "SERVER-500",
},
content=error_response.model_dump(),
)
@@ -84,11 +102,10 @@ def _format_validation_error(
formatted_errors = []
for error in errors:
field_path = ".".join(
str(loc)
for loc in error["loc"]
if loc not in ("body", "query", "path", "header", "cookie")
)
locs = error["loc"]
if locs and locs[0] in _VALIDATION_LOCATION_PARAMS:
locs = locs[1:]
field_path = ".".join(str(loc) for loc in locs)
formatted_errors.append(
{
"field": field_path or "root",
@@ -97,46 +114,35 @@ def _format_validation_error(
}
)
error_response = ErrorResponse(
data={"errors": formatted_errors},
message="Validation Error",
description=f"{len(formatted_errors)} validation error(s) detected",
error_code="VAL-422",
)
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={
"data": {"errors": formatted_errors},
"status": ResponseStatus.FAIL.value,
"message": "Validation Error",
"description": f"{len(formatted_errors)} validation error(s) detected",
"error_code": "VAL-422",
},
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
content=error_response.model_dump(),
)
def _custom_openapi(app: FastAPI) -> dict[str, Any]:
"""Generate custom OpenAPI schema with standardized error format.
Replaces default 422 validation error responses with the custom format.
def _patched_openapi(
app: FastAPI, original_openapi: Callable[[], dict[str, Any]]
) -> dict[str, Any]:
"""Generate the OpenAPI schema and replace default 422 responses.
Args:
app: FastAPI application instance
app: FastAPI application instance.
original_openapi: The previous ``app.openapi`` callable to delegate to.
Returns:
OpenAPI schema dict
Example:
from fastapi import FastAPI
from fastapi_toolsets.exceptions import init_exceptions_handlers
app = FastAPI()
init_exceptions_handlers(app) # Automatically sets custom OpenAPI
Patched OpenAPI schema dict.
"""
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
version=app.version,
openapi_version=app.openapi_version,
description=app.description,
routes=app.routes,
)
openapi_schema = original_openapi()
for path_data in openapi_schema.get("paths", {}).values():
for operation in path_data.values():
@@ -146,20 +152,25 @@ def _custom_openapi(app: FastAPI) -> dict[str, Any]:
"description": "Validation Error",
"content": {
"application/json": {
"example": {
"data": {
"errors": [
{
"field": "field_name",
"message": "value is not valid",
"type": "value_error",
}
]
},
"status": ResponseStatus.FAIL.value,
"message": "Validation Error",
"description": "1 validation error(s) detected",
"error_code": "VAL-422",
"examples": {
"VAL-422": {
"summary": "Validation Error",
"value": {
"data": {
"errors": [
{
"field": "field_name",
"message": "value is not valid",
"type": "value_error",
}
]
},
"status": ResponseStatus.FAIL.value,
"message": "Validation Error",
"description": "1 validation error(s) detected",
"error_code": "VAL-422",
},
}
}
}
},

View File

@@ -1,3 +1,5 @@
"""Fixture system for seeding databases with dependency resolution."""
from .enum import LoadStrategy
from .registry import Context, FixtureRegistry
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context

View File

@@ -1,3 +1,5 @@
"""Enums for fixture loading strategies and contexts."""
from enum import Enum

View File

@@ -2,6 +2,7 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, cast
from sqlalchemy.orm import DeclarativeBase
@@ -12,6 +13,13 @@ from .enum import Context
logger = get_logger()
def _normalize_contexts(
contexts: list[str | Enum] | tuple[str | Enum, ...],
) -> list[str]:
"""Convert a sequence of any Enum subclass and/or plain strings to a list of strings."""
return [c.value if isinstance(c, Enum) else c for c in contexts]
@dataclass
class Fixture:
"""A fixture definition with metadata."""
@@ -26,6 +34,7 @@ class FixtureRegistry:
"""Registry for managing fixtures with dependencies.
Example:
```python
from fastapi_toolsets.fixtures import FixtureRegistry, Context
fixtures = FixtureRegistry()
@@ -48,26 +57,52 @@ class FixtureRegistry:
return [
Post(id=1, title="Test", user_id=1),
]
```
Fixtures with the same name may be registered for **different** contexts.
When multiple contexts are loaded together, their instances are merged:
```python
@fixtures.register(contexts=[Context.BASE])
def users():
return [User(id=1, username="admin")]
@fixtures.register(contexts=[Context.TESTING])
def users():
return [User(id=2, username="tester")]
# load_fixtures_by_context(..., Context.BASE, Context.TESTING)
# → loads both User(admin) and User(tester) under the "users" name
```
"""
def __init__(
self,
contexts: list[str | Context] | None = None,
contexts: list[str | Enum] | None = None,
) -> None:
self._fixtures: dict[str, Fixture] = {}
self._fixtures: dict[str, list[Fixture]] = {}
self._default_contexts: list[str] | None = (
[c.value if isinstance(c, Context) else c for c in contexts]
if contexts
else None
_normalize_contexts(contexts) if contexts else None
)
def _validate_no_context_overlap(self, name: str, new_contexts: list[str]) -> None:
"""Raise ``ValueError`` if any existing variant for *name* overlaps."""
existing_variants = self._fixtures.get(name, [])
new_set = set(new_contexts)
for variant in existing_variants:
if set(variant.contexts) & new_set:
raise ValueError(
f"Fixture '{name}' already exists in the current registry "
f"with overlapping contexts. Use distinct context sets for "
f"each variant of the same fixture name."
)
def register(
self,
func: Callable[[], Sequence[DeclarativeBase]] | None = None,
*,
name: str | None = None,
depends_on: list[str] | None = None,
contexts: list[str | Context] | None = None,
contexts: list[str | Enum] | None = None,
) -> Callable[..., Any]:
"""Register a fixture function.
@@ -77,9 +112,11 @@ class FixtureRegistry:
func: Fixture function returning list of model instances
name: Fixture name (defaults to function name)
depends_on: List of fixture names this depends on
contexts: List of contexts this fixture belongs to
contexts: List of contexts this fixture belongs to. Both
:class:`Context` enum values and plain strings are accepted.
Example:
```python
@fixtures.register
def roles():
return [Role(id=1, name="admin")]
@@ -94,19 +131,20 @@ class FixtureRegistry:
) -> Callable[[], Sequence[DeclarativeBase]]:
fixture_name = name or cast(Any, fn).__name__
if contexts is not None:
fixture_contexts = [
c.value if isinstance(c, Context) else c for c in contexts
]
fixture_contexts = _normalize_contexts(contexts)
elif self._default_contexts is not None:
fixture_contexts = self._default_contexts
else:
fixture_contexts = [Context.BASE.value]
self._fixtures[fixture_name] = Fixture(
name=fixture_name,
func=fn,
depends_on=depends_on or [],
contexts=fixture_contexts,
self._validate_no_context_overlap(fixture_name, fixture_contexts)
self._fixtures.setdefault(fixture_name, []).append(
Fixture(
name=fixture_name,
func=fn,
depends_on=depends_on or [],
contexts=fixture_contexts,
)
)
return fn
@@ -117,13 +155,17 @@ class FixtureRegistry:
def include_registry(self, registry: "FixtureRegistry") -> None:
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
Fixtures with the same name are allowed as long as their context sets
do not overlap. Conflicting contexts raise :class:`ValueError`.
Args:
registry: The `FixtureRegistry` to include
Raises:
ValueError: If a fixture name already exists in the current registry
ValueError: If a fixture name already exists with overlapping contexts
Example:
```python
registry = FixtureRegistry()
dev_registry = FixtureRegistry()
@@ -132,32 +174,75 @@ class FixtureRegistry:
return [...]
registry.include_registry(registry=dev_registry)
```
"""
for name, fixture in registry._fixtures.items():
if name in self._fixtures:
raise ValueError(
f"Fixture '{name}' already exists in the current registry"
)
self._fixtures[name] = fixture
for name, variants in registry._fixtures.items():
for fixture in variants:
self._validate_no_context_overlap(name, fixture.contexts)
self._fixtures.setdefault(name, []).append(fixture)
def get(self, name: str) -> Fixture:
"""Get a fixture by name."""
"""Get a fixture by name.
Raises:
KeyError: If no fixture with *name* is registered.
ValueError: If the fixture has multiple context variants — use
:meth:`get_variants` in that case.
"""
if name not in self._fixtures:
raise KeyError(f"Fixture '{name}' not found")
return self._fixtures[name]
variants = self._fixtures[name]
if len(variants) > 1:
raise ValueError(
f"Fixture '{name}' has {len(variants)} context variants. "
f"Use get_variants('{name}') to retrieve them."
)
return variants[0]
def get_variants(self, name: str, *contexts: str | Enum) -> list[Fixture]:
"""Return all registered variants for *name*, optionally filtered by context.
Args:
name: Fixture name.
*contexts: If given, only return variants whose context set
intersects with these values. Both :class:`Context` enum
values and plain strings are accepted.
Returns:
List of matching :class:`Fixture` objects (may be empty when a
context filter is applied and nothing matches).
Raises:
KeyError: If no fixture with *name* is registered.
"""
if name not in self._fixtures:
raise KeyError(f"Fixture '{name}' not found")
variants = self._fixtures[name]
if not contexts:
return list(variants)
context_values = set(_normalize_contexts(contexts))
return [v for v in variants if set(v.contexts) & context_values]
def get_all(self) -> list[Fixture]:
"""Get all registered fixtures."""
return list(self._fixtures.values())
"""Get all registered fixtures (all variants of all names)."""
return [f for variants in self._fixtures.values() for f in variants]
def get_by_context(self, *contexts: str | Context) -> list[Fixture]:
def get_by_context(self, *contexts: str | Enum) -> list[Fixture]:
"""Get fixtures for specific contexts."""
context_values = {c.value if isinstance(c, Context) else c for c in contexts}
return [f for f in self._fixtures.values() if set(f.contexts) & context_values]
context_values = set(_normalize_contexts(contexts))
return [
f
for variants in self._fixtures.values()
for f in variants
if set(f.contexts) & context_values
]
def resolve_dependencies(self, *names: str) -> list[str]:
"""Resolve fixture dependencies in topological order.
When a fixture name has multiple context variants, the union of all
variants' ``depends_on`` lists is used.
Args:
*names: Fixture names to resolve
@@ -179,9 +264,20 @@ class FixtureRegistry:
raise ValueError(f"Circular dependency detected: {name}")
visiting.add(name)
fixture = self.get(name)
variants = self._fixtures.get(name)
if variants is None:
raise KeyError(f"Fixture '{name}' not found")
for dep in fixture.depends_on:
# Union of depends_on across all variants, preserving first-seen order.
seen_deps: set[str] = set()
all_deps: list[str] = []
for variant in variants:
for dep in variant.depends_on:
if dep not in seen_deps:
all_deps.append(dep)
seen_deps.add(dep)
for dep in all_deps:
visit(dep)
visiting.remove(name)
@@ -193,7 +289,7 @@ class FixtureRegistry:
return resolved
def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]:
def resolve_context_dependencies(self, *contexts: str | Enum) -> list[str]:
"""Resolve all fixtures for contexts with dependencies.
Args:
@@ -203,7 +299,9 @@ class FixtureRegistry:
List of fixture names in load order
"""
context_fixtures = self.get_by_context(*contexts)
names = [f.name for f in context_fixtures]
# Deduplicate names while preserving first-seen order (a name can
# appear multiple times if it has variants in different contexts).
names = list(dict.fromkeys(f.name for f in context_fixtures))
all_deps: set[str] = set()
for name in names:

View File

@@ -1,22 +1,233 @@
from collections.abc import Callable, Sequence
from typing import Any, TypeVar
"""Fixture loading utilities for database seeding."""
from collections.abc import Callable, Sequence
from enum import Enum
from typing import Any
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction
from ..logger import get_logger
from ..types import ModelType
from .enum import LoadStrategy
from .registry import Context, FixtureRegistry
from .registry import FixtureRegistry, _normalize_contexts
logger = get_logger()
T = TypeVar("T", bound=DeclarativeBase)
def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
"""Extract column values from a model instance, skipping unset server-default columns."""
state = sa_inspect(instance)
state_dict = state.dict
result: dict[str, Any] = {}
for prop in state.mapper.column_attrs:
if prop.key not in state_dict:
continue
val = state_dict[prop.key]
if val is None:
col = prop.columns[0]
if (
col.server_default is not None
or (col.default is not None and col.default.is_callable)
or col.autoincrement is True
):
continue
result[prop.key] = val
return result
def _group_by_type(
instances: list[DeclarativeBase],
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
"""Group instances by their concrete model class, preserving insertion order."""
groups: dict[type[DeclarativeBase], list[DeclarativeBase]] = {}
for instance in instances:
groups.setdefault(type(instance), []).append(instance)
return list(groups.items())
def _group_by_column_set(
dicts: list[dict[str, Any]],
instances: list[DeclarativeBase],
) -> list[tuple[list[dict[str, Any]], list[DeclarativeBase]]]:
"""Group (dict, instance) pairs by their dict key sets."""
groups: dict[
frozenset[str], tuple[list[dict[str, Any]], list[DeclarativeBase]]
] = {}
for d, inst in zip(dicts, instances):
key = frozenset(d)
if key not in groups:
groups[key] = ([], [])
groups[key][0].append(d)
groups[key][1].append(inst)
return list(groups.values())
async def _batch_insert(
session: AsyncSession,
model_cls: type[DeclarativeBase],
instances: list[DeclarativeBase],
) -> None:
"""INSERT all instances — raises on conflict (no duplicate handling)."""
dicts = [_instance_to_dict(i) for i in instances]
for group_dicts, _ in _group_by_column_set(dicts, instances):
await session.execute(pg_insert(model_cls).values(group_dicts))
async def _batch_merge(
session: AsyncSession,
model_cls: type[DeclarativeBase],
instances: list[DeclarativeBase],
) -> None:
"""UPSERT: insert new rows, update existing ones with the provided values."""
mapper = model_cls.__mapper__
pk_names = [col.name for col in mapper.primary_key]
pk_names_set = set(pk_names)
non_pk_cols = [
prop.key
for prop in mapper.column_attrs
if not any(col.name in pk_names_set for col in prop.columns)
]
dicts = [_instance_to_dict(i) for i in instances]
for group_dicts, _ in _group_by_column_set(dicts, instances):
stmt = pg_insert(model_cls).values(group_dicts)
inserted_keys = set(group_dicts[0])
update_cols = [col for col in non_pk_cols if col in inserted_keys]
if update_cols:
stmt = stmt.on_conflict_do_update(
index_elements=pk_names,
set_={col: stmt.excluded[col] for col in update_cols},
)
else:
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
await session.execute(stmt)
async def _batch_skip_existing(
session: AsyncSession,
model_cls: type[DeclarativeBase],
instances: list[DeclarativeBase],
) -> list[DeclarativeBase]:
"""INSERT only rows that do not already exist; return the inserted ones."""
mapper = model_cls.__mapper__
pk_names = [col.name for col in mapper.primary_key]
no_pk: list[DeclarativeBase] = []
with_pk_pairs: list[tuple[DeclarativeBase, Any]] = []
for inst in instances:
pk = _get_primary_key(inst)
if pk is None:
no_pk.append(inst)
else:
with_pk_pairs.append((inst, pk))
loaded: list[DeclarativeBase] = list(no_pk)
if no_pk:
no_pk_dicts = [_instance_to_dict(i) for i in no_pk]
for group_dicts, _ in _group_by_column_set(no_pk_dicts, no_pk):
await session.execute(pg_insert(model_cls).values(group_dicts))
if with_pk_pairs:
with_pk = [i for i, _ in with_pk_pairs]
with_pk_dicts = [_instance_to_dict(i) for i in with_pk]
for group_dicts, group_insts in _group_by_column_set(with_pk_dicts, with_pk):
stmt = (
pg_insert(model_cls)
.values(group_dicts)
.on_conflict_do_nothing(index_elements=pk_names)
)
result = await session.execute(stmt.returning(*mapper.primary_key))
inserted_pks = {
row[0] if len(pk_names) == 1 else tuple(row) for row in result
}
loaded.extend(
inst
for inst, pk in zip(
group_insts, [_get_primary_key(i) for i in group_insts]
)
if pk in inserted_pks
)
return loaded
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
contexts: tuple[str, ...] | None = None,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order using batch Core INSERT statements."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
variants = (
registry.get_variants(name, *contexts)
if contexts is not None
else registry.get_variants(name)
)
if contexts is not None and not variants:
variants = registry.get_variants(name)
if not variants:
results[name] = []
continue
instances = [inst for v in variants for inst in v.func()]
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for model_cls, group in _group_by_type(instances):
match strategy:
case LoadStrategy.INSERT:
await _batch_insert(session, model_cls, group)
loaded.extend(group)
case LoadStrategy.MERGE:
await _batch_merge(session, model_cls, group)
loaded.extend(group)
case LoadStrategy.SKIP_EXISTING:
inserted = await _batch_skip_existing(session, model_cls, group)
loaded.extend(inserted)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None
def get_obj_by_attr(
fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any
) -> T:
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
) -> ModelType:
"""Get a SQLAlchemy model instance by matching an attribute value.
Args:
@@ -47,6 +258,8 @@ async def load_fixtures(
) -> dict[str, list[DeclarativeBase]]:
"""Load specific fixtures by name with dependencies.
All context variants of each requested fixture are loaded and merged.
Args:
session: Database session
registry: Fixture registry
@@ -55,11 +268,6 @@ async def load_fixtures(
Returns:
Dict mapping fixture names to loaded instances
Example:
# Loads 'roles' first (dependency), then 'users'
result = await load_fixtures(session, fixtures, "users")
print(result["users"]) # [User(...), ...]
"""
ordered = registry.resolve_dependencies(*names)
return await _load_ordered(session, registry, ordered, strategy)
@@ -68,7 +276,7 @@ async def load_fixtures(
async def load_fixtures_by_context(
session: AsyncSession,
registry: FixtureRegistry,
*contexts: str | Context,
*contexts: str | Enum,
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
"""Load all fixtures for specific contexts.
@@ -76,79 +284,15 @@ async def load_fixtures_by_context(
Args:
session: Database session
registry: Fixture registry
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
*contexts: Contexts to load (e.g., ``Context.BASE``, ``Context.TESTING``,
or plain strings for custom contexts)
strategy: How to handle existing records
Returns:
Dict mapping fixture names to loaded instances
Example:
# Load base + testing fixtures
await load_fixtures_by_context(
session, fixtures,
Context.BASE, Context.TESTING
)
"""
context_strings = tuple(_normalize_contexts(contexts))
ordered = registry.resolve_context_dependencies(*contexts)
return await _load_ordered(session, registry, ordered, strategy)
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None
return await _load_ordered(
session, registry, ordered, strategy, contexts=context_strings
)

View File

@@ -35,6 +35,14 @@ def configure_logging(
Returns:
The configured Logger instance.
Example:
```python
from fastapi_toolsets.logger import configure_logging
logger = configure_logging("DEBUG")
logger.info("Application started")
```
"""
formatter = logging.Formatter(fmt)
@@ -58,7 +66,7 @@ def configure_logging(
_SENTINEL = object()
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment]
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment] # ty:ignore[invalid-parameter-default]
"""Return a logger with the given *name*.
A thin convenience wrapper around :func:`logging.getLogger` that keeps
@@ -75,6 +83,15 @@ def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[
Returns:
A Logger instance.
Example:
```python
from fastapi_toolsets.logger import get_logger
logger = get_logger() # uses caller's __name__
logger = get_logger("myapp") # explicit name
logger = get_logger(None) # root logger
```
"""
if name is _SENTINEL:
name = sys._getframe(1).f_globals.get("__name__")

View File

@@ -0,0 +1,21 @@
"""Prometheus metrics integration for FastAPI applications."""
from typing import Any
from .registry import Metric, MetricsRegistry
try:
from .handler import init_metrics
except ImportError:
def init_metrics(*_args: Any, **_kwargs: Any) -> None:
from .._imports import require_extra
require_extra(package="prometheus_client", extra="metrics")
__all__ = [
"Metric",
"MetricsRegistry",
"init_metrics",
]

View File

@@ -0,0 +1,81 @@
"""Prometheus metrics endpoint for FastAPI applications."""
import inspect
import os
from fastapi import FastAPI
from fastapi.responses import Response
from prometheus_client import (
CONTENT_TYPE_LATEST,
CollectorRegistry,
generate_latest,
multiprocess,
)
from ..logger import get_logger
from .registry import MetricsRegistry
logger = get_logger()
def _is_multiprocess() -> bool:
"""Check if prometheus multi-process mode is enabled."""
return "PROMETHEUS_MULTIPROC_DIR" in os.environ
def init_metrics(
app: FastAPI,
registry: MetricsRegistry,
*,
path: str = "/metrics",
) -> FastAPI:
"""Register a Prometheus ``/metrics`` endpoint on a FastAPI app.
Args:
app: FastAPI application instance.
registry: A :class:`MetricsRegistry` containing providers and collectors.
path: URL path for the metrics endpoint (default ``/metrics``).
Returns:
The same FastAPI instance (for chaining).
Example:
```python
from fastapi import FastAPI
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
metrics = MetricsRegistry()
app = FastAPI()
init_metrics(app, registry=metrics)
```
"""
for provider in registry.get_providers():
logger.debug("Initialising metric provider '%s'", provider.name)
registry._instances[provider.name] = provider.func()
# Partition collectors and cache env check at startup — both are stable for the app lifetime.
async_collectors = [
c for c in registry.get_collectors() if inspect.iscoroutinefunction(c.func)
]
sync_collectors = [
c for c in registry.get_collectors() if not inspect.iscoroutinefunction(c.func)
]
multiprocess_mode = _is_multiprocess()
@app.get(path, include_in_schema=False)
async def metrics_endpoint() -> Response:
for collector in sync_collectors:
collector.func()
for collector in async_collectors:
await collector.func()
if multiprocess_mode:
prom_registry = CollectorRegistry()
multiprocess.MultiProcessCollector(prom_registry)
output = generate_latest(prom_registry)
else:
output = generate_latest()
return Response(content=output, media_type=CONTENT_TYPE_LATEST)
return app

View File

@@ -0,0 +1,104 @@
"""Metrics registry with decorator-based registration."""
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, cast
from ..logger import get_logger
logger = get_logger()
@dataclass
class Metric:
"""A metric definition with metadata."""
name: str
func: Callable[..., Any]
collect: bool = field(default=False)
class MetricsRegistry:
"""Registry for managing Prometheus metric providers and collectors."""
def __init__(self) -> None:
self._metrics: dict[str, Metric] = {}
self._instances: dict[str, Any] = {}
def register(
self,
func: Callable[..., Any] | None = None,
*,
name: str | None = None,
collect: bool = False,
) -> Callable[..., Any]:
"""Register a metric provider or collector function.
Can be used as a decorator with or without arguments.
Args:
func: The metric function to register.
name: Metric name (defaults to function name).
collect: If ``True``, the function is called on every scrape.
If ``False`` (default), called once at init time.
"""
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
metric_name = name or cast(Any, fn).__name__
self._metrics[metric_name] = Metric(
name=metric_name,
func=fn,
collect=collect,
)
return fn
if func is not None:
return decorator(func)
return decorator
def get(self, name: str) -> Any:
"""Return the metric instance created by a provider.
Args:
name: The metric name (defaults to the provider function name).
Raises:
KeyError: If the metric name is unknown or ``init_metrics`` has not
been called yet.
"""
if name not in self._instances:
if name in self._metrics:
raise KeyError(
f"Metric '{name}' exists but has not been initialized yet. "
"Ensure init_metrics() has been called before accessing metric instances."
)
raise KeyError(f"Unknown metric '{name}'.")
return self._instances[name]
def include_registry(self, registry: "MetricsRegistry") -> None:
"""Include another :class:`MetricsRegistry` into this one.
Args:
registry: The registry to merge in.
Raises:
ValueError: If a metric name already exists in the current registry.
"""
for metric_name, definition in registry._metrics.items():
if metric_name in self._metrics:
raise ValueError(
f"Metric '{metric_name}' already exists in the current registry"
)
self._metrics[metric_name] = definition
def get_all(self) -> list[Metric]:
"""Get all registered metric definitions."""
return list(self._metrics.values())
def get_providers(self) -> list[Metric]:
"""Get metric providers (called once at init)."""
return [m for m in self._metrics.values() if not m.collect]
def get_collectors(self) -> list[Metric]:
"""Get collectors (called on each scrape)."""
return [m for m in self._metrics.values() if m.collect]

View File

@@ -0,0 +1,21 @@
"""SQLAlchemy model mixins for common column patterns."""
from .columns import (
CreatedAtMixin,
TimestampMixin,
UUIDMixin,
UUIDv7Mixin,
UpdatedAtMixin,
)
from .watched import EventSession, ModelEvent, listens_for
__all__ = [
"EventSession",
"ModelEvent",
"UUIDMixin",
"UUIDv7Mixin",
"CreatedAtMixin",
"UpdatedAtMixin",
"TimestampMixin",
"listens_for",
]

View File

@@ -0,0 +1,50 @@
"""SQLAlchemy column mixins for common column patterns."""
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Uuid, text
from sqlalchemy.orm import Mapped, mapped_column
class UUIDMixin:
"""Mixin that adds a UUID primary key auto-generated by the database."""
id: Mapped[uuid.UUID] = mapped_column(
Uuid,
primary_key=True,
server_default=text("gen_random_uuid()"),
)
class UUIDv7Mixin:
"""Mixin that adds a UUIDv7 primary key auto-generated by the database."""
id: Mapped[uuid.UUID] = mapped_column(
Uuid,
primary_key=True,
server_default=text("uuidv7()"),
)
class CreatedAtMixin:
"""Mixin that adds a ``created_at`` timestamp column."""
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=text("clock_timestamp()"),
)
class UpdatedAtMixin:
"""Mixin that adds an ``updated_at`` timestamp column."""
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=text("clock_timestamp()"),
onupdate=text("clock_timestamp()"),
)
class TimestampMixin(CreatedAtMixin, UpdatedAtMixin):
"""Mixin that combines ``created_at`` and ``updated_at`` timestamp columns."""

View File

@@ -0,0 +1,290 @@
"""Field-change monitoring via SQLAlchemy session events."""
import inspect
from collections.abc import Callable
from enum import Enum
from typing import Any
from sqlalchemy import event
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
from ..logger import get_logger
_logger = get_logger()
class ModelEvent(str, Enum):
"""Event types dispatched by :class:`EventSession`."""
CREATE = "create"
DELETE = "delete"
UPDATE = "update"
_CALLBACK_ERROR_MSG = "Event callback raised an unhandled exception"
_SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates"
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
_EVENT_HANDLERS: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
_WATCHED_MODELS: set[type] = set()
_WATCHED_CACHE: dict[type, bool] = {}
_HANDLER_CACHE: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
def _invalidate_caches() -> None:
"""Clear lookup caches after handler registration."""
_WATCHED_CACHE.clear()
_HANDLER_CACHE.clear()
def listens_for(
model_class: type,
event_types: list[ModelEvent] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Register a callback for one or more model lifecycle events.
Args:
model_class: The SQLAlchemy model class to listen on.
event_types: List of :class:`ModelEvent` values to listen for.
Defaults to all event types.
"""
evs = event_types if event_types is not None else list(ModelEvent)
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
for ev in evs:
_EVENT_HANDLERS.setdefault((model_class, ev), []).append(fn)
_WATCHED_MODELS.add(model_class)
_invalidate_caches()
return fn
return decorator
def _is_watched(obj: Any) -> bool:
"""Return True if *obj*'s type (or any ancestor) has registered handlers."""
cls = type(obj)
try:
return _WATCHED_CACHE[cls]
except KeyError:
result = any(klass in _WATCHED_MODELS for klass in cls.__mro__)
_WATCHED_CACHE[cls] = result
return result
def _get_handlers(cls: type, ev: ModelEvent) -> list[Callable[..., Any]]:
"""Return registered handlers for *cls* and *ev*, walking the MRO."""
key = (cls, ev)
try:
return _HANDLER_CACHE[key]
except KeyError:
handlers: list[Callable[..., Any]] = []
for klass in cls.__mro__:
handlers.extend(_EVENT_HANDLERS.get((klass, ev), []))
_HANDLER_CACHE[key] = handlers
return handlers
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
"""Read currently-loaded column values into a plain dict."""
state = sa_inspect(obj) # InstanceState
state_dict = state.dict
snapshot: dict[str, Any] = {}
for prop in state.mapper.column_attrs:
if prop.key in state_dict:
snapshot[prop.key] = state_dict[prop.key]
elif ( # pragma: no cover
not state.expired
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
and all(
col.nullable
and col.server_default is None
and col.server_onupdate is None
for col in prop.columns
)
):
snapshot[prop.key] = None
return snapshot
def _get_watched_fields(cls: type) -> tuple[str, ...] | None:
"""Return the watched fields for *cls*."""
fields = getattr(cls, "__watched_fields__", None)
if fields is not None and (
not isinstance(fields, tuple) or not all(isinstance(f, str) for f in fields)
):
raise TypeError(
f"{cls.__name__}.__watched_fields__ must be a tuple[str, ...], "
f"got {type(fields).__name__}"
)
return fields
def _upsert_changes(
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
obj: Any,
changes: dict[str, dict[str, Any]],
) -> None:
"""Insert or merge *changes* into *pending* for *obj*."""
key = id(obj)
if key in pending:
existing = pending[key][1]
for field, change in changes.items():
if field in existing:
existing[field]["new"] = change["new"]
else:
existing[field] = change
else:
pending[key] = (obj, changes)
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
def _after_flush(session: Any, flush_context: Any) -> None:
# New objects: capture reference. Attributes will be refreshed after commit.
for obj in session.new:
if _is_watched(obj):
session.info.setdefault(_SESSION_CREATES, []).append(obj)
# Deleted objects: snapshot now while attributes are still loaded.
for obj in session.deleted:
if _is_watched(obj):
snapshot = _snapshot_column_attrs(obj)
session.info.setdefault(_SESSION_DELETES, []).append((obj, snapshot))
# Dirty objects: read old/new from SQLAlchemy attribute history.
for obj in session.dirty:
if not _is_watched(obj):
continue
watched = _get_watched_fields(type(obj))
changes: dict[str, dict[str, Any]] = {}
inst_attrs = sa_inspect(obj).attrs
attrs = (
((field, inst_attrs[field]) for field in watched)
if watched is not None
else ((s.key, s) for s in inst_attrs)
)
for field, attr_state in attrs:
history = attr_state.history
if history.has_changes() and history.deleted:
changes[field] = {
"old": history.deleted[0],
"new": history.added[0] if history.added else None,
}
if changes:
_upsert_changes(
session.info.setdefault(_SESSION_UPDATES, {}),
obj,
changes,
)
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
def _after_rollback(session: Any) -> None:
if session.in_transaction():
return
session.info.pop(_SESSION_CREATES, None)
session.info.pop(_SESSION_DELETES, None)
session.info.pop(_SESSION_UPDATES, None)
async def _invoke_callback(
fn: Callable[..., Any],
obj: Any,
event_type: ModelEvent,
changes: dict[str, dict[str, Any]] | None,
) -> None:
"""Call *fn* and await the result if it is awaitable."""
result = fn(obj, event_type, changes)
if inspect.isawaitable(result):
await result
class EventSession(AsyncSession):
"""AsyncSession subclass that dispatches lifecycle callbacks after commit."""
async def commit(self) -> None: # noqa: C901
await super().commit()
creates: list[Any] = self.info.pop(_SESSION_CREATES, [])
deletes: list[tuple[Any, dict[str, Any]]] = self.info.pop(_SESSION_DELETES, [])
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = self.info.pop(
_SESSION_UPDATES, {}
)
if not creates and not deletes and not field_changes:
return
# Suppress transient objects (created + deleted in same transaction).
if creates and deletes:
created_ids = {id(o) for o in creates}
deleted_ids = {id(o) for o, _ in deletes}
transient_ids = created_ids & deleted_ids
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [(o, s) for o, s in deletes if id(o) not in transient_ids]
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
}
# Suppress updates for deleted objects (row is gone, refresh would fail).
if deletes and field_changes:
deleted_ids = {id(o) for o, _ in deletes}
field_changes = {
k: v for k, v in field_changes.items() if k not in deleted_ids
}
# Suppress updates for newly created objects (CREATE-only semantics).
if creates and field_changes:
create_ids = {id(o) for o in creates}
field_changes = {
k: v for k, v in field_changes.items() if k not in create_ids
}
# Dispatch CREATE callbacks.
for obj in creates:
try:
state = sa_inspect(obj, raiseerr=False)
if (
state is None or state.detached or state.transient
): # pragma: no cover
continue
await self.refresh(obj)
for handler in _get_handlers(type(obj), ModelEvent.CREATE):
await _invoke_callback(handler, obj, ModelEvent.CREATE, None)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
# Dispatch DELETE callbacks (restore snapshot; row is gone).
for obj, snapshot in deletes:
try:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
for handler in _get_handlers(type(obj), ModelEvent.DELETE):
await _invoke_callback(handler, obj, ModelEvent.DELETE, None)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
# Dispatch UPDATE callbacks.
for obj, changes in field_changes.values():
try:
state = sa_inspect(obj, raiseerr=False)
if (
state is None or state.detached or state.transient
): # pragma: no cover
continue
await self.refresh(obj)
for handler in _get_handlers(type(obj), ModelEvent.UPDATE):
await _invoke_callback(handler, obj, ModelEvent.UPDATE, changes)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
async def rollback(self) -> None:
await super().rollback()
self.info.pop(_SESSION_CREATES, None)
self.info.pop(_SESSION_DELETES, None)
self.info.pop(_SESSION_UPDATES, None)

View File

@@ -1,8 +1,30 @@
from .plugin import register_fixtures
from .utils import create_async_client, create_db_session
"""Pytest helpers for FastAPI testing: sessions, clients, and fixtures."""
try:
from .plugin import register_fixtures
except ImportError:
from .._imports import require_extra
require_extra(package="pytest", extra="pytest")
try:
from .utils import (
cleanup_tables,
create_async_client,
create_db_session,
create_worker_database,
worker_database_url,
)
except ImportError:
from .._imports import require_extra
require_extra(package="httpx", extra="pytest")
__all__ = [
"cleanup_tables",
"create_async_client",
"create_db_session",
"create_worker_database",
"register_fixtures",
"worker_database_url",
]

View File

@@ -1,55 +1,4 @@
"""Pytest plugin for using FixtureRegistry fixtures in tests.
This module provides utilities to automatically generate pytest fixtures
from your FixtureRegistry, with proper dependency resolution.
Example:
# conftest.py
import pytest
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app.fixtures import fixtures # Your FixtureRegistry
from app.models import Base
from fastapi_toolsets.pytest_plugin import register_fixtures
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
@pytest.fixture
async def engine():
engine = create_async_engine(DATABASE_URL)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(engine):
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
session = session_factory()
try:
yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
# Automatically generate pytest fixtures from registry
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
register_fixtures(fixtures, globals())
Usage in tests:
# test_users.py
async def test_user_count(db_session, fixture_users):
# fixture_users automatically loads fixture_roles first (if dependency)
# and returns the list of User models
assert len(fixture_users) > 0
async def test_user_role(db_session, fixture_users):
user = fixture_users[0]
assert user.role_id is not None
"""
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
from collections.abc import Callable, Sequence
from typing import Any
@@ -86,6 +35,7 @@ def register_fixtures(
List of created fixture names
Example:
```python
# conftest.py
from app.fixtures import fixtures
from fastapi_toolsets.pytest_plugin import register_fixtures
@@ -96,6 +46,7 @@ def register_fixtures(
# - fixture_roles
# - fixture_users (depends on fixture_roles if users depends on roles)
# - fixture_posts (depends on fixture_users if posts depends on users)
```
"""
created_fixtures: list[str] = []

View File

@@ -1,26 +1,140 @@
"""Pytest helper utilities for FastAPI testing."""
from collections.abc import AsyncGenerator
import os
from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager
from typing import Any
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from ..db import create_db_context
from ..db import cleanup_tables, create_database
from ..models.watched import EventSession
def _get_xdist_worker(default_test_db: str) -> str:
"""Return the pytest-xdist worker name, or *default_test_db* when not running under xdist.
Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
When xdist is not installed or not active, the variable is absent and
*default_test_db* is returned instead.
Args:
default_test_db: Fallback value returned when ``PYTEST_XDIST_WORKER``
is not set.
"""
return os.environ.get("PYTEST_XDIST_WORKER", default_test_db)
def worker_database_url(database_url: str, default_test_db: str) -> str:
"""Derive a per-worker database URL for pytest-xdist parallel runs.
Appends ``_{worker_name}`` to the database name so each xdist worker
operates on its own database. When not running under xdist,
``_{default_test_db}`` is appended instead.
The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
variable (set automatically by xdist in each worker process).
Args:
database_url: Original database connection URL.
default_test_db: Suffix appended to the database name when
``PYTEST_XDIST_WORKER`` is not set.
Returns:
A database URL with a worker- or default-specific database name.
"""
worker = _get_xdist_worker(default_test_db=default_test_db)
url = make_url(database_url)
url = url.set(database=f"{url.database}_{worker}")
return url.render_as_string(hide_password=False)
@asynccontextmanager
async def create_worker_database(
database_url: str,
default_test_db: str = "test_db",
) -> AsyncGenerator[str, None]:
"""Create and drop a per-worker database for pytest-xdist isolation.
Derives a worker-specific database URL using :func:`worker_database_url`,
then delegates to :func:`~fastapi_toolsets.db.create_database` to create
and drop it. Intended for use as a **session-scoped** fixture.
When running under xdist the database name is suffixed with the worker
name (e.g. ``_gw0``). Otherwise it is suffixed with *default_test_db*.
Args:
database_url: Original database connection URL (used as the server
connection and as the base for the worker database name).
default_test_db: Suffix appended to the database name when
``PYTEST_XDIST_WORKER`` is not set. Defaults to ``"test_db"``.
Yields:
The worker-specific database URL.
Example:
```python
from fastapi_toolsets.pytest import create_worker_database, create_db_session
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
@pytest.fixture(scope="session")
async def worker_db_url():
async with create_worker_database(DATABASE_URL) as url:
yield url
@pytest.fixture
async def db_session(worker_db_url):
async with create_db_session(
worker_db_url, Base, cleanup=True
) as session:
yield session
```
"""
worker_url = worker_database_url(
database_url=database_url, default_test_db=default_test_db
)
worker_db_name = make_url(worker_url).database
assert worker_db_name is not None
engine = create_async_engine(database_url, isolation_level="AUTOCOMMIT")
try:
async with engine.connect() as conn:
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
await create_database(db_name=worker_db_name, server_url=database_url)
yield worker_url
async with engine.connect() as conn:
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
finally:
await engine.dispose()
@asynccontextmanager
async def create_async_client(
app: Any,
base_url: str = "http://test",
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
) -> AsyncGenerator[AsyncClient, None]:
"""Create an async httpx client for testing FastAPI applications.
Args:
app: FastAPI application instance.
base_url: Base URL for requests. Defaults to "http://test".
dependency_overrides: Optional mapping of original dependencies to
their test replacements. Applied via ``app.dependency_overrides``
before yielding and cleaned up after.
Yields:
An AsyncClient configured for the app.
@@ -41,10 +155,39 @@ async def create_async_client(
response = await client.get("/health")
assert response.status_code == 200
```
Example with dependency overrides:
```python
from fastapi_toolsets.pytest import create_async_client, create_db_session
from app.db import get_db
@pytest.fixture
async def db_session():
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
yield session
@pytest.fixture
async def client(db_session):
async def override():
yield db_session
async with create_async_client(
app, dependency_overrides={get_db: override}
) as c:
yield c
```
"""
if dependency_overrides:
app.dependency_overrides.update(dependency_overrides)
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url=base_url) as client:
yield client
try:
async with AsyncClient(transport=transport, base_url=base_url) as client:
yield client
finally:
if dependency_overrides:
for key in dependency_overrides:
app.dependency_overrides.pop(key, None)
@asynccontextmanager
@@ -55,6 +198,7 @@ async def create_db_session(
echo: bool = False,
expire_on_commit: bool = False,
drop_tables: bool = True,
cleanup: bool = False,
) -> AsyncGenerator[AsyncSession, None]:
"""Create a database session for testing.
@@ -67,6 +211,8 @@ async def create_db_session(
echo: Enable SQLAlchemy query logging. Defaults to False.
expire_on_commit: Expire objects after commit. Defaults to False.
drop_tables: Drop tables after test. Defaults to True.
cleanup: Truncate all tables after test using
:func:`cleanup_tables`. Defaults to False.
Yields:
An AsyncSession ready for database operations.
@@ -80,7 +226,9 @@ async def create_db_session(
@pytest.fixture
async def db_session():
async with create_db_session(DATABASE_URL, Base) as session:
async with create_db_session(
DATABASE_URL, Base, cleanup=True
) as session:
yield session
async def test_create_user(db_session: AsyncSession):
@@ -96,13 +244,15 @@ async def create_db_session(
async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)
# Create session using existing db context utility
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
get_session = create_db_context(session_maker)
async with get_session() as session:
session_maker = async_sessionmaker(
engine, expire_on_commit=expire_on_commit, class_=EventSession
)
async with session_maker() as session:
yield session
if cleanup:
await cleanup_tables(session=session, base=base)
if drop_tables:
async with engine.begin() as conn:
await conn.run_sync(base.metadata.drop_all)

View File

@@ -1,22 +1,27 @@
"""Base Pydantic schemas for API responses."""
import math
from enum import Enum
from typing import ClassVar, Generic, TypeVar
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field, computed_field
from .types import DataT
__all__ = [
"ApiError",
"CursorPagination",
"CursorPaginatedResponse",
"ErrorResponse",
"Pagination",
"OffsetPagination",
"OffsetPaginatedResponse",
"PaginatedResponse",
"PaginationType",
"PydanticBase",
"Response",
"ResponseStatus",
]
DataT = TypeVar("DataT")
class PydanticBase(BaseModel):
"""Base class for all Pydantic models with common configuration."""
@@ -50,6 +55,7 @@ class ApiError(PydanticBase):
msg: str
desc: str
err_code: str
data: Any | None = None
class BaseResponse(PydanticBase):
@@ -70,7 +76,9 @@ class Response(BaseResponse, Generic[DataT]):
"""Generic API response with data payload.
Example:
```python
Response[UserRead](data=user, message="User retrieved")
```
"""
data: DataT | None = None
@@ -84,34 +92,115 @@ class ErrorResponse(BaseResponse):
status: ResponseStatus = ResponseStatus.FAIL
description: str | None = None
data: None = None
data: Any | None = None
class Pagination(PydanticBase):
"""Pagination metadata for list responses.
class OffsetPagination(PydanticBase):
"""Pagination metadata for offset-based list responses.
Attributes:
total_count: Total number of items across all pages
total_count: Total number of items across all pages.
``None`` when ``include_total=False``.
items_per_page: Number of items per page
page: Current page number (1-indexed)
has_more: Whether there are more pages
pages: Total number of pages
"""
total_count: int
total_count: int | None
items_per_page: int
page: int
has_more: bool
@computed_field
@property
def pages(self) -> int | None:
"""Total number of pages, or ``None`` when ``total_count`` is unknown."""
if self.total_count is None:
return None
if self.items_per_page == 0:
return 0
return math.ceil(self.total_count / self.items_per_page)
class CursorPagination(PydanticBase):
"""Pagination metadata for cursor-based list responses.
Attributes:
next_cursor: Encoded cursor for the next page, or None on the last page.
prev_cursor: Encoded cursor for the previous page, or None on the first page.
items_per_page: Number of items requested per page.
has_more: Whether there is at least one more page after this one.
"""
next_cursor: str | None
prev_cursor: str | None = None
items_per_page: int
has_more: bool
class PaginationType(str, Enum):
"""Pagination strategy selector for :meth:`.AsyncCrud.paginate`."""
OFFSET = "offset"
CURSOR = "cursor"
class PaginatedResponse(BaseResponse, Generic[DataT]):
"""Paginated API response for list endpoints.
Example:
PaginatedResponse[UserRead](
data=users,
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
)
Base class and return type for endpoints that support both pagination
strategies. Use :class:`OffsetPaginatedResponse` or
:class:`CursorPaginatedResponse` when the strategy is fixed.
When used as ``PaginatedResponse[T]`` in a return annotation, subscripting
returns ``Annotated[Union[CursorPaginatedResponse[T], OffsetPaginatedResponse[T]], Field(discriminator="pagination_type")]``
so FastAPI emits a proper ``oneOf`` + discriminator in the OpenAPI schema.
"""
data: list[DataT]
pagination: Pagination
pagination: OffsetPagination | CursorPagination
pagination_type: PaginationType | None = None
filter_attributes: dict[str, list[Any]] | None = None
search_columns: list[str] | None = None
order_columns: list[str] | None = None
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
def __class_getitem__( # ty:ignore[invalid-method-override]
cls, item: type[Any] | tuple[type[Any], ...]
) -> type[Any]:
if cls is PaginatedResponse and not isinstance(item, TypeVar):
cached = cls._discriminated_union_cache.get(item)
if cached is None:
cached = Annotated[
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # ty:ignore[invalid-type-form]
Field(discriminator="pagination_type"),
]
cls._discriminated_union_cache[item] = cached
return cached # ty:ignore[invalid-return-type]
return super().__class_getitem__(item)
class OffsetPaginatedResponse(PaginatedResponse[DataT]):
"""Paginated response with typed offset-based pagination metadata.
The ``pagination_type`` field is always ``"offset"`` and acts as a
discriminator, allowing frontend clients to narrow the union type returned
by a unified ``paginate()`` endpoint.
"""
pagination: OffsetPagination
pagination_type: Literal[PaginationType.OFFSET] = PaginationType.OFFSET
class CursorPaginatedResponse(PaginatedResponse[DataT]):
"""Paginated response with typed cursor-based pagination metadata.
The ``pagination_type`` field is always ``"cursor"`` and acts as a
discriminator, allowing frontend clients to narrow the union type returned
by a unified ``paginate()`` endpoint.
"""
pagination: CursorPagination
pagination_type: Literal[PaginationType.CURSOR] = PaginationType.CURSOR

View File

@@ -0,0 +1,24 @@
"""Authentication helpers for FastAPI using Security()."""
from .abc import AuthSource
from .oauth import (
oauth_build_authorization_redirect,
oauth_decode_state,
oauth_encode_state,
oauth_fetch_userinfo,
oauth_resolve_provider_urls,
)
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
__all__ = [
"APIKeyHeaderAuth",
"AuthSource",
"BearerTokenAuth",
"CookieAuth",
"MultiAuth",
"oauth_build_authorization_redirect",
"oauth_decode_state",
"oauth_encode_state",
"oauth_fetch_userinfo",
"oauth_resolve_provider_urls",
]

View File

@@ -0,0 +1,53 @@
"""Abstract base class for authentication sources."""
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable
from fastapi import Request
from fastapi.security import SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
def _ensure_async(fn: Callable[..., Any]) -> Callable[..., Any]:
"""Wrap *fn* so it can always be awaited, caching the coroutine check at init time."""
if inspect.iscoroutinefunction(fn):
return fn
async def wrapper(*args: Any, **kwargs: Any) -> Any:
return fn(*args, **kwargs)
return wrapper
class AuthSource(ABC):
"""Abstract base class for authentication sources."""
def __init__(self) -> None:
"""Set up the default FastAPI dependency signature."""
source = self
async def _call(
request: Request,
security_scopes: SecurityScopes, # noqa: ARG001
) -> Any:
credential = await source.extract(request)
if credential is None:
raise UnauthorizedError()
return await source.authenticate(credential)
self._call_fn: Callable[..., Any] = _call
self.__signature__ = inspect.signature(_call)
@abstractmethod
async def extract(self, request: Request) -> str | None:
"""Extract the raw credential from the request without validating."""
@abstractmethod
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the authenticated identity."""
async def __call__(self, **kwargs: Any) -> Any:
"""FastAPI dependency dispatch."""
return await self._call_fn(**kwargs)

View File

@@ -0,0 +1,140 @@
"""OAuth 2.0 / OIDC helper utilities."""
import base64
from typing import Any
from urllib.parse import urlencode
import httpx
from fastapi.responses import RedirectResponse
_discovery_cache: dict[str, dict] = {}
async def oauth_resolve_provider_urls(
discovery_url: str,
) -> tuple[str, str, str | None]:
"""Fetch the OIDC discovery document and return endpoint URLs.
Args:
discovery_url: URL of the provider's ``/.well-known/openid-configuration``.
Returns:
A ``(authorization_url, token_url, userinfo_url)`` tuple.
*userinfo_url* is ``None`` when the provider does not advertise one.
"""
if discovery_url not in _discovery_cache:
async with httpx.AsyncClient() as client:
resp = await client.get(discovery_url)
resp.raise_for_status()
_discovery_cache[discovery_url] = resp.json()
cfg = _discovery_cache[discovery_url]
return (
cfg["authorization_endpoint"],
cfg["token_endpoint"],
cfg.get("userinfo_endpoint"),
)
async def oauth_fetch_userinfo(
*,
token_url: str,
userinfo_url: str,
code: str,
client_id: str,
client_secret: str,
redirect_uri: str,
) -> dict[str, Any]:
"""Exchange an authorization code for tokens and return the userinfo payload.
Performs the two-step OAuth 2.0 / OIDC token exchange:
1. POSTs the authorization *code* to *token_url* to obtain an access token.
2. GETs *userinfo_url* using that access token as a Bearer credential.
Args:
token_url: Provider's token endpoint.
userinfo_url: Provider's userinfo endpoint.
code: Authorization code received from the provider's callback.
client_id: OAuth application client ID.
client_secret: OAuth application client secret.
redirect_uri: Redirect URI that was used in the authorization request.
Returns:
The JSON payload returned by the userinfo endpoint as a plain ``dict``.
"""
async with httpx.AsyncClient() as client:
token_resp = await client.post(
token_url,
data={
"grant_type": "authorization_code",
"code": code,
"client_id": client_id,
"client_secret": client_secret,
"redirect_uri": redirect_uri,
},
headers={"Accept": "application/json"},
)
token_resp.raise_for_status()
access_token = token_resp.json()["access_token"]
userinfo_resp = await client.get(
userinfo_url,
headers={"Authorization": f"Bearer {access_token}"},
)
userinfo_resp.raise_for_status()
return userinfo_resp.json()
def oauth_build_authorization_redirect(
authorization_url: str,
*,
client_id: str,
scopes: str,
redirect_uri: str,
destination: str,
) -> RedirectResponse:
"""Return an OAuth 2.0 authorization ``RedirectResponse``.
Args:
authorization_url: Provider's authorization endpoint.
client_id: OAuth application client ID.
scopes: Space-separated list of requested scopes.
redirect_uri: URI the provider should redirect back to after authorization.
destination: URL the user should be sent to after the full OAuth flow
completes (encoded as ``state``).
Returns:
A :class:`~fastapi.responses.RedirectResponse` to the provider's
authorization page.
"""
params = urlencode(
{
"client_id": client_id,
"response_type": "code",
"scope": scopes,
"redirect_uri": redirect_uri,
"state": oauth_encode_state(destination),
}
)
return RedirectResponse(f"{authorization_url}?{params}")
def oauth_encode_state(url: str) -> str:
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
return base64.urlsafe_b64encode(url.encode()).decode()
def oauth_decode_state(state: str | None, *, fallback: str) -> str:
"""Decode a base64url OAuth ``state`` parameter.
Handles missing padding (some providers strip ``=``).
Returns *fallback* if *state* is absent, the literal string ``"null"``,
or cannot be decoded.
"""
if not state or state == "null":
return fallback
try:
padded = state + "=" * (4 - len(state) % 4)
return base64.urlsafe_b64decode(padded).decode()
except Exception:
return fallback

View File

@@ -0,0 +1,8 @@
"""Built-in authentication source implementations."""
from .header import APIKeyHeaderAuth
from .bearer import BearerTokenAuth
from .cookie import CookieAuth
from .multi import MultiAuth
__all__ = ["APIKeyHeaderAuth", "BearerTokenAuth", "CookieAuth", "MultiAuth"]

View File

@@ -0,0 +1,120 @@
"""Bearer token authentication source."""
import inspect
import secrets
from typing import Annotated, Any, Callable
from fastapi import Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _ensure_async
class BearerTokenAuth(AuthSource):
"""Bearer token authentication source.
Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation.
The validator is called as ``await validator(credential, **kwargs)``
where ``kwargs`` are the extra keyword arguments provided at instantiation.
Args:
validator: Sync or async callable that receives the credential and any
extra keyword arguments, and returns the authenticated identity
(e.g. a ``User`` model). Should raise
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure.
prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens
whose value starts with this prefix are matched. The prefix is
**kept** in the value passed to the validator — store and compare
tokens with their prefix included. Use :meth:`generate_token` to
create correctly-prefixed tokens. This enables multiple
``BearerTokenAuth`` instances in the same app (e.g. ``"user_"``
for user tokens, ``"org_"`` for org tokens).
**kwargs: Extra keyword arguments forwarded to the validator on every
call (e.g. ``role=Role.ADMIN``).
"""
def __init__(
self,
validator: Callable[..., Any],
*,
prefix: str | None = None,
**kwargs: Any,
) -> None:
self._validator = _ensure_async(validator)
self._prefix = prefix
self._kwargs = kwargs
self._scheme = HTTPBearer(auto_error=False)
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
credentials: Annotated[
HTTPAuthorizationCredentials | None, Depends(self._scheme)
] = None,
) -> Any:
if credentials is None:
raise UnauthorizedError()
return await self._validate(credentials.credentials)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
async def _validate(self, token: str) -> Any:
"""Check prefix and call the validator."""
if self._prefix is not None and not token.startswith(self._prefix):
raise UnauthorizedError()
return await self._validator(token, **self._kwargs)
async def extract(self, request: Any) -> str | None:
"""Extract the raw credential from the request without validating.
Returns ``None`` if no ``Authorization: Bearer`` header is present,
the token is empty, or the token does not match the configured prefix.
The prefix is included in the returned value.
"""
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
return None
token = auth[7:]
if not token:
return None
if self._prefix is not None and not token.startswith(self._prefix):
return None
return token
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the identity.
Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
the extra keyword arguments provided at instantiation.
"""
return await self._validate(credential)
def require(self, **kwargs: Any) -> "BearerTokenAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""
return BearerTokenAuth(
self._validator,
prefix=self._prefix,
**{**self._kwargs, **kwargs},
)
def generate_token(self, nbytes: int = 32) -> str:
"""Generate a secure random token for this auth source.
Returns a URL-safe random token. If a prefix is configured it is
prepended — the returned value is what you store in your database
and return to the client as-is.
Args:
nbytes: Number of random bytes before base64 encoding. The
resulting string is ``ceil(nbytes * 4 / 3)`` characters
(43 chars for the default 32 bytes). Defaults to 32.
Returns:
A ready-to-use token string (e.g. ``"user_Xk3..."``).
"""
token = secrets.token_urlsafe(nbytes)
if self._prefix is not None:
return f"{self._prefix}{token}"
return token

View File

@@ -0,0 +1,139 @@
"""Cookie-based authentication source."""
import base64
import hashlib
import hmac
import inspect
import json
import time
from typing import Annotated, Any, Callable
from fastapi import Depends, Request, Response
from fastapi.security import APIKeyCookie, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _ensure_async
class CookieAuth(AuthSource):
"""Cookie-based authentication source.
Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation.
Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper-
proof sessions without any database entry.
Args:
name: Cookie name.
validator: Sync or async callable that receives the cookie value
(plain, after signature verification when ``secret_key`` is set)
and any extra keyword arguments, and returns the authenticated
identity.
secret_key: When provided, the cookie is HMAC-SHA256 signed.
:meth:`set_cookie` embeds an expiry and signs the payload;
:meth:`extract` verifies the signature and expiry before handing
the plain value to the validator. When ``None`` (default), the raw
cookie value is passed to the validator as-is.
ttl: Cookie lifetime in seconds (default 24 h). Only used when
``secret_key`` is set.
**kwargs: Extra keyword arguments forwarded to the validator on every
call (e.g. ``role=Role.ADMIN``).
"""
def __init__(
self,
name: str,
validator: Callable[..., Any],
*,
secret_key: str | None = None,
ttl: int = 86400,
**kwargs: Any,
) -> None:
self._name = name
self._validator = _ensure_async(validator)
self._secret_key = secret_key
self._ttl = ttl
self._kwargs = kwargs
self._scheme = APIKeyCookie(name=name, auto_error=False)
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
value: Annotated[str | None, Depends(self._scheme)] = None,
) -> Any:
if value is None:
raise UnauthorizedError()
plain = self._verify(value)
return await self._validator(plain, **self._kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
def _hmac(self, data: str) -> str:
if self._secret_key is None:
raise RuntimeError("_hmac called without secret_key configured")
return hmac.new(
self._secret_key.encode(), data.encode(), hashlib.sha256
).hexdigest()
def _sign(self, value: str) -> str:
data = base64.urlsafe_b64encode(
json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode()
).decode()
return f"{data}.{self._hmac(data)}"
def _verify(self, cookie_value: str) -> str:
"""Return the plain value, verifying HMAC + expiry when signed."""
if not self._secret_key:
return cookie_value
try:
data, sig = cookie_value.rsplit(".", 1)
except ValueError:
raise UnauthorizedError()
if not hmac.compare_digest(self._hmac(data), sig):
raise UnauthorizedError()
try:
payload = json.loads(base64.urlsafe_b64decode(data))
value: str = payload["v"]
exp: int = payload["exp"]
except Exception:
raise UnauthorizedError()
if exp < int(time.time()):
raise UnauthorizedError()
return value
async def extract(self, request: Request) -> str | None:
return request.cookies.get(self._name)
async def authenticate(self, credential: str) -> Any:
plain = self._verify(credential)
return await self._validator(plain, **self._kwargs)
def require(self, **kwargs: Any) -> "CookieAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""
return CookieAuth(
self._name,
self._validator,
secret_key=self._secret_key,
ttl=self._ttl,
**{**self._kwargs, **kwargs},
)
def set_cookie(self, response: Response, value: str) -> None:
"""Attach the cookie to *response*, signing it when ``secret_key`` is set."""
cookie_value = self._sign(value) if self._secret_key else value
response.set_cookie(
self._name,
cookie_value,
httponly=True,
samesite="lax",
max_age=self._ttl,
)
def delete_cookie(self, response: Response) -> None:
"""Clear the session cookie (logout)."""
response.delete_cookie(self._name, httponly=True, samesite="lax")

View File

@@ -0,0 +1,67 @@
"""API key header authentication source."""
import inspect
from typing import Annotated, Any, Callable
from fastapi import Depends, Request
from fastapi.security import APIKeyHeader, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _ensure_async
class APIKeyHeaderAuth(AuthSource):
"""API key header authentication source.
Wraps :class:`fastapi.security.APIKeyHeader` for OpenAPI documentation.
The validator is called as ``await validator(api_key, **kwargs)``
where ``kwargs`` are the extra keyword arguments provided at instantiation.
Args:
name: HTTP header name that carries the API key (e.g. ``"X-API-Key"``).
validator: Sync or async callable that receives the API key and any
extra keyword arguments, and returns the authenticated identity.
Should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError`
on failure.
**kwargs: Extra keyword arguments forwarded to the validator on every
call (e.g. ``role=Role.ADMIN``).
"""
def __init__(
self,
name: str,
validator: Callable[..., Any],
**kwargs: Any,
) -> None:
self._name = name
self._validator = _ensure_async(validator)
self._kwargs = kwargs
self._scheme = APIKeyHeader(name=name, auto_error=False)
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
api_key: Annotated[str | None, Depends(self._scheme)] = None,
) -> Any:
if api_key is None:
raise UnauthorizedError()
return await self._validator(api_key, **self._kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
async def extract(self, request: Request) -> str | None:
"""Extract the API key from the configured header."""
return request.headers.get(self._name) or None
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the identity."""
return await self._validator(credential, **self._kwargs)
def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""
return APIKeyHeaderAuth(
self._name,
self._validator,
**{**self._kwargs, **kwargs},
)

View File

@@ -0,0 +1,119 @@
"""MultiAuth: combine multiple authentication sources into a single callable."""
import inspect
from typing import Any, cast
from fastapi import Request
from fastapi.security import SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource
class MultiAuth:
"""Combine multiple authentication sources into a single callable.
Sources are tried in order; the first one whose
:meth:`~AuthSource.extract` returns a non-``None`` credential wins.
Its :meth:`~AuthSource.authenticate` is called and the result returned.
If a credential is found but the validator raises, the exception propagates
immediately — the remaining sources are **not** tried. This prevents
silent fallthrough on invalid credentials.
If no source provides a credential,
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` is raised.
The :meth:`~AuthSource.extract` method of each source performs only
string matching (no I/O), so prefix-based dispatch is essentially free.
Any :class:`~AuthSource` subclass — including user-defined ones — can be
passed as a source.
Args:
*sources: Auth source instances to try in order.
Example::
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
cookie = CookieAuth("session", verify_session)
multi = MultiAuth(user_bearer, org_bearer, cookie)
@app.get("/data")
async def data_route(user = Security(multi)):
return user
# Apply a shared requirement to all sources at once
@app.get("/admin")
async def admin_route(user = Security(multi.require(role=Role.ADMIN))):
return user
"""
def __init__(self, *sources: AuthSource) -> None:
self._sources = sources
async def _call(
request: Request,
security_scopes: SecurityScopes, # noqa: ARG001
**kwargs: Any, # noqa: ARG001 — absorbs scheme values injected by FastAPI
) -> Any:
for source in self._sources:
credential = await source.extract(request)
if credential is not None:
return await source.authenticate(credential)
raise UnauthorizedError()
self._call_fn = _call
# Build a merged signature that includes the security-scheme Depends()
# parameters from every source so FastAPI registers them in OpenAPI docs.
seen: set[str] = {"request", "security_scopes"}
merged: list[inspect.Parameter] = [
inspect.Parameter(
"request",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=Request,
),
inspect.Parameter(
"security_scopes",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=SecurityScopes,
),
]
for i, source in enumerate(sources):
for name, param in inspect.signature(source).parameters.items():
if name in seen:
continue
merged.append(param.replace(name=f"_s{i}_{name}"))
seen.add(name)
self.__signature__ = inspect.Signature(merged, return_annotation=Any)
async def __call__(self, **kwargs: Any) -> Any:
return await self._call_fn(**kwargs)
def require(self, **kwargs: Any) -> "MultiAuth":
"""Return a new :class:`MultiAuth` with kwargs forwarded to each source.
Calls ``.require(**kwargs)`` on every source that supports it. Sources
that do not implement ``.require()`` (e.g. custom :class:`~AuthSource`
subclasses) are passed through unchanged.
New kwargs are merged over each source's existing kwargs — new values
win on conflict::
multi = MultiAuth(bearer, cookie)
@app.get("/admin")
async def admin(user = Security(multi.require(role=Role.ADMIN))):
return user
"""
new_sources = tuple(
cast(Any, source).require(**kwargs)
if hasattr(source, "require")
else source
for source in self._sources
)
return MultiAuth(*new_sources)

View File

@@ -0,0 +1,28 @@
"""Shared type aliases for the fastapi-toolsets package."""
from collections.abc import AsyncGenerator, Callable, Mapping
from typing import Any, TypeVar
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnElement
# Generic TypeVars
DataT = TypeVar("DataT")
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
# CRUD type aliases
JoinType = list[tuple[type[DeclarativeBase] | Any, Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
# Search / facet / order type aliases
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
FacetFieldType = SearchFieldType
OrderFieldType = SearchFieldType
# Dependency type aliases
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] | Any

View File

@@ -2,27 +2,39 @@
import os
import uuid
from enum import Enum
import pytest
from pydantic import BaseModel
from sqlalchemy import ForeignKey, String, Uuid
import datetime
import decimal
from sqlalchemy import (
Column,
Date,
DateTime,
Enum as SAEnum,
ForeignKey,
Integer,
JSON,
Numeric,
String,
Table,
Uuid,
)
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.schemas import PydanticBase
# PostgreSQL connection URL from environment or default for local development
DATABASE_URL = os.getenv("DATABASE_URL") or os.getenv(
"TEST_DATABASE_URL",
"postgresql+asyncpg://postgres:postgres@localhost:5432/fastapi_toolsets_test",
DATABASE_URL = os.getenv(
key="DATABASE_URL",
default="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres",
)
# =============================================================================
# Test Models
# =============================================================================
class Base(DeclarativeBase):
"""Base class for test models."""
@@ -49,6 +61,7 @@ class User(Base):
username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True)
is_active: Mapped[bool] = mapped_column(default=True)
notes: Mapped[str | None]
role_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("roles.id"), nullable=True
)
@@ -56,6 +69,64 @@ class User(Base):
role: Mapped[Role | None] = relationship(back_populates="users")
class Tag(Base):
"""Test tag model."""
__tablename__ = "tags"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50), unique=True)
post_tags = Table(
"post_tags",
Base.metadata,
Column(
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
),
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
)
class IntRole(Base):
"""Test role model with auto-increment integer PK."""
__tablename__ = "int_roles"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(50), unique=True)
class Permission(Base):
"""Test model with composite primary key."""
__tablename__ = "permissions"
subject: Mapped[str] = mapped_column(String(50), primary_key=True)
action: Mapped[str] = mapped_column(String(50), primary_key=True)
class Event(Base):
"""Test model with DateTime and Date cursor columns."""
__tablename__ = "events"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(100))
occurred_at: Mapped[datetime.datetime] = mapped_column(DateTime)
scheduled_date: Mapped[datetime.date] = mapped_column(Date)
class Product(Base):
"""Test model with Numeric cursor column."""
__tablename__ = "products"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(100))
price: Mapped[decimal.Decimal] = mapped_column(Numeric(10, 2))
class Post(Base):
"""Test post model."""
@@ -67,10 +138,58 @@ class Post(Base):
is_published: Mapped[bool] = mapped_column(default=False)
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
# =============================================================================
# Test Schemas
# =============================================================================
class OrderStatus(int, Enum):
"""Integer-backed enum for order status."""
PENDING = 1
PROCESSING = 2
SHIPPED = 3
CANCELLED = 4
class Color(str, Enum):
"""String-backed enum for color."""
RED = "red"
GREEN = "green"
BLUE = "blue"
class Order(Base):
"""Test model with an IntEnum column (Enum(int, Enum)) and a raw Integer column."""
__tablename__ = "orders"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(100))
status: Mapped[OrderStatus] = mapped_column(SAEnum(OrderStatus))
priority: Mapped[int] = mapped_column(Integer)
color: Mapped[Color] = mapped_column(SAEnum(Color))
class Transfer(Base):
"""Test model with two FKs to the same table (users)."""
__tablename__ = "transfers"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
amount: Mapped[str] = mapped_column(String(50))
sender_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
receiver_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
class Article(Base):
"""Test article model with ARRAY and JSON columns."""
__tablename__ = "articles"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
title: Mapped[str] = mapped_column(String(200))
labels: Mapped[list[str]] = mapped_column(ARRAY(String))
metadata_: Mapped[dict | None] = mapped_column("metadata", JSON, nullable=True)
class RoleCreate(BaseModel):
@@ -80,6 +199,13 @@ class RoleCreate(BaseModel):
name: str
class RoleRead(PydanticBase):
"""Schema for reading a role."""
id: uuid.UUID
name: str
class RoleUpdate(BaseModel):
"""Schema for updating a role."""
@@ -96,6 +222,14 @@ class UserCreate(BaseModel):
role_id: uuid.UUID | None = None
class UserRead(PydanticBase):
"""Schema for reading a user (subset of fields — no email)."""
id: uuid.UUID
username: str
is_active: bool = True
class UserUpdate(BaseModel):
"""Schema for updating a user."""
@@ -105,6 +239,13 @@ class UserUpdate(BaseModel):
role_id: uuid.UUID | None = None
class TagCreate(BaseModel):
"""Schema for creating a tag."""
id: uuid.UUID | None = None
name: str
class PostCreate(BaseModel):
"""Schema for creating a post."""
@@ -123,18 +264,136 @@ class PostUpdate(BaseModel):
is_published: bool | None = None
# =============================================================================
# CRUD Classes
# =============================================================================
class PostM2MCreate(BaseModel):
"""Schema for creating a post with M2M tag IDs."""
id: uuid.UUID | None = None
title: str
content: str = ""
is_published: bool = False
author_id: uuid.UUID
tag_ids: list[uuid.UUID] = []
class PostM2MUpdate(BaseModel):
"""Schema for updating a post with M2M tag IDs."""
title: str | None = None
content: str | None = None
is_published: bool | None = None
tag_ids: list[uuid.UUID] | None = None
class IntRoleRead(PydanticBase):
"""Schema for reading an IntRole."""
id: int
name: str
class IntRoleCreate(BaseModel):
"""Schema for creating an IntRole."""
name: str
class EventRead(PydanticBase):
"""Schema for reading an Event."""
id: uuid.UUID
name: str
class EventCreate(BaseModel):
"""Schema for creating an Event."""
name: str
occurred_at: datetime.datetime
scheduled_date: datetime.date
class ProductRead(PydanticBase):
"""Schema for reading a Product."""
id: uuid.UUID
name: str
class ProductCreate(BaseModel):
"""Schema for creating a Product."""
name: str
price: decimal.Decimal
class ArticleCreate(BaseModel):
"""Schema for creating an article."""
id: uuid.UUID | None = None
title: str
labels: list[str] = []
class ArticleRead(PydanticBase):
"""Schema for reading an article."""
id: uuid.UUID
title: str
labels: list[str]
class OrderCreate(BaseModel):
"""Schema for creating an order."""
id: uuid.UUID | None = None
name: str
status: OrderStatus
priority: int = 0
color: Color = Color.RED
class OrderRead(PydanticBase):
"""Schema for reading an order."""
id: uuid.UUID
name: str
status: OrderStatus
priority: int
color: Color
class TransferCreate(BaseModel):
"""Schema for creating a transfer."""
id: uuid.UUID | None = None
amount: str
sender_id: uuid.UUID
receiver_id: uuid.UUID
class TransferRead(PydanticBase):
"""Schema for reading a transfer."""
id: uuid.UUID
amount: str
OrderCrud = CrudFactory(Order)
TransferCrud = CrudFactory(Transfer)
ArticleCrud = CrudFactory(Article)
RoleCrud = CrudFactory(Role)
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
UserCrud = CrudFactory(User)
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
PostCrud = CrudFactory(Post)
# =============================================================================
# Fixtures
# =============================================================================
TagCrud = CrudFactory(Tag)
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
EventCrud = CrudFactory(Event)
EventDateTimeCursorCrud = CrudFactory(Event, cursor_column=Event.occurred_at)
EventDateCursorCrud = CrudFactory(Event, cursor_column=Event.scheduled_date)
ProductCrud = CrudFactory(Product)
ProductNumericCursorCrud = CrudFactory(Product, cursor_column=Product.price)
@pytest.fixture
@@ -173,30 +432,3 @@ async def db_session(engine):
# Drop tables after test
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
def sample_role_data() -> RoleCreate:
"""Sample role creation data."""
return RoleCreate(name="admin")
@pytest.fixture
def sample_user_data() -> UserCreate:
"""Sample user creation data."""
return UserCreate(
username="testuser",
email="test@example.com",
is_active=True,
)
@pytest.fixture
def sample_post_data() -> PostCreate:
"""Sample post creation data."""
return PostCreate(
title="Test Post",
content="Test content",
is_published=True,
author_id=uuid.uuid4(),
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +1,28 @@
"""Tests for fastapi_toolsets.db module."""
import asyncio
import uuid
import pytest
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from fastapi_toolsets.db import (
LockMode,
cleanup_tables,
create_database,
create_db_context,
create_db_dependency,
get_transaction,
lock_tables,
wait_for_row_change,
)
from fastapi_toolsets.exceptions import NotFoundError
from fastapi_toolsets.pytest import create_db_session
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
class TestCreateDbDependency:
@@ -57,6 +68,55 @@ class TestCreateDbDependency:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
@pytest.mark.anyio
async def test_in_transaction_on_yield(self):
"""Session is already in a transaction when the endpoint body starts."""
engine = create_async_engine(DATABASE_URL, echo=False)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
get_db = create_db_dependency(session_factory)
async for session in get_db():
assert session.in_transaction()
break
await engine.dispose()
@pytest.mark.anyio
async def test_update_after_lock_tables_is_persisted(self):
"""Changes made after lock_tables exits (before endpoint returns) are committed.
Regression: without the auto-begin fix, lock_tables would start and commit a
real outer transaction, leaving the session idle. Any modifications after that
point were silently dropped.
"""
engine = create_async_engine(DATABASE_URL, echo=False)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
try:
get_db = create_db_dependency(session_factory)
async for session in get_db():
async with lock_tables(session, [Role]):
role = Role(name="lock_then_update")
session.add(role)
await session.flush()
# lock_tables has exited — outer transaction must still be open
assert session.in_transaction()
role.name = "updated_after_lock"
async with session_factory() as verify:
result = await RoleCrud.first(
verify, [Role.name == "updated_after_lock"]
)
assert result is not None
finally:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
class TestCreateDbContext:
"""Tests for create_db_context."""
@@ -241,3 +301,182 @@ class TestLockTables:
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
assert result is None
class TestWaitForRowChange:
"""Tests for wait_for_row_change polling function."""
@pytest.mark.anyio
async def test_detects_update(self, db_session: AsyncSession, engine):
"""Returns updated instance when a column value changes."""
role = Role(name="watch_role")
db_session.add(role)
await db_session.commit()
async def update_later():
await asyncio.sleep(0.15)
factory = async_sessionmaker(engine, expire_on_commit=False)
async with factory() as other:
r = await other.get(Role, role.id)
assert r is not None
r.name = "updated_role"
await other.commit()
update_task = asyncio.create_task(update_later())
result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
await update_task
assert result.name == "updated_role"
@pytest.mark.anyio
async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
"""Only triggers on changes to specified columns."""
user = User(username="testuser", email="test@example.com")
db_session.add(user)
await db_session.commit()
async def update_later():
factory = async_sessionmaker(engine, expire_on_commit=False)
# First: change email (not watched) — should not trigger
await asyncio.sleep(0.15)
async with factory() as other:
u = await other.get(User, user.id)
assert u is not None
u.email = "new@example.com"
await other.commit()
# Second: change username (watched) — should trigger
await asyncio.sleep(0.15)
async with factory() as other:
u = await other.get(User, user.id)
assert u is not None
u.username = "newuser"
await other.commit()
update_task = asyncio.create_task(update_later())
result = await wait_for_row_change(
db_session, User, user.id, columns=["username"], interval=0.05
)
await update_task
assert result.username == "newuser"
assert result.email == "new@example.com"
@pytest.mark.anyio
async def test_nonexistent_row_raises(self, db_session: AsyncSession):
"""Raises NotFoundError when the row does not exist."""
fake_id = uuid.uuid4()
with pytest.raises(NotFoundError, match="not found"):
await wait_for_row_change(db_session, Role, fake_id, interval=0.05)
@pytest.mark.anyio
async def test_timeout_raises(self, db_session: AsyncSession):
"""Raises TimeoutError when no change is detected within timeout."""
role = Role(name="timeout_role")
db_session.add(role)
await db_session.commit()
with pytest.raises(TimeoutError):
await wait_for_row_change(
db_session, Role, role.id, interval=0.05, timeout=0.2
)
@pytest.mark.anyio
async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
"""Raises NotFoundError when the row is deleted during polling."""
role = Role(name="delete_role")
db_session.add(role)
await db_session.commit()
async def delete_later():
await asyncio.sleep(0.15)
factory = async_sessionmaker(engine, expire_on_commit=False)
async with factory() as other:
r = await other.get(Role, role.id)
await other.delete(r)
await other.commit()
delete_task = asyncio.create_task(delete_later())
with pytest.raises(NotFoundError):
await wait_for_row_change(db_session, Role, role.id, interval=0.05)
await delete_task
class TestCreateDatabase:
"""Tests for create_database."""
@pytest.mark.anyio
async def test_creates_database(self):
"""Database is created by create_database."""
target_url = (
make_url(DATABASE_URL)
.set(database="test_create_db_general")
.render_as_string(hide_password=False)
)
expected_db = make_url(target_url).database
assert expected_db is not None
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
try:
async with engine.connect() as conn:
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
await create_database(db_name=expected_db, server_url=DATABASE_URL)
async with engine.connect() as conn:
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = :name"),
{"name": expected_db},
)
assert result.scalar() == 1
# Cleanup
async with engine.connect() as conn:
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
finally:
await engine.dispose()
class TestCleanupTables:
"""Tests for cleanup_tables helper."""
@pytest.mark.anyio
async def test_truncates_all_tables(self):
"""All table rows are removed after cleanup_tables."""
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
role = Role(id=uuid.uuid4(), name="cleanup_role")
session.add(role)
await session.flush()
user = User(
id=uuid.uuid4(),
username="cleanup_user",
email="cleanup@test.com",
role_id=role.id,
)
session.add(user)
await session.commit()
# Verify rows exist
roles_count = await RoleCrud.count(session)
users_count = await UserCrud.count(session)
assert roles_count == 1
assert users_count == 1
await cleanup_tables(session, Base)
# Verify tables are empty
roles_count = await RoleCrud.count(session)
users_count = await UserCrud.count(session)
assert roles_count == 0
assert users_count == 0
@pytest.mark.anyio
async def test_noop_for_empty_metadata(self):
"""cleanup_tables does not raise when metadata has no tables."""
class EmptyBase(DeclarativeBase):
pass
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
# Should not raise
await cleanup_tables(session, EmptyBase)

View File

@@ -3,20 +3,42 @@
import inspect
import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast
from typing import Annotated, Any, cast
import pytest
from fastapi.params import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.dependencies import BodyDependency, PathDependency
from fastapi_toolsets.dependencies import (
BodyDependency,
PathDependency,
_unwrap_session_dep,
)
from .conftest import Role, RoleCreate, RoleCrud, User
async def mock_get_db() -> AsyncGenerator[AsyncSession, None]:
"""Mock session dependency for testing."""
yield None
yield None # type: ignore[misc] # ty:ignore[invalid-yield]
MockSessionDep = Annotated[AsyncSession, Depends(mock_get_db)]
class TestUnwrapSessionDep:
def test_plain_callable_returned_as_is(self):
"""Plain callable is returned unchanged."""
assert _unwrap_session_dep(mock_get_db) is mock_get_db
def test_annotated_with_depends_unwrapped(self):
"""Annotated form with Depends is unwrapped to the plain callable."""
assert _unwrap_session_dep(MockSessionDep) is mock_get_db
def test_annotated_without_depends_returned_as_is(self):
"""Annotated form with no Depends falls back to returning session_dep as-is."""
annotated_no_dep = Annotated[AsyncSession, "not_a_depends"]
assert _unwrap_session_dep(annotated_no_dep) is annotated_no_dep
class TestPathDependency:
@@ -95,6 +117,39 @@ class TestPathDependency:
assert result.id == role.id
assert result.name == "test_role"
def test_annotated_session_dep_returns_depends_instance(self):
"""PathDependency accepts Annotated[AsyncSession, Depends(...)] form."""
dep = PathDependency(Role, Role.id, session_dep=MockSessionDep)
assert isinstance(dep, Depends)
def test_annotated_session_dep_signature(self):
"""PathDependency with Annotated session_dep produces a valid signature."""
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
sig = inspect.signature(dep.dependency)
assert "role_id" in sig.parameters
assert "session" in sig.parameters
assert isinstance(sig.parameters["session"].default, Depends)
def test_annotated_session_dep_unwraps_callable(self):
"""PathDependency with Annotated form uses the underlying callable, not the Annotated type."""
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
sig = inspect.signature(dep.dependency)
inner_dep = sig.parameters["session"].default
assert inner_dep.dependency is mock_get_db
@pytest.mark.anyio
async def test_annotated_session_dep_fetches_object(self, db_session):
"""PathDependency with Annotated session_dep correctly fetches object from database."""
role = await RoleCrud.create(db_session, RoleCreate(name="annotated_role"))
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
result = await dep.dependency(session=db_session, role_id=role.id)
assert result.id == role.id
assert result.name == "annotated_role"
class TestBodyDependency:
"""Tests for BodyDependency factory."""
@@ -184,3 +239,39 @@ class TestBodyDependency:
assert result.id == role.id
assert result.name == "body_test_role"
def test_annotated_session_dep_returns_depends_instance(self):
"""BodyDependency accepts Annotated[AsyncSession, Depends(...)] form."""
dep = BodyDependency(
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
)
assert isinstance(dep, Depends)
def test_annotated_session_dep_unwraps_callable(self):
"""BodyDependency with Annotated form uses the underlying callable, not the Annotated type."""
dep = cast(
Any,
BodyDependency(
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
),
)
sig = inspect.signature(dep.dependency)
inner_dep = sig.parameters["session"].default
assert inner_dep.dependency is mock_get_db
@pytest.mark.anyio
async def test_annotated_session_dep_fetches_object(self, db_session):
"""BodyDependency with Annotated session_dep correctly fetches object from database."""
role = await RoleCrud.create(db_session, RoleCreate(name="body_annotated_role"))
dep = cast(
Any,
BodyDependency(
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
),
)
result = await dep.dependency(session=db_session, role_id=role.id)
assert result.id == role.id
assert result.name == "body_annotated_role"

View File

@@ -0,0 +1,485 @@
"""Live test for the docs/examples/pagination-search.md example.
Spins up the exact FastAPI app described in the example (sourced from
docs_src/examples/pagination_search/) and exercises it through a real HTTP
client against a real PostgreSQL database.
"""
import datetime
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from docs_src.examples.pagination_search.db import get_db
from docs_src.examples.pagination_search.models import Article, Base, Category
from docs_src.examples.pagination_search.routes import router
from fastapi_toolsets.exceptions import init_exceptions_handlers
from fastapi_toolsets.pytest import create_db_session
from .conftest import DATABASE_URL
def build_app(session: AsyncSession) -> FastAPI:
app = FastAPI()
init_exceptions_handlers(app)
async def override_get_db():
yield session
app.dependency_overrides[get_db] = override_get_db
app.include_router(router)
return app
@pytest.fixture(scope="function")
async def ex_db_session():
"""Isolated session for the example models (separate tables from conftest)."""
async with create_db_session(DATABASE_URL, Base) as session:
yield session
@pytest.fixture
async def client(ex_db_session: AsyncSession):
app = build_app(ex_db_session)
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as ac:
yield ac
async def seed(session: AsyncSession):
"""Insert representative fixture data."""
python = Category(name="python")
backend = Category(name="backend")
session.add_all([python, backend])
await session.flush()
now = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
session.add_all(
[
Article(
title="FastAPI tips",
body="Ten useful tips for FastAPI.",
status="published",
published=True,
category_id=python.id,
created_at=now,
),
Article(
title="SQLAlchemy async",
body="How to use async SQLAlchemy.",
status="published",
published=True,
category_id=backend.id,
created_at=now + datetime.timedelta(seconds=1),
),
Article(
title="Draft notes",
body="Work in progress.",
status="draft",
published=False,
category_id=None,
created_at=now + datetime.timedelta(seconds=2),
),
]
)
await session.commit()
class TestAppSessionDep:
@pytest.mark.anyio
async def test_get_db_yields_async_session(self):
"""get_db yields a real AsyncSession when called directly."""
from docs_src.examples.pagination_search.db import get_db
gen = get_db()
session = await gen.__anext__()
assert isinstance(session, AsyncSession)
await gen.aclose()
class TestOffsetPagination:
@pytest.mark.anyio
async def test_returns_all_articles(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 3
assert len(body["data"]) == 3
@pytest.mark.anyio
async def test_pagination_page_size(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?items_per_page=2&page=1")
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 2
assert body["pagination"]["total_count"] == 3
assert body["pagination"]["has_more"] is True
@pytest.mark.anyio
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?search=fastapi")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 1
assert body["data"][0]["title"] == "FastAPI tips"
@pytest.mark.anyio
async def test_search_traverses_relationship(
self, client: AsyncClient, ex_db_session
):
await seed(ex_db_session)
# "python" matches Category.name, not Article.title or body
resp = await client.get("/articles/offset?search=python")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 1
assert body["data"][0]["title"] == "FastAPI tips"
@pytest.mark.anyio
async def test_facet_filter_scalar(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?status=published")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 2
assert all(a["status"] == "published" for a in body["data"])
@pytest.mark.anyio
async def test_facet_filter_multi_value(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?status=published&status=draft")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 3
@pytest.mark.anyio
async def test_filter_attributes_in_response(
self, client: AsyncClient, ex_db_session
):
await seed(ex_db_session)
resp = await client.get("/articles/offset")
assert resp.status_code == 200
body = resp.json()
fa = body["filter_attributes"]
assert set(fa["status"]) == {"draft", "published"}
assert set(fa["category__name"]) == {"backend", "python"}
@pytest.mark.anyio
async def test_filter_attributes_scoped_to_filter(
self, client: AsyncClient, ex_db_session
):
await seed(ex_db_session)
resp = await client.get("/articles/offset?status=published")
body = resp.json()
# draft is filtered out → should not appear in filter_attributes
assert "draft" not in body["filter_attributes"]["status"]
@pytest.mark.anyio
async def test_search_and_filter_combined(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?search=async&status=published")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 1
assert body["data"][0]["title"] == "SQLAlchemy async"
class TestCursorPagination:
@pytest.mark.anyio
async def test_first_page(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/cursor?items_per_page=2")
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 2
assert body["pagination"]["has_more"] is True
assert body["pagination"]["next_cursor"] is not None
assert body["pagination"]["prev_cursor"] is None
@pytest.mark.anyio
async def test_second_page(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
first = await client.get("/articles/cursor?items_per_page=2")
next_cursor = first.json()["pagination"]["next_cursor"]
resp = await client.get(
f"/articles/cursor?items_per_page=2&cursor={next_cursor}"
)
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 1
assert body["pagination"]["has_more"] is False
@pytest.mark.anyio
async def test_facet_filter(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/cursor?status=draft")
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 1
assert body["data"][0]["status"] == "draft"
@pytest.mark.anyio
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/cursor?search=sqlalchemy")
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 1
assert body["data"][0]["title"] == "SQLAlchemy async"
class TestOffsetSorting:
"""Tests for order_by / order query parameters on the offset endpoint."""
@pytest.mark.anyio
async def test_default_order_uses_created_at_asc(
self, client: AsyncClient, ex_db_session
):
"""No order_by → default field (created_at) ASC."""
await seed(ex_db_session)
resp = await client.get("/articles/offset")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
@pytest.mark.anyio
async def test_order_by_title_asc(self, client: AsyncClient, ex_db_session):
"""order_by=title&order=asc returns alphabetical order."""
await seed(ex_db_session)
resp = await client.get("/articles/offset?order_by=title&order=asc")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["Draft notes", "FastAPI tips", "SQLAlchemy async"]
@pytest.mark.anyio
async def test_order_by_title_desc(self, client: AsyncClient, ex_db_session):
"""order_by=title&order=desc returns reverse alphabetical order."""
await seed(ex_db_session)
resp = await client.get("/articles/offset?order_by=title&order=desc")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["SQLAlchemy async", "FastAPI tips", "Draft notes"]
@pytest.mark.anyio
async def test_order_by_created_at_desc(self, client: AsyncClient, ex_db_session):
"""order_by=created_at&order=desc returns newest-first."""
await seed(ex_db_session)
resp = await client.get("/articles/offset?order_by=created_at&order=desc")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["Draft notes", "SQLAlchemy async", "FastAPI tips"]
@pytest.mark.anyio
async def test_invalid_order_by_returns_422(
self, client: AsyncClient, ex_db_session
):
"""Unknown order_by field returns 422 with SORT-422 error code."""
resp = await client.get("/articles/offset?order_by=nonexistent_field")
assert resp.status_code == 422
body = resp.json()
assert body["error_code"] == "SORT-422"
assert body["status"] == "FAIL"
class TestCursorSorting:
"""Tests for order_by / order query parameters on the cursor endpoint.
In cursor_paginate the cursor_column is always the primary sort; order_by
acts as a secondary tiebreaker. With the seeded articles (all having unique
created_at values) the overall ordering is always created_at ASC regardless
of the order_by value — only the valid/invalid field check and the response
shape are meaningful here.
"""
@pytest.mark.anyio
async def test_default_order_uses_created_at_asc(
self, client: AsyncClient, ex_db_session
):
"""No order_by → default field (created_at) ASC."""
await seed(ex_db_session)
resp = await client.get("/articles/cursor")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
@pytest.mark.anyio
async def test_order_by_title_asc_accepted(
self, client: AsyncClient, ex_db_session
):
"""order_by=title is a valid field — request succeeds and returns all articles."""
await seed(ex_db_session)
resp = await client.get("/articles/cursor?order_by=title&order=asc")
assert resp.status_code == 200
assert len(resp.json()["data"]) == 3
@pytest.mark.anyio
async def test_order_by_title_desc_accepted(
self, client: AsyncClient, ex_db_session
):
"""order_by=title&order=desc is valid — request succeeds and returns all articles."""
await seed(ex_db_session)
resp = await client.get("/articles/cursor?order_by=title&order=desc")
assert resp.status_code == 200
assert len(resp.json()["data"]) == 3
@pytest.mark.anyio
async def test_invalid_order_by_returns_422(
self, client: AsyncClient, ex_db_session
):
"""Unknown order_by field returns 422 with SORT-422 error code."""
resp = await client.get("/articles/cursor?order_by=nonexistent_field")
assert resp.status_code == 422
body = resp.json()
assert body["error_code"] == "SORT-422"
assert body["status"] == "FAIL"
class TestPaginateUnified:
"""Tests for the unified GET /articles/ endpoint using paginate()."""
@pytest.mark.anyio
async def test_defaults_to_offset_pagination(
self, client: AsyncClient, ex_db_session
):
"""Without pagination_type, defaults to offset pagination."""
await seed(ex_db_session)
resp = await client.get("/articles/")
assert resp.status_code == 200
body = resp.json()
assert body["pagination_type"] == "offset"
assert "total_count" in body["pagination"]
assert body["pagination"]["total_count"] == 3
@pytest.mark.anyio
async def test_explicit_offset_pagination(self, client: AsyncClient, ex_db_session):
"""pagination_type=offset returns OffsetPagination metadata."""
await seed(ex_db_session)
resp = await client.get(
"/articles/?pagination_type=offset&page=1&items_per_page=2"
)
assert resp.status_code == 200
body = resp.json()
assert body["pagination_type"] == "offset"
assert body["pagination"]["total_count"] == 3
assert body["pagination"]["page"] == 1
assert body["pagination"]["has_more"] is True
assert len(body["data"]) == 2
@pytest.mark.anyio
async def test_cursor_pagination_type(self, client: AsyncClient, ex_db_session):
"""pagination_type=cursor returns CursorPagination metadata."""
await seed(ex_db_session)
resp = await client.get("/articles/?pagination_type=cursor&items_per_page=2")
assert resp.status_code == 200
body = resp.json()
assert body["pagination_type"] == "cursor"
assert "next_cursor" in body["pagination"]
assert "total_count" not in body["pagination"]
assert body["pagination"]["has_more"] is True
assert len(body["data"]) == 2
@pytest.mark.anyio
async def test_cursor_pagination_navigate_pages(
self, client: AsyncClient, ex_db_session
):
"""Cursor from first page can be used to fetch the next page."""
await seed(ex_db_session)
first = await client.get("/articles/?pagination_type=cursor&items_per_page=2")
assert first.status_code == 200
first_body = first.json()
next_cursor = first_body["pagination"]["next_cursor"]
assert next_cursor is not None
second = await client.get(
f"/articles/?pagination_type=cursor&items_per_page=2&cursor={next_cursor}"
)
assert second.status_code == 200
second_body = second.json()
assert second_body["pagination_type"] == "cursor"
assert second_body["pagination"]["has_more"] is False
assert len(second_body["data"]) == 1
@pytest.mark.anyio
async def test_cursor_pagination_with_search(
self, client: AsyncClient, ex_db_session
):
"""paginate() with cursor type respects search parameter."""
await seed(ex_db_session)
resp = await client.get("/articles/?pagination_type=cursor&search=fastapi")
assert resp.status_code == 200
body = resp.json()
assert body["pagination_type"] == "cursor"
assert len(body["data"]) == 1
assert body["data"][0]["title"] == "FastAPI tips"
@pytest.mark.anyio
async def test_offset_pagination_with_filter(
self, client: AsyncClient, ex_db_session
):
"""paginate() with offset type respects filter_by parameter."""
await seed(ex_db_session)
resp = await client.get("/articles/?pagination_type=offset&status=published")
assert resp.status_code == 200
body = resp.json()
assert body["pagination_type"] == "offset"
assert body["pagination"]["total_count"] == 2

View File

@@ -2,12 +2,14 @@
import pytest
from fastapi import FastAPI
from fastapi.exceptions import HTTPException
from fastapi.testclient import TestClient
from fastapi_toolsets.exceptions import (
ApiException,
ConflictError,
ForbiddenError,
InvalidOrderFieldError,
NotFoundError,
UnauthorizedError,
generate_error_responses,
@@ -35,8 +37,8 @@ class TestApiException:
assert error.api_error.msg == "I'm a teapot"
assert str(error) == "I'm a teapot"
def test_custom_detail_message(self):
"""Custom detail overrides default message."""
def test_detail_overrides_msg_and_str(self):
"""detail sets both str(exc) and api_error.msg; class-level msg is unchanged."""
class CustomError(ApiException):
api_error = ApiError(
@@ -46,8 +48,172 @@ class TestApiException:
err_code="BAD-400",
)
error = CustomError("Custom message")
assert str(error) == "Custom message"
error = CustomError("Widget not found")
assert str(error) == "Widget not found"
assert error.api_error.msg == "Widget not found"
assert CustomError.api_error.msg == "Bad Request" # class unchanged
def test_desc_override(self):
"""desc kwarg overrides api_error.desc on the instance only."""
class MyError(ApiException):
api_error = ApiError(
code=400, msg="Error", desc="Default.", err_code="ERR-400"
)
err = MyError(desc="Custom desc.")
assert err.api_error.desc == "Custom desc."
assert MyError.api_error.desc == "Default." # class unchanged
def test_data_override(self):
"""data kwarg sets api_error.data on the instance only."""
class MyError(ApiException):
api_error = ApiError(
code=400, msg="Error", desc="Default.", err_code="ERR-400"
)
err = MyError(data={"key": "value"})
assert err.api_error.data == {"key": "value"}
assert MyError.api_error.data is None # class unchanged
def test_desc_and_data_override(self):
"""detail, desc and data can all be overridden together."""
class MyError(ApiException):
api_error = ApiError(
code=400, msg="Error", desc="Default.", err_code="ERR-400"
)
err = MyError("custom msg", desc="New desc.", data={"x": 1})
assert str(err) == "custom msg"
assert err.api_error.msg == "custom msg" # detail also updates msg
assert err.api_error.desc == "New desc."
assert err.api_error.data == {"x": 1}
assert err.api_error.code == 400 # other fields unchanged
def test_class_api_error_not_mutated_after_instance_override(self):
"""Raising with desc/data does not mutate the class-level api_error."""
class MyError(ApiException):
api_error = ApiError(
code=400, msg="Error", desc="Default.", err_code="ERR-400"
)
MyError(desc="Changed", data={"x": 1})
assert MyError.api_error.desc == "Default."
assert MyError.api_error.data is None
def test_subclass_uses_super_with_desc_and_data(self):
"""Subclasses can delegate detail/desc/data to super().__init__()."""
class BuildValidationError(ApiException):
api_error = ApiError(
code=422,
msg="Build Validation Error",
desc="The build configuration is invalid.",
err_code="BUILD-422",
)
def __init__(self, *errors: str) -> None:
super().__init__(
f"{len(errors)} validation error(s)",
desc=", ".join(errors),
data={"errors": [{"message": e} for e in errors]},
)
err = BuildValidationError("Field A is required", "Field B is invalid")
assert str(err) == "2 validation error(s)"
assert err.api_error.msg == "2 validation error(s)" # detail set msg
assert err.api_error.desc == "Field A is required, Field B is invalid"
assert err.api_error.data == {
"errors": [
{"message": "Field A is required"},
{"message": "Field B is invalid"},
]
}
assert err.api_error.code == 422 # other fields unchanged
def test_detail_desc_data_in_http_response(self):
"""detail/desc/data overrides all appear correctly in the FastAPI HTTP response."""
class DynamicError(ApiException):
api_error = ApiError(
code=400, msg="Error", desc="Default.", err_code="ERR-400"
)
def __init__(self, message: str) -> None:
super().__init__(
message,
desc=f"Detail: {message}",
data={"reason": message},
)
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/error")
async def raise_error():
raise DynamicError("something went wrong")
client = TestClient(app)
response = client.get("/error")
assert response.status_code == 400
body = response.json()
assert body["message"] == "something went wrong"
assert body["description"] == "Detail: something went wrong"
assert body["data"] == {"reason": "something went wrong"}
class TestApiExceptionGuard:
"""Tests for the __init_subclass__ api_error guard."""
def test_missing_api_error_raises_type_error(self):
"""Defining a subclass without api_error raises TypeError at class creation time."""
with pytest.raises(
TypeError, match="must define an 'api_error' class attribute"
):
class BrokenError(ApiException):
pass
def test_abstract_subclass_skips_guard(self):
"""abstract=True allows intermediate base classes without api_error."""
class BaseGroupError(ApiException, abstract=True):
pass
# Concrete child must still define it
class ConcreteError(BaseGroupError):
api_error = ApiError(
code=400, msg="Error", desc="Desc.", err_code="ERR-400"
)
err = ConcreteError()
assert err.api_error.code == 400
def test_abstract_child_still_requires_api_error_on_concrete(self):
"""Concrete subclass of an abstract class must define api_error."""
class Base(ApiException, abstract=True):
pass
with pytest.raises(
TypeError, match="must define an 'api_error' class attribute"
):
class Concrete(Base):
pass
def test_inherited_api_error_satisfies_guard(self):
"""Subclass that inherits api_error from a parent does not need its own."""
class ConcreteError(NotFoundError):
pass
err = ConcreteError()
assert err.api_error.code == 404
class TestBuiltInExceptions:
@@ -89,7 +255,7 @@ class TestGenerateErrorResponses:
assert responses[404]["description"] == "Not Found"
def test_generates_multiple_responses(self):
"""Generates responses for multiple exceptions."""
"""Generates responses for multiple exceptions with distinct status codes."""
responses = generate_error_responses(
UnauthorizedError,
ForbiddenError,
@@ -100,14 +266,81 @@ class TestGenerateErrorResponses:
assert 403 in responses
assert 404 in responses
def test_response_has_example(self):
"""Generated response includes example."""
def test_response_has_named_example(self):
"""Generated response uses named examples keyed by err_code."""
responses = generate_error_responses(NotFoundError)
example = responses[404]["content"]["application/json"]["example"]
examples = responses[404]["content"]["application/json"]["examples"]
assert example["status"] == "FAIL"
assert example["error_code"] == "RES-404"
assert example["message"] == "Not Found"
assert "RES-404" in examples
value = examples["RES-404"]["value"]
assert value["status"] == "FAIL"
assert value["error_code"] == "RES-404"
assert value["message"] == "Not Found"
assert value["data"] is None
def test_response_example_has_summary(self):
"""Each named example carries a summary equal to api_error.msg."""
responses = generate_error_responses(NotFoundError)
example = responses[404]["content"]["application/json"]["examples"]["RES-404"]
assert example["summary"] == "Not Found"
def test_response_example_with_data(self):
"""Generated response includes data when set on ApiError."""
class ErrorWithData(ApiException):
api_error = ApiError(
code=400,
msg="Bad Request",
desc="Invalid input.",
err_code="BAD-400",
data={"details": "some context"},
)
responses = generate_error_responses(ErrorWithData)
value = responses[400]["content"]["application/json"]["examples"]["BAD-400"][
"value"
]
assert value["data"] == {"details": "some context"}
def test_two_errors_same_code_both_present(self):
"""Two exceptions with the same HTTP code produce two named examples."""
class BadRequestA(ApiException):
api_error = ApiError(
code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A"
)
class BadRequestB(ApiException):
api_error = ApiError(
code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B"
)
responses = generate_error_responses(BadRequestA, BadRequestB)
assert 400 in responses
examples = responses[400]["content"]["application/json"]["examples"]
assert "ERR-A" in examples
assert "ERR-B" in examples
assert examples["ERR-A"]["value"]["message"] == "Bad A"
assert examples["ERR-B"]["value"]["message"] == "Bad B"
def test_two_errors_same_code_single_top_level_entry(self):
"""Two exceptions with the same HTTP code produce exactly one top-level entry."""
class BadRequestA(ApiException):
api_error = ApiError(
code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A"
)
class BadRequestB(ApiException):
api_error = ApiError(
code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B"
)
responses = generate_error_responses(BadRequestA, BadRequestB)
assert len([k for k in responses if k == 400]) == 1
class TestInitExceptionsHandlers:
@@ -137,6 +370,59 @@ class TestInitExceptionsHandlers:
assert data["error_code"] == "RES-404"
assert data["message"] == "Not Found"
def test_handles_api_exception_without_data(self):
"""ApiException without data returns null data field."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/error")
async def raise_error():
raise NotFoundError()
client = TestClient(app)
response = client.get("/error")
assert response.status_code == 404
assert response.json()["data"] is None
def test_handles_api_exception_with_data(self):
"""ApiException with data returns the data payload."""
app = FastAPI()
init_exceptions_handlers(app)
class CustomValidationError(ApiException):
api_error = ApiError(
code=422,
msg="Validation Error",
desc="1 validation error(s) detected",
err_code="CUSTOM-422",
data={
"errors": [
{
"field": "email",
"message": "invalid format",
"type": "value_error",
}
]
},
)
@app.get("/error")
async def raise_error():
raise CustomValidationError()
client = TestClient(app)
response = client.get("/error")
assert response.status_code == 422
data = response.json()
assert data["data"] == {
"errors": [
{"field": "email", "message": "invalid format", "type": "value_error"}
]
}
assert data["error_code"] == "CUSTOM-422"
def test_handles_validation_error(self):
"""Handles validation errors with structured response."""
from pydantic import BaseModel
@@ -178,13 +464,68 @@ class TestInitExceptionsHandlers:
assert data["status"] == "FAIL"
assert data["error_code"] == "SERVER-500"
def test_custom_openapi_schema(self):
"""Customizes OpenAPI schema for 422 responses."""
def test_handles_http_exception(self):
"""Handles starlette HTTPException with consistent ErrorResponse envelope."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/protected")
async def protected():
raise HTTPException(status_code=403, detail="Forbidden")
client = TestClient(app)
response = client.get("/protected")
assert response.status_code == 403
data = response.json()
assert data["status"] == "FAIL"
assert data["error_code"] == "HTTP-403"
assert data["message"] == "Forbidden"
def test_handles_http_exception_404_from_route(self):
"""HTTPException(404) raised inside a route uses the consistent ErrorResponse envelope."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/items/{item_id}")
async def get_item(item_id: int):
raise HTTPException(status_code=404, detail="Item not found")
client = TestClient(app)
response = client.get("/items/99")
assert response.status_code == 404
data = response.json()
assert data["status"] == "FAIL"
assert data["error_code"] == "HTTP-404"
assert data["message"] == "Item not found"
def test_handles_http_exception_forwards_headers(self):
"""HTTPException with WWW-Authenticate header forwards it in the response."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/secure")
async def secure():
raise HTTPException(
status_code=401,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
client = TestClient(app)
response = client.get("/secure")
assert response.status_code == 401
assert response.headers.get("www-authenticate") == "Bearer"
def test_custom_openapi_schema(self):
"""Customises OpenAPI schema for 422 responses using named examples."""
from pydantic import BaseModel
app = FastAPI()
init_exceptions_handlers(app)
class Item(BaseModel):
name: str
@@ -197,8 +538,128 @@ class TestInitExceptionsHandlers:
post_op = openapi["paths"]["/items"]["post"]
assert "422" in post_op["responses"]
resp_422 = post_op["responses"]["422"]
example = resp_422["content"]["application/json"]["example"]
assert example["error_code"] == "VAL-422"
examples = resp_422["content"]["application/json"]["examples"]
assert "VAL-422" in examples
assert examples["VAL-422"]["value"]["error_code"] == "VAL-422"
def test_custom_openapi_preserves_app_metadata(self):
"""_patched_openapi preserves custom FastAPI app-level metadata."""
app = FastAPI(
title="My API",
version="2.0.0",
description="Custom description",
)
init_exceptions_handlers(app)
schema = app.openapi()
assert schema["info"]["title"] == "My API"
assert schema["info"]["version"] == "2.0.0"
def test_handles_response_validation_error(self):
"""Handles ResponseValidationError with a structured 422 response."""
from pydantic import BaseModel
class CountResponse(BaseModel):
count: int
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/broken", response_model=CountResponse)
async def broken():
return {"count": "not-a-number"} # triggers ResponseValidationError
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/broken")
assert response.status_code == 422
data = response.json()
assert data["status"] == "FAIL"
assert data["error_code"] == "VAL-422"
assert "errors" in data["data"]
def test_handles_validation_error_with_non_standard_loc(self):
"""Validation error with empty loc tuple maps the field to 'root'."""
from fastapi.exceptions import RequestValidationError
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/root-error")
async def root_error():
raise RequestValidationError(
[
{
"type": "custom",
"loc": (),
"msg": "root level error",
"input": None,
"url": "",
}
]
)
client = TestClient(app)
response = client.get("/root-error")
assert response.status_code == 422
data = response.json()
assert data["data"]["errors"][0]["field"] == "root"
def test_openapi_schema_cached_after_first_call(self):
"""app.openapi() returns the cached schema on subsequent calls."""
from pydantic import BaseModel
app = FastAPI()
init_exceptions_handlers(app)
class Item(BaseModel):
name: str
@app.post("/items")
async def create_item(item: Item):
return item
schema_first = app.openapi()
schema_second = app.openapi()
assert schema_first is schema_second
def test_openapi_skips_operations_without_422(self):
"""_patched_openapi leaves operations that have no 422 response unchanged."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/ping")
async def ping():
return {"ok": True}
schema = app.openapi()
get_op = schema["paths"]["/ping"]["get"]
assert "422" not in get_op["responses"]
assert "200" in get_op["responses"]
def test_openapi_skips_non_dict_path_item_values(self):
"""_patched_openapi ignores non-dict values in path items (e.g. path-level parameters)."""
from fastapi_toolsets.exceptions.handler import _patched_openapi
app = FastAPI()
def fake_openapi() -> dict:
return {
"paths": {
"/items": {
"parameters": [
{"name": "q", "in": "query"}
], # list, not a dict
"get": {"responses": {"200": {"description": "OK"}}},
}
}
}
schema = _patched_openapi(app, fake_openapi)
# The list value was skipped without error; the GET operation is intact
assert schema["paths"]["/items"]["parameters"] == [{"name": "q", "in": "query"}]
assert "422" not in schema["paths"]["/items"]["get"]["responses"]
class TestExceptionIntegration:
@@ -263,3 +724,43 @@ class TestExceptionIntegration:
assert response.status_code == 200
assert response.json() == {"id": 1}
class TestInvalidOrderFieldError:
"""Tests for InvalidOrderFieldError exception."""
def test_api_error_attributes(self):
"""InvalidOrderFieldError has correct api_error metadata."""
assert InvalidOrderFieldError.api_error.code == 422
assert InvalidOrderFieldError.api_error.err_code == "SORT-422"
assert InvalidOrderFieldError.api_error.msg == "Invalid Order Field"
def test_stores_field_and_valid_fields(self):
"""InvalidOrderFieldError stores field and valid_fields on the instance."""
error = InvalidOrderFieldError("unknown", ["name", "created_at"])
assert error.field == "unknown"
assert error.valid_fields == ["name", "created_at"]
def test_description_contains_field_and_valid_fields(self):
"""api_error.desc mentions the bad field and valid options."""
error = InvalidOrderFieldError("bad_field", ["name", "email"])
assert "bad_field" in error.api_error.desc
assert "name" in error.api_error.desc
assert "email" in error.api_error.desc
def test_handled_as_422_by_exception_handler(self):
"""init_exceptions_handlers turns InvalidOrderFieldError into a 422 response."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/items")
async def list_items():
raise InvalidOrderFieldError("bad", ["name"])
client = TestClient(app)
response = client.get("/items")
assert response.status_code == 422
data = response.json()
assert data["error_code"] == "SORT-422"
assert data["status"] == "FAIL"

Some files were not shown because too many files have changed in this diff Show More