How to efficiently find nearest / bracketing numerical values (and do linear interpolation) in SQLAlchemy?

Assume that I have a database table with the contents:

Timestamp  Position
       10       900
       15      1400
       23      1700
       24      1708

I would then want to calculate the respective Position for a given Timestamp using linear interpolation. Say I query for Timestamp 22; then I need to fetch the next smaller and next larger rows, which are (15, 1400) and (23, 1700). The interpolation is then interpolation = ((23-22)*1400 + (22-15)*1700) / (23-15), resulting in 1662.5.
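In plain Python, the same formula reads (a minimal helper; the names are illustrative):

```python
def interpolate(t, t_lo, p_lo, t_hi, p_hi):
    """Linearly interpolate the position at time t between two bracketing rows."""
    return ((t_hi - t) * p_lo + (t - t_lo) * p_hi) / (t_hi - t_lo)

print(interpolate(22, 15, 1400, 23, 1700))  # 1662.5
```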

For a single value, I can easily get the bracketing rows and interpolate using:

query_time = 22
time_lower, pos_lower = session.query(Time, Position).filter(Time <= query_time).order_by(Time.desc()).limit(1).one()
time_higher, pos_higher = session.query(Time, Position).filter(Time > query_time).order_by(Time.asc()).limit(1).one()
result_pos = ((time_higher-query_time)*pos_lower + 
                  (query_time-time_lower)*pos_higher) / (time_higher-time_lower)

My problem is that I need to do this efficiently for a large number of queries. The database contains millions of rows and my queries are for a couple of thousand query values. That means math and memory are not issues, but sending 2000 queries for 1000 values is a lot of overhead. Also, if I could do the calculations within the database, I'd probably save (de-)serialization time and network bandwidth.

For some scenarios, I found a speedup by fetching all rows between the smallest and largest query value and then stuff them into numpy.interp(). This has the disadvantage that for sparse input data, I fetch lots of rows that are not needed.
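For example, with the data from the table above (this assumes the rows between the smallest and largest query value have already been fetched):

```python
import numpy as np

# Known rows fetched between the smallest and largest query value
ts = np.array([10, 15, 23, 24])
pos = np.array([900, 1400, 1700, 1708])

# Interpolate all query values in a single vectorized call
result = np.interp([12, 22], ts, pos)
```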

Any ideas on how to solve this efficiently in SQL using SQLAlchemy? Ideally in a way that works with SQLite, MariaDB, and PostgreSQL.

I've tried to search the web and StackOverflow for solutions, but it seems I'm using the wrong search terms.



Solution 1:

At least three options come to mind that allow querying with multiple values in a single query in this case: scalar subqueries, using LATERAL joins, or window functions. LATERAL is the "for each" of SQL; it allows a sub-SELECT to refer to FROM items that appear before it in the FROM list. Window functions allow performing calculations over sets of rows that are related to the current row.

This is guesswork, but I would imagine that LATERAL wins when the number of query values is small compared to the number of rows in the table (or the table itself is small) and Timestamp is indexed. The scalar subquery approach is similar, but, depending on the planner, has to perform 4 subqueries instead of 2 joins. If the number of query values is close to the number of rows in the table, and spread so wide that almost all of the table must be read to calculate the interpolated values, then the window function approach might take the lead.

Feature            PostgreSQL  MySQL   MariaDB  SQLite
Scalar subqueries  x           x       x        x
LATERAL            9.3         8.0.14  -        -
Window functions   8.4         8.0     10.2     3.25.0

In order to pass multiple values to query against, a "constant table" or a temporary table is needed. In PostgreSQL it is easy to create such a "constant table" using VALUES (or unnest() etc.):

from sqlalchemy import values, column, Integer

# Add query values as needed to `data`
qts = values(column("ts", Integer), name="qts").data([(22,)])

In MySQL (before 8.0.19) and SQLite, a UNION ALL of SELECT statements must be used instead:

from sqlalchemy import literal, select, union_all

qts = union_all(*[select([literal(v).label("ts")]) for v in [22]])

Scalar subqueries

lower = (
    session.query()
    .filter(Time <= qts.c.ts)
    .order_by(Time.desc())
    .limit(1)
)
higher = (
    session.query()
    .filter(Time > qts.c.ts)
    .order_by(Time.asc())
    .limit(1)
)
time_lower = lower.with_entities(Time).label("time_lower")
time_higher = higher.with_entities(Time).label("time_higher")
pos_lower = lower.with_entities(Position).label("pos_lower")
pos_higher = higher.with_entities(Position).label("pos_higher")

with_high_low = session.query(
    qts.c.ts,
    time_lower,
    time_higher,
    pos_lower,
    pos_higher
).subquery()

session.query(
    with_high_low.c.ts,
    (
        (with_high_low.c.time_higher - with_high_low.c.ts)
        * with_high_low.c.pos_lower
        + (with_high_low.c.ts - with_high_low.c.time_lower)
        * with_high_low.c.pos_higher
    ) / (with_high_low.c.time_higher - with_high_low.c.time_lower)
).all()
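The SQL this produces is equivalent to the following hand-written query, shown here against an in-memory SQLite database using the stdlib sqlite3 module (table and column names are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    "CREATE TABLE data (ts INTEGER PRIMARY KEY, pos INTEGER);"
    "INSERT INTO data VALUES (10, 900), (15, 1400), (23, 1700), (24, 1708);"
)
sql = """
WITH qts(ts) AS (VALUES (12), (22)),
hl AS (
    -- one scalar subquery per bracketing time/position
    SELECT qts.ts AS ts,
           (SELECT ts  FROM data WHERE ts <= qts.ts ORDER BY ts DESC LIMIT 1) AS tl,
           (SELECT ts  FROM data WHERE ts >  qts.ts ORDER BY ts ASC  LIMIT 1) AS th,
           (SELECT pos FROM data WHERE ts <= qts.ts ORDER BY ts DESC LIMIT 1) AS pl,
           (SELECT pos FROM data WHERE ts >  qts.ts ORDER BY ts ASC  LIMIT 1) AS ph
    FROM qts
)
SELECT ts, ((th - ts) * pl + (ts - tl) * ph) * 1.0 / (th - tl) AS pos
FROM hl
ORDER BY ts
"""
rows = conn.execute(sql).fetchall()
print(rows)  # [(12, 1100.0), (22, 1662.5)]
```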

LATERAL

lower = (
    session.query(Time.label("ts"), Position.label("pos"))
    .filter(Time <= qts.c.ts)
    .order_by(Time.desc())
    .limit(1)
    .subquery()
    .lateral()
)
higher = (
    session.query(Time.label("ts"), Position.label("pos"))
    .filter(Time > qts.c.ts)
    .order_by(Time.asc())
    .limit(1)
    .subquery()
    .lateral()
)

session.query(
    qts.c.ts,
    (
        (higher.c.ts - qts.c.ts) * lower.c.pos
        + (qts.c.ts - lower.c.ts) * higher.c.pos
    ) / (higher.c.ts - lower.c.ts)
).all()

Window functions

Combine the query values and the known values, removing query values that have a known value:

from sqlalchemy import null, not_, func, case

comb = session.query(
    qts.c.ts.label("ts"),
    null().label("pos")
).filter(
    not_(session.query(Time).filter(Time == qts.c.ts).exists())
).union(
    session.query(Time, Position)
).cte()

Build a partitioned view of the data so that it is possible to pick the lower and higher non-NULL values:

part = session.query(
    comb.c.ts,
    comb.c.pos,
    func.count(comb.c.pos).over(order_by=comb.c.ts).label("lp"),
    func.count(comb.c.pos).over(order_by=comb.c.ts.desc()).label("hp")
).cte()

The above would be unnecessary if lag() and lead() with IGNORE NULLS were supported.
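To see why the running count works, here is a small pure-Python sketch of count(pos) OVER (ORDER BY ts) on the combined rows, with the query value 22 carrying a NULL position:

```python
from itertools import accumulate

# Combined rows sorted by ts; pos is None for pure query timestamps
rows = [(10, 900), (15, 1400), (22, None), (23, 1700), (24, 1708)]

# Running count of non-NULL pos in ascending ts order -- the "lp" column
lp = list(accumulate(1 if pos is not None else 0 for _, pos in rows))
print(lp)  # [1, 2, 2, 3, 4]
```

Rows sharing an lp value form one partition, and the first row of that partition in ts order is always a known value (the one that incremented the count), so first_value() over it recovers the lower bracket. The descending count hp does the same for the upper bracket.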

Finally, interpolate:

lw = dict(
    partition_by=part.c.lp,
    order_by=(part.c.ts, part.c.pos.nullslast())
)
hw = dict(
    partition_by=part.c.hp,
    order_by=(part.c.ts.desc(), part.c.pos.desc().nullslast())
)
fv = func.first_value
time_lower = fv(part.c.ts).over(**lw)
time_higher = fv(part.c.ts).over(**hw)
pos_lower = fv(part.c.pos).over(**lw)
pos_higher = fv(part.c.pos).over(**hw)

interp = session.query(
    part.c.ts,
    case([(part.c.pos == None,
           ((time_higher - part.c.ts) * pos_lower +
            (part.c.ts - time_lower) * pos_higher) /
           (time_higher - time_lower))],
         else_=part.c.pos).label("pos")
).cte()

And then pick the values of interest:

session.query(interp.c.ts, interp.c.pos).filter(
    interp.c.ts.in_(qts.select())
).all()
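The whole window-function pipeline compiles to SQL along these lines; here it is hand-written and run against an in-memory SQLite (>= 3.25) database with the stdlib sqlite3 module (table and column names are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    "CREATE TABLE data (ts INTEGER PRIMARY KEY, pos INTEGER);"
    "INSERT INTO data VALUES (10, 900), (15, 1400), (23, 1700), (24, 1708);"
)
sql = """
WITH qts(ts) AS (VALUES (12), (22)),
comb(ts, pos) AS (            -- query values (pos NULL) plus known rows
    SELECT ts, NULL FROM qts
    WHERE NOT EXISTS (SELECT 1 FROM data WHERE data.ts = qts.ts)
    UNION
    SELECT ts, pos FROM data
),
part AS (                     -- running counts partition the combined rows
    SELECT ts, pos,
           COUNT(pos) OVER (ORDER BY ts)      AS lp,
           COUNT(pos) OVER (ORDER BY ts DESC) AS hp
    FROM comb
),
brackets AS (                 -- first row of each partition is the bracket
    SELECT ts, pos,
           FIRST_VALUE(ts)  OVER (PARTITION BY lp ORDER BY ts)      AS tl,
           FIRST_VALUE(pos) OVER (PARTITION BY lp ORDER BY ts)      AS pl,
           FIRST_VALUE(ts)  OVER (PARTITION BY hp ORDER BY ts DESC) AS th,
           FIRST_VALUE(pos) OVER (PARTITION BY hp ORDER BY ts DESC) AS ph
    FROM part
)
SELECT ts,
       CASE WHEN pos IS NULL
            THEN ((th - ts) * pl + (ts - tl) * ph) * 1.0 / (th - tl)
            ELSE pos END AS pos
FROM brackets
WHERE ts IN (SELECT ts FROM qts)
ORDER BY ts
"""
rows = conn.execute(sql).fetchall()
print(rows)  # [(12, 1100.0), (22, 1662.5)]
```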

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
