Sending notifications to specific clients using Server-Sent Events in FastAPI?

I have been working with server-sent events to send certain types of notifications to specific clients only. I am using the sse-starlette package to try to achieve this. I am fairly new to FastAPI, so I can't figure out how to send the data only to certain clients instead of broadcasting it to everyone.

This is what I have come up with so far:

Subscribe request with a query parameter:

localhost:8000/subscribe?id=1
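from typing import Optional

from fastapi import FastAPI, Request
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

app = FastAPI()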


class EmitEventModel(BaseModel):
    event_name: str
    event_data: Optional[str] = "No Event Data"
    event_id: Optional[int] = None
    recipient_id: str

async def connection_established():
    yield dict(data="Connection established")

clients = {}


@app.get("/subscribe")
async def loopBackStream(req: Request, id: str = ""):
    clients[id] = EventSourceResponse(connection_established())
    return clients[id]


@app.post("/emit")
async def emitEvent(event: EmitEventModel):
    if clients[event.recipient_id]:
        clients[event.recipient_id](publish_event())

Whenever there is an API call to localhost:8000/emit with that body, the event should be routed to the client matching recipient_id. Of course, this doesn't work so far. Any pointers on what should be done to achieve this?

sse_starlette for reference: https://github.com/sysid/sse-starlette/blob/master/sse_starlette/sse.py
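For context, each dict these generators yield (for example `dict(data="Connection established")`) ends up on the wire as one text message in the Server-Sent Events format: optional `id:` and `event:` fields, one `data:` line per line of payload, and a blank line ending the message. The `format_sse` helper below is a hypothetical, minimal sketch of that framing; it is not sse-starlette's code, and the library's exact field order and line separator may differ.

```python
from typing import Optional


def format_sse(data: str, event: Optional[str] = None, event_id: Optional[int] = None) -> str:
    """Build one Server-Sent Events message (hypothetical helper, not a library API)."""
    lines = []
    if event_id is not None:
        lines.append(f"id: {event_id}")
    if event is not None:
        lines.append(f"event: {event}")
    # Every line of the payload gets its own "data:" field
    for chunk in data.splitlines() or [""]:
        lines.append(f"data: {chunk}")
    # A blank line terminates the message
    return "\n".join(lines) + "\n\n"


if __name__ == "__main__":
    print(format_sse("Connection established"), end="")
    print(format_sse("hello", event="notification", event_id=7), end="")
```

On the browser side, an `EventSource` delivers a message with an `event:` field only to listeners registered for that event name, which is one reason to set `event` from `event_name`.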



Solution 1:[1]

The idea here is that the SSE generator itself needs to know the recipient_id, so that each stream only yields events addressed to that recipient. I've slightly modified your code to show what I mean:

from __future__ import annotations

import asyncio
from collections import defaultdict, deque
from typing import AsyncIterator, Optional

from fastapi import FastAPI, Request
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse


app = FastAPI()
# Pending events keyed by recipient_id (in memory, single process only)
clients: dict[str, deque] = defaultdict(deque)


class EmitEventModel(BaseModel):
    event_name: str
    event_data: Optional[str] = "No Event Data"
    event_id: Optional[int] = None
    recipient_id: str


async def retrieve_events(recipient_id: str) -> AsyncIterator[dict]:
    yield dict(data="Connection established")
    while True:
        queue = clients.get(recipient_id)
        # Only deliver events addressed to this recipient, oldest first
        while queue:
            event = queue.popleft()
            yield dict(event=event.event_name, data=event.event_data, id=event.event_id)
        # Nothing pending: check again in a second
        await asyncio.sleep(1)


@app.get("/subscribe/{recipient_id}")
async def loopBackStream(req: Request, recipient_id: str):
    return EventSourceResponse(retrieve_events(recipient_id))


@app.post("/emit")
async def emitEvent(event: EmitEventModel):
    # Queue the event; the recipient's stream picks it up on its next poll
    clients[event.recipient_id].append(event)
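The one-second `asyncio.sleep` loop above polls even when nothing is pending. A common alternative, swapped in here and not part of the original answer, is to give every open stream its own `asyncio.Queue` and have the publisher put events directly into the queues registered for that recipient, so the generator simply awaits `queue.get()`. `Broker` and `demo` are hypothetical names for this stdlib-only sketch; it also lets several tabs subscribe under the same id.

```python
import asyncio
from collections import defaultdict
from typing import AsyncIterator, Dict, Set


class Broker:
    """Hypothetical per-recipient fan-out: one asyncio.Queue per open stream."""

    def __init__(self) -> None:
        # recipient_id -> queues of every stream currently subscribed under that id
        self._subscribers: Dict[str, Set[asyncio.Queue]] = defaultdict(set)

    async def subscribe(self, recipient_id: str) -> AsyncIterator[dict]:
        queue: asyncio.Queue = asyncio.Queue()
        self._subscribers[recipient_id].add(queue)
        try:
            yield {"data": "Connection established"}
            while True:
                # Suspends until publish() puts something in this queue; no polling
                yield await queue.get()
        finally:
            # Runs when the generator is closed or cancelled
            self._subscribers[recipient_id].discard(queue)
            if not self._subscribers[recipient_id]:
                del self._subscribers[recipient_id]

    def publish(self, recipient_id: str, message: dict) -> int:
        """Queue a message for every stream of recipient_id; return how many got it."""
        queues = self._subscribers.get(recipient_id, set())
        for queue in queues:
            queue.put_nowait(message)
        return len(queues)


async def demo():
    broker = Broker()
    stream = broker.subscribe("1")
    received = [await stream.__anext__()]  # first step registers the queue for "1"
    sent_to_1 = broker.publish("1", {"event": "notification", "data": "for client 1"})
    sent_to_2 = broker.publish("2", {"event": "notification", "data": "for client 2"})
    received.append(await stream.__anext__())
    await stream.aclose()  # runs the finally block and unregisters the queue
    sent_after_close = broker.publish("1", {"data": "too late"})
    return received, (sent_to_1, sent_to_2, sent_after_close)


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Wiring it into the app above might look like returning `EventSourceResponse(broker.subscribe(recipient_id))` from `/subscribe/{recipient_id}` and calling `broker.publish(event.recipient_id, {...})` in `/emit`. Like the original `clients` dict, this lives in one process's memory, so with several workers you would need an external pub/sub (Redis, for example) to route events between them.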

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Marcelo Trylesinski