diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 5441e0a9..619f0f68 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -885,11 +885,26 @@ def _save_t1(subsaveat, save_state): return eqx.tree_at(lambda s: s.result, final_state, result), aux_stats +def _validate_solver(solver: Any) -> AbstractSolver: + """Raise a clear error if `solver` was passed as a class, not an instance.""" + if isinstance(solver, AbstractSolver): + return solver + if isinstance(solver, type) and issubclass(solver, AbstractSolver): + raise ValueError( + "It looks like you forgot to instantiate your solver, e.g. by passing " + "`diffrax.Euler` instead of `diffrax.Euler()`." + ) + raise ValueError( + "Argument `solver` must be an instance of (some subclass of) " + "`diffrax.AbstractSolver`, but its type is not recognised." + ) + + @eqx.filter_jit @eqxi.doc_remove_args("discrete_terminating_event") def diffeqsolve( terms: PyTree[AbstractTerm], - solver: AbstractSolver, + solver: AbstractSolver | type[AbstractSolver], t0: RealScalarLike, t1: RealScalarLike, dt0: RealScalarLike | None, @@ -1014,6 +1029,8 @@ def diffeqsolve( # Initial set-up # + validated_solver: AbstractSolver = _validate_solver(solver) + # Backward compatibility if discrete_terminating_event is not None: warnings.warn( @@ -1100,7 +1117,7 @@ def _promote(yi): del timelikes # Backward compatibility - if isinstance(solver, (EulerHeun, ItoMilstein, StratonovichMilstein)): + if isinstance(validated_solver, (EulerHeun, ItoMilstein, StratonovichMilstein)): try: _assert_term_compatible( t0, @@ -1108,14 +1125,14 @@ def _promote(yi): args, terms, (ODETerm, AbstractTerm), - solver.term_compatible_contr_kwargs, + validated_solver.term_compatible_contr_kwargs, ) except Exception as _: pass else: warnings.warn( "Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to " - f"{solver.__class__.__name__} is deprecated in favour of " + f"{validated_solver.__class__.__name__} is deprecated in favour of " "`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that " "the same terms can now be passed used for both general " "and SDE-specific solvers!", @@ -1129,20 +1146,22 @@ def _promote(yi): y0, args, terms, - solver.term_structure, - solver.term_compatible_contr_kwargs, + validated_solver.term_structure, + validated_solver.term_compatible_contr_kwargs, ) if is_sde(terms): - if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)): + if not isinstance( + validated_solver, (AbstractItoSolver, AbstractStratonovichSolver) + ): warnings.warn( - f"`{type(solver).__name__}` is not marked as converging to either the " - "Itô or the Stratonovich solution.", + f"`{type(validated_solver).__name__}` is not marked as converging to " + "either the Itô or the Stratonovich solution.", stacklevel=2, ) if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): # Specific check to not work even if using HalfSolver(Euler()) - if isinstance(solver, Euler): + if isinstance(validated_solver, Euler): raise ValueError( "An SDE should not be solved with adaptive step sizes with Euler's " "method, as it may not converge to the correct solution." @@ -1175,26 +1194,28 @@ def _wrap(term): is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm), ) - if isinstance(solver, AbstractImplicitSolver): + if isinstance(validated_solver, AbstractImplicitSolver): def _get_tols(x): outs = [] for attr in ("rtol", "atol", "norm"): if ( - getattr(cast(AbstractImplicitSolver, solver).root_finder, attr) + getattr( + cast(AbstractImplicitSolver, validated_solver).root_finder, attr + ) is use_stepsize_tol ): outs.append(getattr(x, attr)) return tuple(outs) if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): - solver = eqx.tree_at( + validated_solver = eqx.tree_at( lambda s: _get_tols(s.root_finder), - solver, + validated_solver, _get_tols(stepsize_controller), ) else: - if len(_get_tols(solver.root_finder)) > 0: + if len(_get_tols(validated_solver.root_finder)) > 0: raise ValueError( "A fixed step size controller is being used alongside an implicit " "solver, but the tolerances for the implicit solver have not been " @@ -1248,24 +1269,24 @@ def _subsaveat_direction_fn(x): # Initialise states tprev = t0 - error_order = solver.error_order(terms) + error_order = validated_solver.error_order(terms) if controller_state is None: passed_controller_state = False (tnext, controller_state) = stepsize_controller.init( - terms, t0, t1, y0, dt0, args, solver.func, error_order + terms, t0, t1, y0, dt0, args, validated_solver.func, error_order ) else: passed_controller_state = True if dt0 is None: (tnext, _) = stepsize_controller.init( - terms, t0, t1, y0, dt0, args, solver.func, error_order + terms, t0, t1, y0, dt0, args, validated_solver.func, error_order ) else: tnext = t0 + dt0 tnext = jnp.minimum(tnext, t1) if solver_state is None: passed_solver_state = False - solver_state = solver.init(terms, t0, tnext, y0, args) + solver_state = validated_solver.init(terms, t0, tnext, y0, args) else: passed_solver_state = True @@ -1310,7 +1331,14 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState: result = RESULTS.successful if saveat.dense or event is not None: _, _, dense_info_struct, _, _ = eqx.filter_eval_shape( - solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump + validated_solver.step, + terms, + tprev, + tnext, + y0, + args, + solver_state, + made_jump, ) if saveat.dense: if max_steps is None: @@ -1371,7 +1399,7 @@ def _outer_cond_fn(cond_fn_i): y0, args, terms=terms, - solver=solver, + solver=validated_solver, t0=t0, t1=t1, dt0=dt0, @@ -1456,7 +1484,7 @@ def _outer_cond_fn(cond_fn_i): final_state, aux_stats = adjoint.loop( args=args, terms=terms, - solver=solver, + solver=validated_solver, stepsize_controller=stepsize_controller, event=event, saveat=saveat, @@ -1503,7 +1531,7 @@ def _outer_cond_fn(cond_fn_i): ts=final_state.dense_ts, ts_size=final_state.dense_save_index + 1, infos=final_state.dense_infos, - interpolation_cls=solver.interpolation_cls, + interpolation_cls=validated_solver.interpolation_cls, direction=direction, t0_if_trivial=t0, y0_if_trivial=y0, diff --git a/test/test_integrate.py b/test/test_integrate.py index 9918aa20..3daa0ce8 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -587,6 +587,35 @@ def vector_field(t, y, args): diffrax._integrate._PRINT_STATIC = False +def test_uninstantiated_solver_error(): + msg = ( + r"It looks like you forgot to instantiate your solver, e.g. by passing " + r"`diffrax\.Euler` instead of `diffrax\.Euler\(\)`." + ) + term = ODETerm(lambda t, y, args: -y) + with pytest.raises(ValueError, match=msg): + diffrax.diffeqsolve(term, diffrax.Euler, 0, 1, 0.1, 1.0) + with pytest.raises(ValueError, match=msg): + diffrax.diffeqsolve( + MultiTerm( + ODETerm(lambda t, y, args: -y), + ControlTerm( + lambda t, y, args: 0.1 * t, + diffrax.VirtualBrownianTree( + 0, 1, tol=1e-3, shape=(), key=jr.key(0) + ), + ), + ), + diffrax.EulerHeun, + 0, + 1, + 0.1, + 1.0, + ) + with pytest.raises(ValueError, match=r"not recognised"): + diffrax._integrate._validate_solver("not a solver") + + def test_implicit_tol_error(): msg = "the tolerances for the implicit solver have not been specified" with pytest.raises(ValueError, match=msg):