Differentiable SQL: grad()/jvp()/vjp() over xarray and any DataFusion table#192
Differentiable SQL: grad()/jvp()/vjp() over xarray and any DataFusion table#192alxmrs wants to merge 12 commits into
Conversation
Introduce `src/autograd.rs`, the Rust core of the autograd feature: a `differentiate(&Expr, wrt)` function that symbolically differentiates a DataFusion logical `Expr` tree with respect to a named column and returns a new `Expr` built from ordinary SQL expressions. The design mirrors JAX's per-primitive rule registry (defjvp and friends): each node type has a differentiation rule and the chain rule composes them as the tree is walked. A small 0/1-folding simplifier keeps output compact, playing the role of JAX's Zero tangents and add_tangents. Because each table row is an independent evaluation point, differentiating a column expression and letting DataFusion evaluate it row-by-row is the relational equivalent of vmap(grad(f)). This first cut implements scalar `grad`: rules for +, -, *, / (sum, product, quotient), unary chain rule for sin/cos/tan, asin/acos/atan, exp/ln/log2/ log10/sqrt, sinh/cosh/tanh, abs, and power() with constant base or exponent. Unsupported nodes/functions return a clear NotImplemented error rather than a silently wrong derivative. The engine operates purely on DataFusion `Expr`, keeping the eventual Python<->Rust transport (SQL text, Substrait, or proto) pluggable. Covered by 11 unit tests. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
Add the `grad` marker UDF and a plan-level rewriter (`rewrite_grad_calls`) to the autograd engine, plus a `grad_rewrite` PyO3 function that bridges the differentiation engine into the datafusion-python SessionContext. Because the native extension links its own copy of DataFusion, expressions cross the Python<->Rust boundary as Substrait protobuf. Python produces the logical plan as Substrait; `grad_rewrite` consumes it into a DataFusion LogicalPlan, rewrites every `grad(expr, column)` ScalarFunction into the symbolic derivative via `differentiate`, and re-produces Substrait bytes for Python to consume and execute. The custom xarray table provider round-trips because Substrait serializes table scans by name (resolved against the registry on consume), so the rewrite context only needs empty tables with matching schemas. `grad` is registered as a marker ScalarUDF that carries the differentiation request intact through parsing, planning, and serialization; it is always rewritten away before execution and errors if it ever reaches invoke. Deps: datafusion-substrait 52 and prost 0.14 (matching the substrait crate). Building now requires `protoc` (the substrait crate codegens from .proto). Verified end to end (produce -> grad_rewrite -> consume -> execute) against analytic derivatives for cos, the product rule, and exp with 0.0 error. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
Wire the autograd surface into XarrayContext so users can write calculus
directly in SQL:
ctx.sql("SELECT grad(sin(val), val) AS d_val, sin(val) AS val FROM t")
On construction the context registers the `grad` marker UDF so such queries
parse and plan. XarrayContext.sql() detects `grad(` (a cheap regex gate so
ordinary queries are untouched) and routes through _sql_with_autograd: it
plans the query, produces the logical plan as Substrait, calls the native
grad_rewrite to differentiate every grad(expr, column) symbolically, then
consumes the rewritten Substrait back into an executable DataFrame.
Table scans are resolved by name on the consume side, so _table_schemas()
passes the (name, schema) of each registered table to the rewrite. Schema-
qualified tables (mixed-dimension datasets) are skipped for now and noted as
a follow-up.
Adds tests/test_autograd.py covering sin/cos, product and quotient rules,
power, exp, the non-grad passthrough, and a clear error for unsupported
functions — all checked against numpy analytic derivatives. Existing SQL
tests still pass.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
Adding datafusion-substrait pulls in the `substrait` crate, whose build script generates Rust from .proto files and requires `protoc`. Without it the Rust/maturin builds fail. - ci.yml, ci-build.yml, ci-rust.yml: add arduino/setup-protoc before the build (covers Linux, macOS and Windows runners). - publish.yml: setup-protoc for the macOS/Windows wheel job; for the manylinux maturin-action jobs install protoc inside the container via before-script-linux (arch-aware download). The sdist job is unchanged as it packages source without compiling. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
Extend the autograd surface from scalar grad() to multi-input Jacobians.
SELECT jacobian(sin(x) * y, [x, y]) AS jac FROM g
-- per row: [d/dx, d/dy] = [cos(x)*y, sin(x)] (a List<Float64>)
`jacobian(expr, [c1, c2, ...])` differentiates `expr` with respect to each
listed column and returns the gradient row as an array. Using a SQL array for
the inputs keeps the marker at fixed arity two (avoiding variadic-UDF issues):
the `[c1, c2, ...]` parses to make_array(c1, c2, ...), from which the rewrite
extracts the input columns; the result is built with make_array of the
partials. Array/list columns round-trip through Substrait, verified end to end.
The single grad() marker is generalized into a reusable MarkerUdf (with
grad_marker()/jacobian_marker() constructors and per-marker return types), and
the plan rewrite dispatches on the function name. A full Jacobian can also be
written as separate scalar grad() columns, which already worked; both forms are
covered by tests against numpy analytic derivatives.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
Drop the jacobian(expr, [cols]) -> List<Float64> form: a nested array column breaks the long/tidy data model (a cell should be one value aligned to its coordinates). The same Jacobian is expressed in-model as several scalar columns, e.g. grad(f, x) AS dfdx, grad(f, y) AS dfdy. Add forward- and reverse-mode gradients as scalar SQL functions: * jvp(expr, column, tangent) -> d(expr)/d(column) * tangent (forward) * vjp(expr, column, cotangent) -> cotangent * d(expr)/d(column) (reverse) A multi-input directional derivative is the sum of per-input jvp terms; both stay scalar, so they round-trip cleanly through Substrait and back to xarray. Engine: unify grad and jvp behind a single `linearize` (forward-mode chain rule with a pluggable leaf rule) — grad is a one-hot seed, jvp an arbitrary seed per input. This mirrors JAX's structure and removes rule duplication. vjp is cotangent * grad; for a scalar output forward and reverse coincide (asserted by a jvp/vjp agreement test), differing only in seed placement. Tests: 15 Rust unit tests and 11 Python integration tests (incl. jvp/vjp semantics, the multi-input sum, and jvp==vjp for a unit seed), all checked against numpy analytic derivatives. fmt/clippy/ruff/mypy clean. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
Mixed-dimension datasets register as schema-qualified tables (e.g.
era5.surface / era5.time_x_level). The autograd rewrite consumes the plan in
a throwaway context that registers an empty table per scanned name, but
register_table("era5.time_x", ...) failed with "failed to resolve schema:
era5" because the namespace did not exist.
Add ensure_schema(): before registering each table, parse its name into a
TableReference and, for qualified names, create the schema namespace
(MemorySchemaProvider) in the default catalog if absent. The Python side
already resolves qualified names via ctx.table(name).schema(); only the Rust
rewrite context needed the namespace.
Tests: a mixed-dimension fixture exercising grad on both the 2D surface and
3D atmosphere tables, against numpy analytic derivatives.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
Nested calls such as grad(grad(f, x), x) already yield higher-order derivatives: the plan rewrite walks expressions bottom-up (transform_up), so the inner grad is differentiated to a plain expression first and the outer grad differentiates that result. No code change was needed; this adds tests and documents the behavior. - Rust: a unit test that differentiation composes (d2/dx2 sin = -sin). - Python: second derivatives of sin (-sin) and x^3 (6x) and the third derivative of sin (-cos), against numpy. - Doc: note higher-order support in the module overview. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
| def set_params(ctx: xql.XarrayContext, a: float, b: float) -> None: | ||
| """(Re)register the one-row params table holding the current a, b.""" | ||
| if "params" in ctx._registered_datasets: | ||
| ctx.deregister_table("params") | ||
| del ctx._registered_datasets["params"] | ||
| params = xr.Dataset( | ||
| {"a": (("p",), [a]), "b": (("p",), [b])}, coords={"p": [0]} | ||
| ) | ||
| ctx.from_dataset("params", params, chunks={"p": 1}) |
There was a problem hiding this comment.
Any ideas for what a more SQL native update would be? What if we had a parameter table and added new rows per each update?
There was a problem hiding this comment.
Great idea — implemented it that way. The optimiser trajectory is now a params(step, a, b) table that grows one generation per step, and the update itself is computed in SQL by descending from the current row:
WITH cur AS (SELECT a, b FROM params WHERE step = :k)
SELECT cur.a - lr * AVG(grad(loss, a)) AS a,
cur.b - lr * AVG(grad(loss, b)) AS b
FROM d CROSS JOIN cur
GROUP BY cur.a, cur.bSo each step appends the next generation, and the whole loss curve is a single GROUP BY over the history joined to the data — the trajectory is a relation you can query. One honest caveat: xarray-backed tables are read-only to SQL, so "append a row" is done by re-registering the grown history; a true in-place INSERT would need a mutable table provider. (I also tried WHERE step = (SELECT MAX(step) ...), but scalar subqueries don't convert to Substrait yet, so the loop passes the current step as a literal.)
I moved the runnable demos off this branch to keep it focused on the engine — grad_descent.py (with this change) now lives on the stacked claude/xarray-sql-era5-demo branch, and the MNIST trainer on claude/xarray-sql-mnist-demo stacked above it.
Generated by Claude Code
Document and test that differentiating through SUM/AVG is just linearity: AGG(grad(f, x)) == d/dx AGG(f). Writing grad inside the aggregate composes with SQL scoping (the marker rewrites to plain SQL before the aggregate runs), so it needs no special machinery -- enough to express gradient descent in SQL. Adds tests for SUM/AVG(grad(...)) and an end-to-end gradient-descent convergence test, plus a note in the module overview. The runnable benchmark scripts live on stacked demo branches to keep this feature branch reviewable. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
07a7ff2 to
255413e
Compare
Generalize _table_schemas() to enumerate the catalog instead of only the xarray-registered datasets, so the Substrait rewrite can resolve grad() queries that reference plain DataFusion tables too -- e.g. in-memory MemTables holding model parameters or intermediate results. This makes grad compose with ordinary relational state (a parameter table you INSERT into), not only gridded xarray data. Adds a test differentiating an expression whose coefficient lives in an in-memory table cross-joined to the xarray data. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
Expose the autograd engine as a "calculus compiler": differentiate_sql(expr, wrt, columns) parses a SQL scalar expression (parse_sql_expr), differentiates it with the engine, and unparses the derivative back to SQL (expr_to_sql). Where grad(...) rewrites a whole plan via Substrait, this hands back a single derivative expression as text -- usable where the Substrait round-trip can't carry a grad marker, e.g. embedding a precomputed update rule inside a recursive -CTE training loop (Substrait has no recursion). Exposed as xarray_sql. differentiate_sql; covered by a round-trip test. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6
Make grad()/jvp()/vjp() work inside any query shape (recursive CTEs, DML, subqueries) by rewriting the calls as SQL text before planning, rather than round-tripping the logical plan through Substrait (which could not represent those shapes). Closes the gap tracked in #197. XarrayContext.sql() now hands a query containing a marker to the native rewrite_grad_sql, which parses the statement with sqlparser, differentiates each marker call with the existing engine, and renders the derivative back into the SQL in place. Because it runs before planning, every query shape the parser accepts is supported, and the result is ordinary SQL the stock datafusion-python context plans and executes directly. This removes the Substrait round-trip entirely: the datafusion-substrait and prost dependencies, the grad_rewrite/_sql_with_autograd/_table_schemas plumbing, the marker-UDF registration, and the protoc steps in CI. Unlike the FFI alternative, it needs no datafusion fork and no custom datafusion-python wheel. The grad surface is unchanged (same SQL, same results); marker arguments use unqualified column names, matching existing usage, since differentiation is syntactic and runs before binding. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
a4fc101 to
7b1e530
Compare
Differentiable SQL for xarray-sql: write
grad(expr, column)in a query and getits derivative back as an ordinary column, evaluated row-by-row by DataFusion
alongside everything else. Because each row of a table is an independent
evaluation point, differentiating a column expression and letting DataFusion
evaluate it per row is the relational equivalent of
jax.vmap(jax.grad(f))— therows are the batch dimension. This turns SQL into a place you can express
gradients, directional derivatives, and whole training loops.
The surface
Three scalar operations, each returning one value per row (staying in the
long/tidy data model):
grad(expr, column)— the partial derivatived(expr)/d(column).jvp(expr, column, tangent)— forward-mode directional derivative,d(expr)/d(column) * tangent(seed a tangent on an input). A multi-inputdirectional derivative is a sum of
jvpterms.vjp(expr, column, cotangent)— reverse-mode pullback,cotangent * d(expr)/d(column)(seed a cotangent on the output).A full gradient or Jacobian is expressed as several scalar columns rather than a
nested array, which would break the one-value-per-coordinate model:
What you can write:
Higher-order derivatives, by nesting:
grad(grad(sin(x), x), x).Differentiation through an aggregate — it's just linearity, so put the
marker inside the aggregate:
AVG(grad(loss, theta))is the relationald/dθ (Σ loss) / N. Enough to run gradient descent in SQL.gradinside any query shape — recursive CTEs, DML, and subqueries — so atraining loop can live entirely in one declarative query. For example,
Newton's method for √2 driven by
gradin the recursive term:gradover any registered table, not just xarray datasets: plain in-memorytables holding model parameters, and schema-qualified tables (
era5.surface)from mixed-dimension datasets all work.
Also exported:
differentiate_sql(expr, wrt, columns)— a calculus compilerthat differentiates a single expression and hands back the derivative as SQL
text, for when you want the update rule as a string to embed yourself.
How it works
The core is
src/autograd.rs, a symbolic differentiation engine over DataFusionlogical
Exprtrees. The design mirrors JAX's per-primitive rule registry(
defjvpand friends): every node type has a differentiation rule and the chainrule composes them as the tree is walked. A small
0/1-folding simplifierkeeps the output compact, playing the role of JAX's
Zerotangents andadd_tangents.gradandjvpshare the chain rule and differ only in theirleaf rule (a one-hot seed vs. an arbitrary per-input tangent);
vjpscales thepartial by the cotangent.
Rules are implemented for
+,-,*,/, the unary chain rule forsin/cos/tan,asin/acos/atan,exp/ln/log2/log10/sqrt,sinh/cosh/tanh,abs, andpower()with a constant base or exponent. Anunsupported node or function returns a clear
NotImplementederror rather than asilently wrong derivative.
The markers are differentiated as a SQL source-to-source rewrite, before the
query is planned:
XarrayContext.sql()hands a query containing a marker tothe native
rewrite_grad_sql, which parses the statement withsqlparser,differentiates each
grad/jvp/vjpcall, and renders the derivative back intothe query in place. The result is ordinary SQL the stock
datafusion-pythoncontext plans and executes directly. Running before binding is what makes every
query shape the parser accepts — recursive CTEs, DML, subqueries — work
uniformly, with no special cases per plan type.
This needs no DataFusion fork and no custom
datafusion-pythonwheel: it runsagainst the published package. There is no Substrait round-trip and no
protocbuild dependency.
Things to know
grad(y - a*x - b, a)),matching how the differentiation reads its variable. This is the one
consequence of differentiating syntactically, before binding.
gradinside an aggregate (AVG(grad(f, x))), not outside. Thetransposed form
grad(SUM(f), x)is rejected by SQL's own scoping, since theper-row column is gone after aggregation.
Tests
Covered end to end by the Python suite (
tests/test_autograd.py) — derivativeschecked against numpy analytics, gradient descent and a recursive-CTE training
loop converging to their closed-form solutions,
jvp/vjpagreement, multi-input Jacobians, higher-order derivatives, schema-qualified tables — and by the
Rust unit tests for the differentiation rules and the SQL rewrite.
Co-Authored-By: Claude Opus 4.8 noreply@anthropic.com