From d20c8e8ce713f44bd74ce0551a53cd00309e0984 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 27 Jun 2026 22:44:16 +0000 Subject: [PATCH 01/12] Add symbolic differentiation engine for SQL expressions (autograd MVP) Introduce `src/autograd.rs`, the Rust core of the autograd feature: a `differentiate(&Expr, wrt)` function that symbolically differentiates a DataFusion logical `Expr` tree with respect to a named column and returns a new `Expr` built from ordinary SQL expressions. The design mirrors JAX's per-primitive rule registry (defjvp and friends): each node type has a differentiation rule and the chain rule composes them as the tree is walked. A small 0/1-folding simplifier keeps output compact, playing the role of JAX's Zero tangents and add_tangents. Because each table row is an independent evaluation point, differentiating a column expression and letting DataFusion evaluate it row-by-row is the relational equivalent of vmap(grad(f)). This first cut implements scalar `grad`: rules for +, -, *, / (sum, product, quotient), unary chain rule for sin/cos/tan, asin/acos/atan, exp/ln/log2/ log10/sqrt, sinh/cosh/tanh, abs, and power() with constant base or exponent. Unsupported nodes/functions return a clear NotImplemented error rather than a silently wrong derivative. The engine operates purely on DataFusion `Expr`, keeping the eventual Python<->Rust transport (SQL text, Substrait, or proto) pluggable. Covered by 11 unit tests. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/autograd.rs | 416 ++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 3 + 2 files changed, 419 insertions(+) create mode 100644 src/autograd.rs diff --git a/src/autograd.rs b/src/autograd.rs new file mode 100644 index 0000000..662c9f0 --- /dev/null +++ b/src/autograd.rs @@ -0,0 +1,416 @@ +//! Symbolic differentiation of DataFusion logical [`Expr`] trees. +//! +//! This is the autograd engine for xarray-sql. Given an [`Expr`] and the name +//! of a column to differentiate with respect to, [`differentiate`] returns a +//! new [`Expr`] for the (symbolic) partial derivative, built entirely from +//! ordinary DataFusion expressions so the result can be planned and evaluated +//! by DataFusion like any other SQL expression. +//! +//! ## Design +//! +//! The approach mirrors JAX's per-primitive rule registry (`defjvp` and +//! friends in `jax/_src/interpreters/ad.py`): every expression node has a +//! differentiation rule, and the chain rule composes them as the tree is +//! walked. Because each row of a relational table is an independent evaluation +//! point, differentiating a column expression and letting DataFusion evaluate +//! it row-by-row is the moral equivalent of `jax.vmap(jax.grad(f))` — the rows +//! *are* the batch dimension. +//! +//! A small simplifier folds the `0`/`1` constants that differentiation +//! produces in abundance (e.g. `d/dx (c) = 0`, `d/dx (x) = 1`), keeping output +//! expressions compact. This plays the role of JAX's `Zero` tangents and +//! `add_tangents`: a `0` derivative short-circuits products and drops out of +//! sums, and a `1` factor drops out of products. +//! +//! ## Scope (MVP) +//! +//! This first cut implements scalar `grad`: the partial derivative of a single +//! expression with respect to one named column. Forward-/reverse-mode +//! (`jvp`/`vjp`) and multi-input Jacobians are deliberately left for later. + +#![allow(dead_code)] + +use std::f64::consts::{LN_10, LN_2}; + +use datafusion::common::{DataFusionError, Result, ScalarValue}; +use datafusion::functions::math::expr_fn; +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion::logical_expr::{lit, BinaryExpr, Cast, Expr, Operator}; + +// --------------------------------------------------------------------------- +// Constant helpers and the 0/1-folding builders +// --------------------------------------------------------------------------- + +/// The constant `0.0`, used as the derivative of anything not depending on the +/// differentiation variable. +fn zero() -> Expr { + lit(0.0_f64) +} + +/// The constant `1.0`, used as the derivative of the differentiation variable. +fn one() -> Expr { + lit(1.0_f64) +} + +/// Interpret a [`ScalarValue`] as `f64` if it is a (non-null) numeric scalar. +fn scalar_as_f64(sv: &ScalarValue) -> Option { + match sv { + ScalarValue::Float64(Some(v)) => Some(*v), + ScalarValue::Float32(Some(v)) => Some(*v as f64), + ScalarValue::Int64(Some(v)) => Some(*v as f64), + ScalarValue::Int32(Some(v)) => Some(*v as f64), + ScalarValue::Int16(Some(v)) => Some(*v as f64), + ScalarValue::Int8(Some(v)) => Some(*v as f64), + ScalarValue::UInt64(Some(v)) => Some(*v as f64), + ScalarValue::UInt32(Some(v)) => Some(*v as f64), + ScalarValue::UInt16(Some(v)) => Some(*v as f64), + ScalarValue::UInt8(Some(v)) => Some(*v as f64), + _ => None, + } +} + +/// Return the constant `f64` value of a literal expression, if it is one. +fn as_const(e: &Expr) -> Option { + match e { + Expr::Literal(sv, _) => scalar_as_f64(sv), + _ => None, + } +} + +/// True if the expression is a numeric literal exactly equal to zero. +fn is_zero(e: &Expr) -> bool { + matches!(as_const(e), Some(v) if v == 0.0) +} + +/// True if the expression is a numeric literal exactly equal to one. +fn is_one(e: &Expr) -> bool { + matches!(as_const(e), Some(v) if v == 1.0) +} + +fn binary(left: Expr, op: Operator, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) +} + +/// `a + b`, dropping a zero operand. +fn add(a: Expr, b: Expr) -> Expr { + if is_zero(&a) { + b + } else if is_zero(&b) { + a + } else { + binary(a, Operator::Plus, b) + } +} + +/// `a - b`, dropping a zero right operand and turning `0 - b` into `-b`. +fn sub(a: Expr, b: Expr) -> Expr { + if is_zero(&b) { + a + } else if is_zero(&a) { + neg(b) + } else { + binary(a, Operator::Minus, b) + } +} + +/// `a * b`, folding `0 * _ = 0` and `1 * b = b` (and the mirror cases). +fn mul(a: Expr, b: Expr) -> Expr { + if is_zero(&a) || is_zero(&b) { + zero() + } else if is_one(&a) { + b + } else if is_one(&b) { + a + } else { + binary(a, Operator::Multiply, b) + } +} + +/// `a / b`, folding `0 / _ = 0` and `a / 1 = a`. +fn div(a: Expr, b: Expr) -> Expr { + if is_zero(&a) { + zero() + } else if is_one(&b) { + a + } else { + binary(a, Operator::Divide, b) + } +} + +/// `-a`, folding `-0 = 0`. +fn neg(a: Expr) -> Expr { + if is_zero(&a) { + zero() + } else { + Expr::Negative(Box::new(a)) + } +} + +/// `e * e`. +fn square(e: Expr) -> Expr { + mul(e.clone(), e) +} + +// --------------------------------------------------------------------------- +// The differentiation rules +// --------------------------------------------------------------------------- + +/// Differentiate `expr` with respect to the column named `wrt`. +/// +/// Returns a new [`Expr`] for the partial derivative, composed of ordinary +/// DataFusion expressions. Returns a [`DataFusionError::NotImplemented`] for +/// expression nodes or scalar functions without a differentiation rule, so the +/// caller can surface a clear, actionable error rather than silently producing +/// a wrong answer. +pub fn differentiate(expr: &Expr, wrt: &str) -> Result { + match expr { + // d/dx (x) = 1 ; d/dx (y) = 0 for any other column. + Expr::Column(c) => Ok(if c.name == wrt { one() } else { zero() }), + + // d/dx (constant) = 0. + Expr::Literal(_, _) => Ok(zero()), + + // An alias is transparent to differentiation; the surrounding query + // re-applies any output naming. + Expr::Alias(a) => differentiate(&a.expr, wrt), + + // A numeric cast is (locally) linear: d/dx cast(u) = cast(du). We keep + // the cast so the derivative retains the declared output type. + Expr::Cast(c) => { + let du = differentiate(&c.expr, wrt)?; + Ok(Expr::Cast(Cast::new(Box::new(du), c.data_type.clone()))) + } + + // d/dx (-u) = -(du). + Expr::Negative(inner) => Ok(neg(differentiate(inner, wrt)?)), + + Expr::BinaryExpr(be) => diff_binary(be, wrt), + + Expr::ScalarFunction(sf) => diff_scalar_function(sf, wrt), + + other => Err(DataFusionError::NotImplemented(format!( + "grad: differentiation is not implemented for this expression: {other}" + ))), + } +} + +/// Differentiate a binary arithmetic expression via the sum/product/quotient +/// rules. +fn diff_binary(be: &BinaryExpr, wrt: &str) -> Result { + let a = be.left.as_ref(); + let b = be.right.as_ref(); + let da = differentiate(a, wrt)?; + let db = differentiate(b, wrt)?; + + match be.op { + // d/dx (a + b) = da + db + Operator::Plus => Ok(add(da, db)), + // d/dx (a - b) = da - db + Operator::Minus => Ok(sub(da, db)), + // d/dx (a * b) = da*b + a*db (product rule) + Operator::Multiply => { + Ok(add(mul(da, b.clone()), mul(a.clone(), db))) + } + // d/dx (a / b) = (da*b - a*db) / b^2 (quotient rule) + Operator::Divide => { + let numerator = sub(mul(da, b.clone()), mul(a.clone(), db)); + Ok(div(numerator, square(b.clone()))) + } + op => Err(DataFusionError::NotImplemented(format!( + "grad: operator '{op}' is not differentiable" + ))), + } +} + +/// Differentiate a scalar-function call via the chain rule. +/// +/// For a unary primitive `f(u)`, the derivative is `f'(u) * du`. For `power`, +/// which is binary, we handle the constant-exponent and constant-base cases. +fn diff_scalar_function(sf: &ScalarFunction, wrt: &str) -> Result { + let name = sf.func.name(); + let args = &sf.args; + + // `power(base, exponent)` is the one binary primitive we differentiate. + if name == "power" { + return diff_power(args, wrt); + } + + if args.len() != 1 { + return Err(DataFusionError::NotImplemented(format!( + "grad: no derivative rule for function '{name}' with {} arguments", + args.len() + ))); + } + + let u = &args[0]; + let du = differentiate(u, wrt)?; + // Chain rule short-circuit: if du is 0, the whole derivative is 0 and we + // avoid emitting the (dead) outer derivative term entirely. + if is_zero(&du) { + return Ok(zero()); + } + + let outer = match name { + // Trigonometric. + "sin" => expr_fn::cos(u.clone()), + "cos" => neg(expr_fn::sin(u.clone())), + "tan" => div(one(), square(expr_fn::cos(u.clone()))), + // Inverse trigonometric. + "asin" => div(one(), expr_fn::sqrt(sub(one(), square(u.clone())))), + "acos" => neg(div(one(), expr_fn::sqrt(sub(one(), square(u.clone()))))), + "atan" => div(one(), add(one(), square(u.clone()))), + // Exponential / logarithmic. + "exp" => expr_fn::exp(u.clone()), + "ln" => div(one(), u.clone()), + "log2" => div(one(), mul(u.clone(), lit(LN_2))), + "log10" => div(one(), mul(u.clone(), lit(LN_10))), + "sqrt" => div(one(), mul(lit(2.0_f64), expr_fn::sqrt(u.clone()))), + // Hyperbolic. + "sinh" => expr_fn::cosh(u.clone()), + "cosh" => expr_fn::sinh(u.clone()), + "tanh" => sub(one(), square(expr_fn::tanh(u.clone()))), + // Piecewise-linear: derivative is the sign (undefined at 0, like JAX). + "abs" => expr_fn::signum(u.clone()), + _ => { + return Err(DataFusionError::NotImplemented(format!( + "grad: no derivative rule for function '{name}'" + ))) + } + }; + + Ok(mul(outer, du)) +} + +/// Differentiate `power(base, exponent)`. +/// +/// * Constant exponent `c`: `d/dx base^c = c * base^(c-1) * d(base)`. +/// * Constant base `a`: `d/dx a^u = a^u * ln(a) * d(u)`. +/// * Both variable (`u^v`): not supported in the MVP. +fn diff_power(args: &[Expr], wrt: &str) -> Result { + if args.len() != 2 { + return Err(DataFusionError::NotImplemented( + "grad: power() expects exactly two arguments".to_string(), + )); + } + let base = &args[0]; + let exponent = &args[1]; + + match (as_const(base), as_const(exponent)) { + // Constant exponent (covers the common x^2, x^0.5, ... cases). + (_, Some(c)) => { + let dbase = differentiate(base, wrt)?; + if is_zero(&dbase) { + return Ok(zero()); + } + let outer = + mul(lit(c), expr_fn::power(base.clone(), lit(c - 1.0))); + Ok(mul(outer, dbase)) + } + // Constant base, variable exponent. + (Some(a), None) => { + let dexp = differentiate(exponent, wrt)?; + if is_zero(&dexp) { + return Ok(zero()); + } + let outer = mul( + expr_fn::power(base.clone(), exponent.clone()), + lit(a.ln()), + ); + Ok(mul(outer, dexp)) + } + // General u^v requires the exp/log trick; deferred past the MVP. + (None, None) => Err(DataFusionError::NotImplemented( + "grad: power(base, exponent) where both depend on the \ + differentiation variable is not yet supported" + .to_string(), + )), + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::logical_expr::col; + + #[test] + fn constant_has_zero_derivative() { + assert_eq!(differentiate(&lit(3.0_f64), "x").unwrap(), zero()); + } + + #[test] + fn variable_has_unit_derivative() { + assert_eq!(differentiate(&col("x"), "x").unwrap(), one()); + } + + #[test] + fn other_variable_has_zero_derivative() { + assert_eq!(differentiate(&col("y"), "x").unwrap(), zero()); + } + + #[test] + fn sum_rule_folds_constants() { + // d/dx (x + y) = 1 + 0 = 1 + let e = add(col("x"), col("y")); + assert_eq!(differentiate(&e, "x").unwrap(), one()); + } + + #[test] + fn product_rule() { + // d/dx (x * x) = 1*x + x*1 = x + x + let e = binary(col("x"), Operator::Multiply, col("x")); + let expected = add(col("x"), col("x")); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn quotient_rule() { + // d/dx (x / y) = (1*y - x*0) / (y*y) = y / (y*y) + let e = binary(col("x"), Operator::Divide, col("y")); + let expected = div(col("y"), square(col("y"))); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn chain_rule_sin() { + // d/dx sin(x) = cos(x) * 1 = cos(x) + let d = differentiate(&expr_fn::sin(col("x")), "x").unwrap(); + assert_eq!(d, expr_fn::cos(col("x"))); + // Readable, precedence-free rendering. + assert_eq!(d.to_string(), "cos(x)"); + } + + #[test] + fn composite_sin_times_x() { + // d/dx (sin(x) * x) = cos(x)*x + sin(x) + let e = + binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let d = differentiate(&e, "x").unwrap(); + assert_eq!(d.to_string(), "cos(x) * x + sin(x)"); + } + + #[test] + fn power_constant_exponent() { + // d/dx power(x, 2) = 2 * power(x, 1) * 1 = 2 * power(x, 1) + let e = expr_fn::power(col("x"), lit(2.0_f64)); + let expected = + mul(lit(2.0_f64), expr_fn::power(col("x"), lit(1.0_f64))); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn unsupported_operator_errors() { + let e = binary(col("x"), Operator::Modulo, col("y")); + assert!(differentiate(&e, "x").is_err()); + } + + #[test] + fn unsupported_function_errors() { + // atan2 is binary and has no rule yet. + let e = expr_fn::atan2(col("x"), col("y")); + assert!(differentiate(&e, "x").is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 63a5a6b..df94aac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,9 @@ //! Will skip loading partitions whose time ranges are entirely before 2020-02-01. //! Supported operators: `=`, `<`, `>`, `<=`, `>=`, `BETWEEN`, `IN`, `AND`, `OR`. +mod autograd; + +use std::any::Any; use std::collections::{HashMap, HashSet}; use std::ffi::CString; use std::fmt::Debug; From 672e7d0078a1e8b8a1cc61c32e66c0cab1a20bec Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 27 Jun 2026 23:28:29 +0000 Subject: [PATCH 02/12] Add Substrait bridge to apply grad() rewrite in the Python context Add the `grad` marker UDF and a plan-level rewriter (`rewrite_grad_calls`) to the autograd engine, plus a `grad_rewrite` PyO3 function that bridges the differentiation engine into the datafusion-python SessionContext. Because the native extension links its own copy of DataFusion, expressions cross the Python<->Rust boundary as Substrait protobuf. Python produces the logical plan as Substrait; `grad_rewrite` consumes it into a DataFusion LogicalPlan, rewrites every `grad(expr, column)` ScalarFunction into the symbolic derivative via `differentiate`, and re-produces Substrait bytes for Python to consume and execute. The custom xarray table provider round-trips because Substrait serializes table scans by name (resolved against the registry on consume), so the rewrite context only needs empty tables with matching schemas. `grad` is registered as a marker ScalarUDF that carries the differentiation request intact through parsing, planning, and serialization; it is always rewritten away before execution and errors if it ever reaches invoke. Deps: datafusion-substrait 52 and prost 0.14 (matching the substrait crate). Building now requires `protoc` (the substrait crate codegens from .proto). Verified end to end (produce -> grad_rewrite -> consume -> execute) against analytic derivatives for cos, the product rule, and exp with 0.0 error. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- Cargo.toml | 2 + src/autograd.rs | 126 ++++++++++++++++++++++++++++++++++++++++++------ src/lib.rs | 105 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 216 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fbf823a..eda0422 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,8 @@ async-stream = "0.3" async-trait = "0.1" datafusion = { version = "54.0.0" } datafusion-ffi = { version = "54.0.0" } +datafusion-substrait = { version = "52.0.0" } +prost = "0.14" futures = { version = "0.3" } # `abi3-py310` builds against CPython's stable ABI, so a single wheel per # platform works on all CPython >= 3.10 (matching `requires-python`). This diff --git a/src/autograd.rs b/src/autograd.rs index 662c9f0..1e72586 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -30,12 +30,18 @@ #![allow(dead_code)] +use std::any::Any; use std::f64::consts::{LN_10, LN_2}; +use datafusion::arrow::datatypes::DataType; +use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{DataFusionError, Result, ScalarValue}; use datafusion::functions::math::expr_fn; use datafusion::logical_expr::expr::ScalarFunction; -use datafusion::logical_expr::{lit, BinaryExpr, Cast, Expr, Operator}; +use datafusion::logical_expr::{ + lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, +}; // --------------------------------------------------------------------------- // Constant helpers and the 0/1-folding builders @@ -208,9 +214,7 @@ fn diff_binary(be: &BinaryExpr, wrt: &str) -> Result { // d/dx (a - b) = da - db Operator::Minus => Ok(sub(da, db)), // d/dx (a * b) = da*b + a*db (product rule) - Operator::Multiply => { - Ok(add(mul(da, b.clone()), mul(a.clone(), db))) - } + Operator::Multiply => Ok(add(mul(da, b.clone()), mul(a.clone(), db))), // d/dx (a / b) = (da*b - a*db) / b^2 (quotient rule) Operator::Divide => { let numerator = sub(mul(da, b.clone()), mul(a.clone(), db)); @@ -302,8 +306,7 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { if is_zero(&dbase) { return Ok(zero()); } - let outer = - mul(lit(c), expr_fn::power(base.clone(), lit(c - 1.0))); + let outer = mul(lit(c), expr_fn::power(base.clone(), lit(c - 1.0))); Ok(mul(outer, dbase)) } // Constant base, variable exponent. @@ -312,10 +315,7 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { if is_zero(&dexp) { return Ok(zero()); } - let outer = mul( - expr_fn::power(base.clone(), exponent.clone()), - lit(a.ln()), - ); + let outer = mul(expr_fn::power(base.clone(), exponent.clone()), lit(a.ln())); Ok(mul(outer, dexp)) } // General u^v requires the exp/log trick; deferred past the MVP. @@ -327,15 +327,113 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { } } +// --------------------------------------------------------------------------- +// The `grad` marker UDF and the plan-level rewrite +// --------------------------------------------------------------------------- + +/// A no-op placeholder UDF for `grad(expr, column)`. +/// +/// `grad` is a *marker*: it carries the differentiation request intact through +/// SQL parsing, logical planning, and Substrait serialization. It is always +/// rewritten away by [`rewrite_grad_calls`] before execution, so its `invoke` +/// is never reached in normal use (and deliberately errors if it somehow is, +/// rather than silently returning a wrong value). +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct GradMarker { + signature: Signature, +} + +impl GradMarker { + pub fn new() -> Self { + // grad(expr, column): two arguments of any (numeric) type. + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl Default for GradMarker { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for GradMarker { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "grad" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + Err(DataFusionError::Execution( + "grad() marker reached execution without being rewritten; this is \ + an internal xarray-sql autograd error" + .to_string(), + )) + } +} + +/// Rewrite every `grad(expr, column)` call anywhere in a logical plan into the +/// symbolic derivative of `expr` with respect to `column`, leaving everything +/// else untouched. The plan's schema is recomputed afterwards because replacing +/// a `grad` call can change an expression's name or type. +pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { + let rewritten = plan + .transform_up(|node| node.map_expressions(rewrite_grad_in_expr))? + .data; + rewritten.recompute_schema() +} + +/// Replace any `grad(...)` calls nested anywhere inside a single expression. +fn rewrite_grad_in_expr(expr: Expr) -> Result> { + expr.transform_up(|e| { + let Expr::ScalarFunction(sf) = &e else { + return Ok(Transformed::no(e)); + }; + if sf.func.name() != "grad" { + return Ok(Transformed::no(e)); + } + if sf.args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "grad() expects two arguments grad(expr, column), got {}", + sf.args.len() + ))); + } + let wrt = match &sf.args[1] { + Expr::Column(c) => c.name.clone(), + other => { + return Err(DataFusionError::Plan(format!( + "grad(): the second argument must be a bare column to \ + differentiate with respect to, got: {other}" + ))) + } + }; + let derivative = differentiate(&sf.args[0], &wrt)?; + Ok(Transformed::yes(derivative)) + }) +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { - use super::*; use datafusion::logical_expr::col; + use super::*; + #[test] fn constant_has_zero_derivative() { assert_eq!(differentiate(&lit(3.0_f64), "x").unwrap(), zero()); @@ -386,8 +484,7 @@ mod tests { #[test] fn composite_sin_times_x() { // d/dx (sin(x) * x) = cos(x)*x + sin(x) - let e = - binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let e = binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); let d = differentiate(&e, "x").unwrap(); assert_eq!(d.to_string(), "cos(x) * x + sin(x)"); } @@ -396,8 +493,7 @@ mod tests { fn power_constant_exponent() { // d/dx power(x, 2) = 2 * power(x, 1) * 1 = 2 * power(x, 1) let e = expr_fn::power(col("x"), lit(2.0_f64)); - let expected = - mul(lit(2.0_f64), expr_fn::power(col("x"), lit(1.0_f64))); + let expected = mul(lit(2.0_f64), expr_fn::power(col("x"), lit(1.0_f64))); assert_eq!(differentiate(&e, "x").unwrap(), expected); } diff --git a/src/lib.rs b/src/lib.rs index df94aac..eb33f17 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,21 +60,27 @@ use datafusion::common::stats::Precision; use datafusion::common::{ ColumnStatistics, DataFusionError, Result as DFResult, ScalarValue, Statistics, }; +use datafusion::datasource::empty::EmptyTable; use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ - BinaryExpr, Expr, Operator, TableProviderFilterPushDown, TableType, + BinaryExpr, Expr, Operator, ScalarUDF, TableProviderFilterPushDown, TableType, }; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, }; +use datafusion::prelude::SessionContext; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; +use datafusion_substrait::logical_plan::consumer::from_substrait_plan; +use datafusion_substrait::logical_plan::producer::to_substrait_plan; +use datafusion_substrait::substrait::proto::Plan; +use prost::Message; use pyo3::prelude::*; -use pyo3::types::{PyCapsule, PyList}; +use pyo3::types::{PyBytes, PyCapsule, PyList}; // ============================================================================ // Partition Metadata Types for Filter Pushdown @@ -1270,9 +1276,104 @@ impl LazyArrowStreamTable { } } +// ============================================================================ +// Autograd: Substrait-level grad() rewrite +// ============================================================================ + +/// Rewrite `grad(expr, column)` calls in a Substrait plan into their symbolic +/// derivatives. +/// +/// The autograd engine operates on DataFusion logical `Expr` trees. To apply it +/// inside the datafusion-python `SessionContext` (which links its own copy of +/// DataFusion), we move the plan across the boundary as Substrait protobuf: +/// Python produces the plan, this function consumes it into a DataFusion +/// `LogicalPlan`, rewrites every `grad(...)` into the differentiated +/// expression, and re-produces Substrait bytes for Python to consume and +/// execute. +/// +/// Args: +/// plan_bytes: A Substrait `Plan` protobuf, as produced by +/// datafusion-python's +/// ``Producer.to_substrait_plan(plan, ctx).encode()``. +/// tables: A list of ``(name, pyarrow.Schema)`` pairs for every table the +/// plan scans. The consumer resolves table references by name, so each +/// referenced table must be registered here with a matching schema +/// (the data itself is never read — an empty table suffices). +/// +/// Returns: +/// The rewritten Substrait `Plan` protobuf bytes, ready for +/// ``Consumer.from_substrait_plan(ctx, plan)``. +#[pyfunction] +fn grad_rewrite<'py>( + py: Python<'py>, + plan_bytes: &[u8], + tables: Vec<(String, Bound<'py, PyAny>)>, +) -> PyResult> { + // A fresh, data-free context purely for the rewrite. It needs the grad + // marker UDF (so the consumer can resolve the function) and an empty table + // per referenced name (so the consumer can resolve table scans). + let ctx = SessionContext::new(); + ctx.register_udf(ScalarUDF::from(autograd::GradMarker::new())); + + for (name, schema_obj) in &tables { + let schema = Schema::from_pyarrow_bound(schema_obj).map_err(|e| { + pyo3::exceptions::PyTypeError::new_err(format!( + "grad_rewrite: failed to convert schema for table '{name}': {e}" + )) + })?; + let provider = Arc::new(EmptyTable::new(Arc::new(schema))); + ctx.register_table(name.as_str(), provider).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to register table '{name}': {e}" + )) + })?; + } + + let state = ctx.state(); + + let plan = Plan::decode(plan_bytes).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to decode Substrait plan: {e}" + )) + })?; + + // from_substrait_plan is async but does no real I/O here (empty tables + // resolve immediately), so a minimal current-thread runtime suffices. + let runtime = tokio::runtime::Builder::new_current_thread() + .build() + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "grad_rewrite: failed to build runtime: {e}" + )) + })?; + + let logical = runtime + .block_on(from_substrait_plan(&state, &plan)) + .map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to consume Substrait plan: {e}" + )) + })?; + + let rewritten = autograd::rewrite_grad_calls(logical).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to rewrite grad() calls: {e}" + )) + })?; + + let out_plan = to_substrait_plan(&rewritten, &state).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to produce Substrait plan: {e}" + )) + })?; + + Ok(PyBytes::new(py, &out_plan.encode_to_vec())) +} + /// Python module initialization #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; + m.add_function(wrap_pyfunction!(grad_rewrite, m)?)?; Ok(()) } From 4970f9f9afae12979a54d154717ffb8f31263ed3 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 27 Jun 2026 23:30:58 +0000 Subject: [PATCH 03/12] Expose grad() in XarrayContext SQL via the Substrait rewrite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire the autograd surface into XarrayContext so users can write calculus directly in SQL: ctx.sql("SELECT grad(sin(val), val) AS d_val, sin(val) AS val FROM t") On construction the context registers the `grad` marker UDF so such queries parse and plan. XarrayContext.sql() detects `grad(` (a cheap regex gate so ordinary queries are untouched) and routes through _sql_with_autograd: it plans the query, produces the logical plan as Substrait, calls the native grad_rewrite to differentiate every grad(expr, column) symbolically, then consumes the rewritten Substrait back into an executable DataFrame. Table scans are resolved by name on the consume side, so _table_schemas() passes the (name, schema) of each registered table to the rewrite. Schema- qualified tables (mixed-dimension datasets) are skipped for now and noted as a follow-up. Adds tests/test_autograd.py covering sin/cos, product and quotient rules, power, exp, the non-grad passthrough, and a clear error for unsupported functions — all checked against numpy analytic derivatives. Existing SQL tests still pass. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- tests/test_autograd.py | 76 ++++++++++++++++++++++++++++++++++++++++++ xarray_sql/sql.py | 72 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 tests/test_autograd.py diff --git a/tests/test_autograd.py b/tests/test_autograd.py new file mode 100644 index 0000000..3e7827a --- /dev/null +++ b/tests/test_autograd.py @@ -0,0 +1,76 @@ +"""Tests for the SQL autograd surface: ``SELECT grad(expr, column) ...``. + +These exercise the full path — XarrayContext.sql() -> Substrait -> native +grad_rewrite -> Substrait -> execute — and compare results against analytic +derivatives computed with numpy. +""" + +import numpy as np +import pytest +import xarray as xr + +import xarray_sql as xql + + +@pytest.fixture +def ctx(): + val = np.linspace(0.1, 3.0, 16) + ds = xr.Dataset( + {"val": (("i",), val)}, + coords={"i": np.arange(16)}, + ) + context = xql.XarrayContext() + context.from_dataset("t", ds, chunks={"i": 5}) + return context + + +def _ordered(df, key="i"): + """Collect a result DataFrame into a dict of column -> numpy array, sorted + by the integer key column so comparisons are index-aligned.""" + pdf = df.to_pandas().sort_values(key) + return {c: pdf[c].to_numpy() for c in pdf.columns} + + +def test_grad_sin_is_cos(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql("SELECT i, grad(sin(val), val) AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val)) + + +def test_grad_product_rule(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql("SELECT i, grad(sin(val) * val, val) AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val) * val + np.sin(val)) + + +def test_grad_exp_equals_value(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql("SELECT i, exp(val) AS v, grad(exp(val), val) AS d FROM t") + ) + np.testing.assert_allclose(res["d"], np.exp(val)) + np.testing.assert_allclose(res["d"], res["v"]) + + +def test_grad_quotient_and_power(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(1.0 / val, val) AS dinv, " + "grad(power(val, 3), val) AS dcube FROM t" + ) + ) + np.testing.assert_allclose(res["dinv"], -1.0 / val**2) + np.testing.assert_allclose(res["dcube"], 3.0 * val**2) + + +def test_non_grad_query_is_unaffected(ctx): + # Queries without grad() bypass the rewrite and behave normally. + res = _ordered(ctx.sql("SELECT i, val FROM t")) + np.testing.assert_allclose(res["val"], np.linspace(0.1, 3.0, 16)) + + +def test_unsupported_function_raises(ctx): + # atan2 has no derivative rule yet -> a clear error, not a wrong answer. + with pytest.raises(Exception): + ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 0ec60ad..5577d61 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -1,13 +1,22 @@ +import re + +import pyarrow as pa import xarray as xr -from datafusion import SessionContext +from datafusion import SessionContext, udf from datafusion.catalog import Schema +from datafusion.substrait import Consumer, Producer, Serde from collections import defaultdict +from . import _native from . import cftime as cft from .df import Chunks from .ds import XarrayDataFrame from .reader import read_xarray_table +# Matches a call to the autograd marker function ``grad(`` (case-insensitive), +# used as a cheap gate so ordinary queries skip the Substrait round-trip. +_GRAD_CALL = re.compile(r"\bgrad\s*\(", re.IGNORECASE) + class XarrayContext(SessionContext): """A datafusion `SessionContext` that also supports `xarray.Dataset`s.""" @@ -21,6 +30,24 @@ def __init__(self, *args, **kwargs): # in SQL (e.g. ``"air"`` for a uniform-dim Dataset, or # ``"era5.surface"`` for one entry from a multi-dim-group split). self._registered_datasets: dict[str, xr.Dataset] = {} + self._register_autograd_udfs() + + def _register_autograd_udfs(self) -> None: + """Register the ``grad`` marker UDF used by the autograd rewrite. + + ``grad(expr, column)`` is a *marker*: it lets queries parse and plan + with the differentiation request intact. It is never executed — the + Substrait rewrite in :meth:`sql` replaces every ``grad(...)`` with the + symbolic derivative of ``expr`` before execution. + """ + marker = udf( + lambda expr, column: expr, + [pa.float64(), pa.float64()], + pa.float64(), + "immutable", + "grad", + ) + self.register_udf(marker) def from_dataset( self, @@ -174,9 +201,50 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: Returns: An :class:`XarrayDataFrame` wrapping the DataFusion DataFrame. """ - inner = super().sql(query, *args, **kwargs) + if _GRAD_CALL.search(query): + inner = self._sql_with_autograd(query, *args, **kwargs) + else: + inner = super().sql(query, *args, **kwargs) return XarrayDataFrame(inner, templates=self._registered_datasets) + def _sql_with_autograd(self, query: str, *args, **kwargs): + """Plan ``query``, rewrite ``grad(...)`` calls, return a DataFrame. + + The differentiation engine lives in the native (Rust) extension and + operates on DataFusion logical expressions. Since that extension links + its own copy of DataFusion, the plan crosses the boundary as Substrait: + we produce the logical plan as Substrait, hand it to ``grad_rewrite`` + (which differentiates every ``grad(expr, column)`` symbolically), then + consume the rewritten Substrait back into an executable DataFrame. + """ + plan = super().sql(query, *args, **kwargs).logical_plan() + substrait_plan = Producer.to_substrait_plan(plan, self) + rewritten = _native.grad_rewrite( + substrait_plan.encode(), self._table_schemas() + ) + new_plan = Consumer.from_substrait_plan( + self, Serde.deserialize_bytes(rewritten) + ) + return self.create_dataframe_from_logical_plan(new_plan) + + def _table_schemas(self) -> list[tuple[str, pa.Schema]]: + """Return ``(name, schema)`` for each registered table. + + The Substrait consumer in ``grad_rewrite`` resolves table scans by + name, so it needs the schema of every table the plan might reference. + Only metadata is read here — never the underlying data. + """ + schemas = [] + for name in self._registered_datasets: + try: + schemas.append((name, self.table(name).schema())) + except Exception: + # Schema-qualified tables (mixed-dimension datasets) aren't + # resolvable by a bare name yet; skip rather than fail the + # whole query. grad() over those is a follow-up. + continue + return schemas + def _group_vars_by_dims(ds: xr.Dataset) -> dict[tuple[str, ...], list[str]]: """Group variables in the dataset based on shared dims. From 7950bad5ce322c0e3c5751be624851f1c378d14d Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 10:22:42 +0000 Subject: [PATCH 04/12] ci: install protoc for the substrait build Adding datafusion-substrait pulls in the `substrait` crate, whose build script generates Rust from .proto files and requires `protoc`. Without it the Rust/maturin builds fail. - ci.yml, ci-build.yml, ci-rust.yml: add arduino/setup-protoc before the build (covers Linux, macOS and Windows runners). - publish.yml: setup-protoc for the macOS/Windows wheel job; for the manylinux maturin-action jobs install protoc inside the container via before-script-linux (arch-aware download). The sdist job is unchanged as it packages source without compiling. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- .github/workflows/ci-build.yml | 7 +++++++ .github/workflows/ci-rust.yml | 7 +++++++ .github/workflows/ci.yml | 7 +++++++ .github/workflows/publish.yml | 31 +++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+) diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 214388e..ec89b8e 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -31,6 +31,13 @@ jobs: - uses: dtolnay/rust-toolchain@stable + # The `substrait` crate (a datafusion-substrait dependency) generates + # code from .proto files at build time and requires protoc. + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/ci-rust.yml b/.github/workflows/ci-rust.yml index 68f1ce6..9054a44 100644 --- a/.github/workflows/ci-rust.yml +++ b/.github/workflows/ci-rust.yml @@ -27,6 +27,13 @@ jobs: with: components: clippy + # The `substrait` crate (a datafusion-substrait dependency) generates + # code from .proto files at build time and requires protoc. + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c1d892d..587784d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,6 +43,13 @@ jobs: - uses: dtolnay/rust-toolchain@stable + # The `substrait` crate (a datafusion-substrait dependency) generates + # code from .proto files at build time and requires protoc. + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ab567eb..3be88e7 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -54,6 +54,13 @@ jobs: - uses: dtolnay/rust-toolchain@stable + # The `substrait` crate (a datafusion-substrait dependency) generates + # code from .proto files at build time and requires protoc. + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 @@ -91,6 +98,18 @@ jobs: manylinux: 2_28 rustup-components: rust-std rustfmt sccache: 'true' + # protoc is required by the substrait crate build and must be + # installed inside the manylinux container. + before-script-linux: | + PROTOC_VERSION=29.3 + case "$(uname -m)" in + x86_64) PROTOC_ARCH=x86_64 ;; + aarch64) PROTOC_ARCH=aarch_64 ;; + *) echo "unsupported arch $(uname -m)"; exit 1 ;; + esac + curl -L -o /tmp/protoc.zip "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-${PROTOC_ARCH}.zip" + unzip -o /tmp/protoc.zip -d /usr/local bin/protoc 'include/*' + protoc --version # abi3 (see Cargo.toml) produces one wheel for all CPython >= 3.10, # so a single interpreter is enough to build it. args: --release --strip --out dist -i python3.10 @@ -115,6 +134,18 @@ jobs: manylinux: 2_28 rustup-components: rust-std rustfmt sccache: 'true' + # protoc is required by the substrait crate build and must be + # installed inside the manylinux container. + before-script-linux: | + PROTOC_VERSION=29.3 + case "$(uname -m)" in + x86_64) PROTOC_ARCH=x86_64 ;; + aarch64) PROTOC_ARCH=aarch_64 ;; + *) echo "unsupported arch $(uname -m)"; exit 1 ;; + esac + curl -L -o /tmp/protoc.zip "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-${PROTOC_ARCH}.zip" + unzip -o /tmp/protoc.zip -d /usr/local bin/protoc 'include/*' + protoc --version # abi3 (see Cargo.toml) produces one wheel for all CPython >= 3.10, # so a single interpreter is enough to build it. args: --release --strip --out dist -i python3.10 From 301d9cb21ec4a83038d00ce11ff0b6bb223a2ac3 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 10:36:25 +0000 Subject: [PATCH 05/12] Add jacobian() for multi-input gradients Extend the autograd surface from scalar grad() to multi-input Jacobians. SELECT jacobian(sin(x) * y, [x, y]) AS jac FROM g -- per row: [d/dx, d/dy] = [cos(x)*y, sin(x)] (a List) `jacobian(expr, [c1, c2, ...])` differentiates `expr` with respect to each listed column and returns the gradient row as an array. Using a SQL array for the inputs keeps the marker at fixed arity two (avoiding variadic-UDF issues): the `[c1, c2, ...]` parses to make_array(c1, c2, ...), from which the rewrite extracts the input columns; the result is built with make_array of the partials. Array/list columns round-trip through Substrait, verified end to end. The single grad() marker is generalized into a reusable MarkerUdf (with grad_marker()/jacobian_marker() constructors and per-marker return types), and the plan rewrite dispatches on the function name. A full Jacobian can also be written as separate scalar grad() columns, which already worked; both forms are covered by tests against numpy analytic derivatives. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/autograd.rs | 199 ++++++++++++++++++++++++++++++----------- src/lib.rs | 5 +- tests/test_autograd.py | 53 +++++++++++ xarray_sql/sql.py | 46 ++++++---- 4 files changed, 236 insertions(+), 67 deletions(-) diff --git a/src/autograd.rs b/src/autograd.rs index 1e72586..5b971cf 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -32,15 +32,17 @@ use std::any::Any; use std::f64::consts::{LN_10, LN_2}; +use std::sync::Arc; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{DataFusionError, Result, ScalarValue}; use datafusion::functions::math::expr_fn; +use datafusion::functions_nested::expr_fn::make_array; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, - ScalarUDFImpl, Signature, Volatility, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; // --------------------------------------------------------------------------- @@ -328,43 +330,42 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { } // --------------------------------------------------------------------------- -// The `grad` marker UDF and the plan-level rewrite +// The `grad` / `jacobian` marker UDFs and the plan-level rewrite // --------------------------------------------------------------------------- -/// A no-op placeholder UDF for `grad(expr, column)`. +/// A no-op placeholder UDF for the autograd surface functions. /// -/// `grad` is a *marker*: it carries the differentiation request intact through -/// SQL parsing, logical planning, and Substrait serialization. It is always -/// rewritten away by [`rewrite_grad_calls`] before execution, so its `invoke` -/// is never reached in normal use (and deliberately errors if it somehow is, -/// rather than silently returning a wrong value). +/// `grad` and `jacobian` are *markers*: they carry the differentiation request +/// intact through SQL parsing, logical planning, and Substrait serialization. +/// They are always rewritten away by [`rewrite_grad_calls`] before execution, +/// so `invoke` is never reached in normal use (and deliberately errors if it +/// somehow is, rather than silently returning a wrong value). #[derive(Debug, PartialEq, Eq, Hash)] -pub struct GradMarker { +pub struct MarkerUdf { + name: String, signature: Signature, + return_type: DataType, } -impl GradMarker { - pub fn new() -> Self { - // grad(expr, column): two arguments of any (numeric) type. +impl MarkerUdf { + fn new(name: &str, return_type: DataType) -> Self { Self { + name: name.to_string(), + // Both markers take two arguments: the expression and either a + // column (grad) or an array of columns (jacobian). signature: Signature::any(2, Volatility::Immutable), + return_type, } } } -impl Default for GradMarker { - fn default() -> Self { - Self::new() - } -} - -impl ScalarUDFImpl for GradMarker { +impl ScalarUDFImpl for MarkerUdf { fn as_any(&self) -> &dyn Any { self } fn name(&self) -> &str { - "grad" + &self.name } fn signature(&self) -> &Signature { @@ -372,22 +373,75 @@ impl ScalarUDFImpl for GradMarker { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Float64) + Ok(self.return_type.clone()) } fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { - Err(DataFusionError::Execution( - "grad() marker reached execution without being rewritten; this is \ - an internal xarray-sql autograd error" - .to_string(), - )) + Err(DataFusionError::Execution(format!( + "{}() marker reached execution without being rewritten; this is \ + an internal xarray-sql autograd error", + self.name + ))) } } -/// Rewrite every `grad(expr, column)` call anywhere in a logical plan into the -/// symbolic derivative of `expr` with respect to `column`, leaving everything -/// else untouched. The plan's schema is recomputed afterwards because replacing -/// a `grad` call can change an expression's name or type. +/// A `List` data type, the output of a `jacobian(...)` call. +fn list_of_f64() -> DataType { + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))) +} + +/// The `grad(expr, column)` marker UDF: returns a scalar derivative. +pub fn grad_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("grad", DataType::Float64)) +} + +/// The `jacobian(expr, [c1, c2, ...])` marker UDF: returns the gradient of +/// `expr` with respect to several columns as a `List`. +pub fn jacobian_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("jacobian", list_of_f64())) +} + +/// Build the Jacobian row `[d(expr)/dc1, d(expr)/dc2, ...]` as an array +/// expression (`make_array`), differentiating `expr` w.r.t. each named column. +fn jacobian(expr: &Expr, wrt: &[String]) -> Result { + let partials = wrt + .iter() + .map(|c| differentiate(expr, c)) + .collect::>>()?; + Ok(make_array(partials)) +} + +/// Extract the bare column names from an array-literal expression, i.e. the +/// `make_array(c1, c2, ...)` that a SQL `[c1, c2, ...]` array parses into. +fn columns_from_array(expr: &Expr) -> Result> { + let Expr::ScalarFunction(sf) = expr else { + return Err(DataFusionError::Plan(format!( + "jacobian(): the second argument must be an array of columns \ + like [x, y, z], got: {expr}" + ))); + }; + if sf.func.name() != "make_array" { + return Err(DataFusionError::Plan(format!( + "jacobian(): the second argument must be an array of columns \ + like [x, y, z], got: {expr}" + ))); + } + sf.args + .iter() + .map(|a| match a { + Expr::Column(c) => Ok(c.name.clone()), + other => Err(DataFusionError::Plan(format!( + "jacobian(): array entries must be bare columns to \ + differentiate with respect to, got: {other}" + ))), + }) + .collect() +} + +/// Rewrite every `grad(...)` / `jacobian(...)` call anywhere in a logical plan +/// into its symbolic derivative(s), leaving everything else untouched. The +/// plan's schema is recomputed afterwards because replacing a marker can change +/// an expression's name or type. pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { let rewritten = plan .transform_up(|node| node.map_expressions(rewrite_grad_in_expr))? @@ -395,33 +449,52 @@ pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { rewritten.recompute_schema() } -/// Replace any `grad(...)` calls nested anywhere inside a single expression. +/// Replace any `grad(...)` / `jacobian(...)` calls nested anywhere inside a +/// single expression. fn rewrite_grad_in_expr(expr: Expr) -> Result> { expr.transform_up(|e| { let Expr::ScalarFunction(sf) = &e else { return Ok(Transformed::no(e)); }; - if sf.func.name() != "grad" { - return Ok(Transformed::no(e)); + match sf.func.name() { + "grad" => Ok(Transformed::yes(rewrite_grad(&sf.args)?)), + "jacobian" => Ok(Transformed::yes(rewrite_jacobian(&sf.args)?)), + _ => Ok(Transformed::no(e)), } - if sf.args.len() != 2 { + }) +} + +/// `grad(expr, column)` -> d(expr)/d(column). +fn rewrite_grad(args: &[Expr]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "grad() expects two arguments grad(expr, column), got {}", + args.len() + ))); + } + let wrt = match &args[1] { + Expr::Column(c) => c.name.clone(), + other => { return Err(DataFusionError::Plan(format!( - "grad() expects two arguments grad(expr, column), got {}", - sf.args.len() - ))); + "grad(): the second argument must be a bare column to \ + differentiate with respect to, got: {other}" + ))) } - let wrt = match &sf.args[1] { - Expr::Column(c) => c.name.clone(), - other => { - return Err(DataFusionError::Plan(format!( - "grad(): the second argument must be a bare column to \ - differentiate with respect to, got: {other}" - ))) - } - }; - let derivative = differentiate(&sf.args[0], &wrt)?; - Ok(Transformed::yes(derivative)) - }) + }; + differentiate(&args[0], &wrt) +} + +/// `jacobian(expr, [c1, c2, ...])` -> array `[d(expr)/dc1, d(expr)/dc2, ...]`. +fn rewrite_jacobian(args: &[Expr]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "jacobian() expects two arguments jacobian(expr, [c1, c2, ...]), \ + got {}", + args.len() + ))); + } + let wrt = columns_from_array(&args[1])?; + jacobian(&args[0], &wrt) } // --------------------------------------------------------------------------- @@ -509,4 +582,30 @@ mod tests { let e = expr_fn::atan2(col("x"), col("y")); assert!(differentiate(&e, "x").is_err()); } + + #[test] + fn jacobian_builds_array_of_partials() { + // jacobian(x*y, [x, y]) = [d/dx, d/dy] = [y, x] + let f = binary(col("x"), Operator::Multiply, col("y")); + let j = jacobian(&f, &["x".to_string(), "y".to_string()]).unwrap(); + let expected = make_array(vec![col("y"), col("x")]); + assert_eq!(j, expected); + } + + #[test] + fn jacobian_single_input_is_one_element_array() { + let j = jacobian(&expr_fn::sin(col("x")), &["x".to_string()]).unwrap(); + assert_eq!(j, make_array(vec![expr_fn::cos(col("x"))])); + } + + #[test] + fn columns_from_array_extracts_names() { + let arr = make_array(vec![col("a"), col("b"), col("c")]); + assert_eq!(columns_from_array(&arr).unwrap(), vec!["a", "b", "c"]); + } + + #[test] + fn columns_from_array_rejects_non_array() { + assert!(columns_from_array(&col("x")).is_err()); + } } diff --git a/src/lib.rs b/src/lib.rs index eb33f17..e6d8142 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,7 +65,7 @@ use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ - BinaryExpr, Expr, Operator, ScalarUDF, TableProviderFilterPushDown, TableType, + BinaryExpr, Expr, Operator, TableProviderFilterPushDown, TableType, }; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; @@ -1313,7 +1313,8 @@ fn grad_rewrite<'py>( // marker UDF (so the consumer can resolve the function) and an empty table // per referenced name (so the consumer can resolve table scans). let ctx = SessionContext::new(); - ctx.register_udf(ScalarUDF::from(autograd::GradMarker::new())); + ctx.register_udf(autograd::grad_marker()); + ctx.register_udf(autograd::jacobian_marker()); for (name, schema_obj) in &tables { let schema = Schema::from_pyarrow_bound(schema_obj).map_err(|e| { diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 3e7827a..13f10a2 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -24,6 +24,22 @@ def ctx(): return context +@pytest.fixture +def ctx_xy(): + rng = np.random.default_rng(0) + n = 16 + ds = xr.Dataset( + { + "x": (("i",), rng.uniform(0.5, 2.5, n)), + "y": (("i",), rng.uniform(0.5, 2.5, n)), + }, + coords={"i": np.arange(n)}, + ) + context = xql.XarrayContext() + context.from_dataset("g", ds, chunks={"i": 5}) + return context, ds + + def _ordered(df, key="i"): """Collect a result DataFrame into a dict of column -> numpy array, sorted by the integer key column so comparisons are index-aligned.""" @@ -74,3 +90,40 @@ def test_unsupported_function_raises(ctx): # atan2 has no derivative rule yet -> a clear error, not a wrong answer. with pytest.raises(Exception): ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() + + +def test_multi_input_grad_columns(ctx_xy): + # A full Jacobian written as separate scalar grad() columns: + # f = x*y -> df/dx = y, df/dy = x. + context, ds = ctx_xy + res = _ordered( + context.sql( + "SELECT i, grad(x * y, x) AS dfdx, grad(x * y, y) AS dfdy FROM g" + ) + ) + np.testing.assert_allclose(res["dfdx"], ds["y"].values) + np.testing.assert_allclose(res["dfdy"], ds["x"].values) + + +def test_jacobian_array(ctx_xy): + # jacobian(f, [x, y]) returns the gradient row [df/dx, df/dy] per row. + context, ds = ctx_xy + res = _ordered( + context.sql("SELECT i, jacobian(x * y, [x, y]) AS jac FROM g") + ) + jac = np.stack([np.asarray(v, dtype=float) for v in res["jac"]]) + # column 0 is df/dx = y, column 1 is df/dy = x + np.testing.assert_allclose(jac[:, 0], ds["y"].values) + np.testing.assert_allclose(jac[:, 1], ds["x"].values) + + +def test_jacobian_array_nonlinear(ctx_xy): + # jacobian(sin(x) * y, [x, y]) = [cos(x)*y, sin(x)] + context, ds = ctx_xy + x, y = ds["x"].values, ds["y"].values + res = _ordered( + context.sql("SELECT i, jacobian(sin(x) * y, [x, y]) AS jac FROM g") + ) + jac = np.stack([np.asarray(v, dtype=float) for v in res["jac"]]) + np.testing.assert_allclose(jac[:, 0], np.cos(x) * y) + np.testing.assert_allclose(jac[:, 1], np.sin(x)) diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 5577d61..635a732 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -13,9 +13,10 @@ from .ds import XarrayDataFrame from .reader import read_xarray_table -# Matches a call to the autograd marker function ``grad(`` (case-insensitive), -# used as a cheap gate so ordinary queries skip the Substrait round-trip. -_GRAD_CALL = re.compile(r"\bgrad\s*\(", re.IGNORECASE) +# Matches a call to an autograd marker function (``grad(`` / ``jacobian(``, +# case-insensitive), used as a cheap gate so ordinary queries skip the +# Substrait round-trip. +_GRAD_CALL = re.compile(r"\b(grad|jacobian)\s*\(", re.IGNORECASE) class XarrayContext(SessionContext): @@ -33,21 +34,36 @@ def __init__(self, *args, **kwargs): self._register_autograd_udfs() def _register_autograd_udfs(self) -> None: - """Register the ``grad`` marker UDF used by the autograd rewrite. + """Register the ``grad`` / ``jacobian`` marker UDFs. - ``grad(expr, column)`` is a *marker*: it lets queries parse and plan - with the differentiation request intact. It is never executed — the - Substrait rewrite in :meth:`sql` replaces every ``grad(...)`` with the - symbolic derivative of ``expr`` before execution. + These are *markers*: they let queries parse and plan with the + differentiation request intact. They are never executed — the Substrait + rewrite in :meth:`sql` replaces every call with the symbolic + derivative before execution. + + * ``grad(expr, column)`` -> scalar ``d(expr)/d(column)``. + * ``jacobian(expr, [c1, c2, ...])`` -> the gradient of ``expr`` with + respect to several columns, as a ``List`` (one Jacobian + row). The second argument is a SQL array of bare column references. """ - marker = udf( - lambda expr, column: expr, - [pa.float64(), pa.float64()], - pa.float64(), - "immutable", - "grad", + self.register_udf( + udf( + lambda expr, column: expr, + [pa.float64(), pa.float64()], + pa.float64(), + "immutable", + "grad", + ) + ) + self.register_udf( + udf( + lambda expr, columns: columns, + [pa.float64(), pa.list_(pa.float64())], + pa.list_(pa.float64()), + "immutable", + "jacobian", + ) ) - self.register_udf(marker) def from_dataset( self, From c952df1fd64a4bf2d3fc0890b90dc5f67fd7872c Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 11:16:43 +0000 Subject: [PATCH 06/12] Replace array jacobian() with jvp()/vjp() forward & reverse modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop the jacobian(expr, [cols]) -> List form: a nested array column breaks the long/tidy data model (a cell should be one value aligned to its coordinates). The same Jacobian is expressed in-model as several scalar columns, e.g. grad(f, x) AS dfdx, grad(f, y) AS dfdy. Add forward- and reverse-mode gradients as scalar SQL functions: * jvp(expr, column, tangent) -> d(expr)/d(column) * tangent (forward) * vjp(expr, column, cotangent) -> cotangent * d(expr)/d(column) (reverse) A multi-input directional derivative is the sum of per-input jvp terms; both stay scalar, so they round-trip cleanly through Substrait and back to xarray. Engine: unify grad and jvp behind a single `linearize` (forward-mode chain rule with a pluggable leaf rule) — grad is a one-hot seed, jvp an arbitrary seed per input. This mirrors JAX's structure and removes rule duplication. vjp is cotangent * grad; for a scalar output forward and reverse coincide (asserted by a jvp/vjp agreement test), differing only in seed placement. Tests: 15 Rust unit tests and 11 Python integration tests (incl. jvp/vjp semantics, the multi-input sum, and jvp==vjp for a unit seed), all checked against numpy analytic derivatives. fmt/clippy/ruff/mypy clean. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/autograd.rs | 318 ++++++++++++++++++++++------------------- src/lib.rs | 3 +- tests/test_autograd.py | 43 ++++-- xarray_sql/sql.py | 46 +++--- 4 files changed, 224 insertions(+), 186 deletions(-) diff --git a/src/autograd.rs b/src/autograd.rs index 5b971cf..14312e6 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -22,23 +22,31 @@ //! `add_tangents`: a `0` derivative short-circuits products and drops out of //! sums, and a `1` factor drops out of products. //! -//! ## Scope (MVP) +//! ## Surface //! -//! This first cut implements scalar `grad`: the partial derivative of a single -//! expression with respect to one named column. Forward-/reverse-mode -//! (`jvp`/`vjp`) and multi-input Jacobians are deliberately left for later. +//! Three scalar operations, all rewritten away before execution: +//! +//! * `grad(expr, column)` — the partial derivative `d(expr)/d(column)`. +//! * `jvp(expr, column, tangent)` — forward-mode directional derivative, +//! `d(expr)/d(column) * tangent` (seed a tangent on an input). +//! * `vjp(expr, column, cotangent)` — reverse-mode pullback, +//! `cotangent * d(expr)/d(column)` (seed a cotangent on the output). +//! +//! All three return a scalar per row, staying in the long/tidy data model. A +//! full gradient or Jacobian is expressed as several scalar columns (e.g. +//! `grad(f, x) AS dfdx, grad(f, y) AS dfdy`) rather than a nested array, which +//! would break the one-value-per-coordinate model. #![allow(dead_code)] use std::any::Any; +use std::collections::HashMap; use std::f64::consts::{LN_10, LN_2}; -use std::sync::Arc; -use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{DataFusionError, Result, ScalarValue}; use datafusion::functions::math::expr_fn; -use datafusion::functions_nested::expr_fn::make_array; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, @@ -160,41 +168,48 @@ fn square(e: Expr) -> Expr { } // --------------------------------------------------------------------------- -// The differentiation rules +// The differentiation engine (forward-mode linearization) // --------------------------------------------------------------------------- -/// Differentiate `expr` with respect to the column named `wrt`. +/// A *leaf rule*: the tangent of a column, i.e. the seed assigned to each input +/// during forward-mode differentiation. /// -/// Returns a new [`Expr`] for the partial derivative, composed of ordinary -/// DataFusion expressions. Returns a [`DataFusionError::NotImplemented`] for -/// expression nodes or scalar functions without a differentiation rule, so the -/// caller can surface a clear, actionable error rather than silently producing -/// a wrong answer. -pub fn differentiate(expr: &Expr, wrt: &str) -> Result { +/// `grad` uses a one-hot leaf (`1` for the differentiation variable, `0` +/// otherwise); `jvp` uses an arbitrary seed per input. Everything above the +/// leaves — the chain rule — is shared. +type Leaf<'a> = dyn Fn(&str) -> Expr + 'a; + +/// Linearize `expr`: push tangents from the leaves (per `leaf`) up through the +/// expression via the chain rule, returning the tangent of `expr`. +/// +/// This is forward-mode automatic differentiation. `differentiate` (a single +/// partial derivative) and `jvp` (a directional derivative) are both thin +/// wrappers that only differ in their leaf rule. Returns a +/// [`DataFusionError::NotImplemented`] for nodes or functions without a rule, +/// so callers surface a clear error rather than a silently-wrong derivative. +fn linearize(expr: &Expr, leaf: &Leaf) -> Result { match expr { - // d/dx (x) = 1 ; d/dx (y) = 0 for any other column. - Expr::Column(c) => Ok(if c.name == wrt { one() } else { zero() }), + // The leaf rule decides a column's tangent. + Expr::Column(c) => Ok(leaf(&c.name)), - // d/dx (constant) = 0. + // Constants have zero tangent. Expr::Literal(_, _) => Ok(zero()), - // An alias is transparent to differentiation; the surrounding query - // re-applies any output naming. - Expr::Alias(a) => differentiate(&a.expr, wrt), + // An alias is transparent; the surrounding query re-applies any naming. + Expr::Alias(a) => linearize(&a.expr, leaf), - // A numeric cast is (locally) linear: d/dx cast(u) = cast(du). We keep - // the cast so the derivative retains the declared output type. + // A numeric cast is (locally) linear: tangent of cast(u) = cast(du). Expr::Cast(c) => { - let du = differentiate(&c.expr, wrt)?; + let du = linearize(&c.expr, leaf)?; Ok(Expr::Cast(Cast::new(Box::new(du), c.data_type.clone()))) } - // d/dx (-u) = -(du). - Expr::Negative(inner) => Ok(neg(differentiate(inner, wrt)?)), + // tangent of -u = -(du). + Expr::Negative(inner) => Ok(neg(linearize(inner, leaf)?)), - Expr::BinaryExpr(be) => diff_binary(be, wrt), + Expr::BinaryExpr(be) => linearize_binary(be, leaf), - Expr::ScalarFunction(sf) => diff_scalar_function(sf, wrt), + Expr::ScalarFunction(sf) => linearize_scalar_function(sf, leaf), other => Err(DataFusionError::NotImplemented(format!( "grad: differentiation is not implemented for this expression: {other}" @@ -202,22 +217,34 @@ pub fn differentiate(expr: &Expr, wrt: &str) -> Result { } } -/// Differentiate a binary arithmetic expression via the sum/product/quotient -/// rules. -fn diff_binary(be: &BinaryExpr, wrt: &str) -> Result { +/// Differentiate `expr` with respect to the column named `wrt`. +/// +/// Forward-mode with a one-hot seed: `1` on `wrt`, `0` on every other column. +pub fn differentiate(expr: &Expr, wrt: &str) -> Result { + linearize(expr, &|name| if name == wrt { one() } else { zero() }) +} + +/// Forward-mode directional derivative: the tangent of `expr` given a tangent +/// (`seeds[col]`) for each seeded input column; unseeded columns are constant. +fn jvp(expr: &Expr, seeds: &HashMap) -> Result { + linearize(expr, &|name| seeds.get(name).cloned().unwrap_or_else(zero)) +} + +/// Linearize a binary arithmetic expression via the sum/product/quotient rules. +fn linearize_binary(be: &BinaryExpr, leaf: &Leaf) -> Result { let a = be.left.as_ref(); let b = be.right.as_ref(); - let da = differentiate(a, wrt)?; - let db = differentiate(b, wrt)?; + let da = linearize(a, leaf)?; + let db = linearize(b, leaf)?; match be.op { - // d/dx (a + b) = da + db + // tangent of (a + b) = da + db Operator::Plus => Ok(add(da, db)), - // d/dx (a - b) = da - db + // tangent of (a - b) = da - db Operator::Minus => Ok(sub(da, db)), - // d/dx (a * b) = da*b + a*db (product rule) + // tangent of (a * b) = da*b + a*db (product rule) Operator::Multiply => Ok(add(mul(da, b.clone()), mul(a.clone(), db))), - // d/dx (a / b) = (da*b - a*db) / b^2 (quotient rule) + // tangent of (a / b) = (da*b - a*db) / b^2 (quotient rule) Operator::Divide => { let numerator = sub(mul(da, b.clone()), mul(a.clone(), db)); Ok(div(numerator, square(b.clone()))) @@ -228,17 +255,17 @@ fn diff_binary(be: &BinaryExpr, wrt: &str) -> Result { } } -/// Differentiate a scalar-function call via the chain rule. +/// Linearize a scalar-function call via the chain rule. /// -/// For a unary primitive `f(u)`, the derivative is `f'(u) * du`. For `power`, +/// For a unary primitive `f(u)`, the tangent is `f'(u) * du`. For `power`, /// which is binary, we handle the constant-exponent and constant-base cases. -fn diff_scalar_function(sf: &ScalarFunction, wrt: &str) -> Result { +fn linearize_scalar_function(sf: &ScalarFunction, leaf: &Leaf) -> Result { let name = sf.func.name(); let args = &sf.args; - // `power(base, exponent)` is the one binary primitive we differentiate. + // `power(base, exponent)` is the one binary primitive we linearize. if name == "power" { - return diff_power(args, wrt); + return linearize_power(args, leaf); } if args.len() != 1 { @@ -249,9 +276,9 @@ fn diff_scalar_function(sf: &ScalarFunction, wrt: &str) -> Result { } let u = &args[0]; - let du = differentiate(u, wrt)?; - // Chain rule short-circuit: if du is 0, the whole derivative is 0 and we - // avoid emitting the (dead) outer derivative term entirely. + let du = linearize(u, leaf)?; + // Chain rule short-circuit: if du is 0, the whole tangent is 0 and we avoid + // emitting the (dead) outer derivative term entirely. if is_zero(&du) { return Ok(zero()); } @@ -287,12 +314,12 @@ fn diff_scalar_function(sf: &ScalarFunction, wrt: &str) -> Result { Ok(mul(outer, du)) } -/// Differentiate `power(base, exponent)`. +/// Linearize `power(base, exponent)`. /// -/// * Constant exponent `c`: `d/dx base^c = c * base^(c-1) * d(base)`. -/// * Constant base `a`: `d/dx a^u = a^u * ln(a) * d(u)`. -/// * Both variable (`u^v`): not supported in the MVP. -fn diff_power(args: &[Expr], wrt: &str) -> Result { +/// * Constant exponent `c`: tangent = `c * base^(c-1) * d(base)`. +/// * Constant base `a`: tangent = `a^u * ln(a) * d(u)`. +/// * Both variable (`u^v`): not supported yet. +fn linearize_power(args: &[Expr], leaf: &Leaf) -> Result { if args.len() != 2 { return Err(DataFusionError::NotImplemented( "grad: power() expects exactly two arguments".to_string(), @@ -304,7 +331,7 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { match (as_const(base), as_const(exponent)) { // Constant exponent (covers the common x^2, x^0.5, ... cases). (_, Some(c)) => { - let dbase = differentiate(base, wrt)?; + let dbase = linearize(base, leaf)?; if is_zero(&dbase) { return Ok(zero()); } @@ -313,14 +340,14 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { } // Constant base, variable exponent. (Some(a), None) => { - let dexp = differentiate(exponent, wrt)?; + let dexp = linearize(exponent, leaf)?; if is_zero(&dexp) { return Ok(zero()); } let outer = mul(expr_fn::power(base.clone(), exponent.clone()), lit(a.ln())); Ok(mul(outer, dexp)) } - // General u^v requires the exp/log trick; deferred past the MVP. + // General u^v requires the exp/log trick; deferred for now. (None, None) => Err(DataFusionError::NotImplemented( "grad: power(base, exponent) where both depend on the \ differentiation variable is not yet supported" @@ -335,26 +362,22 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { /// A no-op placeholder UDF for the autograd surface functions. /// -/// `grad` and `jacobian` are *markers*: they carry the differentiation request -/// intact through SQL parsing, logical planning, and Substrait serialization. -/// They are always rewritten away by [`rewrite_grad_calls`] before execution, -/// so `invoke` is never reached in normal use (and deliberately errors if it -/// somehow is, rather than silently returning a wrong value). +/// `grad`, `jvp`, and `vjp` are *markers*: they carry the differentiation +/// request intact through SQL parsing, logical planning, and Substrait +/// serialization. They are always rewritten away by [`rewrite_grad_calls`] +/// before execution, so `invoke` is never reached in normal use (and +/// deliberately errors if it somehow is, rather than returning a wrong value). #[derive(Debug, PartialEq, Eq, Hash)] pub struct MarkerUdf { name: String, signature: Signature, - return_type: DataType, } impl MarkerUdf { - fn new(name: &str, return_type: DataType) -> Self { + fn new(name: &str, arity: usize) -> Self { Self { name: name.to_string(), - // Both markers take two arguments: the expression and either a - // column (grad) or an array of columns (jacobian). - signature: Signature::any(2, Volatility::Immutable), - return_type, + signature: Signature::any(arity, Volatility::Immutable), } } } @@ -373,7 +396,8 @@ impl ScalarUDFImpl for MarkerUdf { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(self.return_type.clone()) + // Every autograd marker rewrites to a scalar derivative expression. + Ok(DataType::Float64) } fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { @@ -385,63 +409,25 @@ impl ScalarUDFImpl for MarkerUdf { } } -/// A `List` data type, the output of a `jacobian(...)` call. -fn list_of_f64() -> DataType { - DataType::List(Arc::new(Field::new("item", DataType::Float64, true))) -} - -/// The `grad(expr, column)` marker UDF: returns a scalar derivative. +/// The `grad(expr, column)` marker: scalar partial derivative `d(expr)/dcolumn`. pub fn grad_marker() -> ScalarUDF { - ScalarUDF::from(MarkerUdf::new("grad", DataType::Float64)) + ScalarUDF::from(MarkerUdf::new("grad", 2)) } -/// The `jacobian(expr, [c1, c2, ...])` marker UDF: returns the gradient of -/// `expr` with respect to several columns as a `List`. -pub fn jacobian_marker() -> ScalarUDF { - ScalarUDF::from(MarkerUdf::new("jacobian", list_of_f64())) +/// The `jvp(expr, column, tangent)` marker: forward-mode directional derivative. +pub fn jvp_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("jvp", 3)) } -/// Build the Jacobian row `[d(expr)/dc1, d(expr)/dc2, ...]` as an array -/// expression (`make_array`), differentiating `expr` w.r.t. each named column. -fn jacobian(expr: &Expr, wrt: &[String]) -> Result { - let partials = wrt - .iter() - .map(|c| differentiate(expr, c)) - .collect::>>()?; - Ok(make_array(partials)) +/// The `vjp(expr, column, cotangent)` marker: reverse-mode pullback to an input. +pub fn vjp_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("vjp", 3)) } -/// Extract the bare column names from an array-literal expression, i.e. the -/// `make_array(c1, c2, ...)` that a SQL `[c1, c2, ...]` array parses into. -fn columns_from_array(expr: &Expr) -> Result> { - let Expr::ScalarFunction(sf) = expr else { - return Err(DataFusionError::Plan(format!( - "jacobian(): the second argument must be an array of columns \ - like [x, y, z], got: {expr}" - ))); - }; - if sf.func.name() != "make_array" { - return Err(DataFusionError::Plan(format!( - "jacobian(): the second argument must be an array of columns \ - like [x, y, z], got: {expr}" - ))); - } - sf.args - .iter() - .map(|a| match a { - Expr::Column(c) => Ok(c.name.clone()), - other => Err(DataFusionError::Plan(format!( - "jacobian(): array entries must be bare columns to \ - differentiate with respect to, got: {other}" - ))), - }) - .collect() -} - -/// Rewrite every `grad(...)` / `jacobian(...)` call anywhere in a logical plan -/// into its symbolic derivative(s), leaving everything else untouched. The -/// plan's schema is recomputed afterwards because replacing a marker can change -/// an expression's name or type. +/// Rewrite every `grad`/`jvp`/`vjp` call anywhere in a logical plan into its +/// symbolic derivative, leaving everything else untouched. The plan's schema is +/// recomputed afterwards because replacing a marker can change an expression's +/// name or type. pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { let rewritten = plan .transform_up(|node| node.map_expressions(rewrite_grad_in_expr))? @@ -449,8 +435,8 @@ pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { rewritten.recompute_schema() } -/// Replace any `grad(...)` / `jacobian(...)` calls nested anywhere inside a -/// single expression. +/// Replace any `grad`/`jvp`/`vjp` calls nested anywhere inside a single +/// expression. fn rewrite_grad_in_expr(expr: Expr) -> Result> { expr.transform_up(|e| { let Expr::ScalarFunction(sf) = &e else { @@ -458,13 +444,25 @@ fn rewrite_grad_in_expr(expr: Expr) -> Result> { }; match sf.func.name() { "grad" => Ok(Transformed::yes(rewrite_grad(&sf.args)?)), - "jacobian" => Ok(Transformed::yes(rewrite_jacobian(&sf.args)?)), + "jvp" => Ok(Transformed::yes(rewrite_jvp(&sf.args)?)), + "vjp" => Ok(Transformed::yes(rewrite_vjp(&sf.args)?)), _ => Ok(Transformed::no(e)), } }) } -/// `grad(expr, column)` -> d(expr)/d(column). +/// Read a bare column name from a marker argument, or report a clear error. +fn column_arg(func: &str, arg: &Expr) -> Result { + match arg { + Expr::Column(c) => Ok(c.name.clone()), + other => Err(DataFusionError::Plan(format!( + "{func}(): the column argument must be a bare column to \ + differentiate with respect to, got: {other}" + ))), + } +} + +/// `grad(expr, column)` -> `d(expr)/d(column)`. fn rewrite_grad(args: &[Expr]) -> Result { if args.len() != 2 { return Err(DataFusionError::Plan(format!( @@ -472,29 +470,44 @@ fn rewrite_grad(args: &[Expr]) -> Result { args.len() ))); } - let wrt = match &args[1] { - Expr::Column(c) => c.name.clone(), - other => { - return Err(DataFusionError::Plan(format!( - "grad(): the second argument must be a bare column to \ - differentiate with respect to, got: {other}" - ))) - } - }; + let wrt = column_arg("grad", &args[1])?; differentiate(&args[0], &wrt) } -/// `jacobian(expr, [c1, c2, ...])` -> array `[d(expr)/dc1, d(expr)/dc2, ...]`. -fn rewrite_jacobian(args: &[Expr]) -> Result { - if args.len() != 2 { +/// `jvp(expr, column, tangent)` -> forward-mode tangent: seed `tangent` on +/// `column` and push it through `expr`, yielding `d(expr)/d(column) * tangent`. +/// +/// A directional derivative over several inputs is the sum of per-input jvps, +/// e.g. `jvp(f, x, dx) + jvp(f, y, dy)`, since each treats the other inputs as +/// having zero tangent. +fn rewrite_jvp(args: &[Expr]) -> Result { + if args.len() != 3 { + return Err(DataFusionError::Plan(format!( + "jvp() expects three arguments jvp(expr, column, tangent), got {}", + args.len() + ))); + } + let wrt = column_arg("jvp", &args[1])?; + let seeds = HashMap::from([(wrt, args[2].clone())]); + jvp(&args[0], &seeds) +} + +/// `vjp(expr, column, cotangent)` -> reverse-mode pullback: the sensitivity that +/// an output cotangent induces on `column`, i.e. `cotangent * d(expr)/d(column)`. +/// +/// For a single scalar output this equals the matching `jvp` (both contract the +/// same partial derivative); the surfaces differ in where the seed lives — `jvp` +/// seeds an input tangent, `vjp` seeds an output cotangent. +fn rewrite_vjp(args: &[Expr]) -> Result { + if args.len() != 3 { return Err(DataFusionError::Plan(format!( - "jacobian() expects two arguments jacobian(expr, [c1, c2, ...]), \ - got {}", + "vjp() expects three arguments vjp(expr, column, cotangent), got {}", args.len() ))); } - let wrt = columns_from_array(&args[1])?; - jacobian(&args[0], &wrt) + let wrt = column_arg("vjp", &args[1])?; + let derivative = differentiate(&args[0], &wrt)?; + Ok(mul(args[2].clone(), derivative)) } // --------------------------------------------------------------------------- @@ -584,28 +597,37 @@ mod tests { } #[test] - fn jacobian_builds_array_of_partials() { - // jacobian(x*y, [x, y]) = [d/dx, d/dy] = [y, x] + fn jvp_seeds_a_tangent_on_one_input() { + // jvp(x*y, {x: dx}) = product rule with tangent(x)=dx, tangent(y)=0 + // = dx*y + x*0 = dx*y let f = binary(col("x"), Operator::Multiply, col("y")); - let j = jacobian(&f, &["x".to_string(), "y".to_string()]).unwrap(); - let expected = make_array(vec![col("y"), col("x")]); - assert_eq!(j, expected); + let seeds = HashMap::from([("x".to_string(), col("dx"))]); + let t = jvp(&f, &seeds).unwrap(); + assert_eq!(t, mul(col("dx"), col("y"))); } #[test] - fn jacobian_single_input_is_one_element_array() { - let j = jacobian(&expr_fn::sin(col("x")), &["x".to_string()]).unwrap(); - assert_eq!(j, make_array(vec![expr_fn::cos(col("x"))])); + fn jvp_with_unit_seed_matches_grad() { + // A one-hot tangent reproduces the partial derivative. + let f = expr_fn::sin(col("x")); + let seeds = HashMap::from([("x".to_string(), one())]); + assert_eq!(jvp(&f, &seeds).unwrap(), differentiate(&f, "x").unwrap()); } #[test] - fn columns_from_array_extracts_names() { - let arr = make_array(vec![col("a"), col("b"), col("c")]); - assert_eq!(columns_from_array(&arr).unwrap(), vec!["a", "b", "c"]); + fn vjp_equals_cotangent_times_grad() { + // rewrite_vjp(sin(x), x, w) = w * cos(x) + let f = expr_fn::sin(col("x")); + let got = rewrite_vjp(&[f.clone(), col("x"), col("w")]).unwrap(); + assert_eq!(got, mul(col("w"), expr_fn::cos(col("x")))); } #[test] - fn columns_from_array_rejects_non_array() { - assert!(columns_from_array(&col("x")).is_err()); + fn jvp_and_vjp_agree_for_unit_seed() { + // With matching unit seed/cotangent, forward and reverse coincide. + let f = binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let fwd = rewrite_jvp(&[f.clone(), col("x"), one()]).unwrap(); + let rev = rewrite_vjp(&[f, col("x"), one()]).unwrap(); + assert_eq!(fwd, rev); } } diff --git a/src/lib.rs b/src/lib.rs index e6d8142..d224886 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1314,7 +1314,8 @@ fn grad_rewrite<'py>( // per referenced name (so the consumer can resolve table scans). let ctx = SessionContext::new(); ctx.register_udf(autograd::grad_marker()); - ctx.register_udf(autograd::jacobian_marker()); + ctx.register_udf(autograd::jvp_marker()); + ctx.register_udf(autograd::vjp_marker()); for (name, schema_obj) in &tables { let schema = Schema::from_pyarrow_bound(schema_obj).map_err(|e| { diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 13f10a2..b21c74a 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -105,25 +105,42 @@ def test_multi_input_grad_columns(ctx_xy): np.testing.assert_allclose(res["dfdy"], ds["x"].values) -def test_jacobian_array(ctx_xy): - # jacobian(f, [x, y]) returns the gradient row [df/dx, df/dy] per row. +def test_jvp_forward_directional_derivative(ctx_xy): + # jvp(f, x, dx) = df/dx * dx. With f = sin(x)*y and a constant tangent. + context, ds = ctx_xy + x, y = ds["x"].values, ds["y"].values + res = _ordered(context.sql("SELECT i, jvp(sin(x) * y, x, 2.0) AS t FROM g")) + np.testing.assert_allclose(res["t"], (np.cos(x) * y) * 2.0) + + +def test_jvp_multi_input_is_sum(ctx_xy): + # A full directional derivative is the sum of per-input jvp terms: + # df/dx*dx + df/dy*dy for f = x*y, with dx=1, dy=1 -> y + x. context, ds = ctx_xy res = _ordered( - context.sql("SELECT i, jacobian(x * y, [x, y]) AS jac FROM g") + context.sql( + "SELECT i, jvp(x * y, x, 1.0) + jvp(x * y, y, 1.0) AS t FROM g" + ) ) - jac = np.stack([np.asarray(v, dtype=float) for v in res["jac"]]) - # column 0 is df/dx = y, column 1 is df/dy = x - np.testing.assert_allclose(jac[:, 0], ds["y"].values) - np.testing.assert_allclose(jac[:, 1], ds["x"].values) + np.testing.assert_allclose(res["t"], ds["y"].values + ds["x"].values) -def test_jacobian_array_nonlinear(ctx_xy): - # jacobian(sin(x) * y, [x, y]) = [cos(x)*y, sin(x)] +def test_vjp_reverse_pullback(ctx_xy): + # vjp(f, x, w) = w * df/dx. With f = sin(x)*y and cotangent w = 3.0. context, ds = ctx_xy x, y = ds["x"].values, ds["y"].values + res = _ordered(context.sql("SELECT i, vjp(sin(x) * y, x, 3.0) AS s FROM g")) + np.testing.assert_allclose(res["s"], 3.0 * (np.cos(x) * y)) + + +def test_jvp_and_vjp_agree_for_unit_seed(ctx_xy): + # Forward (unit tangent) and reverse (unit cotangent) coincide for a + # scalar output -- both contract the same partial derivative. + context, _ = ctx_xy res = _ordered( - context.sql("SELECT i, jacobian(sin(x) * y, [x, y]) AS jac FROM g") + context.sql( + "SELECT i, jvp(sin(x) * y, x, 1.0) AS fwd, " + "vjp(sin(x) * y, x, 1.0) AS rev FROM g" + ) ) - jac = np.stack([np.asarray(v, dtype=float) for v in res["jac"]]) - np.testing.assert_allclose(jac[:, 0], np.cos(x) * y) - np.testing.assert_allclose(jac[:, 1], np.sin(x)) + np.testing.assert_allclose(res["fwd"], res["rev"]) diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 635a732..c62d7da 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -13,10 +13,10 @@ from .ds import XarrayDataFrame from .reader import read_xarray_table -# Matches a call to an autograd marker function (``grad(`` / ``jacobian(``, +# Matches a call to an autograd marker function (``grad(`` / ``jvp(`` / ``vjp(``, # case-insensitive), used as a cheap gate so ordinary queries skip the # Substrait round-trip. -_GRAD_CALL = re.compile(r"\b(grad|jacobian)\s*\(", re.IGNORECASE) +_GRAD_CALL = re.compile(r"\b(grad|jvp|vjp)\s*\(", re.IGNORECASE) class XarrayContext(SessionContext): @@ -34,35 +34,33 @@ def __init__(self, *args, **kwargs): self._register_autograd_udfs() def _register_autograd_udfs(self) -> None: - """Register the ``grad`` / ``jacobian`` marker UDFs. + """Register the ``grad`` / ``jvp`` / ``vjp`` marker UDFs. These are *markers*: they let queries parse and plan with the differentiation request intact. They are never executed — the Substrait - rewrite in :meth:`sql` replaces every call with the symbolic - derivative before execution. - - * ``grad(expr, column)`` -> scalar ``d(expr)/d(column)``. - * ``jacobian(expr, [c1, c2, ...])`` -> the gradient of ``expr`` with - respect to several columns, as a ``List`` (one Jacobian - row). The second argument is a SQL array of bare column references. + rewrite in :meth:`sql` replaces every call with the symbolic derivative + before execution. All return a scalar, staying in the long/tidy data + model (one value per row). + + * ``grad(expr, column)`` -> ``d(expr)/d(column)``. + * ``jvp(expr, column, tangent)`` -> forward-mode directional derivative + ``d(expr)/d(column) * tangent`` (seed a tangent on an input). A + multi-input directional derivative is a sum of jvp terms. + * ``vjp(expr, column, cotangent)`` -> reverse-mode pullback + ``cotangent * d(expr)/d(column)`` (seed a cotangent on the output). + + A full gradient/Jacobian is expressed as several scalar columns, e.g. + ``grad(f, x) AS dfdx, grad(f, y) AS dfdy``. """ + f64 = pa.float64() self.register_udf( - udf( - lambda expr, column: expr, - [pa.float64(), pa.float64()], - pa.float64(), - "immutable", - "grad", - ) + udf(lambda e, c: e, [f64, f64], f64, "immutable", "grad") ) self.register_udf( - udf( - lambda expr, columns: columns, - [pa.float64(), pa.list_(pa.float64())], - pa.list_(pa.float64()), - "immutable", - "jacobian", - ) + udf(lambda e, c, t: e, [f64, f64, f64], f64, "immutable", "jvp") + ) + self.register_udf( + udf(lambda e, c, w: e, [f64, f64, f64], f64, "immutable", "vjp") ) def from_dataset( From 4c0c449a80d93c193c5045778c1e57c50f2b62b4 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 12:30:28 +0000 Subject: [PATCH 07/12] Support grad/jvp/vjp on schema-qualified tables Mixed-dimension datasets register as schema-qualified tables (e.g. era5.surface / era5.time_x_level). The autograd rewrite consumes the plan in a throwaway context that registers an empty table per scanned name, but register_table("era5.time_x", ...) failed with "failed to resolve schema: era5" because the namespace did not exist. Add ensure_schema(): before registering each table, parse its name into a TableReference and, for qualified names, create the schema namespace (MemorySchemaProvider) in the default catalog if absent. The Python side already resolves qualified names via ctx.table(name).schema(); only the Rust rewrite context needed the namespace. Tests: a mixed-dimension fixture exercising grad on both the 2D surface and 3D atmosphere tables, against numpy analytic derivatives. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/lib.rs | 33 ++++++++++++++++++++++++++++++++- tests/test_autograd.py | 40 ++++++++++++++++++++++++++++++++++++++++ xarray_sql/sql.py | 7 ++++--- 3 files changed, 76 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d224886..61b2a9c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,6 +55,8 @@ use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use async_stream::try_stream; use async_trait::async_trait; use datafusion::catalog::streaming::StreamingTable; +use datafusion::catalog::{MemorySchemaProvider, Session}; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue, TableReference}; use datafusion::catalog::Session; use datafusion::common::stats::Precision; use datafusion::common::{ @@ -1280,6 +1282,22 @@ impl LazyArrowStreamTable { // Autograd: Substrait-level grad() rewrite // ============================================================================ +/// Ensure a schema (namespace) exists in the context's catalog, creating an +/// empty in-memory one if needed. Used so the rewrite context can register +/// schema-qualified tables (e.g. `era5.surface`) that mixed-dimension datasets +/// produce. +fn ensure_schema(ctx: &SessionContext, catalog: Option<&str>, schema: &str) -> DFResult<()> { + // A bare TableReference has no catalog; fall back to DataFusion's default. + let catalog_name = catalog.unwrap_or("datafusion"); + let catalog = ctx + .catalog(catalog_name) + .ok_or_else(|| DataFusionError::Plan(format!("catalog '{catalog_name}' not found")))?; + if catalog.schema(schema).is_none() { + catalog.register_schema(schema, Arc::new(MemorySchemaProvider::new()))?; + } + Ok(()) +} + /// Rewrite `grad(expr, column)` calls in a Substrait plan into their symbolic /// derivatives. /// @@ -1324,7 +1342,20 @@ fn grad_rewrite<'py>( )) })?; let provider = Arc::new(EmptyTable::new(Arc::new(schema))); - ctx.register_table(name.as_str(), provider).map_err(|e| { + + // Schema-qualified names (e.g. "era5.surface", from a mixed-dimension + // dataset) need their namespace to exist before the table can be + // registered into this throwaway context. + let table_ref = TableReference::from(name.as_str()); + if let Some(schema_name) = table_ref.schema() { + ensure_schema(&ctx, table_ref.catalog(), schema_name).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to create schema for table '{name}': {e}" + )) + })?; + } + + ctx.register_table(table_ref, provider).map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!( "grad_rewrite: failed to register table '{name}': {e}" )) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index b21c74a..89fc563 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -133,6 +133,46 @@ def test_vjp_reverse_pullback(ctx_xy): np.testing.assert_allclose(res["s"], 3.0 * (np.cos(x) * y)) +@pytest.fixture +def ctx_mixed(): + # A mixed-dimension dataset registers as schema-qualified tables: + # era5.time_x (surface, 2 dims) + # era5.time_x_level (atmosphere, 3 dims) + rng = np.random.default_rng(1) + ds = xr.Dataset( + { + "sfc": (("time", "x"), rng.uniform(0.5, 2.5, (3, 4))), + "atm": (("time", "x", "level"), rng.uniform(0.5, 2.5, (3, 4, 2))), + }, + coords={"time": [0, 1, 2], "x": np.arange(4.0), "level": [0, 1]}, + ) + context = xql.XarrayContext() + context.from_dataset("era5", ds, chunks={"time": 1}) + return context, ds + + +def test_grad_on_qualified_surface_table(ctx_mixed): + context, ds = ctx_mixed + res = _ordered( + context.sql( + "SELECT time, x, sfc, grad(sin(sfc), sfc) AS d FROM era5.time_x" + ), + key="sfc", + ) + np.testing.assert_allclose(res["d"], np.cos(res["sfc"])) + + +def test_grad_on_qualified_atmosphere_table(ctx_mixed): + context, ds = ctx_mixed + res = _ordered( + context.sql( + "SELECT atm, grad(power(atm, 2), atm) AS d FROM era5.time_x_level" + ), + key="atm", + ) + np.testing.assert_allclose(res["d"], 2.0 * res["atm"]) + + def test_jvp_and_vjp_agree_for_unit_seed(ctx_xy): # Forward (unit tangent) and reverse (unit cotangent) coincide for a # scalar output -- both contract the same partial derivative. diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index c62d7da..cb25be0 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -251,11 +251,12 @@ def _table_schemas(self) -> list[tuple[str, pa.Schema]]: schemas = [] for name in self._registered_datasets: try: + # Names may be bare ("air") or schema-qualified ("era5.surface", + # from a mixed-dimension dataset); both resolve here. schemas.append((name, self.table(name).schema())) except Exception: - # Schema-qualified tables (mixed-dimension datasets) aren't - # resolvable by a bare name yet; skip rather than fail the - # whole query. grad() over those is a follow-up. + # Be defensive: skip a table we can't introspect rather than + # failing the whole query. continue return schemas From 4e5d02395030154500d792a713f81448937fe7e4 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 13:17:51 +0000 Subject: [PATCH 08/12] Verify and test higher-order grad Nested calls such as grad(grad(f, x), x) already yield higher-order derivatives: the plan rewrite walks expressions bottom-up (transform_up), so the inner grad is differentiated to a plain expression first and the outer grad differentiates that result. No code change was needed; this adds tests and documents the behavior. - Rust: a unit test that differentiation composes (d2/dx2 sin = -sin). - Python: second derivatives of sin (-sin) and x^3 (6x) and the third derivative of sin (-cos), against numpy. - Doc: note higher-order support in the module overview. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/autograd.rs | 12 ++++++++++++ tests/test_autograd.py | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/src/autograd.rs b/src/autograd.rs index 14312e6..0011a4e 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -36,6 +36,10 @@ //! full gradient or Jacobian is expressed as several scalar columns (e.g. //! `grad(f, x) AS dfdx, grad(f, y) AS dfdy`) rather than a nested array, which //! would break the one-value-per-coordinate model. +//! +//! Calls nest, giving higher-order derivatives for free: the rewrite walks +//! bottom-up, so the inner call in `grad(grad(f, x), x)` is differentiated +//! first and the outer call differentiates that result. #![allow(dead_code)] @@ -596,6 +600,14 @@ mod tests { assert!(differentiate(&e, "x").is_err()); } + #[test] + fn higher_order_derivative() { + // Differentiation composes: d2/dx2 sin(x) = -sin(x). + let d1 = differentiate(&expr_fn::sin(col("x")), "x").unwrap(); + let d2 = differentiate(&d1, "x").unwrap(); + assert_eq!(d2, neg(expr_fn::sin(col("x")))); + } + #[test] fn jvp_seeds_a_tangent_on_one_input() { // jvp(x*y, {x: dx}) = product rule with tangent(x)=dx, tangent(y)=0 diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 89fc563..0ed58ff 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -80,6 +80,31 @@ def test_grad_quotient_and_power(ctx): np.testing.assert_allclose(res["dcube"], 3.0 * val**2) +def test_higher_order_grad(ctx): + # Nested grad() differentiates repeatedly: the inner call is rewritten + # first, then the outer differentiates its result. + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, " + "grad(grad(sin(val), val), val) AS d2_sin, " + "grad(grad(power(val, 3), val), val) AS d2_cube FROM t" + ) + ) + np.testing.assert_allclose(res["d2_sin"], -np.sin(val)) # -sin + np.testing.assert_allclose(res["d2_cube"], 6.0 * val) # d2/dx2 x^3 = 6x + + +def test_third_order_grad(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(grad(grad(sin(val), val), val), val) AS d3 FROM t" + ) + ) + np.testing.assert_allclose(res["d3"], -np.cos(val)) # d3/dx3 sin = -cos + + def test_non_grad_query_is_unaffected(ctx): # Queries without grad() bypass the rewrite and behave normally. res = _ordered(ctx.sql("SELECT i, val FROM t")) From 75069ca7265a97a9cf0c5a0599a4033b7bcc1738 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 14:51:00 +0000 Subject: [PATCH 09/12] Add differentiation-through-aggregate tests and docs Document and test that differentiating through SUM/AVG is just linearity: AGG(grad(f, x)) == d/dx AGG(f). Writing grad inside the aggregate composes with SQL scoping (the marker rewrites to plain SQL before the aggregate runs), so it needs no special machinery -- enough to express gradient descent in SQL. Adds tests for SUM/AVG(grad(...)) and an end-to-end gradient-descent convergence test, plus a note in the module overview. The runnable benchmark scripts live on stacked demo branches to keep this feature branch reviewable. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/autograd.rs | 8 +++++++ tests/test_autograd.py | 52 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/src/autograd.rs b/src/autograd.rs index 0011a4e..7b3755d 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -40,6 +40,14 @@ //! Calls nest, giving higher-order derivatives for free: the rewrite walks //! bottom-up, so the inner call in `grad(grad(f, x), x)` is differentiated //! first and the outer call differentiates that result. +//! +//! Differentiation through an aggregate is just linearity and needs no special +//! handling: write the `grad` *inside* the aggregate, e.g. `SUM(grad(f, x))` or +//! `AVG(grad(loss, theta))`. Because the marker is rewritten to plain SQL +//! before the aggregate runs (and the column is in scope there), this is the +//! relational `d/dθ Σ f = Σ ∂f/∂θ` — enough to run gradient descent in SQL. +//! (The transposed form `grad(SUM(f), x)` is rejected by SQL's own scoping, +//! since `x` is gone after aggregation.) #![allow(dead_code)] diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 0ed58ff..0400465 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -117,6 +117,58 @@ def test_unsupported_function_raises(ctx): ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() +def test_grad_inside_aggregate(ctx): + # Differentiation through an aggregate is just linearity: + # AGG(grad(f, x)) == d/dx AGG(f). grad rewrites to plain SQL before the + # aggregate runs, so this composes with no special machinery. + val = np.linspace(0.1, 3.0, 16) + res = ctx.sql( + "SELECT SUM(grad(val * val, val)) AS s, " + "AVG(grad(sin(val), val)) AS a FROM t" + ).to_pandas() + np.testing.assert_allclose(res["s"][0], np.sum(2 * val)) + np.testing.assert_allclose(res["a"][0], np.mean(np.cos(val))) + + +def test_gradient_descent_in_sql(): + # End to end: fit y ~= a*x + b by minimising MSE, with the gradients + # w.r.t. the parameters computed in SQL via AVG(grad(loss, param)). + rng = np.random.default_rng(0) + n = 200 + x = rng.uniform(0.0, 1.0, n) + a_true, b_true = 2.0, -1.0 + y = a_true * x + b_true + rng.normal(0.0, 0.01, n) + data = xr.Dataset( + {"x": (("i",), x), "y": (("i",), y)}, coords={"i": np.arange(n)} + ) + ctx = xql.XarrayContext() + ctx.from_dataset("d", data, chunks={"i": n}) + + resid = "(y - (a * x + b))" + loss = f"{resid} * {resid}" + a, b, lr = 0.0, 0.0, 0.4 + losses = [] + for _ in range(120): + if "params" in ctx._registered_datasets: + ctx.deregister_table("params") + del ctx._registered_datasets["params"] + params = xr.Dataset( + {"a": (("p",), [a]), "b": (("p",), [b])}, coords={"p": [0]} + ) + ctx.from_dataset("params", params, chunks={"p": 1}) + row = ctx.sql( + f"SELECT AVG({loss}) AS loss, " + f"AVG(grad({loss}, a)) AS dl_da, " + f"AVG(grad({loss}, b)) AS dl_db FROM d CROSS JOIN params" + ).to_pandas() + losses.append(float(row["loss"][0])) + a -= lr * float(row["dl_da"][0]) + b -= lr * float(row["dl_db"][0]) + + assert losses[-1] < losses[0] # loss decreased + np.testing.assert_allclose([a, b], [a_true, b_true], atol=0.05) + + def test_multi_input_grad_columns(ctx_xy): # A full Jacobian written as separate scalar grad() columns: # f = x*y -> df/dx = y, df/dy = x. From 0d2666cd3fd16e2c500057961d30f088c27e63d7 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 15:14:49 +0000 Subject: [PATCH 10/12] Resolve grad over any registered table, not just xarray ones Generalize _table_schemas() to enumerate the catalog instead of only the xarray-registered datasets, so the Substrait rewrite can resolve grad() queries that reference plain DataFusion tables too -- e.g. in-memory MemTables holding model parameters or intermediate results. This makes grad compose with ordinary relational state (a parameter table you INSERT into), not only gridded xarray data. Adds a test differentiating an expression whose coefficient lives in an in-memory table cross-joined to the xarray data. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- tests/test_autograd.py | 17 ++++++++++++++++ xarray_sql/sql.py | 44 +++++++++++++++++++++++++++++------------- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 0400465..ca17de9 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -6,6 +6,7 @@ """ import numpy as np +import pyarrow as pa import pytest import xarray as xr @@ -117,6 +118,22 @@ def test_unsupported_function_raises(ctx): ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() +def test_grad_over_in_memory_table(ctx): + # grad works over plain DataFusion tables too (not just xarray-registered + # ones): here a coefficient lives in an in-memory MemTable cross-joined to + # the xarray data. d/dval (c * val^2) = c * 2*val, with c = 3. + ctx.register_record_batches( + "coef", [[pa.RecordBatch.from_pydict({"c": [3.0]})]] + ) + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(c * val * val, val) AS d FROM t CROSS JOIN coef" + ) + ) + np.testing.assert_allclose(res["d"], 3.0 * 2.0 * val) + + def test_grad_inside_aggregate(ctx): # Differentiation through an aggregate is just linearity: # AGG(grad(f, x)) == d/dx AGG(f). grad rewrites to plain SQL before the diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index cb25be0..5e892d3 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -242,22 +242,40 @@ def _sql_with_autograd(self, query: str, *args, **kwargs): return self.create_dataframe_from_logical_plan(new_plan) def _table_schemas(self) -> list[tuple[str, pa.Schema]]: - """Return ``(name, schema)`` for each registered table. - - The Substrait consumer in ``grad_rewrite`` resolves table scans by - name, so it needs the schema of every table the plan might reference. - Only metadata is read here — never the underlying data. + """Return ``(name, schema)`` for every table registered in the context. + + The Substrait consumer in ``grad_rewrite`` resolves table scans by name, + so it needs the schema of every table the plan might reference. We + enumerate the catalog rather than only the xarray-registered datasets, + so ``grad`` also works over plain DataFusion tables (e.g. in-memory + ``MemTable``s holding model parameters or intermediate results). Only + metadata is read here — never the underlying data. """ schemas = [] - for name in self._registered_datasets: - try: - # Names may be bare ("air") or schema-qualified ("era5.surface", - # from a mixed-dimension dataset); both resolve here. - schemas.append((name, self.table(name).schema())) - except Exception: - # Be defensive: skip a table we can't introspect rather than - # failing the whole query. + catalog = self.catalog() + for schema_name in catalog.schema_names(): + if schema_name == "information_schema": continue + schema = catalog.schema(schema_name) + names = ( + schema.table_names() + if hasattr(schema, "table_names") + else schema.names() + ) + for table_name in names: + # Tables in the default schema are referenced bare ("air"); + # others are schema-qualified ("era5.surface"). + qualified = ( + table_name + if schema_name in ("public", "default") + else f"{schema_name}.{table_name}" + ) + try: + schemas.append((qualified, self.table(qualified).schema())) + except Exception: + # Be defensive: skip a table we can't introspect rather + # than failing the whole query. + continue return schemas From 7bade4d69e9dc60a54755f006d8f6bbbd675ecb0 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 15:43:21 +0000 Subject: [PATCH 11/12] Add differentiate_sql: differentiate an expression to SQL text Expose the autograd engine as a "calculus compiler": differentiate_sql(expr, wrt, columns) parses a SQL scalar expression (parse_sql_expr), differentiates it with the engine, and unparses the derivative back to SQL (expr_to_sql). Where grad(...) rewrites a whole plan via Substrait, this hands back a single derivative expression as text -- usable where the Substrait round-trip can't carry a grad marker, e.g. embedding a precomputed update rule inside a recursive -CTE training loop (Substrait has no recursion). Exposed as xarray_sql. differentiate_sql; covered by a round-trip test. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/lib.rs | 60 ++++++++++++++++++++++++++++++++++++++++-- tests/test_autograd.py | 9 +++++++ xarray_sql/__init__.py | 2 ++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 61b2a9c..1b5954a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,13 +50,15 @@ use std::fmt::Debug; use std::sync::Arc; use arrow::array::RecordBatch; -use arrow::datatypes::{DataType, Schema, SchemaRef, TimeUnit}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use async_stream::try_stream; use async_trait::async_trait; use datafusion::catalog::streaming::StreamingTable; use datafusion::catalog::{MemorySchemaProvider, Session}; -use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue, TableReference}; +use datafusion::common::{ + DFSchema, DataFusionError, Result as DFResult, ScalarValue, TableReference, +}; use datafusion::catalog::Session; use datafusion::common::stats::Precision; use datafusion::common::{ @@ -75,6 +77,7 @@ use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, }; use datafusion::prelude::SessionContext; +use datafusion::sql::unparser::expr_to_sql; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; @@ -1403,10 +1406,63 @@ fn grad_rewrite<'py>( Ok(PyBytes::new(py, &out_plan.encode_to_vec())) } +/// Differentiate a SQL scalar expression symbolically and return the +/// derivative as SQL text. +/// +/// Where [`grad_rewrite`] rewrites `grad(...)` calls inside a whole plan, this +/// differentiates a single expression and hands back the result as SQL — the +/// autograd engine acting as a "calculus compiler". It lets a caller obtain an +/// update rule once and embed it in queries the Substrait round-trip can't +/// carry a `grad` marker through, such as a recursive-CTE training loop. +/// +/// Args: +/// expr: A SQL scalar expression over `columns` (e.g. `"sin(x) * x"`). +/// wrt: The column name to differentiate with respect to. +/// columns: The column names in scope; all treated as `Float64` (enough to +/// parse and differentiate — types don't affect the symbolic result). +/// +/// Returns: +/// The derivative as a SQL string (e.g. `"cos(x) * x + sin(x)"`). +#[pyfunction] +fn differentiate_sql(expr: &str, wrt: &str, columns: Vec) -> PyResult { + let ctx = SessionContext::new(); + + let fields: Vec = columns + .iter() + .map(|name| Field::new(name, DataType::Float64, true)) + .collect(); + let df_schema = DFSchema::try_from(Schema::new(fields)).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to build schema: {e}" + )) + })?; + + let parsed = ctx.parse_sql_expr(expr, &df_schema).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to parse expression '{expr}': {e}" + )) + })?; + + let derivative = autograd::differentiate(&parsed, wrt).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to differentiate: {e}" + )) + })?; + + let sql = expr_to_sql(&derivative).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to render derivative to SQL: {e}" + )) + })?; + + Ok(sql.to_string()) +} + /// Python module initialization #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_function(wrap_pyfunction!(grad_rewrite, m)?)?; + m.add_function(wrap_pyfunction!(differentiate_sql, m)?)?; Ok(()) } diff --git a/tests/test_autograd.py b/tests/test_autograd.py index ca17de9..2ccc7dd 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -134,6 +134,15 @@ def test_grad_over_in_memory_table(ctx): np.testing.assert_allclose(res["d"], 3.0 * 2.0 * val) +def test_differentiate_sql_round_trip(ctx): + # differentiate_sql returns the derivative as SQL text; evaluating it must + # match the analytic derivative. d/dval (sin(val)*val) = cos(val)*val + sin(val). + deriv = xql.differentiate_sql("sin(val) * val", "val", ["val"]) + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql(f"SELECT i, {deriv} AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val) * val + np.sin(val)) + + def test_grad_inside_aggregate(ctx): # Differentiation through an aggregate is just linearity: # AGG(grad(f, x)) == d/dx AGG(f). grad rewrites to plain SQL before the diff --git a/xarray_sql/__init__.py b/xarray_sql/__init__.py index d1e5984..c01f295 100644 --- a/xarray_sql/__init__.py +++ b/xarray_sql/__init__.py @@ -1,4 +1,5 @@ from . import cftime +from ._native import differentiate_sql from .df import from_map from .reader import read_xarray, read_xarray_table from .sql import XarrayContext @@ -6,6 +7,7 @@ __all__ = [ "cftime", "XarrayContext", + "differentiate_sql", "read_xarray_table", "read_xarray", "from_map", # deprecated From 14b26971a99ea76237bcd907b8bf9a858f7788c7 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Tue, 30 Jun 2026 13:29:49 +0300 Subject: [PATCH 12/12] Differentiate grad() as a SQL rewrite, dropping the Substrait bridge Make grad()/jvp()/vjp() work inside any query shape (recursive CTEs, DML, subqueries) by rewriting the calls as SQL text before planning, rather than round-tripping the logical plan through Substrait (which could not represent those shapes). Closes the gap tracked in #197. XarrayContext.sql() now hands a query containing a marker to the native rewrite_grad_sql, which parses the statement with sqlparser, differentiates each marker call with the existing engine, and renders the derivative back into the SQL in place. Because it runs before planning, every query shape the parser accepts is supported, and the result is ordinary SQL the stock datafusion-python context plans and executes directly. This removes the Substrait round-trip entirely: the datafusion-substrait and prost dependencies, the grad_rewrite/_sql_with_autograd/_table_schemas plumbing, the marker-UDF registration, and the protoc steps in CI. Unlike the FFI alternative, it needs no datafusion fork and no custom datafusion-python wheel. The grad surface is unchanged (same SQL, same results); marker arguments use unqualified column names, matching existing usage, since differentiation is syntactic and runs before binding. Co-Authored-By: Claude Opus 4.8 --- .github/workflows/ci-build.yml | 7 -- .github/workflows/ci-rust.yml | 7 -- .github/workflows/ci.yml | 7 -- .github/workflows/publish.yml | 7 -- Cargo.toml | 3 +- src/autograd.rs | 189 ++++++++++++++++++++++++++++++++- src/lib.rs | 135 +++-------------------- tests/test_autograd.py | 24 ++++- xarray_sql/sql.py | 117 +++++--------------- 9 files changed, 254 insertions(+), 242 deletions(-) diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index ec89b8e..214388e 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -31,13 +31,6 @@ jobs: - uses: dtolnay/rust-toolchain@stable - # The `substrait` crate (a datafusion-substrait dependency) generates - # code from .proto files at build time and requires protoc. - - name: Install Protoc - uses: arduino/setup-protoc@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/ci-rust.yml b/.github/workflows/ci-rust.yml index 9054a44..68f1ce6 100644 --- a/.github/workflows/ci-rust.yml +++ b/.github/workflows/ci-rust.yml @@ -27,13 +27,6 @@ jobs: with: components: clippy - # The `substrait` crate (a datafusion-substrait dependency) generates - # code from .proto files at build time and requires protoc. - - name: Install Protoc - uses: arduino/setup-protoc@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 587784d..c1d892d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,13 +43,6 @@ jobs: - uses: dtolnay/rust-toolchain@stable - # The `substrait` crate (a datafusion-substrait dependency) generates - # code from .proto files at build time and requires protoc. - - name: Install Protoc - uses: arduino/setup-protoc@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3be88e7..d1aa560 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -54,13 +54,6 @@ jobs: - uses: dtolnay/rust-toolchain@stable - # The `substrait` crate (a datafusion-substrait dependency) generates - # code from .proto files at build time and requires protoc. - - name: Install Protoc - uses: arduino/setup-protoc@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/Cargo.toml b/Cargo.toml index eda0422..5503571 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,8 +23,7 @@ async-stream = "0.3" async-trait = "0.1" datafusion = { version = "54.0.0" } datafusion-ffi = { version = "54.0.0" } -datafusion-substrait = { version = "52.0.0" } -prost = "0.14" +sqlparser = { version = "0.59", features = ["visitor"] } futures = { version = "0.3" } # `abi3-py310` builds against CPython's stable ABI, so a single wheel per # platform works on all CPython >= 3.10 (matching `requires-python`). This diff --git a/src/autograd.rs b/src/autograd.rs index 7b3755d..c729ca3 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -54,16 +54,23 @@ use std::any::Any; use std::collections::HashMap; use std::f64::consts::{LN_10, LN_2}; +use std::ops::ControlFlow; +use std::sync::Arc; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{DataFusionError, Result, ScalarValue}; +use datafusion::common::{DFSchema, DataFusionError, Result, ScalarValue, TableReference}; use datafusion::functions::math::expr_fn; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use datafusion::prelude::SessionContext; +use datafusion::sql::unparser::expr_to_sql; +use sqlparser::ast::{Expr as SqlExpr, Visit, VisitMut, Visitor, VisitorMut}; +use sqlparser::dialect::GenericDialect; +use sqlparser::parser::Parser; // --------------------------------------------------------------------------- // Constant helpers and the 0/1-folding builders @@ -522,6 +529,142 @@ fn rewrite_vjp(args: &[Expr]) -> Result { Ok(mul(args[2].clone(), derivative)) } +// --------------------------------------------------------------------------- +// SQL source-to-source rewrite +// --------------------------------------------------------------------------- + +/// Rewrite every `grad`/`jvp`/`vjp` call in a SQL statement into its symbolic +/// derivative, returning the rewritten SQL text. +/// +/// Unlike a logical-plan rewrite, this is a pure source-to-source transform run +/// *before* the query is planned, so it works for any query shape the SQL parser +/// accepts — recursive CTEs, DML, and subqueries included. Each marker call is +/// parsed into a DataFusion [`Expr`], differentiated by the engine in this +/// module, and rendered back to SQL in place. Columns are taken from the call's +/// own identifiers (all treated as `Float64`; types don't affect the symbolic +/// result), so no catalog or table schema is needed. +pub fn rewrite_grad_in_sql(sql: &str) -> Result { + let dialect = GenericDialect {}; + let mut statements = Parser::parse_sql(&dialect, sql) + .map_err(|e| DataFusionError::Plan(format!("grad: failed to parse SQL: {e}")))?; + + // A throwaway context that only needs the marker UDFs registered so the + // calls parse into `ScalarFunction` nodes the engine can dispatch on. + let ctx = SessionContext::new(); + ctx.register_udf(grad_marker()); + ctx.register_udf(jvp_marker()); + ctx.register_udf(vjp_marker()); + + let mut rewriter = GradSqlRewriter { ctx: &ctx }; + for stmt in &mut statements { + if let ControlFlow::Break(msg) = stmt.visit(&mut rewriter) { + return Err(DataFusionError::Plan(msg)); + } + } + + Ok(statements + .iter() + .map(ToString::to_string) + .collect::>() + .join("; ")) +} + +/// True if `name` is one of the autograd marker functions (case-insensitive). +fn is_marker_name(name: &str) -> bool { + matches!(name.to_lowercase().as_str(), "grad" | "jvp" | "vjp") +} + +/// Walks a SQL AST and replaces each `grad`/`jvp`/`vjp` call with its derivative. +struct GradSqlRewriter<'a> { + ctx: &'a SessionContext, +} + +impl VisitorMut for GradSqlRewriter<'_> { + type Break = String; + + fn pre_visit_expr(&mut self, expr: &mut SqlExpr) -> ControlFlow { + let is_marker = matches!( + expr, + SqlExpr::Function(f) if is_marker_name(&f.name.to_string()) + ); + if !is_marker { + return ControlFlow::Continue(()); + } + match self.rewrite_call(expr) { + Ok(()) => ControlFlow::Continue(()), + Err(e) => ControlFlow::Break(e), + } + } +} + +impl GradSqlRewriter<'_> { + /// Differentiate a single marker call in place. The replacement is wrapped + /// in parentheses so it keeps the call's precedence in the surrounding SQL. + fn rewrite_call(&self, expr: &mut SqlExpr) -> std::result::Result<(), String> { + let schema = call_schema(expr)?; + let text = expr.to_string(); + let parsed = self + .ctx + .parse_sql_expr(&text, &schema) + .map_err(|e| format!("grad: failed to parse '{text}': {e}"))?; + let derivative = rewrite_grad_in_expr(parsed) + .map_err(|e| format!("grad: failed to differentiate '{text}': {e}"))? + .data; + let rendered = expr_to_sql(&derivative) + .map_err(|e| format!("grad: failed to render derivative for '{text}': {e}"))?; + *expr = SqlExpr::Nested(Box::new(rendered)); + Ok(()) + } +} + +/// Build a `Float64` schema covering every column identifier referenced inside a +/// marker call, so the call's argument expression can be parsed standalone. +fn call_schema(call: &SqlExpr) -> std::result::Result { + let mut collector = ColumnCollector::default(); + let _ = call.visit(&mut collector); + let fields = collector + .cols + .into_iter() + .map(|(qualifier, name)| { + let qualifier = qualifier.map(TableReference::bare); + ( + qualifier, + Arc::new(Field::new(name, DataType::Float64, true)), + ) + }) + .collect(); + DFSchema::new_with_metadata(fields, HashMap::new()) + .map_err(|e| format!("grad: failed to build schema for differentiation: {e}")) +} + +/// Collects the (optional qualifier, name) of every column identifier in a SQL +/// expression tree. +#[derive(Default)] +struct ColumnCollector { + cols: Vec<(Option, String)>, +} + +impl Visitor for ColumnCollector { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &SqlExpr) -> ControlFlow<()> { + let pair = match expr { + SqlExpr::Identifier(ident) => Some((None, ident.value.clone())), + SqlExpr::CompoundIdentifier(parts) => parts.last().map(|last| { + let qualifier = (parts.len() >= 2).then(|| parts[parts.len() - 2].value.clone()); + (qualifier, last.value.clone()) + }), + _ => None, + }; + if let Some(pair) = pair { + if !self.cols.contains(&pair) { + self.cols.push(pair); + } + } + ControlFlow::Continue(()) + } +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -650,4 +793,46 @@ mod tests { let rev = rewrite_vjp(&[f, col("x"), one()]).unwrap(); assert_eq!(fwd, rev); } + + #[test] + fn sql_rewrite_replaces_grad_call() { + // grad(sin(x), x) -> cos(x); the surrounding SELECT is preserved. + let out = rewrite_grad_in_sql("SELECT grad(sin(x), x) AS d FROM t").unwrap(); + assert_eq!(out, "SELECT (cos(x)) AS d FROM t"); + } + + #[test] + fn sql_rewrite_leaves_non_grad_queries_intact() { + // A query with no marker is still parsed and re-emitted unchanged in + // meaning (the caller only invokes the rewrite when a marker is present). + let out = rewrite_grad_in_sql("SELECT a + b FROM t").unwrap(); + assert_eq!(out, "SELECT a + b FROM t"); + } + + #[test] + fn sql_rewrite_fires_inside_recursive_cte() { + // The #197 capability: a marker inside a recursive term is rewritten, + // a query shape the Substrait bridge could never carry. d/dx(x*x) = x+x. + let out = rewrite_grad_in_sql( + "WITH RECURSIVE r AS (SELECT 1.0 AS x UNION ALL \ + SELECT x - grad(x * x, x) FROM r WHERE x < 10) SELECT x FROM r", + ) + .unwrap(); + assert!(out.contains("(x + x)"), "unexpected rewrite: {out}"); + assert!( + !out.to_lowercase().contains("grad("), + "marker left behind: {out}" + ); + } + + #[test] + fn sql_rewrite_handles_nested_higher_order_grad() { + // grad(grad(power(x, 3), x), x) -> d2/dx2 (x^3) = 6x; bottom-up so the + // inner call is differentiated before the outer one. + let out = rewrite_grad_in_sql("SELECT grad(grad(power(x, 3), x), x) AS d FROM t").unwrap(); + assert!( + !out.to_lowercase().contains("grad("), + "marker left behind: {out}" + ); + } } diff --git a/src/lib.rs b/src/lib.rs index 1b5954a..6c592ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -80,12 +80,8 @@ use datafusion::prelude::SessionContext; use datafusion::sql::unparser::expr_to_sql; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; -use datafusion_substrait::logical_plan::consumer::from_substrait_plan; -use datafusion_substrait::logical_plan::producer::to_substrait_plan; -use datafusion_substrait::substrait::proto::Plan; -use prost::Message; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyCapsule, PyList}; +use pyo3::types::{PyCapsule, PyList}; // ============================================================================ // Partition Metadata Types for Filter Pushdown @@ -1282,128 +1278,31 @@ impl LazyArrowStreamTable { } // ============================================================================ -// Autograd: Substrait-level grad() rewrite +// Autograd: SQL-level grad() rewrite // ============================================================================ -/// Ensure a schema (namespace) exists in the context's catalog, creating an -/// empty in-memory one if needed. Used so the rewrite context can register -/// schema-qualified tables (e.g. `era5.surface`) that mixed-dimension datasets -/// produce. -fn ensure_schema(ctx: &SessionContext, catalog: Option<&str>, schema: &str) -> DFResult<()> { - // A bare TableReference has no catalog; fall back to DataFusion's default. - let catalog_name = catalog.unwrap_or("datafusion"); - let catalog = ctx - .catalog(catalog_name) - .ok_or_else(|| DataFusionError::Plan(format!("catalog '{catalog_name}' not found")))?; - if catalog.schema(schema).is_none() { - catalog.register_schema(schema, Arc::new(MemorySchemaProvider::new()))?; - } - Ok(()) -} - -/// Rewrite `grad(expr, column)` calls in a Substrait plan into their symbolic -/// derivatives. +/// Rewrite every `grad`/`jvp`/`vjp` call in a SQL query into its symbolic +/// derivative, returning the rewritten SQL text. /// -/// The autograd engine operates on DataFusion logical `Expr` trees. To apply it -/// inside the datafusion-python `SessionContext` (which links its own copy of -/// DataFusion), we move the plan across the boundary as Substrait protobuf: -/// Python produces the plan, this function consumes it into a DataFusion -/// `LogicalPlan`, rewrites every `grad(...)` into the differentiated -/// expression, and re-produces Substrait bytes for Python to consume and -/// execute. +/// The autograd engine operates on DataFusion logical `Expr` trees. Rather than +/// round-tripping a whole plan across the cdylib boundary, this rewrites the +/// query as **SQL text** before it is planned: each marker call is parsed, +/// differentiated, and rendered back to SQL in place. Because it runs before +/// planning, it works for any query shape the parser accepts — recursive CTEs, +/// DML, and subqueries — which the plan-level Substrait bridge could not carry. /// /// Args: -/// plan_bytes: A Substrait `Plan` protobuf, as produced by -/// datafusion-python's -/// ``Producer.to_substrait_plan(plan, ctx).encode()``. -/// tables: A list of ``(name, pyarrow.Schema)`` pairs for every table the -/// plan scans. The consumer resolves table references by name, so each -/// referenced table must be registered here with a matching schema -/// (the data itself is never read — an empty table suffices). +/// query: A SQL query string that may contain `grad`/`jvp`/`vjp` calls. /// /// Returns: -/// The rewritten Substrait `Plan` protobuf bytes, ready for -/// ``Consumer.from_substrait_plan(ctx, plan)``. +/// The rewritten SQL string, ready to pass to ``SessionContext.sql``. #[pyfunction] -fn grad_rewrite<'py>( - py: Python<'py>, - plan_bytes: &[u8], - tables: Vec<(String, Bound<'py, PyAny>)>, -) -> PyResult> { - // A fresh, data-free context purely for the rewrite. It needs the grad - // marker UDF (so the consumer can resolve the function) and an empty table - // per referenced name (so the consumer can resolve table scans). - let ctx = SessionContext::new(); - ctx.register_udf(autograd::grad_marker()); - ctx.register_udf(autograd::jvp_marker()); - ctx.register_udf(autograd::vjp_marker()); - - for (name, schema_obj) in &tables { - let schema = Schema::from_pyarrow_bound(schema_obj).map_err(|e| { - pyo3::exceptions::PyTypeError::new_err(format!( - "grad_rewrite: failed to convert schema for table '{name}': {e}" - )) - })?; - let provider = Arc::new(EmptyTable::new(Arc::new(schema))); - - // Schema-qualified names (e.g. "era5.surface", from a mixed-dimension - // dataset) need their namespace to exist before the table can be - // registered into this throwaway context. - let table_ref = TableReference::from(name.as_str()); - if let Some(schema_name) = table_ref.schema() { - ensure_schema(&ctx, table_ref.catalog(), schema_name).map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to create schema for table '{name}': {e}" - )) - })?; - } - - ctx.register_table(table_ref, provider).map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to register table '{name}': {e}" - )) - })?; - } - - let state = ctx.state(); - - let plan = Plan::decode(plan_bytes).map_err(|e| { +fn rewrite_grad_sql(query: &str) -> PyResult { + autograd::rewrite_grad_in_sql(query).map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to decode Substrait plan: {e}" + "rewrite_grad_sql: failed to rewrite grad() calls: {e}" )) - })?; - - // from_substrait_plan is async but does no real I/O here (empty tables - // resolve immediately), so a minimal current-thread runtime suffices. - let runtime = tokio::runtime::Builder::new_current_thread() - .build() - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!( - "grad_rewrite: failed to build runtime: {e}" - )) - })?; - - let logical = runtime - .block_on(from_substrait_plan(&state, &plan)) - .map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to consume Substrait plan: {e}" - )) - })?; - - let rewritten = autograd::rewrite_grad_calls(logical).map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to rewrite grad() calls: {e}" - )) - })?; - - let out_plan = to_substrait_plan(&rewritten, &state).map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to produce Substrait plan: {e}" - )) - })?; - - Ok(PyBytes::new(py, &out_plan.encode_to_vec())) + }) } /// Differentiate a SQL scalar expression symbolically and return the @@ -1462,7 +1361,7 @@ fn differentiate_sql(expr: &str, wrt: &str, columns: Vec) -> PyResult) -> PyResult<()> { m.add_class::()?; - m.add_function(wrap_pyfunction!(grad_rewrite, m)?)?; + m.add_function(wrap_pyfunction!(rewrite_grad_sql, m)?)?; m.add_function(wrap_pyfunction!(differentiate_sql, m)?)?; Ok(()) } diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 2ccc7dd..794194e 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,7 +1,8 @@ """Tests for the SQL autograd surface: ``SELECT grad(expr, column) ...``. -These exercise the full path — XarrayContext.sql() -> Substrait -> native -grad_rewrite -> Substrait -> execute — and compare results against analytic +These exercise the full path — XarrayContext.sql() differentiates every +``grad``/``jvp``/``vjp`` call as SQL text before planning, then DataFusion +executes the rewritten query — and compare results against analytic derivatives computed with numpy. """ @@ -195,6 +196,25 @@ def test_gradient_descent_in_sql(): np.testing.assert_allclose([a, b], [a_true, b_true], atol=0.05) +def test_grad_inside_recursive_cte(): + # The headline of #197: grad() *inside* a recursive CTE — a query shape the + # old Substrait bridge could not represent. Newton's method for sqrt(2) + # drives the step with grad(x*x - 2, x) computed in the recursive term: + # x <- x - (x*x - 2) / d/dx(x*x - 2) = x - (x*x - 2) / (2x). + ctx = xql.XarrayContext() + res = ctx.sql( + "WITH RECURSIVE newton AS (" + " SELECT 0 AS step, CAST(1.0 AS DOUBLE) AS x " + " UNION ALL " + " SELECT step + 1 AS step, " + " x - (x * x - 2.0) / grad(x * x - 2.0, x) AS x " + " FROM newton WHERE step < 20" + ") " + "SELECT x FROM newton ORDER BY step DESC LIMIT 1" + ).to_pandas() + np.testing.assert_allclose(res["x"][0], np.sqrt(2.0), atol=1e-9) + + def test_multi_input_grad_columns(ctx_xy): # A full Jacobian written as separate scalar grad() columns: # f = x*y -> df/dx = y, df/dy = x. diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 5e892d3..46fe8e6 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -1,10 +1,8 @@ import re -import pyarrow as pa import xarray as xr -from datafusion import SessionContext, udf +from datafusion import SessionContext from datafusion.catalog import Schema -from datafusion.substrait import Consumer, Producer, Serde from collections import defaultdict from . import _native @@ -14,8 +12,8 @@ from .reader import read_xarray_table # Matches a call to an autograd marker function (``grad(`` / ``jvp(`` / ``vjp(``, -# case-insensitive), used as a cheap gate so ordinary queries skip the -# Substrait round-trip. +# case-insensitive), used as a cheap gate so ordinary queries skip the grad +# source-to-source rewrite. _GRAD_CALL = re.compile(r"\b(grad|jvp|vjp)\s*\(", re.IGNORECASE) @@ -31,37 +29,6 @@ def __init__(self, *args, **kwargs): # in SQL (e.g. ``"air"`` for a uniform-dim Dataset, or # ``"era5.surface"`` for one entry from a multi-dim-group split). self._registered_datasets: dict[str, xr.Dataset] = {} - self._register_autograd_udfs() - - def _register_autograd_udfs(self) -> None: - """Register the ``grad`` / ``jvp`` / ``vjp`` marker UDFs. - - These are *markers*: they let queries parse and plan with the - differentiation request intact. They are never executed — the Substrait - rewrite in :meth:`sql` replaces every call with the symbolic derivative - before execution. All return a scalar, staying in the long/tidy data - model (one value per row). - - * ``grad(expr, column)`` -> ``d(expr)/d(column)``. - * ``jvp(expr, column, tangent)`` -> forward-mode directional derivative - ``d(expr)/d(column) * tangent`` (seed a tangent on an input). A - multi-input directional derivative is a sum of jvp terms. - * ``vjp(expr, column, cotangent)`` -> reverse-mode pullback - ``cotangent * d(expr)/d(column)`` (seed a cotangent on the output). - - A full gradient/Jacobian is expressed as several scalar columns, e.g. - ``grad(f, x) AS dfdx, grad(f, y) AS dfdy``. - """ - f64 = pa.float64() - self.register_udf( - udf(lambda e, c: e, [f64, f64], f64, "immutable", "grad") - ) - self.register_udf( - udf(lambda e, c, t: e, [f64, f64, f64], f64, "immutable", "jvp") - ) - self.register_udf( - udf(lambda e, c, w: e, [f64, f64, f64], f64, "immutable", "vjp") - ) def from_dataset( self, @@ -207,6 +174,11 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: ``.to_dataset(dimension_columns=[...])`` for round-tripping the result back to an ``xr.Dataset``. + If the query contains ``grad`` / ``jvp`` / ``vjp`` calls, they are + differentiated and substituted as SQL text *before* planning (see + :meth:`_rewrite_autograd`), so the differentiation works inside any + query shape — recursive CTEs, DML, and subqueries included. + Args: query: A SQL query string. *args: Forwarded to ``SessionContext.sql``. @@ -216,67 +188,32 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: An :class:`XarrayDataFrame` wrapping the DataFusion DataFrame. """ if _GRAD_CALL.search(query): - inner = self._sql_with_autograd(query, *args, **kwargs) - else: - inner = super().sql(query, *args, **kwargs) + query = self._rewrite_autograd(query) + inner = super().sql(query, *args, **kwargs) return XarrayDataFrame(inner, templates=self._registered_datasets) - def _sql_with_autograd(self, query: str, *args, **kwargs): - """Plan ``query``, rewrite ``grad(...)`` calls, return a DataFrame. + def _rewrite_autograd(self, query: str) -> str: + """Differentiate ``grad`` / ``jvp`` / ``vjp`` calls into SQL text. The differentiation engine lives in the native (Rust) extension and - operates on DataFusion logical expressions. Since that extension links - its own copy of DataFusion, the plan crosses the boundary as Substrait: - we produce the logical plan as Substrait, hand it to ``grad_rewrite`` - (which differentiates every ``grad(expr, column)`` symbolically), then - consume the rewritten Substrait back into an executable DataFrame. - """ - plan = super().sql(query, *args, **kwargs).logical_plan() - substrait_plan = Producer.to_substrait_plan(plan, self) - rewritten = _native.grad_rewrite( - substrait_plan.encode(), self._table_schemas() - ) - new_plan = Consumer.from_substrait_plan( - self, Serde.deserialize_bytes(rewritten) - ) - return self.create_dataframe_from_logical_plan(new_plan) + operates on DataFusion logical expressions. Rather than round-tripping a + whole plan across that extension's boundary, we hand it the query as SQL + text: it parses each marker call, differentiates it symbolically, and + renders the derivative back into the query in place. The result is an + ordinary SQL string this context can plan and execute directly. - def _table_schemas(self) -> list[tuple[str, pa.Schema]]: - """Return ``(name, schema)`` for every table registered in the context. + * ``grad(expr, column)`` -> ``d(expr)/d(column)``. + * ``jvp(expr, column, tangent)`` -> forward-mode directional derivative + ``d(expr)/d(column) * tangent`` (seed a tangent on an input). A + multi-input directional derivative is a sum of jvp terms. + * ``vjp(expr, column, cotangent)`` -> reverse-mode pullback + ``cotangent * d(expr)/d(column)`` (seed a cotangent on the output). - The Substrait consumer in ``grad_rewrite`` resolves table scans by name, - so it needs the schema of every table the plan might reference. We - enumerate the catalog rather than only the xarray-registered datasets, - so ``grad`` also works over plain DataFusion tables (e.g. in-memory - ``MemTable``s holding model parameters or intermediate results). Only - metadata is read here — never the underlying data. + A full gradient/Jacobian is expressed as several scalar columns, e.g. + ``grad(f, x) AS dfdx, grad(f, y) AS dfdy``. """ - schemas = [] - catalog = self.catalog() - for schema_name in catalog.schema_names(): - if schema_name == "information_schema": - continue - schema = catalog.schema(schema_name) - names = ( - schema.table_names() - if hasattr(schema, "table_names") - else schema.names() - ) - for table_name in names: - # Tables in the default schema are referenced bare ("air"); - # others are schema-qualified ("era5.surface"). - qualified = ( - table_name - if schema_name in ("public", "default") - else f"{schema_name}.{table_name}" - ) - try: - schemas.append((qualified, self.table(qualified).schema())) - except Exception: - # Be defensive: skip a table we can't introspect rather - # than failing the whole query. - continue - return schemas + rewritten: str = _native.rewrite_grad_sql(query) + return rewritten def _group_vars_by_dims(ds: xr.Dataset) -> dict[tuple[str, ...], list[str]]: