Differentiable SQL: grad()/jvp()/vjp() over xarray and any DataFusion table#192

Open
alxmrs wants to merge 12 commits into main from claude/xarray-sql-autograd-73ovqq

Conversation

alxmrs (Collaborator) commented Jun 27, 2026

Differentiable SQL for xarray-sql: write grad(expr, column) in a query and get
its derivative back as an ordinary column, evaluated row-by-row by DataFusion
alongside everything else. Because each row of a table is an independent
evaluation point, differentiating a column expression and letting DataFusion
evaluate it per row is the relational equivalent of jax.vmap(jax.grad(f)) — the
rows are the batch dimension. This turns SQL into a place you can express
gradients, directional derivatives, and whole training loops.

The surface

Three scalar operations, each returning one value per row (staying in the
long/tidy data model):

  • grad(expr, column) — the partial derivative d(expr)/d(column).
  • jvp(expr, column, tangent) — forward-mode directional derivative,
    d(expr)/d(column) * tangent (seed a tangent on an input). A multi-input
    directional derivative is a sum of jvp terms.
  • vjp(expr, column, cotangent) — reverse-mode pullback,
    cotangent * d(expr)/d(column) (seed a cotangent on the output).

A full gradient or Jacobian is expressed as several scalar columns rather than a
nested array, which would break the one-value-per-coordinate model:

SELECT i, grad(x * y, x) AS dfdx, grad(x * y, y) AS dfdy FROM g
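As a rough numeric illustration of the jvp/vjp semantics above, the sketch below uses f(x, y) = sin(x) * y with hand-written partials and plain Python in place of the SQL markers. The helper names (`f`, `dfdx`, `dfdy`, `jvp`, `vjp`) are illustrative only and are not part of xarray-sql; each tuple in `rows` plays the role of one table row.

```python
import math

# f(x, y) = sin(x) * y and its two partial derivatives, written by hand.
def f(x, y):
    return math.sin(x) * y

def dfdx(x, y):
    return math.cos(x) * y

def dfdy(x, y):
    return math.sin(x)

def jvp(partial, x, y, tangent):
    # Forward mode: d(expr)/d(column) * tangent (seed on the input).
    return partial(x, y) * tangent

def vjp(partial, x, y, cotangent):
    # Reverse mode: cotangent * d(expr)/d(column) (seed on the output).
    return cotangent * partial(x, y)

rows = [(0.3, 2.0), (1.1, -0.5), (2.0, 4.0)]  # each tuple is one "row"
tx, ty = 0.7, -1.2                            # tangent direction (assumed values)
eps = 1e-6

checks = []
for x, y in rows:
    # A multi-input directional derivative is a sum of jvp terms.
    directional = jvp(dfdx, x, y, tx) + jvp(dfdy, x, y, ty)
    # Central finite difference along the same direction, for comparison.
    numeric = (f(x + eps * tx, y + eps * ty) - f(x - eps * tx, y - eps * ty)) / (2 * eps)
    checks.append((directional, numeric))

# For a scalar output and a unit seed, forward and reverse mode coincide.
unit_seed_agrees = all(
    jvp(dfdx, x, y, 1.0) == vjp(dfdx, x, y, 1.0) for x, y in rows
)
```

In SQL terms, `directional` corresponds to something like `jvp(sin(x) * y, x, tx) + jvp(sin(x) * y, y, ty)` evaluated per row.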

What you can write:

  • Higher-order derivatives, by nesting: grad(grad(sin(x), x), x).

  • Differentiation through an aggregate — it's just linearity, so put the
    marker inside the aggregate: AVG(grad(loss, theta)) is the relational
    d/dθ (Σ loss) / N. Enough to run gradient descent in SQL.

  • grad inside any query shape — recursive CTEs, DML, and subqueries — so a
    training loop can live entirely in one declarative query. For example,
    Newton's method for √2 driven by grad in the recursive term:

    WITH RECURSIVE newton AS (
      SELECT 0 AS step, CAST(1.0 AS DOUBLE) AS x
      UNION ALL
      SELECT step + 1, x - (x * x - 2.0) / grad(x * x - 2.0, x)
      FROM newton WHERE step < 20)
    SELECT x FROM newton ORDER BY step DESC LIMIT 1

  • grad over any registered table, not just xarray datasets: plain in-memory
    tables holding model parameters, and schema-qualified tables (era5.surface)
    from mixed-dimension datasets all work.
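To make the Newton example from the list above concrete, here is a small sketch of what the query could look like once the marker has been replaced by a derivative. It runs on Python's built-in sqlite3 rather than DataFusion, and the derivative `2.0 * x` is substituted by hand; the exact SQL text the PR's rewrite renders may differ.

```python
import math
import sqlite3

# The recursive CTE from the PR description, with grad(x * x - 2.0, x)
# replaced by hand with its derivative 2.0 * x (an assumption about the
# rewritten form, not the engine's literal output).
REWRITTEN = """
WITH RECURSIVE newton(step, x) AS (
  SELECT 0, 1.0
  UNION ALL
  SELECT step + 1, x - (x * x - 2.0) / (2.0 * x)
  FROM newton WHERE step < 20)
SELECT x FROM newton ORDER BY step DESC LIMIT 1
"""

con = sqlite3.connect(":memory:")
(root,) = con.execute(REWRITTEN).fetchone()

# Distance from the closed-form answer, sqrt(2).
error = abs(root - math.sqrt(2.0))
print(root, error)
```

Because the rewritten statement contains no marker, any engine that supports recursive CTEs can execute it, which is the property the source-to-source approach relies on.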

Also exported: differentiate_sql(expr, wrt, columns) — a calculus compiler
that differentiates a single expression and hands back the derivative as SQL
text, for when you want the update rule as a string to embed yourself.

How it works

The core is src/autograd.rs, a symbolic differentiation engine over DataFusion
logical Expr trees. The design mirrors JAX's per-primitive rule registry
(defjvp and friends): every node type has a differentiation rule and the chain
rule composes them as the tree is walked. A small 0/1-folding simplifier
keeps the output compact, playing the role of JAX's Zero tangents and
add_tangents. grad and jvp share the chain rule and differ only in their
leaf rule (a one-hot seed vs. an arbitrary per-input tangent); vjp scales the
partial by the cotangent.

Rules are implemented for +, -, *, /, the unary chain rule for
sin/cos/tan, asin/acos/atan, exp/ln/log2/log10/sqrt,
sinh/cosh/tanh, abs, and power() with a constant base or exponent. An
unsupported node or function returns a clear NotImplemented error rather than a
silently wrong derivative.
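The structure described above (a rule per node type, a shared chain rule with a pluggable leaf rule, and a 0/1-folding simplifier) can be sketched in a few dozen lines of Python. This is a toy over nested tuples, not the Rust engine in src/autograd.rs: the tuple encoding, the helper names, and the small rule set (add, mul, neg, sin, cos, exp) are all assumptions made for illustration.

```python
import math

ZERO = ("const", 0.0)
ONE = ("const", 1.0)

def col(name):
    return ("col", name)

def add(a, b):
    # 0-folding: a + 0 -> a and 0 + b -> b.
    if a == ZERO:
        return b
    if b == ZERO:
        return a
    return ("add", a, b)

def mul(a, b):
    # 0/1-folding: anything * 0 -> 0, 1 * b -> b, a * 1 -> a.
    if a == ZERO or b == ZERO:
        return ZERO
    if a == ONE:
        return b
    if b == ONE:
        return a
    return ("mul", a, b)

def neg(a):
    return ZERO if a == ZERO else ("neg", a)

# Rule registry for unary functions: op -> builder of the local derivative f'(u).
UNARY_RULES = {
    "sin": lambda u: ("cos", u),
    "cos": lambda u: neg(("sin", u)),
    "exp": lambda u: ("exp", u),
}

def linearize(expr, leaf):
    """Forward-mode chain rule; `leaf` decides the tangent of each column."""
    op = expr[0]
    if op == "const":
        return ZERO
    if op == "col":
        return leaf(expr[1])
    if op == "add":
        return add(linearize(expr[1], leaf), linearize(expr[2], leaf))
    if op == "mul":  # product rule
        a, b = expr[1], expr[2]
        return add(mul(linearize(a, leaf), b), mul(a, linearize(b, leaf)))
    if op == "neg":
        return neg(linearize(expr[1], leaf))
    if op in UNARY_RULES:  # unary chain rule: f'(u) * u'
        u = expr[1]
        return mul(UNARY_RULES[op](u), linearize(u, leaf))
    raise NotImplementedError(f"no differentiation rule for {op!r}")

def grad(expr, wrt):
    # One-hot seed: 1 on the chosen column, 0 elsewhere.
    return linearize(expr, lambda name: ONE if name == wrt else ZERO)

def jvp(expr, wrt, tangent):
    # Arbitrary tangent expression seeded on the chosen input.
    return linearize(expr, lambda name: tangent if name == wrt else ZERO)

def vjp(expr, wrt, cotangent):
    # Scale the partial by the cotangent.
    return mul(cotangent, grad(expr, wrt))

def evaluate(expr, row):
    """Evaluate an expression against one row (a dict of column values)."""
    op = expr[0]
    if op == "const":
        return expr[1]
    if op == "col":
        return row[expr[1]]
    if op == "add":
        return evaluate(expr[1], row) + evaluate(expr[2], row)
    if op == "mul":
        return evaluate(expr[1], row) * evaluate(expr[2], row)
    if op == "neg":
        return -evaluate(expr[1], row)
    return getattr(math, op)(evaluate(expr[1], row))

x, y = col("x"), col("y")
dfdx = grad(("mul", x, y), "x")            # folds down to the column y
d2_sin = grad(grad(("sin", x), "x"), "x")  # nesting gives the second derivative
rows = [{"x": 0.5, "y": 3.0}, {"x": 2.0, "y": -1.0}]
per_row = [evaluate(d2_sin, r) for r in rows]  # rows act as the batch dimension
```

Without the folding helpers, `grad(x * y, x)` would come back as `1 * y + x * 0`; with them it collapses to `y`, which is the role the description assigns to the simplifier.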

The markers are differentiated as a SQL source-to-source rewrite, before the
query is planned: XarrayContext.sql() hands a query containing a marker to
the native rewrite_grad_sql, which parses the statement with sqlparser,
differentiates each grad/jvp/vjp call, and renders the derivative back into
the query in place. The result is ordinary SQL the stock datafusion-python
context plans and executes directly. Running before binding is what makes every
query shape the parser accepts — recursive CTEs, DML, subqueries — work
uniformly, with no special cases per plan type.

This needs no DataFusion fork and no custom datafusion-python wheel: it runs
against the published package. There is no Substrait round-trip and no protoc
build dependency.

Things to know

  • Marker arguments use unqualified column names (e.g. grad(y - a*x - b, a)),
    matching how the differentiation reads its variable. This is the one
    consequence of differentiating syntactically, before binding.
  • Put grad inside an aggregate (AVG(grad(f, x))), not outside. The
    transposed form grad(SUM(f), x) is rejected by SQL's own scoping, since the
    per-row column is gone after aggregation.

Tests

Covered end to end by the Python suite (tests/test_autograd.py): derivatives
checked against numpy analytics, gradient descent and a recursive-CTE training
loop converging to their closed-form solutions, jvp/vjp agreement,
multi-input Jacobians, higher-order derivatives, and schema-qualified tables.
The Rust unit tests cover the differentiation rules and the SQL rewrite.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

claude added 8 commits June 27, 2026 22:44

Introduce `src/autograd.rs`, the Rust core of the autograd feature: a
`differentiate(&Expr, wrt)` function that symbolically differentiates a
DataFusion logical `Expr` tree with respect to a named column and returns a
new `Expr` built from ordinary SQL expressions.

The design mirrors JAX's per-primitive rule registry (defjvp and friends):
each node type has a differentiation rule and the chain rule composes them
as the tree is walked. A small 0/1-folding simplifier keeps output compact,
playing the role of JAX's Zero tangents and add_tangents.

Because each table row is an independent evaluation point, differentiating a
column expression and letting DataFusion evaluate it row-by-row is the
relational equivalent of vmap(grad(f)).

This first cut implements scalar `grad`: rules for +, -, *, / (sum, product,
quotient), unary chain rule for sin/cos/tan, asin/acos/atan, exp/ln/log2/
log10/sqrt, sinh/cosh/tanh, abs, and power() with constant base or exponent.
Unsupported nodes/functions return a clear NotImplemented error rather than a
silently wrong derivative.

The engine operates purely on DataFusion `Expr`, keeping the eventual
Python<->Rust transport (SQL text, Substrait, or proto) pluggable. Covered by
11 unit tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Add the `grad` marker UDF and a plan-level rewriter (`rewrite_grad_calls`)
to the autograd engine, plus a `grad_rewrite` PyO3 function that bridges the
differentiation engine into the datafusion-python SessionContext.

Because the native extension links its own copy of DataFusion, expressions
cross the Python<->Rust boundary as Substrait protobuf. Python produces the
logical plan as Substrait; `grad_rewrite` consumes it into a DataFusion
LogicalPlan, rewrites every `grad(expr, column)` ScalarFunction into the
symbolic derivative via `differentiate`, and re-produces Substrait bytes for
Python to consume and execute. The custom xarray table provider round-trips
because Substrait serializes table scans by name (resolved against the
registry on consume), so the rewrite context only needs empty tables with
matching schemas.

`grad` is registered as a marker ScalarUDF that carries the differentiation
request intact through parsing, planning, and serialization; it is always
rewritten away before execution and errors if it ever reaches invoke.

Deps: datafusion-substrait 52 and prost 0.14 (matching the substrait crate).
Building now requires `protoc` (the substrait crate codegens from .proto).

Verified end to end (produce -> grad_rewrite -> consume -> execute) against
analytic derivatives for cos, the product rule, and exp with 0.0 error.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Wire the autograd surface into XarrayContext so users can write calculus
directly in SQL:

    ctx.sql("SELECT grad(sin(val), val) AS d_val, sin(val) AS val FROM t")

On construction the context registers the `grad` marker UDF so such queries
parse and plan. XarrayContext.sql() detects `grad(` (a cheap regex gate so
ordinary queries are untouched) and routes through _sql_with_autograd: it
plans the query, produces the logical plan as Substrait, calls the native
grad_rewrite to differentiate every grad(expr, column) symbolically, then
consumes the rewritten Substrait back into an executable DataFrame.

Table scans are resolved by name on the consume side, so _table_schemas()
passes the (name, schema) of each registered table to the rewrite.
Schema-qualified tables (mixed-dimension datasets) are skipped for now and
noted as a follow-up.

Adds tests/test_autograd.py covering sin/cos, product and quotient rules,
power, exp, the non-grad passthrough, and a clear error for unsupported
functions — all checked against numpy analytic derivatives. Existing SQL
tests still pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Adding datafusion-substrait pulls in the `substrait` crate, whose build
script generates Rust from .proto files and requires `protoc`. Without it the
Rust/maturin builds fail.

- ci.yml, ci-build.yml, ci-rust.yml: add arduino/setup-protoc before the
  build (covers Linux, macOS and Windows runners).
- publish.yml: setup-protoc for the macOS/Windows wheel job; for the
  manylinux maturin-action jobs install protoc inside the container via
  before-script-linux (arch-aware download). The sdist job is unchanged as it
  packages source without compiling.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Extend the autograd surface from scalar grad() to multi-input Jacobians.

    SELECT jacobian(sin(x) * y, [x, y]) AS jac FROM g
    -- per row: [d/dx, d/dy] = [cos(x)*y, sin(x)]  (a List<Float64>)

`jacobian(expr, [c1, c2, ...])` differentiates `expr` with respect to each
listed column and returns the gradient row as an array. Using a SQL array for
the inputs keeps the marker at fixed arity two (avoiding variadic-UDF issues):
the `[c1, c2, ...]` parses to make_array(c1, c2, ...), from which the rewrite
extracts the input columns; the result is built with make_array of the
partials. Array/list columns round-trip through Substrait, verified end to end.

The single grad() marker is generalized into a reusable MarkerUdf (with
grad_marker()/jacobian_marker() constructors and per-marker return types), and
the plan rewrite dispatches on the function name. A full Jacobian can also be
written as separate scalar grad() columns, which already worked; both forms are
covered by tests against numpy analytic derivatives.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Drop the jacobian(expr, [cols]) -> List<Float64> form: a nested array column
breaks the long/tidy data model (a cell should be one value aligned to its
coordinates). The same Jacobian is expressed in-model as several scalar
columns, e.g. grad(f, x) AS dfdx, grad(f, y) AS dfdy.

Add forward- and reverse-mode gradients as scalar SQL functions:

  * jvp(expr, column, tangent)   -> d(expr)/d(column) * tangent   (forward)
  * vjp(expr, column, cotangent) -> cotangent * d(expr)/d(column) (reverse)

A multi-input directional derivative is the sum of per-input jvp terms; both
stay scalar, so they round-trip cleanly through Substrait and back to xarray.

Engine: unify grad and jvp behind a single `linearize` (forward-mode chain
rule with a pluggable leaf rule) — grad is a one-hot seed, jvp an arbitrary
seed per input. This mirrors JAX's structure and removes rule duplication.
vjp is cotangent * grad; for a scalar output forward and reverse coincide
(asserted by a jvp/vjp agreement test), differing only in seed placement.

Tests: 15 Rust unit tests and 11 Python integration tests (incl. jvp/vjp
semantics, the multi-input sum, and jvp==vjp for a unit seed), all checked
against numpy analytic derivatives. fmt/clippy/ruff/mypy clean.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Mixed-dimension datasets register as schema-qualified tables (e.g.
era5.surface / era5.time_x_level). The autograd rewrite consumes the plan in
a throwaway context that registers an empty table per scanned name, but
register_table("era5.time_x", ...) failed with "failed to resolve schema:
era5" because the namespace did not exist.

Add ensure_schema(): before registering each table, parse its name into a
TableReference and, for qualified names, create the schema namespace
(MemorySchemaProvider) in the default catalog if absent. The Python side
already resolves qualified names via ctx.table(name).schema(); only the Rust
rewrite context needed the namespace.

Tests: a mixed-dimension fixture exercising grad on both the 2D surface and
3D atmosphere tables, against numpy analytic derivatives.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Nested calls such as grad(grad(f, x), x) already yield higher-order
derivatives: the plan rewrite walks expressions bottom-up (transform_up), so
the inner grad is differentiated to a plain expression first and the outer
grad differentiates that result. No code change was needed; this adds tests
and documents the behavior.

- Rust: a unit test that differentiation composes (d2/dx2 sin = -sin).
- Python: second derivatives of sin (-sin) and x^3 (6x) and the third
  derivative of sin (-cos), against numpy.
- Doc: note higher-order support in the module overview.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Comment thread benchmarks/grad_descent.py Outdated
Comment on lines +48 to +56
def set_params(ctx: xql.XarrayContext, a: float, b: float) -> None:
    """(Re)register the one-row params table holding the current a, b."""
    if "params" in ctx._registered_datasets:
        ctx.deregister_table("params")
        del ctx._registered_datasets["params"]
    params = xr.Dataset(
        {"a": (("p",), [a]), "b": (("p",), [b])}, coords={"p": [0]}
    )
    ctx.from_dataset("params", params, chunks={"p": 1})

Collaborator Author

Any ideas for what a more SQL native update would be? What if we had a parameter table and added new rows per each update?

Collaborator Author

Great idea — implemented it that way. The optimiser trajectory is now a params(step, a, b) table that grows one generation per step, and the update itself is computed in SQL by descending from the current row:

WITH cur AS (SELECT a, b FROM params WHERE step = :k)
SELECT cur.a - lr * AVG(grad(loss, a)) AS a,
       cur.b - lr * AVG(grad(loss, b)) AS b
FROM d CROSS JOIN cur
GROUP BY cur.a, cur.b

So each step appends the next generation, and the whole loss curve is a single GROUP BY over the history joined to the data — the trajectory is a relation you can query. One honest caveat: xarray-backed tables are read-only to SQL, so "append a row" is done by re-registering the grown history; a true in-place INSERT would need a mutable table provider. (I also tried WHERE step = (SELECT MAX(step) ...), but scalar subqueries don't convert to Substrait yet, so the loop passes the current step as a literal.)
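A runnable sketch of this growing-history pattern, using Python's built-in sqlite3 so that a real INSERT is available. Several things here are assumptions rather than the PR's code: the loss is taken to be the squared residual (y - a*x - b)^2, its partials are written out by hand where the PR would use AVG(grad(loss, a)) and AVG(grad(loss, b)), and the data, learning rate, and step count are made up.

```python
import sqlite3

con = sqlite3.connect(":memory:")

# Data table d: points on the line y = 2x + 1 (synthetic).
con.execute("CREATE TABLE d (x REAL, y REAL)")
con.executemany(
    "INSERT INTO d VALUES (?, ?)", [(float(i), 2.0 * i + 1.0) for i in range(5)]
)

# Parameter history: one row per optimiser step, starting from a = b = 0.
con.execute("CREATE TABLE params (step INTEGER, a REAL, b REAL)")
con.execute("INSERT INTO params VALUES (0, 0.0, 0.0)")

# One descent step. The two AVG(...) arguments are the hand-written partials
# of (y - a*x - b)^2 with respect to a and b.
STEP = """
INSERT INTO params
SELECT cur.step + 1,
       cur.a - :lr * AVG(-2.0 * d.x * (d.y - cur.a * d.x - cur.b)),
       cur.b - :lr * AVG(-2.0 * (d.y - cur.a * d.x - cur.b))
FROM d CROSS JOIN (SELECT step, a, b FROM params WHERE step = :k) AS cur
GROUP BY cur.step, cur.a, cur.b
"""
for k in range(1000):
    con.execute(STEP, {"lr": 0.05, "k": k})

a, b = con.execute(
    "SELECT a, b FROM params ORDER BY step DESC LIMIT 1"
).fetchone()

# The loss curve: a single GROUP BY over the history joined to the data.
losses = [
    loss
    for (loss,) in con.execute(
        """
        SELECT AVG((d.y - p.a * d.x - p.b) * (d.y - p.a * d.x - p.b))
        FROM params AS p CROSS JOIN d
        GROUP BY p.step
        ORDER BY p.step
        """
    )
]
print(a, b, losses[0], losses[-1])
```

The current step is passed in as a bound parameter here, mirroring the `:k` placeholder in the query above.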

I moved the runnable demos off this branch to keep it focused on the engine — grad_descent.py (with this change) now lives on the stacked claude/xarray-sql-era5-demo branch, and the MNIST trainer on claude/xarray-sql-mnist-demo stacked above it.


Generated by Claude Code

Document and test that differentiating through SUM/AVG is just linearity:
AGG(grad(f, x)) == d/dx AGG(f). Writing grad inside the aggregate composes
with SQL scoping (the marker rewrites to plain SQL before the aggregate runs),
so it needs no special machinery -- enough to express gradient descent in SQL.

Adds tests for SUM/AVG(grad(...)) and an end-to-end gradient-descent
convergence test, plus a note in the module overview. The runnable benchmark
scripts live on stacked demo branches to keep this feature branch reviewable.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

claude added 2 commits June 28, 2026 15:14

Generalize _table_schemas() to enumerate the catalog instead of only the
xarray-registered datasets, so the Substrait rewrite can resolve grad() queries
that reference plain DataFusion tables too -- e.g. in-memory MemTables holding
model parameters or intermediate results. This makes grad compose with ordinary
relational state (a parameter table you INSERT into), not only gridded xarray
data.

Adds a test differentiating an expression whose coefficient lives in an
in-memory table cross-joined to the xarray data.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Expose the autograd engine as a "calculus compiler": differentiate_sql(expr,
wrt, columns) parses a SQL scalar expression (parse_sql_expr), differentiates it
with the engine, and unparses the derivative back to SQL (expr_to_sql).

Where grad(...) rewrites a whole plan via Substrait, this hands back a single
derivative expression as text -- usable where the Substrait round-trip can't
carry a grad marker, e.g. embedding a precomputed update rule inside a
recursive-CTE training loop (Substrait has no recursion). Exposed as
xarray_sql.differentiate_sql; covered by a round-trip test.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6

Make grad()/jvp()/vjp() work inside any query shape (recursive CTEs, DML,
subqueries) by rewriting the calls as SQL text before planning, rather than
round-tripping the logical plan through Substrait (which could not represent
those shapes). Closes the gap tracked in #197.

XarrayContext.sql() now hands a query containing a marker to the native
rewrite_grad_sql, which parses the statement with sqlparser, differentiates
each marker call with the existing engine, and renders the derivative back
into the SQL in place. Because it runs before planning, every query shape the
parser accepts is supported, and the result is ordinary SQL the stock
datafusion-python context plans and executes directly.

This removes the Substrait round-trip entirely: the datafusion-substrait and
prost dependencies, the grad_rewrite/_sql_with_autograd/_table_schemas plumbing,
the marker-UDF registration, and the protoc steps in CI. Unlike the FFI
alternative, it needs no datafusion fork and no custom datafusion-python wheel.

The grad surface is unchanged (same SQL, same results); marker arguments use
unqualified column names, matching existing usage, since differentiation is
syntactic and runs before binding.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

alxmrs changed the title from "Add symbolic differentiation engine for SQL expressions (autograd MVP)" to "Differentiable SQL: grad()/jvp()/vjp() over xarray and any DataFusion table" on Jun 30, 2026
alxmrs force-pushed the claude/xarray-sql-autograd-73ovqq branch 2 times, most recently from a4fc101 to 7b1e530, on June 30, 2026 13:34

2 participants