4
4
5
5
from fastapi import HTTPException
6
6
from sqlalchemy import create_engine , func , select
7
+ from sqlalchemy .ext .asyncio import (
8
+ AsyncEngine ,
9
+ AsyncSession ,
10
+ async_sessionmaker ,
11
+ create_async_engine ,
12
+ )
7
13
from sqlalchemy .orm import (
8
14
DeclarativeBase ,
9
15
Mapped ,
19
25
from src .helpers .casing import snakecase
20
26
from src .helpers .sql import random_uuid , utcnow
21
27
28
+ # Sync engine and session
22
29
engine = create_engine (str (settings .database_url ))
23
30
SessionLocal = sessionmaker (autocommit = False , autoflush = False , bind = engine )
24
31
32
+ # Async engine and session
33
+ async_engine : AsyncEngine = create_async_engine (str (settings .async_database_url ))
34
+
35
+
36
+ def async_session_generator () -> async_sessionmaker [AsyncSession ]:
37
+ return async_sessionmaker (
38
+ autocommit = False , autoflush = False , bind = async_engine , class_ = AsyncSession
39
+ )
40
+
25
41
26
42
class SQLBase (DeclarativeBase ):
27
43
__abstract__ = True
@@ -34,6 +50,12 @@ def __tablename__(cls) -> str:
34
50
def objects (cls : Type ["_Model" ], session : Session ) -> "Objects[_Model]" :
35
51
return Objects (cls , session )
36
52
53
+ @classmethod
54
+ def async_objects (
55
+ cls : Type ["_Model" ], session : AsyncSession
56
+ ) -> "AsyncObjects[_Model]" :
57
+ return AsyncObjects (cls , session )
58
+
37
59
38
60
_Model = TypeVar ("_Model" , bound = SQLBase )
39
61
@@ -94,6 +116,72 @@ def create(self, data: Dict[str, Any]) -> _Model:
94
116
return obj
95
117
96
118
119
+ class AsyncObjects (Generic [_Model ]):
120
+ cls : Type [_Model ]
121
+ session : AsyncSession
122
+ base_statement : Select
123
+ queryset_filters : Any = None
124
+
125
+ def __init__ (
126
+ self ,
127
+ cls : Type [_Model ],
128
+ session : AsyncSession ,
129
+ * queryset_filters : Any ,
130
+ ) -> None :
131
+ self .cls = cls
132
+ self .session = session
133
+ base_statement = select (cls )
134
+ if queryset_filters :
135
+ self .queryset_filters = queryset_filters
136
+ base_statement = base_statement .where (* queryset_filters )
137
+ self .base_statement = base_statement
138
+
139
+ async def all (self ) -> Sequence [_Model ]:
140
+ result = await self .session .execute (self .base_statement )
141
+ return result .scalars ().unique ().all ()
142
+
143
+ async def get (self , * where_clause : Any ) -> _Model | None :
144
+ statement = self .base_statement .where (* where_clause )
145
+ result = await self .session .execute (statement )
146
+ return result .unique ().scalar_one_or_none ()
147
+
148
+ async def get_or_404 (self , * where_clause : Any ) -> _Model :
149
+ obj = await self .get (* where_clause )
150
+ if obj is None :
151
+ raise HTTPException (
152
+ status_code = 404 , detail = f"{ self .cls .__name__ } not found"
153
+ )
154
+ return obj
155
+
156
+ async def get_all (self , * where_clause : Any ) -> Sequence [_Model ]:
157
+ statement = self .base_statement .where (* where_clause )
158
+ result = await self .session .execute (statement )
159
+ return result .scalars ().unique ().all ()
160
+
161
+ async def count (self , * where_clause : Any ) -> int :
162
+ statement = select (func .count ()).select_from (self .cls )
163
+ if self .queryset_filters :
164
+ statement = statement .where (* self .queryset_filters )
165
+ if where_clause :
166
+ statement = statement .where (* where_clause )
167
+
168
+ result = await self .session .execute (statement )
169
+ return result .scalar_one ()
170
+
171
+ async def create (self , data : Dict [str , Any ]) -> _Model :
172
+ obj = self .cls (** data )
173
+ self .session .add (obj )
174
+ await self .session .commit ()
175
+ await self .session .refresh (obj )
176
+ return obj
177
+
178
+ async def bulk_create (self , data : Sequence [Dict [str , Any ]]) -> Sequence [_Model ]:
179
+ objs = [self .cls (** item ) for item in data ]
180
+ self .session .add_all (objs )
181
+ await self .session .commit ()
182
+ return objs
183
+
184
+
97
185
@declarative_mixin
98
186
class TableIdMixin :
99
187
id : Mapped [uuid .UUID ] = mapped_column (
0 commit comments