Demos: differentiable SQL over ARCO-ERA5 and gradient descent in SQL #195

Merged
alxmrs merged 0 commit into claude/xarray-sql-autograd-73ovqq from claude/xarray-sql-era5-demo
Jun 30, 2026

Conversation

@alxmrs alxmrs commented Jun 28, 2026

Stacked on the autograd feature branch (#192) — runnable benchmark scripts kept out of the core branch so it stays reviewable.

What this adds (benchmarks/)

  • grad_era5.py — the autograd feature on real climate data. A physical quantity is written as an analytic SQL formula over ARCO-ERA5 variables, and grad(...) differentiates it symbolically, per grid cell (the relational vmap(grad(f))). Two cases, each checked against an analytic reference:

    • wind speed sqrt(u² + v²): grad(speed, u) = u/speed (exact)
    • saturation vapour pressure A·exp(B·tc/(tc+C)): grad(e_s, T) vs the closed-form Clausius–Clapeyron slope

    Reads ARCO-ERA5 anonymously from GCS (needs gcsfs + network); each query round-trips back to an xarray.Dataset.

  • grad_descent.py — gradient descent in SQL. Differentiating through an aggregate is linearity, so the gradient is AVG(grad(loss, θ)). The optimiser trajectory is a params(step, a, b) table that grows one generation per step, with the update computed in SQL (new_a = a - lr*AVG(grad(loss, a))); the whole loss curve is one GROUP BY over the history. Fit matches numpy least-squares. Self-contained (no network).

  • benchmarks/README.md describing both.
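
For readers who want to see why the two analytic references in grad_era5.py are the right targets, here is a minimal sketch that writes both derivatives by hand and compares them with a central finite difference. It does not use the SQL engine or ERA5 data. The constants A, B, C are commonly quoted Magnus-type coefficients and are an assumption here; the script's own values may differ.

```python
import math

# Illustrative Magnus-type constants (assumed; the script's values may differ).
A, B, C = 6.112, 17.67, 243.5  # hPa, dimensionless, degC


def speed(u, v):
    return math.sqrt(u * u + v * v)


def dspeed_du(u, v):
    # d/du sqrt(u^2 + v^2) = u / sqrt(u^2 + v^2)
    return u / speed(u, v)


def e_s(tc):
    return A * math.exp(B * tc / (tc + C))


def de_s_dt(tc):
    # Chain rule: d/dtc [B*tc/(tc+C)] = B*C/(tc+C)^2, times e_s itself.
    return e_s(tc) * B * C / (tc + C) ** 2


def central_diff(f, x, h=1e-5):
    # Second-order accurate numerical derivative, used only as a cross-check.
    return (f(x + h) - f(x - h)) / (2 * h)


u, v, tc = 3.0, -4.0, 15.0
fd_speed = central_diff(lambda x: speed(x, v), u)
fd_es = central_diff(e_s, tc)
print("d speed / du:", dspeed_du(u, v), "finite difference:", fd_speed)
print("d e_s / dT  :", de_s_dt(tc), "finite difference:", fd_es)
```

The hand-written slope has the same shape as the reference used in the benchmark, e_s · (B·C)/(tc + C)².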

Each script declares its deps inline (PEP 723) and runs with uv run benchmarks/<name>.py; nothing is added to pyproject.toml, and the suite lives outside the package so it isn't collected by pytest or shipped in the wheel.
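
For anyone unfamiliar with PEP 723, the inline declaration is a specially formatted comment block at the top of the file. The sketch below shows the general shape only; the dependency list and the main function are placeholders, not the contents of the actual benchmark scripts.

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
# ]
# ///
# The comment block above is PEP 723 inline metadata: `uv run` reads it,
# prepares an environment with those packages, then runs the file.
# The dependency list here is illustrative, not the scripts' actual list.


def main() -> str:
    # Real benchmark logic would import the declared dependencies here.
    return "ran outside the package; nothing added to pyproject.toml"


if __name__ == "__main__":
    print(main())
```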

🤖 Generated with Claude Code


Generated by Claude Code

Comment thread benchmarks/grad_descent.py Outdated
Comment on lines +94 to +100
for k in range(STEPS):
    register_params()
    row = ctx.sql(update_sql(k)).to_pandas()
    steps.append(k + 1)
    a_hist.append(float(row["a"][0]))
    b_hist.append(float(row["b"][0]))
register_params()

I believe we can do this iteration in SQL too -- these just add more rows into a (train?) table.

Love this — done. params is now a DataFusion in-memory table and each step is literally "add a row":

INSERT INTO params VALUES (:next_step, :new_a, :new_b)

where (new_a, new_b) come from the SQL update a - lr*AVG(grad(loss, a)). So the gradient, the update, and the append all happen in SQL — a training step is an INSERT into the table, and the loss curve is one GROUP BY over the accumulated history. (This needed a small core change: the rewrite now resolves grad over any registered table, not just xarray-backed ones — see #192.)

One honest seam remains: the Python for loop still sequences the steps. Doing the iteration itself in SQL — no Python loop — would be a recursive CTE, which DataFusion/Substrait don't round-trip yet (tracked in #194). So today it's "each step is SQL INSERT", not "the whole loop is one SQL statement". The INSERT ... SELECT grad(...) single-statement form also doesn't go through the Substrait rewrite (DML plans don't convert), which is why the update is a SELECT feeding an INSERT VALUES.
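
To make the "each step is an INSERT" pattern concrete without the xarray-sql stack, here is a minimal sketch using Python's built-in sqlite3. It is an analogue, not the benchmark: SQLite has no grad(), so the per-row gradients of loss = (a·x + b − y)² are written by hand, and the data, learning rate, and step count are made up. Averaging the per-row gradient is valid because differentiation is linear, so the gradient of the mean loss is the mean of the per-row gradients.

```python
import sqlite3

LR, STEPS = 0.1, 200  # illustrative values, not the benchmark's

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE d (x REAL, y REAL)")
# Synthetic data on the line y = 2x + 1, so the fit has a known answer.
con.executemany(
    "INSERT INTO d VALUES (?, ?)",
    [(x, 2.0 * x + 1.0) for x in (-1.0, -0.5, 0.0, 0.5, 1.0)],
)
con.execute("CREATE TABLE params (step INTEGER, a REAL, b REAL)")
con.execute("INSERT INTO params VALUES (0, 0.0, 0.0)")

# Hand-written per-row gradients of (a*x + b - y)^2; in the PR these come
# from grad(loss, a) and grad(loss, b) instead.
UPDATE_SQL = """
WITH cur AS (SELECT a, b FROM params WHERE step = :k)
SELECT cur.a - :lr * AVG(2 * (cur.a * d.x + cur.b - d.y) * d.x) AS a,
       cur.b - :lr * AVG(2 * (cur.a * d.x + cur.b - d.y)) AS b
FROM d CROSS JOIN cur
GROUP BY cur.a, cur.b
"""

for k in range(STEPS):
    new_a, new_b = con.execute(UPDATE_SQL, {"k": k, "lr": LR}).fetchone()
    # A training step is one appended row.
    con.execute("INSERT INTO params VALUES (?, ?, ?)", (k + 1, new_a, new_b))

# The loss curve is one GROUP BY over the accumulated history.
loss_curve = con.execute(
    """
    SELECT p.step,
           AVG((p.a * d.x + p.b - d.y) * (p.a * d.x + p.b - d.y)) AS loss
    FROM params AS p CROSS JOIN d
    GROUP BY p.step
    ORDER BY p.step
    """
).fetchall()

final_a, final_b = con.execute(
    "SELECT a, b FROM params ORDER BY step DESC LIMIT 1"
).fetchone()
print("fit:", final_a, final_b, "final loss:", loss_curve[-1][1])
```

The Python loop here only sequences the steps and carries two floats from the SELECT to the INSERT, which is the same seam described above.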


Generated by Claude Code

@alxmrs alxmrs force-pushed the claude/xarray-sql-era5-demo branch from afb1036 to fdb17fb Compare June 28, 2026 15:17
Comment thread benchmarks/grad_descent.py Outdated
FROM d CROSS JOIN cur
GROUP BY cur.a, cur.b
"""
).to_pandas()

.to_pandas() forces a copy; let's keep it all in DataFusion (lazily evaluated).

Resolved by the recursive-CTE rewrite (see the thread below): the training loop is now a single query, so there's no per-step to_pandas() — just one .to_pandas() at the very end to print the trajectory/report the fit.


Generated by Claude Code

Comment thread benchmarks/grad_descent.py Outdated
Comment on lines +79 to +95
for k in range(STEPS):
    # Compute the next generation in SQL: descend from the current row along
    # the SQL-computed gradient...
    row = ctx.sql(
        f"""
        WITH cur AS (SELECT a, b FROM params WHERE step = {k})
        SELECT cur.a - {LR} * AVG(grad({LOSS}, a)) AS a,
               cur.b - {LR} * AVG(grad({LOSS}, b)) AS b
        FROM d CROSS JOIN cur
        GROUP BY cur.a, cur.b
        """
    ).to_pandas()
    # ...and append it to the parameter table with a SQL INSERT.
    na, nb = float(row["a"][0]), float(row["b"][0])
    ctx.sql(
        f"INSERT INTO params VALUES ({k + 1}, {na!r}, {nb!r})"
    ).collect()

Can't we select over a range of STEPS values (the iteration) and do this in one query as a set of nested SELECT statements? Why do we need to spill into Python and use the imperative paradigm? I think the entire training loop could be declarative tables.

Done — the whole training loop is now one declarative recursive CTE, no Python iteration:

WITH RECURSIVE params(step, a, b) AS (
  SELECT 0, 0.0, 0.0
  UNION ALL
  SELECT params.step + 1, params.a - lr*AVG(da), params.b - lr*AVG(db)
  FROM params CROSS JOIN d WHERE params.step < STEPS
  GROUP BY params.step, params.a, params.b)
SELECT * FROM params ORDER BY step

The one wrinkle: grad() can't appear inside the recursion, because it needs the Substrait round-trip and Substrait has no recursion (to_substrait_plan rejects a RecursiveQuery). So I split it: a new engine helper xql.differentiate_sql(loss, "a", cols) differentiates the per-row loss to SQL text once (autograd as a "compiler"), and the recursive CTE iterates that plain-SQL rule. da/db are the compiled gradients; gradient + update + iteration are all in SQL. Fits to the OLS solution exactly. (differentiate_sql is in #192.)
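
For a version of the single-query loop that runs with nothing but the standard library, here is a rough analogue in sqlite3. Two things differ from the benchmark and are assumptions of this sketch: the gradient is derived by hand for loss = (a·x + b − y)² rather than produced by differentiate_sql, and because SQLite rejects aggregate functions inside the recursive term, the data averages are computed once in a helper CTE (called m here) instead of using AVG with GROUP BY inside the recursion. The data, learning rate, and step count are made up.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE d (x REAL, y REAL)")
data = [(x, 2.0 * x + 1.0) for x in (-1.0, -0.5, 0.0, 0.5, 1.0)]
con.executemany("INSERT INTO d VALUES (?, ?)", data)

# For this loss, AVG(d loss / d a) = 2*(a*AVG(x*x) + b*AVG(x) - AVG(x*y))
# and AVG(d loss / d b) = 2*(a*AVG(x) + b - AVG(y)), so the recursion only
# needs four data moments, aggregated once outside the recursive term.
TRAIN_SQL = """
WITH RECURSIVE
  m AS (
    SELECT AVG(x * x) AS xx, AVG(x) AS x1, AVG(x * y) AS xy, AVG(y) AS y1
    FROM d
  ),
  params(step, a, b) AS (
    SELECT 0, 0.0, 0.0
    UNION ALL
    SELECT params.step + 1,
           params.a - :lr * 2 * (params.a * m.xx + params.b * m.x1 - m.xy),
           params.b - :lr * 2 * (params.a * m.x1 + params.b - m.y1)
    FROM params CROSS JOIN m
    WHERE params.step < :steps
  )
SELECT step, a, b FROM params ORDER BY step
"""

trajectory = con.execute(TRAIN_SQL, {"lr": 0.1, "steps": 200}).fetchall()
last_step, fit_a, fit_b = trajectory[-1]

# Closed-form least squares for a line, as the reference to compare against.
n = len(data)
mx = sum(x for x, _ in data) / n
my = sum(y for _, y in data) / n
ols_a = sum((x - mx) * (y - my) for x, y in data) / sum(
    (x - mx) ** 2 for x, _ in data
)
ols_b = my - ols_a * mx
print("SQL fit:", fit_a, fit_b, "OLS:", ols_a, ols_b)
```

The moment trick only works because this particular gradient is linear in the parameters; a general loss needs an engine that allows aggregation inside the recursion.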


Generated by Claude Code

@alxmrs alxmrs left a comment

More feedback, this time on the ERA5 example.

Comment thread benchmarks/grad_era5.py Outdated
"""
)
.to_dataset(dims=["latitude", "longitude"])
.sortby(_SORT)

This can be done with an ORDER BY SQL expr.

Done — the query now ends with ORDER BY latitude, longitude and I dropped the Python .sortby on the SQL output. (The reference array is still sorted in Python, since it comes from xarray in ERA5's native descending-latitude order.)


Generated by Claude Code

Comment thread benchmarks/grad_era5.py Outdated
"""
)
.to_dataset(dims=["latitude", "longitude"])
.sortby(_SORT)

Same here.

Done here too — ORDER BY latitude, longitude in the Clausius–Clapeyron query, output .sortby removed.


Generated by Claude Code

Comment thread benchmarks/grad_era5.py Outdated
des_dt_ref = es_ref * (b * c) / (tc_ref + c) ** 2

xr.testing.assert_allclose(
out["e_s"], es_ref.sortby(_SORT).rename("e_s"), rtol=1e-5

You can eliminate the reference sortby if you ORDER BY latitude DESC in the SQL query.

Done — both queries now ORDER BY latitude DESC, longitude to match ERA5's native order, and both reference .sortby calls are gone (the _SORT helper too). No sorting on either side now.

One subtlety I hit making this work: to_dataset only preserves the SQL ORDER BY order for a single partition. Across multiple partitions it reconstructs coordinates ascending regardless (a latitude DESC query still came back ascending). The demo block is tiny, so I switched it to one partition (block.chunk()) and noted why. The general fix — preserve source dim order across partitions on round-trip — is exactly #193; once that lands this works for chunked tables too.
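
As a small illustration of the ordering idea, independent of xarray-sql and to_dataset, the sketch below uses sqlite3 and plain lists. A toy grid is stored in scrambled row order, read back with ORDER BY latitude DESC, longitude, and folded into a 2D grid that matches a reference kept in descending-latitude order with no sort on either side. The table name, coordinates, and values are all made up, and the multi-partition behaviour described above is not modelled.

```python
import sqlite3

# Reference grid in "native" order: latitude descending, longitude ascending.
lats = [90.0, 0.0, -90.0]
lons = [0.0, 120.0, 240.0]
reference = [[lat + lon / 1000.0 for lon in lons] for lat in lats]

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE cells (latitude REAL, longitude REAL, value REAL)")
# Insert in reversed order: row order in a table carries no meaning by itself.
rows = [
    (lat, lon, lat + lon / 1000.0)
    for lat in reversed(lats)
    for lon in reversed(lons)
]
con.executemany("INSERT INTO cells VALUES (?, ?, ?)", rows)

ordered = con.execute(
    "SELECT value FROM cells ORDER BY latitude DESC, longitude"
).fetchall()

# Fold the flat, ordered result back into a (latitude, longitude) grid.
n = len(lons)
grid = [
    [value for (value,) in ordered[i * n:(i + 1) * n]]
    for i in range(len(lats))
]
print("matches reference without sorting:", grid == reference)
```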


Generated by Claude Code

Comment thread benchmarks/grad_era5.py Outdated
out["e_s"], es_ref.sortby(_SORT).rename("e_s"), rtol=1e-5
)
xr.testing.assert_allclose(
out["de_s_dt"], des_dt_ref.sortby(_SORT).rename("de_s_dt"), rtol=1e-5

Same here.

Done — same change in the Clausius–Clapeyron query (ORDER BY latitude DESC, longitude, reference sort removed).


Generated by Claude Code

@alxmrs alxmrs force-pushed the claude/xarray-sql-era5-demo branch from 8f97173 to 27c02d4 Compare June 28, 2026 15:58
@alxmrs alxmrs merged commit 27c02d4 into claude/xarray-sql-autograd-73ovqq Jun 30, 2026
2 checks passed
@alxmrs alxmrs deleted the claude/xarray-sql-era5-demo branch June 30, 2026 13:31