mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
fix: cascade delete M2M association rows via ORM session (#121)
This commit is contained in:
@@ -14,7 +14,6 @@ from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
|||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
|
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
|
||||||
from sqlalchemy import delete as sql_delete
|
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@@ -674,8 +673,10 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
``None``, or ``Response[None]`` when ``return_response=True``.
|
``None``, or ``Response[None]`` when ``return_response=True``.
|
||||||
"""
|
"""
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
q = sql_delete(cls.model).where(and_(*filters))
|
result = await session.execute(select(cls.model).where(and_(*filters)))
|
||||||
await session.execute(q)
|
objects = result.scalars().all()
|
||||||
|
for obj in objects:
|
||||||
|
await session.delete(obj)
|
||||||
if return_response:
|
if return_response:
|
||||||
return Response(data=None)
|
return Response(data=None)
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from .conftest import (
|
|||||||
RoleCursorCrud,
|
RoleCursorCrud,
|
||||||
RoleRead,
|
RoleRead,
|
||||||
RoleUpdate,
|
RoleUpdate,
|
||||||
|
Tag,
|
||||||
TagCreate,
|
TagCreate,
|
||||||
TagCrud,
|
TagCrud,
|
||||||
User,
|
User,
|
||||||
@@ -480,6 +481,69 @@ class TestCrudDelete:
|
|||||||
assert result.data is None
|
assert result.data is None
|
||||||
assert await RoleCrud.first(db_session, [Role.id == role.id]) is None
|
assert await RoleCrud.first(db_session, [Role.id == role.id]) is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_m2m_cascade(self, db_session: AsyncSession):
|
||||||
|
"""Deleting a record with M2M relationships cleans up the association table."""
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag1 = await TagCrud.create(db_session, TagCreate(name="python"))
|
||||||
|
tag2 = await TagCrud.create(db_session, TagCreate(name="fastapi"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="M2M Delete Test",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag1.id, tag2.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await PostM2MCrud.delete(db_session, [Post.id == post.id])
|
||||||
|
|
||||||
|
# Post is gone
|
||||||
|
assert await PostCrud.first(db_session, [Post.id == post.id]) is None
|
||||||
|
|
||||||
|
# Association rows are gone — tags themselves must still exist
|
||||||
|
assert await TagCrud.first(db_session, [Tag.id == tag1.id]) is not None
|
||||||
|
assert await TagCrud.first(db_session, [Tag.id == tag2.id]) is not None
|
||||||
|
|
||||||
|
# No orphaned rows in post_tags
|
||||||
|
result = await db_session.execute(
|
||||||
|
text("SELECT COUNT(*) FROM post_tags WHERE post_id = :pid").bindparams(
|
||||||
|
pid=post.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert result.scalar() == 0
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_m2m_does_not_delete_related_records(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Deleting a post with M2M tags must not delete the tags themselves."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author2", email="author2@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="shared_tag"))
|
||||||
|
|
||||||
|
post1 = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(title="Post 1", author_id=user.id, tag_ids=[tag.id]),
|
||||||
|
)
|
||||||
|
post2 = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(title="Post 2", author_id=user.id, tag_ids=[tag.id]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete only post1
|
||||||
|
await PostM2MCrud.delete(db_session, [Post.id == post1.id])
|
||||||
|
|
||||||
|
# Tag and post2 still exist
|
||||||
|
assert await TagCrud.first(db_session, [Tag.id == tag.id]) is not None
|
||||||
|
assert await PostCrud.first(db_session, [Post.id == post2.id]) is not None
|
||||||
|
|
||||||
|
|
||||||
class TestCrudExists:
|
class TestCrudExists:
|
||||||
"""Tests for CRUD exists operations."""
|
"""Tests for CRUD exists operations."""
|
||||||
|
|||||||
Reference in New Issue
Block a user