diff --git a/Cargo.lock b/Cargo.lock index 8b58e75..b6cbcf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2656,9 +2656,9 @@ checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.227" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80ece43fc6fbed4eb5392ab50c07334d3e577cbf40997ee896fe7af40bba4245" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ "serde_core", "serde_derive", @@ -2666,18 +2666,18 @@ dependencies = [ [[package]] name = "serde_core" -version = "1.0.227" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a576275b607a2c86ea29e410193df32bc680303c82f31e275bbfcafe8b33be5" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.227" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51e694923b8824cf0e9b382adf0f60d4e05f348f357b38833a3fa5ed7c2ede04" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -2851,18 +2851,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.16" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.16" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", @@ -3375,7 +3375,7 @@ checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "xarray_sql" -version = "0.2.3" +version = "0.3.0" dependencies = [ "arrow", "async-stream", @@ -3385,6 +3385,7 @@ dependencies = [ "futures", "pyo3", "pyo3-build-config", + "sqlparser", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 1dc95bd..21f5b43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ async-stream = "0.3" async-trait = "0.1" datafusion = { version = "52.0.0" } datafusion-ffi = { version = "52.0.0" } +sqlparser = { version = "0.59", features = ["visitor"] } futures = { version = "0.3" } pyo3 = { version = "0.26.0", features = ["extension-module"] } tokio = { version = "1.46.1", features = ["rt"] } diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..5fa3fcd --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,125 @@ +# Benchmarks & demos + +Standalone scripts that exercise xarray-sql against real data. Each declares its +own dependencies inline (PEP 723) and points `xarray_sql` at this checkout, so +they run with no setup: + +```bash +uv run benchmarks/grad_era5.py +``` + +## `grad_era5.py` — differentiable SQL over ARCO-ERA5 + +Demonstrates the autograd feature on a real climate archive +([ARCO-ERA5](https://github.com/google-research/arco-era5), read anonymously +from GCS — needs `gcsfs` and network access). + +The key idea: a physical quantity is written as an **analytic SQL formula** over +ERA5 variables, and `grad(...)` differentiates that formula **symbolically**, +evaluated at every grid cell. Because each row is an independent point, this is +the relational equivalent of `jax.vmap(jax.grad(f))`. It is *not* a finite- +difference spatial gradient — `grad(f(u, v), u)` is the exact partial derivative +of `f`. + +Two worked cases, each checked against an analytic reference: + +| Quantity | SQL | Derivative | Check | +| --- | --- | --- | --- | +| Wind speed | `sqrt(power(u,2) + power(v,2))` | `grad(speed, u) = u/speed` | exact | +| Saturation vapour pressure | `A*exp(B*tc/(tc+C))` | `grad(e_s, T)` | closed-form Clausius-Clapeyron slope | + +Each query round-trips back to an `xarray.Dataset` via `.to_dataset(...)`. + +## `grad_descent.py` — gradient descent as one declarative SQL query + +Fits a line `y ~= a*x + b` by minimising the mean squared error, with the +**entire training loop expressed as a single recursive CTE** — no Python +iteration. Two pieces: + +- **`grad` compiles the update rule.** `xql.differentiate_sql(loss, "a", cols)` + turns the per-row loss into its symbolic derivative *as SQL text* — the + autograd engine as a calculus compiler. +- **A recursive CTE is the optimiser.** `params(step, a, b)` starts at one row + and each recursion appends the next generation, descending along the gradient + (`AVG` of the compiled rule over the data): + + ```sql + WITH RECURSIVE params(step, a, b) AS ( + SELECT 0, 0.0, 0.0 + UNION ALL + SELECT params.step + 1, params.a - lr*AVG(da), params.b - lr*AVG(db) + FROM params CROSS JOIN d WHERE params.step < STEPS + GROUP BY params.step, params.a, params.b) + SELECT * FROM params ORDER BY step + ``` + +So gradient, update, and iteration are all declarative SQL; the trajectory is +the rows of one query. The fit matches numpy's least-squares solution. +Self-contained (no network). + +(Why differentiate to text instead of `grad(...)` inside the recursion? `grad` +needs the Substrait round-trip, and Substrait has no recursion — so a `grad` +marker can't live inside a recursive CTE. Differentiating once to plain SQL +sidesteps that.) + +## `mnist_mlp.py` — an MNIST MLP as relational tensor algebra + +An MLP (196 -> 32 tanh -> 10 softmax on 2x2-pooled 14x14 MNIST) built on one +idea: **a neural net is a chain of tensor contractions (einsums), and an einsum +over coordinate-indexed arrays *is* relational algebra.** + +``` +C[i,k] = sum_j A[i,j] * B[j,k] <=> JOIN A, B ON A.j = B.j + GROUP BY i, k -> SUM(A.val * B.val) +``` + +Contracting a shared index is a join on it followed by a grouped `SUM` over the +indices that survive. In xarray-sql an array indexed by named dims is a table +keyed by those dims, so **the dimension names are the join keys**. + +**The whole network is one relation.** Two moves get there: + +- **Bias folded into the weights (an `nn.Linear`).** Each layer's bias is the + weight of a constant-`1` input, kept as the extra row `inp = width` of the same + weight array — so a layer is a single matrix. +- **A `layer` dimension.** Every layer's weight lives in one + `weight(layer, inp, out)` array, so the forward/backward filter on the `layer` + *column* instead of referencing a table per layer. + +So **the architecture is data**: the whole model is one `xr.Dataset` with a +`layer` dim, registered via `from_dataset`. The dim sizes are the layer widths +and the number of layers is the depth — differing neuron counts are just +differing sizes, NaN-padded in the dense array and dropped on the way in (the +relational form is naturally ragged). Change `WIDTHS` (e.g. `196, 64, 32, 10`) +and the same code trains the deeper net. + +A small `contract()` helper turns an einsum spec into one query, and a single +generic loop trains a net of any shape: + +- **forward** contracts the activation with `weight WHERE layer = L`, adds the + bias row, `tanh` (softmax on the last layer). +- **backward is the *same* operator with indices transposed** — the VJP of a + contraction is a contraction — accumulated into one `gweight` relation, with + `grad(tanh(z), z)` for the only genuinely-calculus part. Even the update is one + query over the whole `weight` relation. Linear algebra is joins; the + derivatives of the nonlinearities are `grad`. + +Everything stays relational: every stage is an inspectable table (`a1`, `delta2`, +`gweight`, …); the only hand-written gradient is softmax + cross-entropy's +`delta = softmax - onehot`. Even the training metrics are a table — each logged +step appends a `(step, loss, train_acc, test_acc)` row to a `metrics` relation +rather than a Python list (NN training produces a lot of such data; it belongs in +rows). Evaluation is SQL too (a forward pass + `ROW_NUMBER()` argmax), and the +trained model, predictions, and metrics all come **back out as xarray** via +`to_dataset`. Reaches ~83% test accuracy over 60 steps. Downloads MNIST on first +run. + +This is not a numpy replacement — relational matmul carries join overhead a BLAS +inner product doesn't. What it buys is a fully declarative, inspectable pipeline +whose data side is chunked xarray (parallel over the batch, larger-than-memory). +The *outer* training loop stays in Python because the steps must be materialised +between iterations: a multi-layer net can't be one recursive CTE (the recursive +relation may be referenced only once, but the weights are used several times per +step), and unrolling the steps as non-recursive CTEs blows up exponentially +(DataFusion inlines CTEs). The thin loop does exactly that materialisation; all +the maths stays in SQL. diff --git a/benchmarks/grad_descent.py b/benchmarks/grad_descent.py new file mode 100644 index 0000000..daff207 --- /dev/null +++ b/benchmarks/grad_descent.py @@ -0,0 +1,115 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "xarray_sql", +# "xarray", +# "numpy", +# ] +# +# [tool.uv.sources] +# xarray_sql = { path = "..", editable = true } +# /// +"""Gradient descent as a single declarative SQL query. + +Fits a line ``y ~= a*x + b`` by minimising the mean squared error — with the +**entire training loop expressed as one recursive CTE**, no Python iteration. + +Two pieces: + +1. **grad compiles the update rule.** ``differentiate_sql`` turns the per-row + loss into the symbolic derivative *as SQL text* — the autograd engine acting + as a calculus compiler: + + da = differentiate_sql("(y-(a*x+b))^2", "a") # -> "-2*((a*x+b)-y)*x", etc. + +2. **A recursive CTE is the optimiser.** ``params(step, a, b)`` starts at one + row and each recursion appends the next generation, descending along the + gradient (``AVG`` of the compiled rule over the data): + + params.a - lr * AVG(da) + + So the whole loop — gradient, update, and iteration — is declarative SQL; + the optimisation trajectory is the rows of one query. + +Why two pieces instead of ``grad(...)`` directly inside the recursion? ``grad`` +needs the Substrait round-trip, and Substrait has no recursion — so ``grad`` +can't live inside a recursive CTE (tracked in #194 / a follow-up). Differentiating +once to plain SQL sidesteps that: the recursive query contains no ``grad`` marker. + +Run standalone: + + uv run benchmarks/grad_descent.py +""" + +from __future__ import annotations + +import numpy as np +import xarray as xr + +import xarray_sql as xql + +# Per-row loss r^2 with residual r = y - (a*x + b), over columns a, b, x, y. +RESIDUAL = "(y - (a * x + b))" +LOSS = f"{RESIDUAL} * {RESIDUAL}" +COLUMNS = ["a", "b", "x", "y"] +LR = 0.4 +STEPS = 200 + + +def main() -> None: + rng = np.random.default_rng(0) + n = 500 + x = rng.uniform(0.0, 1.0, n) + a_true, b_true = 2.0, -1.0 + y = a_true * x + b_true + rng.normal(0.0, 0.01, n) + + ctx = xql.XarrayContext() + ctx.from_dataset( + "d", + xr.Dataset( + {"x": (("i",), x), "y": (("i",), y)}, coords={"i": np.arange(n)} + ), + chunks={"i": n}, + ) + + # grad compiles the per-row update rule to SQL, once. + da = xql.differentiate_sql(LOSS, "a", COLUMNS) + db = xql.differentiate_sql(LOSS, "b", COLUMNS) + print(f"d(loss)/da = {da}") + print(f"d(loss)/db = {db}\n") + + # The entire training loop is one declarative recursive query: each step + # appends the next generation, descending along the SQL-computed gradient. + trajectory = ctx.sql( + f""" + WITH RECURSIVE params(step, a, b) AS ( + SELECT 0 AS step, CAST(0.0 AS DOUBLE) AS a, CAST(0.0 AS DOUBLE) AS b + UNION ALL + SELECT params.step + 1 AS step, + params.a - {LR} * AVG({da}) AS a, + params.b - {LR} * AVG({db}) AS b + FROM params CROSS JOIN d + WHERE params.step < {STEPS} + GROUP BY params.step, params.a, params.b + ) + SELECT step, a, b FROM params ORDER BY step + """ + ).to_pandas() + + print("trajectory (every 40th generation):") + print(trajectory.iloc[::40].to_string(index=False)) + + a, b = float(trajectory["a"].iloc[-1]), float(trajectory["b"].iloc[-1]) + a_ols, b_ols = np.polyfit(x, y, 1) + print( + f"\nSQL gradient descent: a={a:.4f} b={b:.4f} ({len(trajectory)} generations)" + ) + print(f"least-squares (numpy): a={a_ols:.4f} b={b_ols:.4f}") + assert abs(a - a_ols) < 1e-2 and abs(b - b_ols) < 1e-2 + print( + "\nOK: a single recursive-CTE query fit the line to the OLS solution." + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/grad_era5.py b/benchmarks/grad_era5.py new file mode 100644 index 0000000..866f066 --- /dev/null +++ b/benchmarks/grad_era5.py @@ -0,0 +1,171 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "xarray_sql", +# "xarray[io]", +# "gcsfs", +# "numpy", +# ] +# +# [tool.uv.sources] +# xarray_sql = { path = "..", editable = true } +# /// +"""Differentiable SQL over ARCO-ERA5. + +A minimal demonstration of xarray-sql's autograd: take a real climate archive +(ARCO-ERA5, read anonymously from GCS), express a physical quantity as an +*analytic* SQL formula over its variables, and let ``grad(...)`` differentiate +that formula symbolically — evaluated per grid cell, which is the relational +equivalent of ``jax.vmap(jax.grad(f))`` (each row is an independent point). + +Note this is *symbolic* differentiation of an expression, not a finite- +difference spatial gradient: ``grad(f(u, v), u)`` is the exact partial +derivative of the formula ``f``, evaluated at every cell's values. + +Two cases: + +1. Wind-speed magnitude ``speed = sqrt(u^2 + v^2)``. Its sensitivity to the + eastward wind is ``d(speed)/du = u / speed`` — checked exactly. + +2. Saturation vapour pressure ``e_s(T)`` (August-Roche-Magnus form of the + Clausius-Clapeyron relation). ``d(e_s)/dT`` governs how fast the atmosphere's + moisture capacity grows with temperature — checked against the closed-form + slope. + +Run standalone (builds the local extension on first use): + + uv run benchmarks/grad_era5.py +""" + +from __future__ import annotations + +import time + +import numpy as np +import xarray as xr + +import xarray_sql as xql + +ARCO_ERA5 = ( + "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" +) + +# ERA5 variable names start with a digit, so they must be double-quoted in SQL. +U = '"10m_u_component_of_wind"' +V = '"10m_v_component_of_wind"' +T = '"2m_temperature"' + + +def load_era5_block() -> xr.Dataset: + """Open ARCO-ERA5 and pull one timestamp over a small region. + + Lazy open of the whole archive; only the requested block is read. We keep + it to a few thousand cells so the demo runs in seconds. + """ + full = xr.open_zarr( + ARCO_ERA5, chunks=None, storage_options={"token": "anon"} + ) + block = ( + full[ + [ + "10m_u_component_of_wind", + "10m_v_component_of_wind", + "2m_temperature", + ] + ] + .sel(time="2020-01-01T00") + # A ~North-America box (index-based to avoid lat-orientation pitfalls). + .isel(latitude=slice(120, 200), longitude=slice(900, 1000)) + .load() + ) + # One partition, so a SQL `ORDER BY latitude DESC` survives the round-trip + # back to xarray (across multiple partitions, to_dataset reconstructs + # coordinates in ascending order regardless of ORDER BY). + return block.chunk() + + +def wind_speed_sensitivity(ctx: xql.XarrayContext, ref: xr.Dataset) -> None: + """grad(sqrt(u^2 + v^2)) checked against the exact u / speed, v / speed.""" + speed = f"sqrt(power({U}, 2) + power({V}, 2))" + out = ctx.sql( + f""" + SELECT + latitude, + longitude, + {speed} AS wind_speed, + grad({speed}, {U}) AS d_speed_d_u, + grad({speed}, {V}) AS d_speed_d_v + FROM era5 + ORDER BY latitude DESC, longitude + """ + ).to_dataset(dims=["latitude", "longitude"]) + + u = ref["10m_u_component_of_wind"] + v = ref["10m_v_component_of_wind"] + speed_ref = np.sqrt(u**2 + v**2) + + xr.testing.assert_allclose( + out["wind_speed"], speed_ref.rename("wind_speed") + ) + xr.testing.assert_allclose( + out["d_speed_d_u"], (u / speed_ref).rename("d_speed_d_u") + ) + xr.testing.assert_allclose( + out["d_speed_d_v"], (v / speed_ref).rename("d_speed_d_v") + ) + print(" wind-speed sensitivity matches u/|w|, v/|w| exactly") + print(out) + + +def clausius_clapeyron(ctx: xql.XarrayContext, ref: xr.Dataset) -> None: + """grad(e_s(T)) checked against the closed-form Clausius-Clapeyron slope.""" + # August-Roche-Magnus: e_s(T) = A * exp(B * tc / (tc + C)), tc = T - 273.15. + a, b, c = 6.1094, 17.625, 243.04 + tc = f"({T} - 273.15)" + es = f"{a} * exp({b} * {tc} / ({tc} + {c}))" + out = ctx.sql( + f""" + SELECT + latitude, + longitude, + {es} AS e_s, + grad({es}, {T}) AS de_s_dt + FROM era5 + ORDER BY latitude DESC, longitude + """ + ).to_dataset(dims=["latitude", "longitude"]) + + # Reference in float64 (the columns are float32): the exact derivative is + # d(e_s)/dT = e_s * B*C / (tc + C)^2. + temp = ref["2m_temperature"].astype("float64") + tc_ref = temp - 273.15 + es_ref = a * np.exp(b * tc_ref / (tc_ref + c)) + des_dt_ref = es_ref * (b * c) / (tc_ref + c) ** 2 + + xr.testing.assert_allclose(out["e_s"], es_ref.rename("e_s"), rtol=1e-5) + xr.testing.assert_allclose( + out["de_s_dt"], des_dt_ref.rename("de_s_dt"), rtol=1e-5 + ) + print(" d(e_s)/dT matches the closed-form Clausius-Clapeyron slope") + print(out) + + +def main() -> None: + t0 = time.time() + ds = load_era5_block() + print(f"loaded ERA5 block {dict(ds.sizes)} in {time.time() - t0:.1f}s") + + ctx = xql.XarrayContext() + ctx.from_dataset("era5", ds) + + print("\n== wind-speed sensitivity: grad(sqrt(u^2 + v^2)) ==") + wind_speed_sensitivity(ctx, ds) + + print("\n== Clausius-Clapeyron: grad(e_s(T)) ==") + clausius_clapeyron(ctx, ds) + + print("\nOK: symbolic SQL gradients match the analytic references.") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/mnist_mlp.py b/benchmarks/mnist_mlp.py new file mode 100644 index 0000000..fe31cee --- /dev/null +++ b/benchmarks/mnist_mlp.py @@ -0,0 +1,459 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "xarray_sql", +# "xarray", +# "numpy", +# ] +# +# [tool.uv.sources] +# xarray_sql = { path = "..", editable = true } +# /// +"""Train an MNIST MLP as relational tensor algebra — the whole net is one table. + +A neural network is a chain of **tensor contractions** (einsums), and an einsum +over coordinate-indexed arrays *is* relational algebra: + + C[i,k] = sum_j A[i,j] * B[j,k] <=> JOIN A, B ON A.j = B.j + GROUP BY i, k -> SUM(A.val * B.val) + +Contracting a shared index is a join on it followed by a grouped SUM. In +xarray-sql an array indexed by named dims is a table keyed by those dims, so the +dim names are the join keys. + +Two simplifications make the whole model **one relation**: + +* **Bias folded into the weights (an ``nn.Linear``).** Each layer's bias is the + weight of a constant-``1`` input, stored as the extra row ``inp = width`` in the + same weight array — so a layer is a single matrix. The forward reads the matmul + rows and that bias row from the one relation (no separate bias table). +* **A ``layer`` dimension.** Every layer's weight lives in one + ``weight(layer, inp, out)`` array, so the forward/backward filter on the + ``layer`` *column* instead of referencing a table per layer. The whole network + is one ``xr.Dataset`` registered with ``from_dataset``; differing layer widths + are NaN-padded in the dense array and dropped on the way in (the relational + form is naturally ragged). The architecture is data — change ``WIDTHS`` and the + same code trains a different net. + +A single ``contract()`` and one generic loop train a net of any depth: forward +contracts the activation with ``weight WHERE layer = L``; backward is the same +contraction transposed (the VJP of a contraction is a contraction), with +``grad(tanh(z), z)`` for the one local-derivative step. Even the weight update is +one query over the whole ``weight`` relation. Linear algebra is joins; the +derivatives of the nonlinearities are ``grad``. + +Everything stays relational and inspectable: activations, errors, gradients, and +the per-step training metrics are all tables; the trained model, predictions, and +metrics come back out as ``xarray`` via ``to_dataset``. + +This is not a numpy replacement — the long form puts one matrix entry per row, so +the matmul-as-join carries overhead a BLAS inner product doesn't. What it buys is +a declarative, inspectable pipeline whose data side is chunked xarray (parallel +over the batch, larger-than-memory). Recovering BLAS speed would mean storing +dense *tiles* per cell and contracting them with a tile-matmul — a future +direction, not done here. + +Run standalone (builds the local extension on first use): + + uv run benchmarks/mnist_mlp.py +""" + +from __future__ import annotations + +import gzip +import struct +import tempfile +import time +import urllib.request +from pathlib import Path + +import numpy as np +import pyarrow as pa +import xarray as xr + +import xarray_sql as xql + +MIRROR = "https://storage.googleapis.com/cvdf-datasets/mnist" +CACHE = Path(tempfile.gettempdir()) / "mnist-xql" + +# The architecture, as data: layer widths. 196 pooled pixels -> 32 tanh -> 10. +# Add an entry (e.g. 196, 64, 32, 10) and the same code trains the deeper net. +WIDTHS = [196, 32, 10] +DEPTH = len(WIDTHS) - 1 # number of weight layers +N_TRAIN, N_TEST = 1000, 500 +LR, STEPS, CHUNK = 0.5, 60, 250 + + +# --- the one idea: a tensor contraction is a relational query ----------------- + + +def contract(spec: str, left: str, right: str) -> str: + """An einsum over two coordinate-indexed relations, as one SQL query. + + ``contract("sample,inp * inp,out -> sample,out", "x", w)`` joins ``x`` and + ``w`` on their shared dim ``inp``, groups by the output dims, and sums the + product of values — a matmul. ``left`` / ``right`` are table names or + parenthesised subqueries; each exposes its dims plus a ``val`` column. + Indices in the inputs but not the output are contracted (summed over). + """ + spec = spec.replace(" ", "") + lhs, out = spec.split("->") + da, db = (operand.split(",") for operand in lhs.split("*")) + out_dims = out.split(",") + shared = [d for d in da if d in db] + join = ( + f"JOIN {right} r ON " + " AND ".join(f"l.{d} = r.{d}" for d in shared) + if shared + else f"CROSS JOIN {right} r" + ) + pick = ", ".join(f"{'l' if d in da else 'r'}.{d} AS {d}" for d in out_dims) + return ( + f"SELECT {pick}, SUM(l.val * r.val) AS val " + f"FROM {left} l {join} GROUP BY {', '.join(out_dims)}" + ) + + +def register_tensor( + ctx: xql.XarrayContext, + name: str, + arr: np.ndarray, + dims: tuple[str, ...], + var: str = "val", + chunk: int | None = None, +) -> None: + """Register a numpy array as a relation, the array-relational way: wrap it as + an ``xr.Dataset`` whose named dims become the table's key columns, then hand + it to ``from_dataset``. A tensor is an array at the edge and a relation + inside; ``from_dataset`` is the bridge, and the dims become the join keys.""" + arr = np.asarray(arr, dtype=np.float64) + ds = xr.Dataset( + {var: (dims, arr)}, + coords={d: np.arange(n) for d, n in zip(dims, arr.shape)}, + ) + ctx.from_dataset(name, ds, chunks={dims[0]: chunk or arr.shape[0]}) + + +class Tensors: + """A step rewrites a handful of relations; ``put`` materialises a query as a + named table (the stages of the forward/backward pass).""" + + def __init__(self, ctx: xql.XarrayContext): + self.ctx = ctx + + def put(self, name: str, sql: str) -> None: + batches = self.ctx.sql(sql).collect() + # UNION branches can yield batches that differ only in field nullability; + # cast them all to one (nullable) schema so registration accepts them. + if batches: + target = pa.schema( + [pa.field(f.name, f.type) for f in batches[0].schema] + ) + batches = [b.cast(target) for b in batches] + if self.ctx.table_exist(name): + self.ctx.deregister_table(name) + self.ctx.register_record_batches(name, [batches]) + + +# --- the model: one weight relation, bias folded in --------------------------- + + +def build_model(rng: np.random.Generator) -> xr.Dataset: + """The whole network as one ``weight(layer, inp, out)`` Dataset. + + Layer ``L`` connects ``WIDTHS[L]`` inputs (plus a constant-1 bias input, index + ``WIDTHS[L]``) to ``WIDTHS[L+1]`` outputs. The dense array is NaN-padded to the + widest layer; the padding is dropped when the relation is seeded, so the live + table is the ragged set of real weights. + """ + max_in = max(WIDTHS[layer] + 1 for layer in range(DEPTH)) + max_out = max(WIDTHS[layer + 1] for layer in range(DEPTH)) + arr = np.full((DEPTH, max_in, max_out), np.nan) + for layer in range(DEPTH): + n_in, n_out = WIDTHS[layer], WIDTHS[layer + 1] + arr[layer, :n_in, :n_out] = rng.standard_normal((n_in, n_out)) * 0.1 + arr[layer, n_in, :n_out] = ( + 0.0 # bias row (weight of the constant input) + ) + return xr.Dataset( + {"weight": (("layer", "inp", "out"), arr)}, + coords={ + "layer": np.arange(DEPTH), + "inp": np.arange(max_in), + "out": np.arange(max_out), + }, + ) + + +def matmul_rows(layer: int) -> str: + """The matmul (non-bias) rows of one layer's weight, as a subquery.""" + return f"(SELECT inp, out, val FROM weight WHERE layer = {layer} AND inp < {WIDTHS[layer]})" + + +def bias_row(layer: int) -> str: + """The bias row (inp = width) of one layer's weight, as a subquery over out.""" + return f"(SELECT out, val FROM weight WHERE layer = {layer} AND inp = {WIDTHS[layer]})" + + +# --- the network, as contractions (generic over depth) ------------------------ + + +def forward(t: Tensors, inp: str = "x") -> None: + """Forward pass from ``inp``: per layer, contract with the matmul rows and add + the bias row (both from the one weight relation), then tanh on the hidden + layers. Leaves ``a{L}.z`` for backprop and the output ``logits``.""" + prev = inp + for layer in range(DEPTH): + zc = contract( + "sample,inp * inp,out -> sample,out", prev, matmul_rows(layer) + ) + if layer < DEPTH - 1: + t.put( + f"a{layer + 1}", + f"""WITH zc AS ({zc}) + SELECT zc.sample, zc.out AS inp, zc.val + b.val AS z, + tanh(zc.val + b.val) AS val + FROM zc JOIN {bias_row(layer)} b ON zc.out = b.out""", + ) + prev = f"a{layer + 1}" + else: + t.put( + "logits", + f"""WITH zc AS ({zc}) + SELECT zc.sample, zc.out, zc.val + b.val AS z + FROM zc JOIN {bias_row(layer)} b ON zc.out = b.out""", + ) + + +def softmax_delta_sql() -> str: + """Output error delta = softmax(logits) - onehot(label). The one hand-derived + rule: softmax couples classes through a per-sample normaliser an aggregate + grad() does not cross.""" + return """ + WITH m AS (SELECT sample, MAX(z) AS m FROM logits GROUP BY sample), + e AS (SELECT logits.sample, logits.out, exp(logits.z - m.m) AS e + FROM logits JOIN m ON logits.sample = m.sample), + s AS (SELECT sample, SUM(e) AS s FROM e GROUP BY sample) + SELECT e.sample, e.out, + e.e / s.s - CASE WHEN e.out = y.label THEN 1.0 ELSE 0.0 END AS val + FROM e JOIN s ON e.sample = s.sample JOIN y ON y.sample = e.sample""" + + +def train_step(t: Tensors) -> None: + """Forward, backward (the same contraction transposed), one SGD update.""" + forward(t) + t.put(f"delta{DEPTH}", softmax_delta_sql()) + # Backward: gradients are contractions over the batch, accumulated into one + # gweight relation tagged by layer. delta{L} is the error at layer L's units. + for layer in reversed(range(DEPTH)): + a_in = "x" if layer == 0 else f"a{layer}" + # matmul gradient (mean over batch) + bias gradient (mean of delta), + # both tagged with this layer, as rows of one gweight relation. + gw = contract( + "sample,inp * sample,out -> inp,out", a_in, f"delta{layer + 1}" + ) + rows = ( + f"SELECT CAST({layer} AS BIGINT) AS layer, inp, out, " + f"val / {N_TRAIN} AS val FROM ({gw}) " + f"UNION ALL " + f"SELECT CAST({layer} AS BIGINT) AS layer, " + f"CAST({WIDTHS[layer]} AS BIGINT) AS inp, out, AVG(val) AS val " + f"FROM delta{layer + 1} GROUP BY out" + ) + t.put( + "gweight", + f"SELECT * FROM gweight UNION ALL {rows}" + if t.ctx.table_exist("gweight") + else rows, + ) + if layer > 0: # propagate the cotangent, scaled by the local derivative + dc = contract( + "sample,out * inp,out -> sample,inp", + f"delta{layer + 1}", + matmul_rows(layer), + ) + t.put( + f"delta{layer}", + f"""WITH dc AS ({dc}) + SELECT dc.sample, dc.inp AS out, + dc.val * grad(tanh(a{layer}.z), a{layer}.z) AS val + FROM dc JOIN a{layer} + ON dc.sample = a{layer}.sample AND dc.inp = a{layer}.inp""", + ) + # One SGD update for the whole network: weight <- weight - lr * gweight. + t.put( + "weight", + f"""SELECT w.layer, w.inp, w.out, w.val - {LR} * g.val AS val + FROM weight w JOIN gweight g + ON w.layer = g.layer AND w.inp = g.inp AND w.out = g.out""", + ) + t.ctx.deregister_table("gweight") + + +def accuracy(t: Tensors, inp: str, lab: str) -> float: + """A forward pass over ``inp`` + argmax, compared to ``lab`` — all in SQL.""" + forward(t, inp) + return float( + t.ctx.sql( + f"""WITH pred AS ( + SELECT sample, out, + ROW_NUMBER() OVER (PARTITION BY sample ORDER BY z DESC) AS rk + FROM logits) + SELECT AVG(CASE WHEN p.out = l.label THEN 1.0 ELSE 0.0 END) AS acc + FROM pred p JOIN {lab} l ON p.sample = l.sample WHERE p.rk = 1""" + ).to_pandas()["acc"][0] + ) + + +def record_metrics(t: Tensors, step: int) -> None: + """Append a (step, loss, train_acc, test_acc) row to the ``metrics`` table. + + NN training emits a lot of data — loss curves, per-step accuracies — and like + everything else here it lives as rows in a relation, grown each time, not a + Python list. Read it back at the end as a tidy ``(step,)`` xarray. + """ + train = accuracy(t, "x", "y") # leaves the training forward in `logits` + loss = float( + t.ctx.sql( + """WITH m AS (SELECT sample, MAX(z) AS m FROM logits GROUP BY sample), + e AS (SELECT logits.sample, logits.out, exp(logits.z - m.m) AS e + FROM logits JOIN m ON logits.sample = m.sample), + s AS (SELECT sample, SUM(e) AS s FROM e GROUP BY sample) + SELECT -AVG(ln(e.e / s.s)) AS loss + FROM e JOIN s ON e.sample = s.sample JOIN y ON y.sample = e.sample + WHERE e.out = y.label""" + ).to_pandas()["loss"][0] + ) + test = accuracy(t, "x_te", "y_te") + row = ( + f"SELECT CAST({step} AS BIGINT) AS step, CAST({loss} AS DOUBLE) AS loss, " + f"CAST({train} AS DOUBLE) AS train_acc, CAST({test} AS DOUBLE) AS test_acc" + ) + t.put( + "metrics", + f"SELECT * FROM metrics UNION ALL {row}" + if t.ctx.table_exist("metrics") + else row, + ) + print( + f"step {step:2d}: loss {loss:.3f} train {train:.3f} test {test:.3f}" + ) + + +# --- MNIST loading ------------------------------------------------------------ + + +def _download(url: str, dest: Path, tries: int = 5) -> None: + last = None + for _ in range(tries): + try: + with urllib.request.urlopen(url, timeout=120) as resp: + data = resp.read() + if len(data) < 1024: + raise OSError(f"suspiciously small download: {len(data)} bytes") + dest.write_bytes(data) + return + except Exception as exc: # noqa: BLE001 - retry any transient failure + last = exc + raise OSError(f"failed to download {url}: {last}") + + +def _read_idx(path: Path) -> np.ndarray: + with gzip.open(path, "rb") as f: + (magic,) = struct.unpack(">I", f.read(4)) + if magic == 2051: # images + n, r, c = struct.unpack(">III", f.read(12)) + return np.frombuffer(f.read(), np.uint8).reshape(n, r, c) + struct.unpack(">I", f.read(4)) # labels: skip the count + return np.frombuffer(f.read(), np.uint8) + + +def load_mnist(): + CACHE.mkdir(exist_ok=True) + files = { + "images": "train-images-idx3-ubyte.gz", + "labels": "train-labels-idx1-ubyte.gz", + } + paths = {} + for key, name in files.items(): + dest = CACHE / name + if not dest.exists(): + _download(f"{MIRROR}/{name}", dest) + paths[key] = dest + imgs = _read_idx(paths["images"]).astype(np.float32) / 255.0 + labs = _read_idx(paths["labels"]).astype(np.int64) + side = WIDTHS[0] # pooled pixels per image + pool = 28 // int(round(side**0.5)) # 2 for 196 pixels (14x14) + k = 28 // pool + pooled = ( + imgs.reshape(-1, k, pool, k, pool).mean(axis=(2, 4)).reshape(-1, side) + ) + rng = np.random.default_rng(0) + idx = rng.permutation(len(pooled)) + tr, te = idx[:N_TRAIN], idx[N_TRAIN : N_TRAIN + N_TEST] + return pooled[tr], labs[tr], pooled[te], labs[te] + + +# --- driver ------------------------------------------------------------------- + + +def main() -> None: + Xtr, ytr, Xte, yte = load_mnist() + print(f"MNIST: train {Xtr.shape}, test {Xte.shape} architecture {WIDTHS}") + + ctx = xql.XarrayContext() + # The whole model is one Dataset with a layer dim; from_dataset gives one + # `net` table, and seeding drops the NaN padding to the live `weight` relation. + rng = np.random.default_rng(1) + model = build_model(rng) + ctx.from_dataset( + "net", + model, + chunks={ + "layer": DEPTH, + "inp": model.sizes["inp"], + "out": model.sizes["out"], + }, + ) + t = Tensors(ctx) + t.put( + "weight", + "SELECT layer, inp, out, weight AS val FROM net WHERE weight IS NOT NULL", + ) + + # Inputs and labels (the bias is in the weight relation, so no augmentation). + register_tensor(ctx, "x", Xtr, ("sample", "inp"), chunk=CHUNK) + register_tensor(ctx, "y", ytr, ("sample",), var="label") + register_tensor(ctx, "x_te", Xte, ("sample", "inp")) + register_tensor(ctx, "y_te", yte, ("sample",), var="label") + + print(f"init: test acc {accuracy(t, 'x_te', 'y_te'):.3f}") + t0 = time.time() + for step in range(STEPS): + train_step(t) + if step % 10 == 0 or step == STEPS - 1: + record_metrics(t, step) + dt = time.time() - t0 + + # The trained model, predictions, and metrics all come back out as xarray. + weights = ( + ctx.sql("SELECT layer, inp, out, val FROM weight") + .to_dataset(dims=["layer", "inp", "out"]) + .rename({"val": "weight"}) + ) + metrics = ctx.sql("SELECT * FROM metrics ORDER BY step").to_dataset( + dims=["step"] + ) + + print( + f"\ntrained a {WIDTHS} MLP as relational tensor algebra in {dt:.0f}s: " + f"test accuracy {accuracy(t, 'x_te', 'y_te'):.3f}." + ) + print( + f"the whole model is one weight relation -> xarray " + f"{dict(weights.sizes)}; metrics are a table -> xarray " + f"{list(metrics.data_vars)} over {dict(metrics.sizes)}." + ) + + +if __name__ == "__main__": + main() diff --git a/src/autograd.rs b/src/autograd.rs new file mode 100644 index 0000000..c729ca3 --- /dev/null +++ b/src/autograd.rs @@ -0,0 +1,838 @@ +//! Symbolic differentiation of DataFusion logical [`Expr`] trees. +//! +//! This is the autograd engine for xarray-sql. Given an [`Expr`] and the name +//! of a column to differentiate with respect to, [`differentiate`] returns a +//! new [`Expr`] for the (symbolic) partial derivative, built entirely from +//! ordinary DataFusion expressions so the result can be planned and evaluated +//! by DataFusion like any other SQL expression. +//! +//! ## Design +//! +//! The approach mirrors JAX's per-primitive rule registry (`defjvp` and +//! friends in `jax/_src/interpreters/ad.py`): every expression node has a +//! differentiation rule, and the chain rule composes them as the tree is +//! walked. Because each row of a relational table is an independent evaluation +//! point, differentiating a column expression and letting DataFusion evaluate +//! it row-by-row is the moral equivalent of `jax.vmap(jax.grad(f))` — the rows +//! *are* the batch dimension. +//! +//! A small simplifier folds the `0`/`1` constants that differentiation +//! produces in abundance (e.g. `d/dx (c) = 0`, `d/dx (x) = 1`), keeping output +//! expressions compact. This plays the role of JAX's `Zero` tangents and +//! `add_tangents`: a `0` derivative short-circuits products and drops out of +//! sums, and a `1` factor drops out of products. +//! +//! ## Surface +//! +//! Three scalar operations, all rewritten away before execution: +//! +//! * `grad(expr, column)` — the partial derivative `d(expr)/d(column)`. +//! * `jvp(expr, column, tangent)` — forward-mode directional derivative, +//! `d(expr)/d(column) * tangent` (seed a tangent on an input). +//! * `vjp(expr, column, cotangent)` — reverse-mode pullback, +//! `cotangent * d(expr)/d(column)` (seed a cotangent on the output). +//! +//! All three return a scalar per row, staying in the long/tidy data model. A +//! full gradient or Jacobian is expressed as several scalar columns (e.g. +//! `grad(f, x) AS dfdx, grad(f, y) AS dfdy`) rather than a nested array, which +//! would break the one-value-per-coordinate model. +//! +//! Calls nest, giving higher-order derivatives for free: the rewrite walks +//! bottom-up, so the inner call in `grad(grad(f, x), x)` is differentiated +//! first and the outer call differentiates that result. +//! +//! Differentiation through an aggregate is just linearity and needs no special +//! handling: write the `grad` *inside* the aggregate, e.g. `SUM(grad(f, x))` or +//! `AVG(grad(loss, theta))`. Because the marker is rewritten to plain SQL +//! before the aggregate runs (and the column is in scope there), this is the +//! relational `d/dθ Σ f = Σ ∂f/∂θ` — enough to run gradient descent in SQL. +//! (The transposed form `grad(SUM(f), x)` is rejected by SQL's own scoping, +//! since `x` is gone after aggregation.) + +#![allow(dead_code)] + +use std::any::Any; +use std::collections::HashMap; +use std::f64::consts::{LN_10, LN_2}; +use std::ops::ControlFlow; +use std::sync::Arc; + +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::common::{DFSchema, DataFusionError, Result, ScalarValue, TableReference}; +use datafusion::functions::math::expr_fn; +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion::logical_expr::{ + lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::prelude::SessionContext; +use datafusion::sql::unparser::expr_to_sql; +use sqlparser::ast::{Expr as SqlExpr, Visit, VisitMut, Visitor, VisitorMut}; +use sqlparser::dialect::GenericDialect; +use sqlparser::parser::Parser; + +// --------------------------------------------------------------------------- +// Constant helpers and the 0/1-folding builders +// --------------------------------------------------------------------------- + +/// The constant `0.0`, used as the derivative of anything not depending on the +/// differentiation variable. +fn zero() -> Expr { + lit(0.0_f64) +} + +/// The constant `1.0`, used as the derivative of the differentiation variable. +fn one() -> Expr { + lit(1.0_f64) +} + +/// Interpret a [`ScalarValue`] as `f64` if it is a (non-null) numeric scalar. +fn scalar_as_f64(sv: &ScalarValue) -> Option { + match sv { + ScalarValue::Float64(Some(v)) => Some(*v), + ScalarValue::Float32(Some(v)) => Some(*v as f64), + ScalarValue::Int64(Some(v)) => Some(*v as f64), + ScalarValue::Int32(Some(v)) => Some(*v as f64), + ScalarValue::Int16(Some(v)) => Some(*v as f64), + ScalarValue::Int8(Some(v)) => Some(*v as f64), + ScalarValue::UInt64(Some(v)) => Some(*v as f64), + ScalarValue::UInt32(Some(v)) => Some(*v as f64), + ScalarValue::UInt16(Some(v)) => Some(*v as f64), + ScalarValue::UInt8(Some(v)) => Some(*v as f64), + _ => None, + } +} + +/// Return the constant `f64` value of a literal expression, if it is one. +fn as_const(e: &Expr) -> Option { + match e { + Expr::Literal(sv, _) => scalar_as_f64(sv), + _ => None, + } +} + +/// True if the expression is a numeric literal exactly equal to zero. +fn is_zero(e: &Expr) -> bool { + matches!(as_const(e), Some(v) if v == 0.0) +} + +/// True if the expression is a numeric literal exactly equal to one. +fn is_one(e: &Expr) -> bool { + matches!(as_const(e), Some(v) if v == 1.0) +} + +fn binary(left: Expr, op: Operator, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) +} + +/// `a + b`, dropping a zero operand. +fn add(a: Expr, b: Expr) -> Expr { + if is_zero(&a) { + b + } else if is_zero(&b) { + a + } else { + binary(a, Operator::Plus, b) + } +} + +/// `a - b`, dropping a zero right operand and turning `0 - b` into `-b`. +fn sub(a: Expr, b: Expr) -> Expr { + if is_zero(&b) { + a + } else if is_zero(&a) { + neg(b) + } else { + binary(a, Operator::Minus, b) + } +} + +/// `a * b`, folding `0 * _ = 0` and `1 * b = b` (and the mirror cases). +fn mul(a: Expr, b: Expr) -> Expr { + if is_zero(&a) || is_zero(&b) { + zero() + } else if is_one(&a) { + b + } else if is_one(&b) { + a + } else { + binary(a, Operator::Multiply, b) + } +} + +/// `a / b`, folding `0 / _ = 0` and `a / 1 = a`. +fn div(a: Expr, b: Expr) -> Expr { + if is_zero(&a) { + zero() + } else if is_one(&b) { + a + } else { + binary(a, Operator::Divide, b) + } +} + +/// `-a`, folding `-0 = 0`. +fn neg(a: Expr) -> Expr { + if is_zero(&a) { + zero() + } else { + Expr::Negative(Box::new(a)) + } +} + +/// `e * e`. +fn square(e: Expr) -> Expr { + mul(e.clone(), e) +} + +// --------------------------------------------------------------------------- +// The differentiation engine (forward-mode linearization) +// --------------------------------------------------------------------------- + +/// A *leaf rule*: the tangent of a column, i.e. the seed assigned to each input +/// during forward-mode differentiation. +/// +/// `grad` uses a one-hot leaf (`1` for the differentiation variable, `0` +/// otherwise); `jvp` uses an arbitrary seed per input. Everything above the +/// leaves — the chain rule — is shared. +type Leaf<'a> = dyn Fn(&str) -> Expr + 'a; + +/// Linearize `expr`: push tangents from the leaves (per `leaf`) up through the +/// expression via the chain rule, returning the tangent of `expr`. +/// +/// This is forward-mode automatic differentiation. `differentiate` (a single +/// partial derivative) and `jvp` (a directional derivative) are both thin +/// wrappers that only differ in their leaf rule. Returns a +/// [`DataFusionError::NotImplemented`] for nodes or functions without a rule, +/// so callers surface a clear error rather than a silently-wrong derivative. +fn linearize(expr: &Expr, leaf: &Leaf) -> Result { + match expr { + // The leaf rule decides a column's tangent. + Expr::Column(c) => Ok(leaf(&c.name)), + + // Constants have zero tangent. + Expr::Literal(_, _) => Ok(zero()), + + // An alias is transparent; the surrounding query re-applies any naming. + Expr::Alias(a) => linearize(&a.expr, leaf), + + // A numeric cast is (locally) linear: tangent of cast(u) = cast(du). + Expr::Cast(c) => { + let du = linearize(&c.expr, leaf)?; + Ok(Expr::Cast(Cast::new(Box::new(du), c.data_type.clone()))) + } + + // tangent of -u = -(du). + Expr::Negative(inner) => Ok(neg(linearize(inner, leaf)?)), + + Expr::BinaryExpr(be) => linearize_binary(be, leaf), + + Expr::ScalarFunction(sf) => linearize_scalar_function(sf, leaf), + + other => Err(DataFusionError::NotImplemented(format!( + "grad: differentiation is not implemented for this expression: {other}" + ))), + } +} + +/// Differentiate `expr` with respect to the column named `wrt`. +/// +/// Forward-mode with a one-hot seed: `1` on `wrt`, `0` on every other column. +pub fn differentiate(expr: &Expr, wrt: &str) -> Result { + linearize(expr, &|name| if name == wrt { one() } else { zero() }) +} + +/// Forward-mode directional derivative: the tangent of `expr` given a tangent +/// (`seeds[col]`) for each seeded input column; unseeded columns are constant. +fn jvp(expr: &Expr, seeds: &HashMap) -> Result { + linearize(expr, &|name| seeds.get(name).cloned().unwrap_or_else(zero)) +} + +/// Linearize a binary arithmetic expression via the sum/product/quotient rules. +fn linearize_binary(be: &BinaryExpr, leaf: &Leaf) -> Result { + let a = be.left.as_ref(); + let b = be.right.as_ref(); + let da = linearize(a, leaf)?; + let db = linearize(b, leaf)?; + + match be.op { + // tangent of (a + b) = da + db + Operator::Plus => Ok(add(da, db)), + // tangent of (a - b) = da - db + Operator::Minus => Ok(sub(da, db)), + // tangent of (a * b) = da*b + a*db (product rule) + Operator::Multiply => Ok(add(mul(da, b.clone()), mul(a.clone(), db))), + // tangent of (a / b) = (da*b - a*db) / b^2 (quotient rule) + Operator::Divide => { + let numerator = sub(mul(da, b.clone()), mul(a.clone(), db)); + Ok(div(numerator, square(b.clone()))) + } + op => Err(DataFusionError::NotImplemented(format!( + "grad: operator '{op}' is not differentiable" + ))), + } +} + +/// Linearize a scalar-function call via the chain rule. +/// +/// For a unary primitive `f(u)`, the tangent is `f'(u) * du`. For `power`, +/// which is binary, we handle the constant-exponent and constant-base cases. +fn linearize_scalar_function(sf: &ScalarFunction, leaf: &Leaf) -> Result { + let name = sf.func.name(); + let args = &sf.args; + + // `power(base, exponent)` is the one binary primitive we linearize. + if name == "power" { + return linearize_power(args, leaf); + } + + if args.len() != 1 { + return Err(DataFusionError::NotImplemented(format!( + "grad: no derivative rule for function '{name}' with {} arguments", + args.len() + ))); + } + + let u = &args[0]; + let du = linearize(u, leaf)?; + // Chain rule short-circuit: if du is 0, the whole tangent is 0 and we avoid + // emitting the (dead) outer derivative term entirely. + if is_zero(&du) { + return Ok(zero()); + } + + let outer = match name { + // Trigonometric. + "sin" => expr_fn::cos(u.clone()), + "cos" => neg(expr_fn::sin(u.clone())), + "tan" => div(one(), square(expr_fn::cos(u.clone()))), + // Inverse trigonometric. + "asin" => div(one(), expr_fn::sqrt(sub(one(), square(u.clone())))), + "acos" => neg(div(one(), expr_fn::sqrt(sub(one(), square(u.clone()))))), + "atan" => div(one(), add(one(), square(u.clone()))), + // Exponential / logarithmic. + "exp" => expr_fn::exp(u.clone()), + "ln" => div(one(), u.clone()), + "log2" => div(one(), mul(u.clone(), lit(LN_2))), + "log10" => div(one(), mul(u.clone(), lit(LN_10))), + "sqrt" => div(one(), mul(lit(2.0_f64), expr_fn::sqrt(u.clone()))), + // Hyperbolic. + "sinh" => expr_fn::cosh(u.clone()), + "cosh" => expr_fn::sinh(u.clone()), + "tanh" => sub(one(), square(expr_fn::tanh(u.clone()))), + // Piecewise-linear: derivative is the sign (undefined at 0, like JAX). + "abs" => expr_fn::signum(u.clone()), + _ => { + return Err(DataFusionError::NotImplemented(format!( + "grad: no derivative rule for function '{name}'" + ))) + } + }; + + Ok(mul(outer, du)) +} + +/// Linearize `power(base, exponent)`. +/// +/// * Constant exponent `c`: tangent = `c * base^(c-1) * d(base)`. +/// * Constant base `a`: tangent = `a^u * ln(a) * d(u)`. +/// * Both variable (`u^v`): not supported yet. +fn linearize_power(args: &[Expr], leaf: &Leaf) -> Result { + if args.len() != 2 { + return Err(DataFusionError::NotImplemented( + "grad: power() expects exactly two arguments".to_string(), + )); + } + let base = &args[0]; + let exponent = &args[1]; + + match (as_const(base), as_const(exponent)) { + // Constant exponent (covers the common x^2, x^0.5, ... cases). + (_, Some(c)) => { + let dbase = linearize(base, leaf)?; + if is_zero(&dbase) { + return Ok(zero()); + } + let outer = mul(lit(c), expr_fn::power(base.clone(), lit(c - 1.0))); + Ok(mul(outer, dbase)) + } + // Constant base, variable exponent. + (Some(a), None) => { + let dexp = linearize(exponent, leaf)?; + if is_zero(&dexp) { + return Ok(zero()); + } + let outer = mul(expr_fn::power(base.clone(), exponent.clone()), lit(a.ln())); + Ok(mul(outer, dexp)) + } + // General u^v requires the exp/log trick; deferred for now. + (None, None) => Err(DataFusionError::NotImplemented( + "grad: power(base, exponent) where both depend on the \ + differentiation variable is not yet supported" + .to_string(), + )), + } +} + +// --------------------------------------------------------------------------- +// The `grad` / `jacobian` marker UDFs and the plan-level rewrite +// --------------------------------------------------------------------------- + +/// A no-op placeholder UDF for the autograd surface functions. +/// +/// `grad`, `jvp`, and `vjp` are *markers*: they carry the differentiation +/// request intact through SQL parsing, logical planning, and Substrait +/// serialization. They are always rewritten away by [`rewrite_grad_calls`] +/// before execution, so `invoke` is never reached in normal use (and +/// deliberately errors if it somehow is, rather than returning a wrong value). +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct MarkerUdf { + name: String, + signature: Signature, +} + +impl MarkerUdf { + fn new(name: &str, arity: usize) -> Self { + Self { + name: name.to_string(), + signature: Signature::any(arity, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MarkerUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // Every autograd marker rewrites to a scalar derivative expression. + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + Err(DataFusionError::Execution(format!( + "{}() marker reached execution without being rewritten; this is \ + an internal xarray-sql autograd error", + self.name + ))) + } +} + +/// The `grad(expr, column)` marker: scalar partial derivative `d(expr)/dcolumn`. +pub fn grad_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("grad", 2)) +} + +/// The `jvp(expr, column, tangent)` marker: forward-mode directional derivative. +pub fn jvp_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("jvp", 3)) +} + +/// The `vjp(expr, column, cotangent)` marker: reverse-mode pullback to an input. +pub fn vjp_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("vjp", 3)) +} + +/// Rewrite every `grad`/`jvp`/`vjp` call anywhere in a logical plan into its +/// symbolic derivative, leaving everything else untouched. The plan's schema is +/// recomputed afterwards because replacing a marker can change an expression's +/// name or type. +pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { + let rewritten = plan + .transform_up(|node| node.map_expressions(rewrite_grad_in_expr))? + .data; + rewritten.recompute_schema() +} + +/// Replace any `grad`/`jvp`/`vjp` calls nested anywhere inside a single +/// expression. +fn rewrite_grad_in_expr(expr: Expr) -> Result> { + expr.transform_up(|e| { + let Expr::ScalarFunction(sf) = &e else { + return Ok(Transformed::no(e)); + }; + match sf.func.name() { + "grad" => Ok(Transformed::yes(rewrite_grad(&sf.args)?)), + "jvp" => Ok(Transformed::yes(rewrite_jvp(&sf.args)?)), + "vjp" => Ok(Transformed::yes(rewrite_vjp(&sf.args)?)), + _ => Ok(Transformed::no(e)), + } + }) +} + +/// Read a bare column name from a marker argument, or report a clear error. +fn column_arg(func: &str, arg: &Expr) -> Result { + match arg { + Expr::Column(c) => Ok(c.name.clone()), + other => Err(DataFusionError::Plan(format!( + "{func}(): the column argument must be a bare column to \ + differentiate with respect to, got: {other}" + ))), + } +} + +/// `grad(expr, column)` -> `d(expr)/d(column)`. +fn rewrite_grad(args: &[Expr]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "grad() expects two arguments grad(expr, column), got {}", + args.len() + ))); + } + let wrt = column_arg("grad", &args[1])?; + differentiate(&args[0], &wrt) +} + +/// `jvp(expr, column, tangent)` -> forward-mode tangent: seed `tangent` on +/// `column` and push it through `expr`, yielding `d(expr)/d(column) * tangent`. +/// +/// A directional derivative over several inputs is the sum of per-input jvps, +/// e.g. `jvp(f, x, dx) + jvp(f, y, dy)`, since each treats the other inputs as +/// having zero tangent. +fn rewrite_jvp(args: &[Expr]) -> Result { + if args.len() != 3 { + return Err(DataFusionError::Plan(format!( + "jvp() expects three arguments jvp(expr, column, tangent), got {}", + args.len() + ))); + } + let wrt = column_arg("jvp", &args[1])?; + let seeds = HashMap::from([(wrt, args[2].clone())]); + jvp(&args[0], &seeds) +} + +/// `vjp(expr, column, cotangent)` -> reverse-mode pullback: the sensitivity that +/// an output cotangent induces on `column`, i.e. `cotangent * d(expr)/d(column)`. +/// +/// For a single scalar output this equals the matching `jvp` (both contract the +/// same partial derivative); the surfaces differ in where the seed lives — `jvp` +/// seeds an input tangent, `vjp` seeds an output cotangent. +fn rewrite_vjp(args: &[Expr]) -> Result { + if args.len() != 3 { + return Err(DataFusionError::Plan(format!( + "vjp() expects three arguments vjp(expr, column, cotangent), got {}", + args.len() + ))); + } + let wrt = column_arg("vjp", &args[1])?; + let derivative = differentiate(&args[0], &wrt)?; + Ok(mul(args[2].clone(), derivative)) +} + +// --------------------------------------------------------------------------- +// SQL source-to-source rewrite +// --------------------------------------------------------------------------- + +/// Rewrite every `grad`/`jvp`/`vjp` call in a SQL statement into its symbolic +/// derivative, returning the rewritten SQL text. +/// +/// Unlike a logical-plan rewrite, this is a pure source-to-source transform run +/// *before* the query is planned, so it works for any query shape the SQL parser +/// accepts — recursive CTEs, DML, and subqueries included. Each marker call is +/// parsed into a DataFusion [`Expr`], differentiated by the engine in this +/// module, and rendered back to SQL in place. Columns are taken from the call's +/// own identifiers (all treated as `Float64`; types don't affect the symbolic +/// result), so no catalog or table schema is needed. +pub fn rewrite_grad_in_sql(sql: &str) -> Result { + let dialect = GenericDialect {}; + let mut statements = Parser::parse_sql(&dialect, sql) + .map_err(|e| DataFusionError::Plan(format!("grad: failed to parse SQL: {e}")))?; + + // A throwaway context that only needs the marker UDFs registered so the + // calls parse into `ScalarFunction` nodes the engine can dispatch on. + let ctx = SessionContext::new(); + ctx.register_udf(grad_marker()); + ctx.register_udf(jvp_marker()); + ctx.register_udf(vjp_marker()); + + let mut rewriter = GradSqlRewriter { ctx: &ctx }; + for stmt in &mut statements { + if let ControlFlow::Break(msg) = stmt.visit(&mut rewriter) { + return Err(DataFusionError::Plan(msg)); + } + } + + Ok(statements + .iter() + .map(ToString::to_string) + .collect::>() + .join("; ")) +} + +/// True if `name` is one of the autograd marker functions (case-insensitive). +fn is_marker_name(name: &str) -> bool { + matches!(name.to_lowercase().as_str(), "grad" | "jvp" | "vjp") +} + +/// Walks a SQL AST and replaces each `grad`/`jvp`/`vjp` call with its derivative. +struct GradSqlRewriter<'a> { + ctx: &'a SessionContext, +} + +impl VisitorMut for GradSqlRewriter<'_> { + type Break = String; + + fn pre_visit_expr(&mut self, expr: &mut SqlExpr) -> ControlFlow { + let is_marker = matches!( + expr, + SqlExpr::Function(f) if is_marker_name(&f.name.to_string()) + ); + if !is_marker { + return ControlFlow::Continue(()); + } + match self.rewrite_call(expr) { + Ok(()) => ControlFlow::Continue(()), + Err(e) => ControlFlow::Break(e), + } + } +} + +impl GradSqlRewriter<'_> { + /// Differentiate a single marker call in place. The replacement is wrapped + /// in parentheses so it keeps the call's precedence in the surrounding SQL. + fn rewrite_call(&self, expr: &mut SqlExpr) -> std::result::Result<(), String> { + let schema = call_schema(expr)?; + let text = expr.to_string(); + let parsed = self + .ctx + .parse_sql_expr(&text, &schema) + .map_err(|e| format!("grad: failed to parse '{text}': {e}"))?; + let derivative = rewrite_grad_in_expr(parsed) + .map_err(|e| format!("grad: failed to differentiate '{text}': {e}"))? + .data; + let rendered = expr_to_sql(&derivative) + .map_err(|e| format!("grad: failed to render derivative for '{text}': {e}"))?; + *expr = SqlExpr::Nested(Box::new(rendered)); + Ok(()) + } +} + +/// Build a `Float64` schema covering every column identifier referenced inside a +/// marker call, so the call's argument expression can be parsed standalone. +fn call_schema(call: &SqlExpr) -> std::result::Result { + let mut collector = ColumnCollector::default(); + let _ = call.visit(&mut collector); + let fields = collector + .cols + .into_iter() + .map(|(qualifier, name)| { + let qualifier = qualifier.map(TableReference::bare); + ( + qualifier, + Arc::new(Field::new(name, DataType::Float64, true)), + ) + }) + .collect(); + DFSchema::new_with_metadata(fields, HashMap::new()) + .map_err(|e| format!("grad: failed to build schema for differentiation: {e}")) +} + +/// Collects the (optional qualifier, name) of every column identifier in a SQL +/// expression tree. +#[derive(Default)] +struct ColumnCollector { + cols: Vec<(Option, String)>, +} + +impl Visitor for ColumnCollector { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &SqlExpr) -> ControlFlow<()> { + let pair = match expr { + SqlExpr::Identifier(ident) => Some((None, ident.value.clone())), + SqlExpr::CompoundIdentifier(parts) => parts.last().map(|last| { + let qualifier = (parts.len() >= 2).then(|| parts[parts.len() - 2].value.clone()); + (qualifier, last.value.clone()) + }), + _ => None, + }; + if let Some(pair) = pair { + if !self.cols.contains(&pair) { + self.cols.push(pair); + } + } + ControlFlow::Continue(()) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use datafusion::logical_expr::col; + + use super::*; + + #[test] + fn constant_has_zero_derivative() { + assert_eq!(differentiate(&lit(3.0_f64), "x").unwrap(), zero()); + } + + #[test] + fn variable_has_unit_derivative() { + assert_eq!(differentiate(&col("x"), "x").unwrap(), one()); + } + + #[test] + fn other_variable_has_zero_derivative() { + assert_eq!(differentiate(&col("y"), "x").unwrap(), zero()); + } + + #[test] + fn sum_rule_folds_constants() { + // d/dx (x + y) = 1 + 0 = 1 + let e = add(col("x"), col("y")); + assert_eq!(differentiate(&e, "x").unwrap(), one()); + } + + #[test] + fn product_rule() { + // d/dx (x * x) = 1*x + x*1 = x + x + let e = binary(col("x"), Operator::Multiply, col("x")); + let expected = add(col("x"), col("x")); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn quotient_rule() { + // d/dx (x / y) = (1*y - x*0) / (y*y) = y / (y*y) + let e = binary(col("x"), Operator::Divide, col("y")); + let expected = div(col("y"), square(col("y"))); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn chain_rule_sin() { + // d/dx sin(x) = cos(x) * 1 = cos(x) + let d = differentiate(&expr_fn::sin(col("x")), "x").unwrap(); + assert_eq!(d, expr_fn::cos(col("x"))); + // Readable, precedence-free rendering. + assert_eq!(d.to_string(), "cos(x)"); + } + + #[test] + fn composite_sin_times_x() { + // d/dx (sin(x) * x) = cos(x)*x + sin(x) + let e = binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let d = differentiate(&e, "x").unwrap(); + assert_eq!(d.to_string(), "cos(x) * x + sin(x)"); + } + + #[test] + fn power_constant_exponent() { + // d/dx power(x, 2) = 2 * power(x, 1) * 1 = 2 * power(x, 1) + let e = expr_fn::power(col("x"), lit(2.0_f64)); + let expected = mul(lit(2.0_f64), expr_fn::power(col("x"), lit(1.0_f64))); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn unsupported_operator_errors() { + let e = binary(col("x"), Operator::Modulo, col("y")); + assert!(differentiate(&e, "x").is_err()); + } + + #[test] + fn unsupported_function_errors() { + // atan2 is binary and has no rule yet. + let e = expr_fn::atan2(col("x"), col("y")); + assert!(differentiate(&e, "x").is_err()); + } + + #[test] + fn higher_order_derivative() { + // Differentiation composes: d2/dx2 sin(x) = -sin(x). + let d1 = differentiate(&expr_fn::sin(col("x")), "x").unwrap(); + let d2 = differentiate(&d1, "x").unwrap(); + assert_eq!(d2, neg(expr_fn::sin(col("x")))); + } + + #[test] + fn jvp_seeds_a_tangent_on_one_input() { + // jvp(x*y, {x: dx}) = product rule with tangent(x)=dx, tangent(y)=0 + // = dx*y + x*0 = dx*y + let f = binary(col("x"), Operator::Multiply, col("y")); + let seeds = HashMap::from([("x".to_string(), col("dx"))]); + let t = jvp(&f, &seeds).unwrap(); + assert_eq!(t, mul(col("dx"), col("y"))); + } + + #[test] + fn jvp_with_unit_seed_matches_grad() { + // A one-hot tangent reproduces the partial derivative. + let f = expr_fn::sin(col("x")); + let seeds = HashMap::from([("x".to_string(), one())]); + assert_eq!(jvp(&f, &seeds).unwrap(), differentiate(&f, "x").unwrap()); + } + + #[test] + fn vjp_equals_cotangent_times_grad() { + // rewrite_vjp(sin(x), x, w) = w * cos(x) + let f = expr_fn::sin(col("x")); + let got = rewrite_vjp(&[f.clone(), col("x"), col("w")]).unwrap(); + assert_eq!(got, mul(col("w"), expr_fn::cos(col("x")))); + } + + #[test] + fn jvp_and_vjp_agree_for_unit_seed() { + // With matching unit seed/cotangent, forward and reverse coincide. + let f = binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let fwd = rewrite_jvp(&[f.clone(), col("x"), one()]).unwrap(); + let rev = rewrite_vjp(&[f, col("x"), one()]).unwrap(); + assert_eq!(fwd, rev); + } + + #[test] + fn sql_rewrite_replaces_grad_call() { + // grad(sin(x), x) -> cos(x); the surrounding SELECT is preserved. + let out = rewrite_grad_in_sql("SELECT grad(sin(x), x) AS d FROM t").unwrap(); + assert_eq!(out, "SELECT (cos(x)) AS d FROM t"); + } + + #[test] + fn sql_rewrite_leaves_non_grad_queries_intact() { + // A query with no marker is still parsed and re-emitted unchanged in + // meaning (the caller only invokes the rewrite when a marker is present). + let out = rewrite_grad_in_sql("SELECT a + b FROM t").unwrap(); + assert_eq!(out, "SELECT a + b FROM t"); + } + + #[test] + fn sql_rewrite_fires_inside_recursive_cte() { + // The #197 capability: a marker inside a recursive term is rewritten, + // a query shape the Substrait bridge could never carry. d/dx(x*x) = x+x. + let out = rewrite_grad_in_sql( + "WITH RECURSIVE r AS (SELECT 1.0 AS x UNION ALL \ + SELECT x - grad(x * x, x) FROM r WHERE x < 10) SELECT x FROM r", + ) + .unwrap(); + assert!(out.contains("(x + x)"), "unexpected rewrite: {out}"); + assert!( + !out.to_lowercase().contains("grad("), + "marker left behind: {out}" + ); + } + + #[test] + fn sql_rewrite_handles_nested_higher_order_grad() { + // grad(grad(power(x, 3), x), x) -> d2/dx2 (x^3) = 6x; bottom-up so the + // inner call is differentiated before the outer one. + let out = rewrite_grad_in_sql("SELECT grad(grad(power(x, 3), x), x) AS d FROM t").unwrap(); + assert!( + !out.to_lowercase().contains("grad("), + "marker left behind: {out}" + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index c489609..d157d72 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,8 @@ //! Will skip loading partitions whose time ranges are entirely before 2020-02-01. //! Supported operators: `=`, `<`, `>`, `<=`, `>=`, `BETWEEN`, `IN`, `AND`, `OR`. +mod autograd; + use std::any::Any; use std::collections::{HashMap, HashSet}; use std::ffi::CString; @@ -48,13 +50,13 @@ use std::fmt::Debug; use std::sync::Arc; use arrow::array::RecordBatch; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::pyarrow::FromPyArrow; use async_stream::try_stream; use async_trait::async_trait; use datafusion::catalog::streaming::StreamingTable; use datafusion::catalog::Session; -use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::common::{DFSchema, DataFusionError, Result as DFResult, ScalarValue}; use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::logical_expr::expr::InList; @@ -64,6 +66,8 @@ use datafusion::logical_expr::{ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use datafusion::prelude::SessionContext; +use datafusion::sql::unparser::expr_to_sql; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; use pyo3::prelude::*; @@ -981,9 +985,91 @@ impl LazyArrowStreamTable { } } +// ============================================================================ +// Autograd: SQL-level grad() rewrite +// ============================================================================ + +/// Rewrite every `grad`/`jvp`/`vjp` call in a SQL query into its symbolic +/// derivative, returning the rewritten SQL text. +/// +/// The autograd engine operates on DataFusion logical `Expr` trees. Rather than +/// round-tripping a whole plan across the cdylib boundary, this rewrites the +/// query as **SQL text** before it is planned: each marker call is parsed, +/// differentiated, and rendered back to SQL in place. Because it runs before +/// planning, it works for any query shape the parser accepts — recursive CTEs, +/// DML, and subqueries — which the plan-level Substrait bridge could not carry. +/// +/// Args: +/// query: A SQL query string that may contain `grad`/`jvp`/`vjp` calls. +/// +/// Returns: +/// The rewritten SQL string, ready to pass to ``SessionContext.sql``. +#[pyfunction] +fn rewrite_grad_sql(query: &str) -> PyResult { + autograd::rewrite_grad_in_sql(query).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "rewrite_grad_sql: failed to rewrite grad() calls: {e}" + )) + }) +} + +/// Differentiate a SQL scalar expression symbolically and return the +/// derivative as SQL text. +/// +/// Where [`grad_rewrite`] rewrites `grad(...)` calls inside a whole plan, this +/// differentiates a single expression and hands back the result as SQL — the +/// autograd engine acting as a "calculus compiler". It lets a caller obtain an +/// update rule once and embed it in queries the Substrait round-trip can't +/// carry a `grad` marker through, such as a recursive-CTE training loop. +/// +/// Args: +/// expr: A SQL scalar expression over `columns` (e.g. `"sin(x) * x"`). +/// wrt: The column name to differentiate with respect to. +/// columns: The column names in scope; all treated as `Float64` (enough to +/// parse and differentiate — types don't affect the symbolic result). +/// +/// Returns: +/// The derivative as a SQL string (e.g. `"cos(x) * x + sin(x)"`). +#[pyfunction] +fn differentiate_sql(expr: &str, wrt: &str, columns: Vec) -> PyResult { + let ctx = SessionContext::new(); + + let fields: Vec = columns + .iter() + .map(|name| Field::new(name, DataType::Float64, true)) + .collect(); + let df_schema = DFSchema::try_from(Schema::new(fields)).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to build schema: {e}" + )) + })?; + + let parsed = ctx.parse_sql_expr(expr, &df_schema).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to parse expression '{expr}': {e}" + )) + })?; + + let derivative = autograd::differentiate(&parsed, wrt).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to differentiate: {e}" + )) + })?; + + let sql = expr_to_sql(&derivative).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to render derivative to SQL: {e}" + )) + })?; + + Ok(sql.to_string()) +} + /// Python module initialization #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; + m.add_function(wrap_pyfunction!(rewrite_grad_sql, m)?)?; + m.add_function(wrap_pyfunction!(differentiate_sql, m)?)?; Ok(()) } diff --git a/tests/test_autograd.py b/tests/test_autograd.py new file mode 100644 index 0000000..794194e --- /dev/null +++ b/tests/test_autograd.py @@ -0,0 +1,309 @@ +"""Tests for the SQL autograd surface: ``SELECT grad(expr, column) ...``. + +These exercise the full path — XarrayContext.sql() differentiates every +``grad``/``jvp``/``vjp`` call as SQL text before planning, then DataFusion +executes the rewritten query — and compare results against analytic +derivatives computed with numpy. +""" + +import numpy as np +import pyarrow as pa +import pytest +import xarray as xr + +import xarray_sql as xql + + +@pytest.fixture +def ctx(): + val = np.linspace(0.1, 3.0, 16) + ds = xr.Dataset( + {"val": (("i",), val)}, + coords={"i": np.arange(16)}, + ) + context = xql.XarrayContext() + context.from_dataset("t", ds, chunks={"i": 5}) + return context + + +@pytest.fixture +def ctx_xy(): + rng = np.random.default_rng(0) + n = 16 + ds = xr.Dataset( + { + "x": (("i",), rng.uniform(0.5, 2.5, n)), + "y": (("i",), rng.uniform(0.5, 2.5, n)), + }, + coords={"i": np.arange(n)}, + ) + context = xql.XarrayContext() + context.from_dataset("g", ds, chunks={"i": 5}) + return context, ds + + +def _ordered(df, key="i"): + """Collect a result DataFrame into a dict of column -> numpy array, sorted + by the integer key column so comparisons are index-aligned.""" + pdf = df.to_pandas().sort_values(key) + return {c: pdf[c].to_numpy() for c in pdf.columns} + + +def test_grad_sin_is_cos(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql("SELECT i, grad(sin(val), val) AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val)) + + +def test_grad_product_rule(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql("SELECT i, grad(sin(val) * val, val) AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val) * val + np.sin(val)) + + +def test_grad_exp_equals_value(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql("SELECT i, exp(val) AS v, grad(exp(val), val) AS d FROM t") + ) + np.testing.assert_allclose(res["d"], np.exp(val)) + np.testing.assert_allclose(res["d"], res["v"]) + + +def test_grad_quotient_and_power(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(1.0 / val, val) AS dinv, " + "grad(power(val, 3), val) AS dcube FROM t" + ) + ) + np.testing.assert_allclose(res["dinv"], -1.0 / val**2) + np.testing.assert_allclose(res["dcube"], 3.0 * val**2) + + +def test_higher_order_grad(ctx): + # Nested grad() differentiates repeatedly: the inner call is rewritten + # first, then the outer differentiates its result. + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, " + "grad(grad(sin(val), val), val) AS d2_sin, " + "grad(grad(power(val, 3), val), val) AS d2_cube FROM t" + ) + ) + np.testing.assert_allclose(res["d2_sin"], -np.sin(val)) # -sin + np.testing.assert_allclose(res["d2_cube"], 6.0 * val) # d2/dx2 x^3 = 6x + + +def test_third_order_grad(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(grad(grad(sin(val), val), val), val) AS d3 FROM t" + ) + ) + np.testing.assert_allclose(res["d3"], -np.cos(val)) # d3/dx3 sin = -cos + + +def test_non_grad_query_is_unaffected(ctx): + # Queries without grad() bypass the rewrite and behave normally. + res = _ordered(ctx.sql("SELECT i, val FROM t")) + np.testing.assert_allclose(res["val"], np.linspace(0.1, 3.0, 16)) + + +def test_unsupported_function_raises(ctx): + # atan2 has no derivative rule yet -> a clear error, not a wrong answer. + with pytest.raises(Exception): + ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() + + +def test_grad_over_in_memory_table(ctx): + # grad works over plain DataFusion tables too (not just xarray-registered + # ones): here a coefficient lives in an in-memory MemTable cross-joined to + # the xarray data. d/dval (c * val^2) = c * 2*val, with c = 3. + ctx.register_record_batches( + "coef", [[pa.RecordBatch.from_pydict({"c": [3.0]})]] + ) + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(c * val * val, val) AS d FROM t CROSS JOIN coef" + ) + ) + np.testing.assert_allclose(res["d"], 3.0 * 2.0 * val) + + +def test_differentiate_sql_round_trip(ctx): + # differentiate_sql returns the derivative as SQL text; evaluating it must + # match the analytic derivative. d/dval (sin(val)*val) = cos(val)*val + sin(val). + deriv = xql.differentiate_sql("sin(val) * val", "val", ["val"]) + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql(f"SELECT i, {deriv} AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val) * val + np.sin(val)) + + +def test_grad_inside_aggregate(ctx): + # Differentiation through an aggregate is just linearity: + # AGG(grad(f, x)) == d/dx AGG(f). grad rewrites to plain SQL before the + # aggregate runs, so this composes with no special machinery. + val = np.linspace(0.1, 3.0, 16) + res = ctx.sql( + "SELECT SUM(grad(val * val, val)) AS s, " + "AVG(grad(sin(val), val)) AS a FROM t" + ).to_pandas() + np.testing.assert_allclose(res["s"][0], np.sum(2 * val)) + np.testing.assert_allclose(res["a"][0], np.mean(np.cos(val))) + + +def test_gradient_descent_in_sql(): + # End to end: fit y ~= a*x + b by minimising MSE, with the gradients + # w.r.t. the parameters computed in SQL via AVG(grad(loss, param)). + rng = np.random.default_rng(0) + n = 200 + x = rng.uniform(0.0, 1.0, n) + a_true, b_true = 2.0, -1.0 + y = a_true * x + b_true + rng.normal(0.0, 0.01, n) + data = xr.Dataset( + {"x": (("i",), x), "y": (("i",), y)}, coords={"i": np.arange(n)} + ) + ctx = xql.XarrayContext() + ctx.from_dataset("d", data, chunks={"i": n}) + + resid = "(y - (a * x + b))" + loss = f"{resid} * {resid}" + a, b, lr = 0.0, 0.0, 0.4 + losses = [] + for _ in range(120): + if "params" in ctx._registered_datasets: + ctx.deregister_table("params") + del ctx._registered_datasets["params"] + params = xr.Dataset( + {"a": (("p",), [a]), "b": (("p",), [b])}, coords={"p": [0]} + ) + ctx.from_dataset("params", params, chunks={"p": 1}) + row = ctx.sql( + f"SELECT AVG({loss}) AS loss, " + f"AVG(grad({loss}, a)) AS dl_da, " + f"AVG(grad({loss}, b)) AS dl_db FROM d CROSS JOIN params" + ).to_pandas() + losses.append(float(row["loss"][0])) + a -= lr * float(row["dl_da"][0]) + b -= lr * float(row["dl_db"][0]) + + assert losses[-1] < losses[0] # loss decreased + np.testing.assert_allclose([a, b], [a_true, b_true], atol=0.05) + + +def test_grad_inside_recursive_cte(): + # The headline of #197: grad() *inside* a recursive CTE — a query shape the + # old Substrait bridge could not represent. Newton's method for sqrt(2) + # drives the step with grad(x*x - 2, x) computed in the recursive term: + # x <- x - (x*x - 2) / d/dx(x*x - 2) = x - (x*x - 2) / (2x). + ctx = xql.XarrayContext() + res = ctx.sql( + "WITH RECURSIVE newton AS (" + " SELECT 0 AS step, CAST(1.0 AS DOUBLE) AS x " + " UNION ALL " + " SELECT step + 1 AS step, " + " x - (x * x - 2.0) / grad(x * x - 2.0, x) AS x " + " FROM newton WHERE step < 20" + ") " + "SELECT x FROM newton ORDER BY step DESC LIMIT 1" + ).to_pandas() + np.testing.assert_allclose(res["x"][0], np.sqrt(2.0), atol=1e-9) + + +def test_multi_input_grad_columns(ctx_xy): + # A full Jacobian written as separate scalar grad() columns: + # f = x*y -> df/dx = y, df/dy = x. + context, ds = ctx_xy + res = _ordered( + context.sql( + "SELECT i, grad(x * y, x) AS dfdx, grad(x * y, y) AS dfdy FROM g" + ) + ) + np.testing.assert_allclose(res["dfdx"], ds["y"].values) + np.testing.assert_allclose(res["dfdy"], ds["x"].values) + + +def test_jvp_forward_directional_derivative(ctx_xy): + # jvp(f, x, dx) = df/dx * dx. With f = sin(x)*y and a constant tangent. + context, ds = ctx_xy + x, y = ds["x"].values, ds["y"].values + res = _ordered(context.sql("SELECT i, jvp(sin(x) * y, x, 2.0) AS t FROM g")) + np.testing.assert_allclose(res["t"], (np.cos(x) * y) * 2.0) + + +def test_jvp_multi_input_is_sum(ctx_xy): + # A full directional derivative is the sum of per-input jvp terms: + # df/dx*dx + df/dy*dy for f = x*y, with dx=1, dy=1 -> y + x. + context, ds = ctx_xy + res = _ordered( + context.sql( + "SELECT i, jvp(x * y, x, 1.0) + jvp(x * y, y, 1.0) AS t FROM g" + ) + ) + np.testing.assert_allclose(res["t"], ds["y"].values + ds["x"].values) + + +def test_vjp_reverse_pullback(ctx_xy): + # vjp(f, x, w) = w * df/dx. With f = sin(x)*y and cotangent w = 3.0. + context, ds = ctx_xy + x, y = ds["x"].values, ds["y"].values + res = _ordered(context.sql("SELECT i, vjp(sin(x) * y, x, 3.0) AS s FROM g")) + np.testing.assert_allclose(res["s"], 3.0 * (np.cos(x) * y)) + + +@pytest.fixture +def ctx_mixed(): + # A mixed-dimension dataset registers as schema-qualified tables: + # era5.time_x (surface, 2 dims) + # era5.time_x_level (atmosphere, 3 dims) + rng = np.random.default_rng(1) + ds = xr.Dataset( + { + "sfc": (("time", "x"), rng.uniform(0.5, 2.5, (3, 4))), + "atm": (("time", "x", "level"), rng.uniform(0.5, 2.5, (3, 4, 2))), + }, + coords={"time": [0, 1, 2], "x": np.arange(4.0), "level": [0, 1]}, + ) + context = xql.XarrayContext() + context.from_dataset("era5", ds, chunks={"time": 1}) + return context, ds + + +def test_grad_on_qualified_surface_table(ctx_mixed): + context, ds = ctx_mixed + res = _ordered( + context.sql( + "SELECT time, x, sfc, grad(sin(sfc), sfc) AS d FROM era5.time_x" + ), + key="sfc", + ) + np.testing.assert_allclose(res["d"], np.cos(res["sfc"])) + + +def test_grad_on_qualified_atmosphere_table(ctx_mixed): + context, ds = ctx_mixed + res = _ordered( + context.sql( + "SELECT atm, grad(power(atm, 2), atm) AS d FROM era5.time_x_level" + ), + key="atm", + ) + np.testing.assert_allclose(res["d"], 2.0 * res["atm"]) + + +def test_jvp_and_vjp_agree_for_unit_seed(ctx_xy): + # Forward (unit tangent) and reverse (unit cotangent) coincide for a + # scalar output -- both contract the same partial derivative. + context, _ = ctx_xy + res = _ordered( + context.sql( + "SELECT i, jvp(sin(x) * y, x, 1.0) AS fwd, " + "vjp(sin(x) * y, x, 1.0) AS rev FROM g" + ) + ) + np.testing.assert_allclose(res["fwd"], res["rev"]) diff --git a/xarray_sql/__init__.py b/xarray_sql/__init__.py index d1e5984..c01f295 100644 --- a/xarray_sql/__init__.py +++ b/xarray_sql/__init__.py @@ -1,4 +1,5 @@ from . import cftime +from ._native import differentiate_sql from .df import from_map from .reader import read_xarray, read_xarray_table from .sql import XarrayContext @@ -6,6 +7,7 @@ __all__ = [ "cftime", "XarrayContext", + "differentiate_sql", "read_xarray_table", "read_xarray", "from_map", # deprecated diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 0ec60ad..46fe8e6 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -1,13 +1,21 @@ +import re + import xarray as xr from datafusion import SessionContext from datafusion.catalog import Schema from collections import defaultdict +from . import _native from . import cftime as cft from .df import Chunks from .ds import XarrayDataFrame from .reader import read_xarray_table +# Matches a call to an autograd marker function (``grad(`` / ``jvp(`` / ``vjp(``, +# case-insensitive), used as a cheap gate so ordinary queries skip the grad +# source-to-source rewrite. +_GRAD_CALL = re.compile(r"\b(grad|jvp|vjp)\s*\(", re.IGNORECASE) + class XarrayContext(SessionContext): """A datafusion `SessionContext` that also supports `xarray.Dataset`s.""" @@ -166,6 +174,11 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: ``.to_dataset(dimension_columns=[...])`` for round-tripping the result back to an ``xr.Dataset``. + If the query contains ``grad`` / ``jvp`` / ``vjp`` calls, they are + differentiated and substituted as SQL text *before* planning (see + :meth:`_rewrite_autograd`), so the differentiation works inside any + query shape — recursive CTEs, DML, and subqueries included. + Args: query: A SQL query string. *args: Forwarded to ``SessionContext.sql``. @@ -174,9 +187,34 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: Returns: An :class:`XarrayDataFrame` wrapping the DataFusion DataFrame. """ + if _GRAD_CALL.search(query): + query = self._rewrite_autograd(query) inner = super().sql(query, *args, **kwargs) return XarrayDataFrame(inner, templates=self._registered_datasets) + def _rewrite_autograd(self, query: str) -> str: + """Differentiate ``grad`` / ``jvp`` / ``vjp`` calls into SQL text. + + The differentiation engine lives in the native (Rust) extension and + operates on DataFusion logical expressions. Rather than round-tripping a + whole plan across that extension's boundary, we hand it the query as SQL + text: it parses each marker call, differentiates it symbolically, and + renders the derivative back into the query in place. The result is an + ordinary SQL string this context can plan and execute directly. + + * ``grad(expr, column)`` -> ``d(expr)/d(column)``. + * ``jvp(expr, column, tangent)`` -> forward-mode directional derivative + ``d(expr)/d(column) * tangent`` (seed a tangent on an input). A + multi-input directional derivative is a sum of jvp terms. + * ``vjp(expr, column, cotangent)`` -> reverse-mode pullback + ``cotangent * d(expr)/d(column)`` (seed a cotangent on the output). + + A full gradient/Jacobian is expressed as several scalar columns, e.g. + ``grad(f, x) AS dfdx, grad(f, y) AS dfdy``. + """ + rewritten: str = _native.rewrite_grad_sql(query) + return rewritten + def _group_vars_by_dims(ds: xr.Dataset) -> dict[tuple[str, ...], list[str]]: """Group variables in the dataset based on shared dims.