Make grad() an in-planner rewrite (via FFI) so it works inside any query — recursive CTEs, DML #197

Description

@alxmrs

Context

The autograd grad() surface (#192) is applied by round-tripping the whole logical plan through Substrait: XarrayContext.sql() plans the query and serializes it to a Substrait plan; the native extension (grad_rewrite) consumes that plan, rewrites every grad(...) into the differentiated expression, and emits a new Substrait plan for Python to execute.

This works for ordinary SELECT / aggregate / join plans, but Substrait can't represent every plan shape, so grad() can't appear inside:

  • Recursive CTEs: to_substrait_plan rejects RecursiveQuery (NotImplemented("Unsupported plan type: RecursiveQuery ...")). This is what blocks a fully declarative training loop with grad inside the recursion.
  • DML: INSERT ... SELECT grad(...) doesn't convert (DML plans are not serialized to Substrait).
  • Scalar subqueries: these also don't convert (Cannot convert <subquery> to Substrait).

In the gradient-descent demo (#195) we worked around this with differentiate_sql (differentiate once to SQL text, then iterate it in a recursive CTE that contains no grad marker). That's a good pattern, but the original goal — grad() directly inside the query, including recursive/DML — needs an upstream change.
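The workaround pattern can be sketched roughly as follows. This is a minimal illustration, not the code from #195: it uses Python's built-in sqlite3 as a stand-in engine (the real demo runs on DataFusion via XarrayContext), and GRAD_SQL is a hand-written stand-in for the SQL text that a differentiate-once step such as differentiate_sql would produce for loss(w) = (w - 3)^2. The train helper and the CTE column names are hypothetical.

```python
import sqlite3

# Assumption: stand-in for the output of a "differentiate once to SQL text" step.
# It is d/dw of loss(w) = (w - 3)^2, written by hand for this sketch.
GRAD_SQL = "2.0 * (w - 3.0)"


def train(steps, lr=0.1, w0=0.0):
    """Run gradient descent inside a recursive CTE that contains no grad marker."""
    query = f"""
    WITH RECURSIVE train(step, w) AS (
        SELECT 0, ?                               -- initial weight
        UNION ALL
        SELECT step + 1, w - ? * ({GRAD_SQL})     -- one descent step
        FROM train
        WHERE step < ?
    )
    SELECT w FROM train ORDER BY step DESC LIMIT 1
    """
    conn = sqlite3.connect(":memory:")
    try:
        # Placeholders are bound in order of appearance: w0, lr, steps.
        (w,) = conn.execute(query, (w0, lr, steps)).fetchone()
    finally:
        conn.close()
    return w


if __name__ == "__main__":
    print(train(200))  # converges towards the minimiser w = 3
```

The gradient expression is spliced into the query as text before planning, so the planner only ever sees ordinary arithmetic inside the recursion. That is why this pattern sidesteps the Substrait limitation, and also why it is less direct than writing grad() in the query itself.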

Root cause

The native engine is a separate cdylib that statically links its own copy of DataFusion, so it can't share Expr/LogicalPlan objects with datafusion-python — hence the Substrait serialization bridge. Substrait is then the limiting factor: any plan shape Substrait can't carry, grad() can't be rewritten in.

The fix is to stop round-tripping plans and instead run the rewrite in-planner, as a native DataFusion rewrite that fires for any query shape. That requires one of the following upstream capabilities.

Upstream changes that would enable it

Option 1 (preferred): datafusion-ffi — forward ScalarUDFImpl::simplify for foreign UDFs

If a foreign (FFI-registered) scalar UDF could implement simplify, grad/jvp/vjp would be self-rewriting UDFs: their simplify() differentiates the argument Expr during the SimplifyExpressions pass, which runs for every query shape (recursive CTEs, DML, subqueries) with no Substrait involved.

  • Today FFI_ScalarUDF / ForeignScalarUDF only forward name, signature, return_type/return_field_from_args, coerce_types, invoke_with_args (see datafusion/ffi/src/udf/mod.rs in datafusion-ffi 52) — not simplify, so the default (identity) is used.
  • Fix: add simplify to the FFI_ScalarUDF vtable (an extern "C" entry) and call it from ForeignScalarUDF::simplify. The hard part is passing the argument Exprs across the C ABI, which needs an FFI representation of Expr (e.g. reuse datafusion-proto's Expr (de)serialization, or Substrait ExtendedExpression, for just the UDF's args). This is the smallest, most surgical change and unblocks all query shapes at once. It would also let us delete the whole plan-level Substrait bridge.
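To make the "self-rewriting UDF" idea concrete, here is a toy model in pure Python of what such a simplify hook would do: walk an expression tree and replace each grad(expr, col) call with the symbolic derivative of its argument. Everything here (Col, Lit, BinOp, Call, diff, rewrite, evaluate) is hypothetical and invented for illustration; it is not DataFusion's Expr type or the ScalarUDFImpl API, and it only handles + and *.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Col:
    name: str


@dataclass(frozen=True)
class Lit:
    value: float


@dataclass(frozen=True)
class BinOp:
    op: str  # "+" or "*" only in this sketch
    left: object
    right: object


@dataclass(frozen=True)
class Call:
    func: str
    args: tuple


def diff(e, wrt):
    """Symbolic derivative of e with respect to the column named wrt."""
    if isinstance(e, Lit):
        return Lit(0.0)
    if isinstance(e, Col):
        return Lit(1.0 if e.name == wrt else 0.0)
    if isinstance(e, BinOp) and e.op == "+":
        return BinOp("+", diff(e.left, wrt), diff(e.right, wrt))
    if isinstance(e, BinOp) and e.op == "*":
        # Product rule: (f * g)' = f' * g + f * g'
        return BinOp("+",
                     BinOp("*", diff(e.left, wrt), e.right),
                     BinOp("*", e.left, diff(e.right, wrt)))
    raise NotImplementedError(f"cannot differentiate {e!r}")


def rewrite(e):
    """Bottom-up pass that replaces every grad(expr, col) with its derivative."""
    if isinstance(e, BinOp):
        return BinOp(e.op, rewrite(e.left), rewrite(e.right))
    if isinstance(e, Call):
        args = tuple(rewrite(a) for a in e.args)
        if e.func == "grad":
            body, wrt = args
            return diff(body, wrt.name)
        return Call(e.func, args)
    return e


def evaluate(e, env):
    """Evaluate a fully rewritten expression; leftover calls are an error."""
    if isinstance(e, Lit):
        return e.value
    if isinstance(e, Col):
        return env[e.name]
    if isinstance(e, BinOp):
        left, right = evaluate(e.left, env), evaluate(e.right, env)
        return left + right if e.op == "+" else left * right
    raise TypeError(f"unrewritten call: {e!r}")


if __name__ == "__main__":
    x = Col("x")
    # grad(x*x + 3*x, x), i.e. 2x + 3 after rewriting
    expr = Call("grad", (BinOp("+", BinOp("*", x, x), BinOp("*", Lit(3.0), x)), x))
    print(evaluate(rewrite(expr), {"x": 2.0}))  # 7.0
```

Because the rewrite operates on one expression at a time, it does not care what kind of plan node the expression lives in. That is the property that would make a forwarded simplify work for recursive CTEs, DML, and subqueries alike.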

Option 2: datafusion-python + datafusion-ffi — register a foreign AnalyzerRule / FunctionRewrite over FFI

Expose registering a foreign analyzer/optimizer rule on the SessionContext (e.g. SessionContext.register_analyzer_rule(<capsule>)), backed by a new FFI_AnalyzerRule in datafusion-ffi. Then grad is an analyzer rule that rewrites grad(...) during analysis for any plan. Broader than Option 1 (general extension point) but more API surface.

Option 3 (narrower / long path): Substrait support for recursion + DML

Keep the current bridge but teach it the missing plan shapes:

  • DML: Substrait has a WriteRel; datafusion-substrait would need producer/consumer support for it so INSERT ... SELECT grad(...) round-trips.
  • Recursion: the Substrait spec itself has no recursive-relation construct, so this is an upstream substrait-io spec gap, not just a datafusion-substrait change. That makes it the longest path.

Recommendation

Pursue Option 1 (FFI simplify forwarding). It's the smallest upstream change, removes the Substrait round-trip and its codec complexity entirely, and makes grad() work inside recursive CTEs, DML, and subqueries — i.e. the fully-declarative training loop with grad in the query. Until then, differentiate_sql + a recursive CTE (see #195) is the supported way to express an in-SQL training loop.

Filed as a follow-up to the autograd feature (#192) and its demos (#195); related to the mutable-state discussion in #194.
