Skip to content

Commit e486b76

Browse files
authoredMar 21, 2025
feat: Add DB async connection (#37)
* feat: Add DB async connection * style: use built-in list for type hints
1 parent dc88b79 commit e486b76

File tree

10 files changed

+225
-10
lines changed

10 files changed

+225
-10
lines changed
 

‎.env.example

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
# The name of the project will be used to build the final docker image and has to
1+
# The name of the project will be used to build the final docker image and has to
22
# be exactly the same as project.name as defined in pyproject.toml
33
PROJECT_NAME="python-template"
44

55
DATABASE_URL=postgresql://dev:dev@postgres:5432/dev
6+
ASYNC_DATABASE_URL=postgresql+asyncpg://dev:dev@postgres:5432/dev
67
LOG_LEVEL=DEBUG
78
SERVER_URL=example.com
89
ACCESS_TOKEN_EXPIRE_MINUTES=15

‎poetry.lock

+65-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description = "Xmartlabs' Python project template"
66
authors = [{ name = "Xmartlabs", email = "getintouch@xmartlabs.com" }]
77
readme = "README.md"
88
requires-python = "^3.13"
9-
dependencies = []
9+
dependencies = ["asyncpg (>=0.30.0,<0.31.0)"]
1010

1111
[tool.poetry]
1212
# TODO(remer): this can be removed when the source files are moved to project name folder within src

‎src/api/dependencies.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
from typing import Iterator
1+
from typing import AsyncIterator, Iterator
22

33
from fastapi import Depends, Request
44

5-
from src.core.database import Session, SessionLocal
5+
from src.core.database import (
6+
AsyncSession,
7+
Session,
8+
SessionLocal,
9+
async_session_generator,
10+
)
611
from src.core.security import AuthManager
712
from src.models import User
813

@@ -15,6 +20,18 @@ def db_session() -> Iterator[Session]:
1520
db.close()
1621

1722

23+
async def async_db_session() -> AsyncIterator[AsyncSession]:
24+
try:
25+
async_session = async_session_generator()
26+
async with async_session() as session:
27+
yield session
28+
except Exception:
29+
await session.rollback()
30+
raise
31+
finally:
32+
await session.close()
33+
34+
1835
def get_user(request: Request, session: Session = Depends(db_session)) -> User:
1936
manager = AuthManager()
2037
return manager(request=request, session=session)

‎src/api/v1/routers/item.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from fastapi_pagination import Page
55
from fastapi_pagination.ext.sqlalchemy import paginate
66

7-
from src.api.dependencies import db_session, get_user
8-
from src.api.v1.schemas import Item, ItemCreate
7+
from src.api.dependencies import async_db_session, db_session, get_user
8+
from src.api.v1.schemas import BulkItemCreate, Item, ItemCreate
99
from src.controllers import ItemController
10-
from src.core.database import Session
10+
from src.core.database import AsyncSession, Session
1111
from src.models import User
1212

1313
router = APIRouter()
@@ -27,3 +27,24 @@ def create_item(
2727
session: Session = Depends(db_session),
2828
) -> Any:
2929
return ItemController.create(item_data=item_data, owner_id=user.id, session=session)
30+
31+
32+
@router.get("/async", response_model=Page[Item])
33+
async def get_items_async(
34+
user: User = Depends(get_user),
35+
async_session: AsyncSession = Depends(async_db_session),
36+
) -> Any:
37+
"""Get items asynchronously."""
38+
return await paginate(async_session, user.get_items())
39+
40+
41+
@router.post("/async", response_model=list[Item], status_code=201)
42+
async def create_item_async(
43+
item_data: BulkItemCreate,
44+
user: User = Depends(get_user),
45+
async_session: AsyncSession = Depends(async_db_session),
46+
) -> Any:
47+
"""Create items asynchronously."""
48+
return await ItemController.bulk_create(
49+
items_data=item_data.items, owner_id=user.id, async_session=async_session
50+
)

‎src/api/v1/schemas/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .item import Item, ItemCreate
1+
from .item import BulkItemCreate, Item, ItemCreate
22
from .token import Token, TokenPayload
33
from .user import User, UserCreate

‎src/api/v1/schemas/item.py

+4
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ class ItemCreate(BaseModel):
1212
class Item(ItemCreate):
1313
owner_id: UUID
1414
model_config = ConfigDict(from_attributes=True)
15+
16+
17+
class BulkItemCreate(BaseModel):
18+
items: list[ItemCreate]

‎src/controllers/item.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from typing import Sequence
12
from uuid import UUID
23

34
from src import models
45
from src.api.v1 import schemas
5-
from src.core.database import Session
6+
from src.core.database import AsyncSession, Session
67

78

89
class ItemController:
@@ -13,3 +14,21 @@ def create(
1314
item_data = schemas.Item(owner_id=owner_id, **item_data.model_dump())
1415
item = models.Item.objects(session).create(item_data.model_dump())
1516
return item
17+
18+
@staticmethod
19+
async def bulk_create(
20+
items_data: Sequence[schemas.ItemCreate],
21+
owner_id: UUID,
22+
async_session: AsyncSession,
23+
) -> Sequence[models.Item]:
24+
items_data = [
25+
schemas.Item(owner_id=owner_id, **item_data.model_dump())
26+
for item_data in items_data
27+
]
28+
items = await models.Item.async_objects(async_session).bulk_create(
29+
[item_data.model_dump() for item_data in items_data]
30+
)
31+
for item in items:
32+
await async_session.refresh(item)
33+
34+
return items

‎src/core/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class LogLevel(str, Enum):
1414

1515
class Settings(BaseSettings):
1616
database_url: PostgresDsn
17+
async_database_url: PostgresDsn
1718
log_level: LogLevel = LogLevel.debug
1819
server_url: str
1920

‎src/core/database.py

+88
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44

55
from fastapi import HTTPException
66
from sqlalchemy import create_engine, func, select
7+
from sqlalchemy.ext.asyncio import (
8+
AsyncEngine,
9+
AsyncSession,
10+
async_sessionmaker,
11+
create_async_engine,
12+
)
713
from sqlalchemy.orm import (
814
DeclarativeBase,
915
Mapped,
@@ -19,9 +25,19 @@
1925
from src.helpers.casing import snakecase
2026
from src.helpers.sql import random_uuid, utcnow
2127

28+
# Sync engine and session
2229
engine = create_engine(str(settings.database_url))
2330
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
2431

32+
# Async engine and session
33+
async_engine: AsyncEngine = create_async_engine(str(settings.async_database_url))
34+
35+
36+
def async_session_generator() -> async_sessionmaker[AsyncSession]:
37+
return async_sessionmaker(
38+
autocommit=False, autoflush=False, bind=async_engine, class_=AsyncSession
39+
)
40+
2541

2642
class SQLBase(DeclarativeBase):
2743
__abstract__ = True
@@ -34,6 +50,12 @@ def __tablename__(cls) -> str:
3450
def objects(cls: Type["_Model"], session: Session) -> "Objects[_Model]":
3551
return Objects(cls, session)
3652

53+
@classmethod
54+
def async_objects(
55+
cls: Type["_Model"], session: AsyncSession
56+
) -> "AsyncObjects[_Model]":
57+
return AsyncObjects(cls, session)
58+
3759

3860
_Model = TypeVar("_Model", bound=SQLBase)
3961

@@ -94,6 +116,72 @@ def create(self, data: Dict[str, Any]) -> _Model:
94116
return obj
95117

96118

119+
class AsyncObjects(Generic[_Model]):
120+
cls: Type[_Model]
121+
session: AsyncSession
122+
base_statement: Select
123+
queryset_filters: Any = None
124+
125+
def __init__(
126+
self,
127+
cls: Type[_Model],
128+
session: AsyncSession,
129+
*queryset_filters: Any,
130+
) -> None:
131+
self.cls = cls
132+
self.session = session
133+
base_statement = select(cls)
134+
if queryset_filters:
135+
self.queryset_filters = queryset_filters
136+
base_statement = base_statement.where(*queryset_filters)
137+
self.base_statement = base_statement
138+
139+
async def all(self) -> Sequence[_Model]:
140+
result = await self.session.execute(self.base_statement)
141+
return result.scalars().unique().all()
142+
143+
async def get(self, *where_clause: Any) -> _Model | None:
144+
statement = self.base_statement.where(*where_clause)
145+
result = await self.session.execute(statement)
146+
return result.unique().scalar_one_or_none()
147+
148+
async def get_or_404(self, *where_clause: Any) -> _Model:
149+
obj = await self.get(*where_clause)
150+
if obj is None:
151+
raise HTTPException(
152+
status_code=404, detail=f"{self.cls.__name__} not found"
153+
)
154+
return obj
155+
156+
async def get_all(self, *where_clause: Any) -> Sequence[_Model]:
157+
statement = self.base_statement.where(*where_clause)
158+
result = await self.session.execute(statement)
159+
return result.scalars().unique().all()
160+
161+
async def count(self, *where_clause: Any) -> int:
162+
statement = select(func.count()).select_from(self.cls)
163+
if self.queryset_filters:
164+
statement = statement.where(*self.queryset_filters)
165+
if where_clause:
166+
statement = statement.where(*where_clause)
167+
168+
result = await self.session.execute(statement)
169+
return result.scalar_one()
170+
171+
async def create(self, data: Dict[str, Any]) -> _Model:
172+
obj = self.cls(**data)
173+
self.session.add(obj)
174+
await self.session.commit()
175+
await self.session.refresh(obj)
176+
return obj
177+
178+
async def bulk_create(self, data: Sequence[Dict[str, Any]]) -> Sequence[_Model]:
179+
objs = [self.cls(**item) for item in data]
180+
self.session.add_all(objs)
181+
await self.session.commit()
182+
return objs
183+
184+
97185
@declarative_mixin
98186
class TableIdMixin:
99187
id: Mapped[uuid.UUID] = mapped_column(

0 commit comments

Comments
 (0)