diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 1c87b0a..4bca120 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -170,6 +170,18 @@ class AsyncCrud(Generic[ModelType]): return load_options return cls.default_load_options + @classmethod + async def _reload_with_options( + cls: type[Self], session: AsyncSession, instance: ModelType + ) -> ModelType: + """Re-query instance by PK with default_load_options applied.""" + mapper = cls.model.__mapper__ + pk_filters = [ + getattr(cls.model, col.key) == getattr(instance, col.key) + for col in mapper.primary_key + ] + return await cls.get(session, filters=pk_filters) + @classmethod async def _resolve_m2m( cls: type[Self], @@ -705,6 +717,8 @@ class AsyncCrud(Generic[ModelType]): session.add(db_model) await session.refresh(db_model) + if cls.default_load_options: + db_model = await cls._reload_with_options(session, db_model) result = cast(ModelType, db_model) if schema: return Response(data=schema.model_validate(result)) @@ -1060,6 +1074,8 @@ class AsyncCrud(Generic[ModelType]): for rel_attr, related_instances in m2m_resolved.items(): setattr(db_model, rel_attr, related_instances) await session.refresh(db_model) + if cls.default_load_options: + db_model = await cls._reload_with_options(session, db_model) if schema: return Response(data=schema.model_validate(db_model)) return db_model diff --git a/tests/test_crud.py b/tests/test_crud.py index 3667a08..f44652e 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -380,6 +380,43 @@ class TestDefaultLoadOptionsIntegration: assert result.data[0].role is not None assert result.data[0].role.name == "admin" + @pytest.mark.anyio + async def test_default_load_options_applied_to_create( + self, db_session: AsyncSession + ): + """default_load_options loads relationships after create().""" + UserWithDefaultLoad = CrudFactory( + User, default_load_options=[selectinload(User.role)] + ) + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + user = await UserWithDefaultLoad.create( + db_session, + UserCreate(username="alice", email="alice@test.com", role_id=role.id), + ) + assert user.role is not None + assert user.role.name == "admin" + + @pytest.mark.anyio + async def test_default_load_options_applied_to_update( + self, db_session: AsyncSession + ): + """default_load_options loads relationships after update().""" + UserWithDefaultLoad = CrudFactory( + User, default_load_options=[selectinload(User.role)] + ) + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + user = await UserCrud.create( + db_session, + UserCreate(username="alice", email="alice@test.com"), + ) + updated = await UserWithDefaultLoad.update( + db_session, + UserUpdate(role_id=role.id), + filters=[User.id == user.id], + ) + assert updated.role is not None + assert updated.role.name == "admin" + @pytest.mark.anyio async def test_load_options_overrides_default_load_options( self, db_session: AsyncSession