Context
The autograd grad() surface (#192) is applied by round-tripping the whole logical plan through Substrait: XarrayContext.sql() plans the query, produces a Substrait plan, the native extension (grad_rewrite) consumes it, rewrites every grad(...) into the differentiated expression, and re-produces Substrait for Python to execute.
This works for ordinary SELECT / aggregate / join plans, but Substrait can't represent every plan shape, so grad() can't appear inside:
- Recursive CTEs: to_substrait_plan rejects RecursiveQuery (NotImplemented("Unsupported plan type: RecursiveQuery ...")). This is what blocks a fully-declarative training loop with grad inside the recursion.
- DML: INSERT ... SELECT grad(...) doesn't convert (DML plans aren't produced to Substrait).
- Scalar subqueries: these also don't convert (Cannot convert <subquery> to Substrait).
In the gradient-descent demo (#195) we worked around this with differentiate_sql (differentiate once to SQL text, then iterate it in a recursive CTE that contains no grad marker). That's a good pattern, but the original goal — grad() directly inside the query, including recursive/DML — needs an upstream change.
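As a rough illustration of that workaround pattern, the sketch below runs gradient descent inside a recursive CTE whose body contains only plain SQL and no grad marker. It makes several assumptions: the standard-library sqlite3 module stands in for the real engine, the loss (w - 3)^2 is invented for the example, and GRAD_SQL is a hand-written derivative standing in for whatever text differentiate_sql would return (its actual signature is not shown here).

```python
import sqlite3

# Hand-written derivative of loss(w) = (w - 3)^2. This is a stand-in for the
# SQL text a differentiate-once step would produce (assumption for illustration).
GRAD_SQL = "2.0 * (w - 3.0)"


def train(steps: int = 50, lr: float = 0.1) -> float:
    """Run gradient descent inside a recursive CTE and return the final w."""
    query = f"""
        WITH RECURSIVE descent(step, w) AS (
            SELECT 0, 0.0                                -- initial parameter
            UNION ALL
            SELECT step + 1, w - :lr * ({GRAD_SQL})      -- one update step
            FROM descent
            WHERE step < :steps
        )
        SELECT w FROM descent ORDER BY step DESC LIMIT 1
    """
    conn = sqlite3.connect(":memory:")
    try:
        (w,) = conn.execute(query, {"lr": lr, "steps": steps}).fetchone()
    finally:
        conn.close()
    return w


final_w = train()
```

The derivative is spliced in as text before the query is planned, so the recursive plan never contains a grad(...) call that would need rewriting.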
Root cause
The native engine is a separate cdylib that statically links its own copy of DataFusion, so it can't share Expr/LogicalPlan objects with datafusion-python — hence the Substrait serialization bridge. Substrait is then the limiting factor: any plan shape Substrait can't carry, grad() can't be rewritten in.
The fix is to stop round-tripping plans and instead run the rewrite in-planner, as a native DataFusion rewrite that fires for any query shape. That requires one of the following upstream capabilities.
Upstream changes that would enable it
Option 1 (preferred): datafusion-ffi — forward ScalarUDFImpl::simplify for foreign UDFs
If a foreign (FFI-registered) scalar UDF could implement simplify, grad/jvp/vjp would be self-rewriting UDFs: their simplify() differentiates the argument Expr during the SimplifyExpressions pass, which runs for every query shape (recursive CTEs, DML, subqueries) with no Substrait involved.
- Today: FFI_ScalarUDF / ForeignScalarUDF only forward name, signature, return_type/return_field_from_args, coerce_types, invoke_with_args (see datafusion/ffi/src/udf/mod.rs in datafusion-ffi 52), but not simplify, so the default (identity) is used.
- Fix: add simplify to the FFI_ScalarUDF vtable (an extern "C" entry) and call it from ForeignScalarUDF::simplify. The hard part is passing the argument Exprs across the C ABI, which needs an FFI representation of Expr (e.g. reuse datafusion-proto's Expr (de)serialization, or Substrait ExtendedExpression, for just the UDF's args). This is the smallest, most surgical change and unblocks all query shapes at once. It would also let us delete the whole plan-level Substrait bridge.
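To make the "self-rewriting UDF" idea concrete, here is a toy model in pure Python. It is not DataFusion's API: the Col/Lit/BinOp/Call types, the SIMPLIFY_HOOKS registry, and the simplify walker are all hypothetical stand-ins for Expr, per-UDF simplify, and the SimplifyExpressions pass. The sketch only shows the shape of the mechanism: a generic expression pass calls each function's hook, and grad's hook returns the differentiated expression in its place.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Col:
    name: str


@dataclass(frozen=True)
class Lit:
    value: float


@dataclass(frozen=True)
class BinOp:
    op: str  # "+" or "*"
    left: "Expr"
    right: "Expr"


@dataclass(frozen=True)
class Call:
    func: str
    args: Tuple["Expr", ...]


Expr = Union[Col, Lit, BinOp, Call]


def d(e: Expr, wrt: str) -> Expr:
    """Symbolic derivative of e with respect to column `wrt`."""
    if isinstance(e, Col):
        return Lit(1.0 if e.name == wrt else 0.0)
    if isinstance(e, Lit):
        return Lit(0.0)
    if isinstance(e, BinOp) and e.op == "+":
        return BinOp("+", d(e.left, wrt), d(e.right, wrt))
    if isinstance(e, BinOp) and e.op == "*":  # product rule
        return BinOp("+", BinOp("*", d(e.left, wrt), e.right),
                     BinOp("*", e.left, d(e.right, wrt)))
    raise NotImplementedError(f"cannot differentiate {e!r}")


def simplify_grad(args: Tuple[Expr, ...]) -> Expr:
    """Hypothetical simplify hook for grad(body, var): return d(body)/d(var)."""
    body, var = args
    return d(body, var.name)


SIMPLIFY_HOOKS = {"grad": simplify_grad}  # functions without a hook are left as-is


def simplify(e: Expr) -> Expr:
    """Generic pass: rewrite children first, then apply the function's hook."""
    if isinstance(e, BinOp):
        return BinOp(e.op, simplify(e.left), simplify(e.right))
    if isinstance(e, Call):
        args = tuple(simplify(a) for a in e.args)
        hook = SIMPLIFY_HOOKS.get(e.func)
        return hook(args) if hook else Call(e.func, args)
    return e


def evaluate(e: Expr, row: dict) -> float:
    if isinstance(e, Col):
        return row[e.name]
    if isinstance(e, Lit):
        return e.value
    if isinstance(e, BinOp):
        l, r = evaluate(e.left, row), evaluate(e.right, row)
        return l + r if e.op == "+" else l * r
    raise ValueError(f"unsimplified call: {e!r}")


x = Col("x")
# grad(x*x + 3*x, x)
expr = Call("grad", (BinOp("+", BinOp("*", x, x), BinOp("*", Lit(3.0), x)), x))
rewritten = simplify(expr)
value_at_2 = evaluate(rewritten, {"x": 2.0})
```

Because the hook operates on expressions rather than plans, it does not matter which plan node the expression sits in. That is the property Option 1 relies on.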
Option 2: datafusion-python + datafusion-ffi — register a foreign AnalyzerRule / FunctionRewrite over FFI
Expose registering a foreign analyzer/optimizer rule on the SessionContext (e.g. SessionContext.register_analyzer_rule(<capsule>)), backed by a new FFI_AnalyzerRule in datafusion-ffi. Then grad is an analyzer rule that rewrites grad(...) during analysis for any plan. Broader than Option 1 (general extension point) but more API surface.
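The plan-level variant can be sketched the same way. Everything below is a toy model under stated assumptions: PlanNode, the tuple-encoded expressions, and grad_rule are hypothetical and do not correspond to any existing datafusion-ffi or datafusion-python type. The point is only that a rule walking the plan tree in-process reaches node kinds such as RecursiveQuery and Dml, because no serialization step is involved.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class PlanNode:
    kind: str  # e.g. "Projection", "RecursiveQuery", "Dml" (labels only)
    exprs: List[tuple] = field(default_factory=list)
    inputs: List["PlanNode"] = field(default_factory=list)


def derivative(e: tuple, wrt: str) -> tuple:
    """Symbolic derivative over tuple-encoded expressions."""
    tag = e[0]
    if tag == "col":
        return ("lit", 1.0 if e[1] == wrt else 0.0)
    if tag == "lit":
        return ("lit", 0.0)
    if tag == "add":
        return ("add", derivative(e[1], wrt), derivative(e[2], wrt))
    if tag == "mul":  # product rule
        return ("add", ("mul", derivative(e[1], wrt), e[2]),
                ("mul", e[1], derivative(e[2], wrt)))
    raise NotImplementedError(tag)


def rewrite_expr(e: tuple) -> tuple:
    """Replace every ("grad", body, var) with the derivative of body."""
    if e[0] in ("col", "lit"):
        return e
    if e[0] == "grad":
        return derivative(rewrite_expr(e[1]), e[2])
    return (e[0],) + tuple(rewrite_expr(c) for c in e[1:])


def grad_rule(plan: PlanNode) -> PlanNode:
    """Hypothetical analyzer rule: rewrite expressions in every node, any kind."""
    return PlanNode(plan.kind,
                    [rewrite_expr(e) for e in plan.exprs],
                    [grad_rule(child) for child in plan.inputs])


def has_grad_expr(e: tuple) -> bool:
    return e[0] == "grad" or any(has_grad_expr(c) for c in e[1:] if isinstance(c, tuple))


def has_grad(plan: PlanNode) -> bool:
    return (any(has_grad_expr(e) for e in plan.exprs)
            or any(has_grad(child) for child in plan.inputs))


w = ("col", "w")
# Recursive step: w + (-0.1) * grad(w * w, w)
step = ("add", w, ("mul", ("lit", -0.1), ("grad", ("mul", w, w), "w")))
plan = PlanNode("Dml", inputs=[
    PlanNode("RecursiveQuery", inputs=[
        PlanNode("Projection", exprs=[("lit", 0.0)]),  # static term
        PlanNode("Projection", exprs=[step]),          # recursive term
    ]),
])
rewritten_plan = grad_rule(plan)
```

Compared with the per-UDF hook of Option 1, the rule sees the whole plan, which is what makes it a more general extension point and also a larger API surface.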
Option 3 (narrower / long path): Substrait support for recursion + DML
Keep the current bridge but teach it the missing plan shapes:
- DML: Substrait has a WriteRel; datafusion-substrait would need producer/consumer support for it so INSERT ... SELECT grad(...) round-trips.
- Recursion: the Substrait spec has no recursive-relation construct, so this is an upstream substrait-io spec gap, not just a datafusion-substrait change. This is the longest path.
Recommendation
Pursue Option 1 (FFI simplify forwarding). It's the smallest upstream change, removes the Substrait round-trip and its codec complexity entirely, and makes grad() work inside recursive CTEs, DML, and subqueries — i.e. the fully-declarative training loop with grad in the query. Until then, differentiate_sql + a recursive CTE (see #195) is the supported way to express an in-SQL training loop.
Filed as a follow-up to the autograd feature (#192) and its demos (#195); related to the mutable-state discussion in #194.