How to keep a persistent database connection in FastAPI?

I am writing my first project in FastAPI and I am struggling a bit. In particular, I am not sure how I am supposed to use an asyncpg connection pool in my app. Currently, what I have goes like this.

In db.py I have:

import asyncpg

pgpool = None


async def get_pool():
    # Lazily create the pool on first use and reuse it afterwards
    global pgpool
    if not pgpool:
        pgpool = await asyncpg.create_pool(dsn='MYDB_DSN')
    return pgpool

Then, in the individual router files, I use get_pool as a dependency:

@router.post("/user/", response_model=models.User, status_code=201)
async def create_user(user: models.UserCreate, pgpool = Depends(get_pool)):
    ...  # do things with pgpool

First, every endpoint I have uses the database, so it seems silly to add that dependency argument to every single function. Second, this seems like a roundabout way of doing things: I define a global, then I define a function that returns that global, and then I inject the function. I am sure there is a more natural way of going about it.

I have seen people suggest simply adding whatever I need as an attribute on the app object:

@app.on_event("startup")
async def startup():
    app.pool = await asyncpg.create_pool(dsn='MYDB_DSN')

but that doesn't work when my routers live in multiple files, because I don't know how to access the app object from a router.

What am I missing?



Solution 1:[1]

You can use the application factory pattern to set up your application.

To avoid using a global or attaching things directly to the app object, you can create your own Database class to hold the connection pool.

To pass the connection pool to every route, you can use a middleware that adds the pool to request.state.

Here's example code:

import asyncpg
from fastapi import FastAPI, Request


class Database:
    """Holds the asyncpg pool so it doesn't have to live in a global."""

    def __init__(self):
        self.pool = None

    async def create_pool(self):
        self.pool = await asyncpg.create_pool(dsn='MYDB_DSN')


def create_app():

    app = FastAPI()
    db = Database()

    @app.middleware("http")
    async def db_session_middleware(request: Request, call_next):
        # Expose the pool to every route via request.state
        request.state.pgpool = db.pool
        response = await call_next(request)
        return response

    @app.on_event("startup")
    async def startup():
        await db.create_pool()

    @app.on_event("shutdown")
    async def shutdown():
        # Cleanup: close all pooled connections if the pool was created
        if db.pool is not None:
            await db.pool.close()

    @app.get("/")
    async def hello(request: Request):
        # The attribute name must match the one set in the middleware
        print(request.state.pgpool)

    return app


app = create_app()

Solution 2:[2]

The way I do it is with a Database class in db.py:

import logging

import asyncpg

logger = logging.getLogger(__name__)


class Database:
    def __init__(self, user, password, host, database, port="5432"):
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.database = database
        self._connection_pool = None

    async def connect(self):
        # Create the pool once; later calls are no-ops
        if not self._connection_pool:
            try:
                self._connection_pool = await asyncpg.create_pool(
                    min_size=1,
                    max_size=20,
                    command_timeout=60,
                    host=self.host,
                    port=self.port,
                    user=self.user,
                    password=self.password,
                    database=self.database,
                    ssl="require"
                )
                logger.info("Database pool connection opened")

            except Exception as e:
                logger.exception(e)

    async def fetch_rows(self, query: str, *args):
        # Connect lazily if needed, then run the query in either case
        if not self._connection_pool:
            await self.connect()
        con = await self._connection_pool.acquire()
        try:
            result = await con.fetch(query, *args)
            return result
        except Exception as e:
            logger.exception(e)
        finally:
            # Always hand the connection back to the pool
            await self._connection_pool.release(con)

    async def close(self):
        # Only close a pool that was actually opened
        if self._connection_pool:
            try:
                await self._connection_pool.close()
                logger.info("Database pool connection closed")
            except Exception as e:
                logger.exception(e)

Then, in the app module:

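import logging

from fastapi import FastAPI

import db

logger = logging.getLogger(__name__)
app = FastAPI()


@app.on_event("startup")
async def startup_event():
    # db_arguments: dict of Database() keyword arguments
    # (user, password, host, database, optionally port)
    database_instance = db.Database(**db_arguments)
    await database_instance.connect()
    app.state.db = database_instance
    logger.info("Server Startup")


@app.on_event("shutdown")
async def shutdown_event():
    # Close the pool only if startup actually created the Database
    if getattr(app.state, "db", None):
        await app.state.db.close()
    logger.info("Server Shutdown")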

You can then get the Database instance in any route, including routes on an APIRouter defined in another file, by adding a request: Request parameter to the route and reading request.app.state.db.
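A side note on the fetch_rows method above: asyncpg's Pool.acquire() can also be used as an async context manager, which returns the connection to the pool even if the query raises, so the manual acquire/release pair is not required. The sketch below shows that form as a standalone helper. _FakePool and _FakeConnection are hypothetical stand-ins (not part of asyncpg) that exist only so the snippet runs without a PostgreSQL server; in a real app you would pass the pool returned by asyncpg.create_pool() instead.

```python
import asyncio
from contextlib import asynccontextmanager


async def fetch_rows(pool, query: str, *args):
    """Run a query on a connection borrowed from the pool.

    `async with pool.acquire()` gives the connection back to the pool
    even if fetch() raises, so no explicit release() call is needed.
    """
    async with pool.acquire() as con:
        return await con.fetch(query, *args)


# Hypothetical stand-ins for asyncpg's Pool/Connection, used only so
# this sketch runs without a database.
class _FakeConnection:
    async def fetch(self, query, *args):
        # Echo the query and its arguments back as a single "row"
        return [(query, args)]


class _FakePool:
    def __init__(self):
        self.in_use = 0  # number of connections currently checked out

    @asynccontextmanager
    async def acquire(self):
        self.in_use += 1
        try:
            yield _FakeConnection()
        finally:
            self.in_use -= 1  # "release" happens here, even on errors


pool = _FakePool()
rows = asyncio.run(fetch_rows(pool, "SELECT $1::int", 42))
```

asyncpg's Pool also exposes fetch(), fetchrow() and execute() directly, each of which acquires and releases a connection internally, so for one-off queries a helper like this may not be needed at all.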

Solution 3:[3]

The info in your post allowed me to come up with this solution. After a little digging in the class definitions, I found that APIRouter also has a startup event that async functions can be hooked onto.

db.py

from asyncpg import create_pool, Pool

pgpool: Pool | None = None

async def get_pool():
    global pgpool
    if not pgpool:
        pgpool = await create_pool(dsn='MY_DSN')
    return pgpool

my_router.py

from fastapi import APIRouter
from asyncpg import Pool
from db import get_pool

router = APIRouter()
pgpool: Pool | None = None

@router.on_event("startup")
async def router_startup():
    global pgpool
    pgpool = await get_pool()

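Once the router is included in the app with app.include_router(router), its startup handler runs when the application starts, so pgpool.acquire() is available to the async route functions in my_router.py.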
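One caveat applies when the lazy get_pool() is called per request as a dependency, as in the question: if two requests arrive before the pool exists, both can pass the "if not pgpool" check while create_pool() is still being awaited, so two pools get created and one of them is never closed. A minimal sketch of guarding creation with an asyncio.Lock plus a second check follows. LazyPool and fake_create_pool are hypothetical names for illustration; fake_create_pool stands in for asyncpg.create_pool() so the snippet runs without a database.

```python
import asyncio


class LazyPool:
    """Create a pool on first use, at most once, even under concurrency.

    `factory` is any zero-argument callable returning an awaitable that
    produces the pool (assumed, e.g. a wrapper around asyncpg.create_pool).
    """

    def __init__(self, factory):
        self._factory = factory
        self._pool = None
        self._lock = asyncio.Lock()

    async def get(self):
        if self._pool is None:  # fast path once the pool exists
            async with self._lock:
                # Re-check: another coroutine may have created the pool
                # while this one was waiting for the lock.
                if self._pool is None:
                    self._pool = await self._factory()
        return self._pool


calls = 0  # how many times the factory actually ran


async def fake_create_pool():
    # Stand-in for asyncpg.create_pool(); the sleep makes the concurrent
    # callers below overlap while the "pool" is being built.
    global calls
    calls += 1
    await asyncio.sleep(0.01)
    return object()


async def main():
    lazy = LazyPool(fake_create_pool)
    # Ten concurrent "requests" all asking for the pool at once
    return await asyncio.gather(*(lazy.get() for _ in range(10)))


pools = asyncio.run(main())
```

With asyncpg, the factory could be something like lambda: asyncpg.create_pool(dsn='MY_DSN'), since the object create_pool() returns can be awaited to finish initialising the pool. Creating the pool eagerly in a startup handler, as this answer does, sidesteps the race entirely.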

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Gabriel Cappelli
Solution 2
Solution 3 l2affiki