How to efficiently find nearest / bracketing numerical values (and do linear interpolation) in SQLAlchemy?
Assume that I have a database table with the contents:
| Timestamp | Position |
|---|---|
| 10 | 900 |
| 15 | 1400 |
| 23 | 1700 |
| 24 | 1708 |
I would then want to calculate the respective Position for a given Timestamp using linear interpolation. Say, I've got a query for Timestamp: 22, then I'd need to fetch the next smaller and next bigger rows, which are 15, 1400 and 23, 1700. The interpolation would then be interpolation = ((23-22)*1400 + (22-15)*1700) / (23-15) resulting in 1662.5.
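As a quick sanity check, the formula above can be written as a small helper. The function name `interpolate` and its argument names are made up for illustration and do not appear in the original code.

```python
def interpolate(t, t_lo, p_lo, t_hi, p_hi):
    """Linearly interpolate the position at time t between two bracketing rows."""
    return ((t_hi - t) * p_lo + (t - t_lo) * p_hi) / (t_hi - t_lo)


# Timestamp 22 lies between the rows (15, 1400) and (23, 1700).
print(interpolate(22, 15, 1400, 23, 1700))  # 1662.5
```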
For a single value, I can easily get the bracketing rows and interpolate using:
query_time = 22
time_lower, pos_lower = session.query(Time, Position).filter(Time <= query_time).order_by(Time.desc()).limit(1).one()
time_higher, pos_higher = session.query(Time, Position).filter(Time > query_time).order_by(Time.asc()).limit(1).one()
result_pos = ((time_higher - query_time) * pos_lower +
              (query_time - time_lower) * pos_higher) / (time_higher - time_lower)
My problem is that I need to do that efficiently for a large number of queries. The database contains millions of rows and my queries are for a couple of thousand query values. That means, math and memory are no issues but sending 2000 queries for 1000 values is a lot of overhead. Also, if I could do the calculations within the database, I'd probably save (de-)serialization time and network bandwidth.
For some scenarios, I found a speedup by fetching all rows between the smallest and largest query value and then passing them to numpy.interp(). The disadvantage is that for sparse query values, I fetch lots of rows that are not needed.
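A rough sketch of that client-side variant, using only the standard library: `bisect` stands in for `numpy.interp()`, and the table `track(ts, pos)` in an in-memory SQLite database is a hypothetical stand-in for the real schema. The fetched range is widened to the nearest known row on either side so that every query value is bracketed.

```python
import sqlite3
from bisect import bisect_right

# Hypothetical schema and sample data mirroring the table in the question.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE track (ts INTEGER PRIMARY KEY, pos REAL)")
conn.executemany(
    "INSERT INTO track VALUES (?, ?)",
    [(10, 900), (15, 1400), (23, 1700), (24, 1708)],
)

query_values = [12, 22]
lo, hi = min(query_values), max(query_values)

# One round trip: all rows from the row at or below `lo` up to the row above `hi`.
rows = conn.execute(
    """
    SELECT ts, pos FROM track
    WHERE ts >= (SELECT COALESCE(MAX(ts), ?) FROM track WHERE ts <= ?)
      AND ts <= (SELECT COALESCE(MIN(ts), ?) FROM track WHERE ts > ?)
    ORDER BY ts
    """,
    (lo, lo, hi, hi),
).fetchall()

times = [t for t, _ in rows]


def interp(q):
    # Index of the first fetched row with ts > q; the row before it has ts <= q.
    # Query values outside the table's range are not handled in this sketch.
    i = bisect_right(times, q)
    (t_lo, p_lo), (t_hi, p_hi) = rows[i - 1], rows[i]
    return ((t_hi - q) * p_lo + (q - t_lo) * p_hi) / (t_hi - t_lo)


result = [interp(q) for q in query_values]
print(result)
```

This keeps the work to a single query, but it still transfers every row inside the range, which is exactly the cost described above for sparse query values.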
Any ideas how to solve this in SQL using SQLAlchemy efficiently? Ideally in a way that works with SQLite, MariaDB, and PostgreSQL.
I've tried to search the web and StackOverflow for solutions, but it seems I'm using the wrong search terms.
Solution 1:[1]
At least three options come to mind that allow querying with multiple values in a single query in this case: scalar subqueries, using LATERAL joins, or window functions. LATERAL is the "for each" of SQL; it allows a sub-SELECT to refer to FROM items that appear before it in the FROM list. Window functions allow performing calculations over sets of rows that are related to the current row.
This is a guess rather than a benchmark, but I would expect LATERAL to win when the number of query values is small compared to the number of rows in the table (or the table itself is small) and Timestamp is indexed. The scalar subquery approach is similar, but depending on the planner it may have to run 4 subqueries instead of 2 joins. If the number of query values is close to the number of rows in the table and the values are spread widely, so that almost all of the table must be read to calculate the interpolated values, then the window function approach might take the lead.
| Feature | PostgreSQL | MySQL | MariaDB | SQLite |
|---|---|---|---|---|
| Scalar subqueries | x | x | x | x |
| LATERAL | 9.3 | 8.0.14 | | |
| Window functions | 8.4 | 8.0 | 10.2 | 3.25.0 |
In order to pass multiple values to query against, a "constant table" or a temporary table is needed. In PostgreSQL it is easy to create such a "constant table" using VALUES (or unnest() etc.):
from sqlalchemy import values, column, Integer
# Add query values as needed to `data`
qts = values(column("ts", Integer), name="qts").data([(22,)])
In MySQL (before 8.0.19) and SQLite a UNION ALL of SELECTs must be used instead:
from sqlalchemy import literal, select, union_all
# Legacy 1.x calling style; SQLAlchemy 1.4+ also accepts select(literal(v).label("ts"))
qts = union_all(*[select([literal(v).label("ts")]) for v in [22]])
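For reference, the UNION ALL "constant table" looks roughly like this in plain SQL. The sketch below uses the standard library `sqlite3` module instead of SQLAlchemy, and the helper name `constant_table_sql` is hypothetical.

```python
import sqlite3


def constant_table_sql(n):
    """Build a UNION ALL of n single-row SELECTs, each binding one parameter as ts."""
    return " UNION ALL ".join(["SELECT ? AS ts"] * n)


conn = sqlite3.connect(":memory:")
query_values = [22, 12, 23]

sql = constant_table_sql(len(query_values))
rows = conn.execute(sql, query_values).fetchall()
print(sorted(rows))
```

As far as I know, SQLite limits the number of terms in a compound SELECT (500 by default), so for a few thousand query values a temporary table is probably the more practical choice there.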
Scalar subqueries
lower = (
    session.query()
    .filter(Time <= qts.c.ts)
    .order_by(Time.desc())
    .limit(1)
)
higher = (
    session.query()
    .filter(Time > qts.c.ts)
    .order_by(Time.asc())
    .limit(1)
)
time_lower = lower.with_entities(Time).label("time_lower")
time_higher = higher.with_entities(Time).label("time_higher")
pos_lower = lower.with_entities(Position).label("pos_lower")
pos_higher = higher.with_entities(Position).label("pos_higher")
with_high_low = session.query(
    qts.c.ts,
    time_lower,
    time_higher,
    pos_lower,
    pos_higher
).subquery()
session.query(
    with_high_low.c.ts,
    (
        (with_high_low.c.time_higher - with_high_low.c.ts)
        * with_high_low.c.pos_lower
        + (with_high_low.c.ts - with_high_low.c.time_lower)
        * with_high_low.c.pos_higher
    ) / (with_high_low.c.time_higher - with_high_low.c.time_lower)
).all()
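The SQL this approach boils down to can be tried directly against SQLite. The following is a sketch with the standard library `sqlite3` module, not the output of the SQLAlchemy query above; the table `track(ts, pos)` and the aliases `t_lo`, `p_lo`, `t_hi`, `p_hi` are assumptions made for the example.

```python
import sqlite3

# Hypothetical schema and sample data mirroring the table in the question.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE track (ts INTEGER PRIMARY KEY, pos REAL)")
conn.executemany(
    "INSERT INTO track VALUES (?, ?)",
    [(10, 900), (15, 1400), (23, 1700), (24, 1708)],
)

query_values = [12, 22]
# "Constant table" of query values, one bound parameter per value.
qts = " UNION ALL ".join(["SELECT ? AS ts"] * len(query_values))

sql = f"""
SELECT ts,
       ((t_hi - ts) * p_lo + (ts - t_lo) * p_hi) / (t_hi - t_lo) AS pos
FROM (
    SELECT q.ts AS ts,
           (SELECT t.ts  FROM track AS t WHERE t.ts <= q.ts
            ORDER BY t.ts DESC LIMIT 1) AS t_lo,
           (SELECT t.pos FROM track AS t WHERE t.ts <= q.ts
            ORDER BY t.ts DESC LIMIT 1) AS p_lo,
           (SELECT t.ts  FROM track AS t WHERE t.ts > q.ts
            ORDER BY t.ts ASC LIMIT 1) AS t_hi,
           (SELECT t.pos FROM track AS t WHERE t.ts > q.ts
            ORDER BY t.ts ASC LIMIT 1) AS p_hi
    FROM ({qts}) AS q
)
ORDER BY ts
"""

# All query values are answered in a single round trip.
result = conn.execute(sql, query_values).fetchall()
print(result)
```

Query values that fall before the first or after the last row get a NULL position, because one of the bracketing subqueries returns no row.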
LATERAL
lower = (
    session.query(Time.label("ts"), Position.label("pos"))
    .filter(Time <= qts.c.ts)
    .order_by(Time.desc())
    .limit(1)
    .subquery()
    .lateral()
)
higher = (
    session.query(Time.label("ts"), Position.label("pos"))
    .filter(Time > qts.c.ts)
    .order_by(Time.asc())
    .limit(1)
    .subquery()
    .lateral()
)
session.query(
    qts.c.ts,
    (
        (higher.c.ts - qts.c.ts) * lower.c.pos
        + (qts.c.ts - lower.c.ts) * higher.c.pos
    ) / (higher.c.ts - lower.c.ts)
).all()
Window Functions
Combine the query values and the known values, removing query values that have a known value:
from sqlalchemy import null, not_, func, case
comb = session.query(
    qts.c.ts.label("ts"),
    null().label("pos")
).filter(
    not_(session.query(Time).filter(Time == qts.c.ts).exists())
).union(
    session.query(Time, Position)
).cte()
Build a partitioned view of the data so that it is possible to pick the lower and higher non-NULL values:
part = session.query(
    comb.c.ts,
    comb.c.pos,
    func.count(comb.c.pos).over(order_by=comb.c.ts).label("lp"),
    func.count(comb.c.pos).over(order_by=comb.c.ts.desc()).label("hp")
).cte()
The above would be unnecessary if lag() and lead() supported IGNORE NULLS.
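To see why counting non-NULL positions gives usable partitions, here is a small pure-Python illustration. It assumes query values 12 and 22 have been merged into the sample table from the question, and it mimics the two running counts with `itertools.accumulate`; it is an illustration of the idea, not what the database executes.

```python
from itertools import accumulate

# Combined rows ordered by ts; None marks a query value without a known position.
rows = [(10, 900), (12, None), (15, 1400), (22, None), (23, 1700), (24, 1708)]

known = [int(pos is not None) for _, pos in rows]
# Running count of known positions in ascending ts order ("lp").
lp = list(accumulate(known))
# Running count of known positions in descending ts order ("hp").
hp = list(accumulate(reversed(known)))[::-1]

for (ts, pos), low_part, high_part in zip(rows, lp, hp):
    print(ts, pos, low_part, high_part)
```

Each query row shares its `lp` value with the nearest known row below it and its `hp` value with the nearest known row above it, so partitioning by these counts groups every NULL row with its bracketing neighbours.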
Finally, interpolate:
lw = dict(
    partition_by=part.c.lp,
    order_by=(part.c.ts, part.c.pos.nullslast())
)
hw = dict(
    partition_by=part.c.hp,
    order_by=(part.c.ts.desc(), part.c.pos.desc().nullslast())
)
fv = func.first_value
time_lower = fv(part.c.ts).over(**lw)
time_higher = fv(part.c.ts).over(**hw)
pos_lower = fv(part.c.pos).over(**lw)
pos_higher = fv(part.c.pos).over(**hw)
interp = session.query(
    part.c.ts,
    # Interpolate only the rows that came from the query values (pos IS NULL)
    case([(part.c.pos.is_(None),
           ((time_higher - part.c.ts) * pos_lower +
            (part.c.ts - time_lower) * pos_higher) /
           (time_higher - time_lower))],
         else_=part.c.pos).label("pos")
).cte()
And then pick the values of interest:
session.query(interp.c.ts, interp.c.pos).filter(
    interp.c.ts.in_(qts.select())
).all()
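Put together, the window function pipeline can be sketched in plain SQL and run against SQLite 3.25.0 or newer through the standard library `sqlite3` module. This is a hand-written approximation, not the SQL that SQLAlchemy emits: the table `track(ts, pos)` and the CTE name `brack` are assumptions, and the window ordering is simplified to `ts` alone on the assumption that timestamps are unique.

```python
import sqlite3

# Hypothetical schema and sample data mirroring the table in the question.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE track (ts INTEGER PRIMARY KEY, pos REAL)")
conn.executemany(
    "INSERT INTO track VALUES (?, ?)",
    [(10, 900), (15, 1400), (23, 1700), (24, 1708)],
)

query_values = [12, 22, 23]  # 23 already has a known position
qts = " UNION ALL ".join(["SELECT ? AS ts"] * len(query_values))

sql = f"""
WITH qts AS ({qts}),
comb AS (  -- query values without a known row, plus all known rows
    SELECT q.ts AS ts, NULL AS pos
    FROM qts AS q
    WHERE NOT EXISTS (SELECT 1 FROM track AS t WHERE t.ts = q.ts)
    UNION
    SELECT ts, pos FROM track
),
part AS (  -- running counts of known positions in both directions
    SELECT ts, pos,
           COUNT(pos) OVER (ORDER BY ts)      AS lp,
           COUNT(pos) OVER (ORDER BY ts DESC) AS hp
    FROM comb
),
brack AS (  -- bracketing known row on each side of every row
    SELECT ts, pos,
           FIRST_VALUE(ts)  OVER (PARTITION BY lp ORDER BY ts)      AS t_lo,
           FIRST_VALUE(pos) OVER (PARTITION BY lp ORDER BY ts)      AS p_lo,
           FIRST_VALUE(ts)  OVER (PARTITION BY hp ORDER BY ts DESC) AS t_hi,
           FIRST_VALUE(pos) OVER (PARTITION BY hp ORDER BY ts DESC) AS p_hi
    FROM part
)
SELECT ts,
       CASE WHEN pos IS NULL
            THEN ((t_hi - ts) * p_lo + (ts - t_lo) * p_hi) / (t_hi - t_lo)
            ELSE pos
       END AS pos
FROM brack
WHERE ts IN (SELECT ts FROM qts)
ORDER BY ts
"""

result = conn.execute(sql, query_values).fetchall()
print(result)
```

Note that this variant reads the whole table, which matches the expectation above that it only pays off when the query values cover most of the data.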
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 |
