How do I keep a persistent database connection in FastAPI?
I am writing my first project in FastAPI and I am struggling a bit. In particular, I am not sure how I am supposed to use an asyncpg connection pool in my app. Currently, what I have goes like this:
In db.py I have:
```python
import asyncpg

pgpool = None

async def get_pool():
    global pgpool
    if not pgpool:
        pgpool = await asyncpg.create_pool(dsn='MYDB_DSN')
    return pgpool
```
Then, in the individual router files, I use get_pool as a dependency:
```python
@router.post("/user/", response_model=models.User, status_code=201)
async def create_user(user: models.UserCreate, pgpool = Depends(get_pool)):
    # ... do things ...
    ...
```
First, every endpoint I have uses the database, so it seems silly to add that dependency argument to every single function. Second, this seems like a roundabout way of doing things: I define a global, then I define a function that returns that global, and then I inject the function. I am sure there is a more natural way of going about it.
I have seen people suggest just adding whatever I need as an attribute of the app object:
```python
@app.on_event("startup")
async def startup():
    app.pool = await asyncpg.create_pool(dsn='MYDB_DSN')
```
But this doesn't work when I have multiple files with routers, because I don't know how to access the app object from a router object.
What am I missing?
Solution 1:[1]
You can use the application factory pattern to set up your application.
To avoid using a global or adding things directly to the app object, you can create your own Database class to hold your connection pool.
To pass the connection pool to every route, you can use a middleware that adds the pool to request.state.
Here's the example code:
```python
import asyncpg
from fastapi import FastAPI, Request


class Database:
    def __init__(self):
        self.pool = None

    async def create_pool(self):
        self.pool = await asyncpg.create_pool(dsn='MYDB_DSN')


def create_app():
    app = FastAPI()
    db = Database()

    @app.middleware("http")
    async def db_session_middleware(request: Request, call_next):
        # Attach the shared pool to every incoming request
        request.state.pgpool = db.pool
        response = await call_next(request)
        return response

    @app.on_event("startup")
    async def startup():
        await db.create_pool()

    @app.on_event("shutdown")
    async def shutdown():
        # Cleanup: close every connection held by the pool
        await db.pool.close()

    @app.get("/")
    async def hello(request: Request):
        # The middleware stored the pool under the name "pgpool"
        print(request.state.pgpool)

    return app


app = create_app()
```
Solution 2:[2]
This is how I do it. In db.py:
```python
import logging

import asyncpg

logger = logging.getLogger(__name__)


class Database:
    def __init__(self, user, password, host, database, port=5432):
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.database = database
        self._cursor = None
        self._connection_pool = None

    async def connect(self):
        if not self._connection_pool:
            try:
                self._connection_pool = await asyncpg.create_pool(
                    min_size=1,
                    max_size=20,
                    command_timeout=60,
                    host=self.host,
                    port=self.port,
                    user=self.user,
                    password=self.password,
                    database=self.database,
                    ssl="require",
                )
                logger.info("Database pool connection opened")
            except Exception as e:
                logger.exception(e)

    async def fetch_rows(self, query: str, *args):
        # Open the pool lazily if it is not open yet, then run the query
        if not self._connection_pool:
            await self.connect()
        con = await self._connection_pool.acquire()
        try:
            return await con.fetch(query, *args)
        except Exception as e:
            logger.exception(e)  # logged, and None is returned
        finally:
            await self._connection_pool.release(con)

    async def close(self):
        # Only close a pool that was actually opened
        if self._connection_pool:
            try:
                await self._connection_pool.close()
                logger.info("Database pool connection closed")
            except Exception as e:
                logger.exception(e)
```
Then, in the app module:
```python
@app.on_event("startup")
async def startup_event():
    # db_arguments: user, password, host, database (and optionally port)
    database_instance = db.Database(**db_arguments)
    await database_instance.connect()
    app.state.db = database_instance
    logger.info("Server Startup")

@app.on_event("shutdown")
async def shutdown_event():
    # Close the pool only if startup created the Database instance
    if app.state.db:
        await app.state.db.close()
    logger.info("Server Shutdown")
```
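You can then get the Database instance as request.app.state.db by declaring a Request parameter in your routes; this works from any router module, because every request carries a reference to the running app.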
Solution 3:[3]
The info in your post allowed me to come up with this solution. A little digging in the class definitions turned up a startup event on APIRouter that async functions can be hooked onto.
db.py
```python
from asyncpg import create_pool, Pool

pgpool: Pool | None = None

async def get_pool():
    global pgpool
    if not pgpool:
        pgpool = await create_pool(dsn='MY_DSN')
    return pgpool
```
my_router.py
```python
from fastapi import APIRouter
from asyncpg import Pool
from db import get_pool

router = APIRouter()
pgpool: Pool | None = None

@router.on_event("startup")
async def router_startup():
    global pgpool
    pgpool = await get_pool()
```
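Once the router is registered with app.include_router(router), its startup handler runs when the application starts, and pgpool.acquire() is then available to the async functions in my_router.py.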
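One caveat with the lazily initialised `get_pool` (in the question, and in db.py here): if it is used as a per-request dependency, several concurrent first requests can all see `pgpool` as `None` while the first `create_pool` call is suspended at `await`, and each will then create its own pool. A minimal sketch of a guard using `asyncio.Lock` with a double check follows. `LazyPool` and `pool_factory` are hypothetical names; the factory would be something like `lambda: create_pool(dsn='MY_DSN')`, and it is passed in so the logic can be exercised without a database.

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional


class LazyPool:
    """Create a pool on first use, exactly once, even under concurrent calls."""

    def __init__(self, pool_factory: Callable[[], Awaitable[Any]]):
        # e.g. pool_factory = lambda: create_pool(dsn='MY_DSN')  (assumption)
        self._factory = pool_factory
        self._pool: Optional[Any] = None
        self._lock = asyncio.Lock()

    async def get(self) -> Any:
        if self._pool is None:              # fast path once initialised
            async with self._lock:
                if self._pool is None:      # re-check: another task may have won
                    self._pool = await self._factory()
        return self._pool
```

A dependency could then be `async def get_pool(): return await lazy.get()`. On Python versions before 3.10, create the `LazyPool` inside the running event loop (for example in a startup handler), because `asyncio.Lock` bound itself to the current loop at construction time there.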
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Gabriel Cappelli |
| Solution 2 | |
| Solution 3 | l2affiki |
