diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index d9b8fb31ab..8d0720a1d2 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -542,7 +542,7 @@ Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `ps` are used to set the order of the dependent variable and parameter vectors, respectively. """ -struct ODEFunctionExpr{iip} end +struct ODEFunctionExpr{iip, specialize} end struct ODEFunctionClosure{O, I} <: Function f_oop::O @@ -551,7 +551,7 @@ end (f::ODEFunctionClosure)(u, p, t) = f.f_oop(u, p, t) (f::ODEFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t) -function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys), +function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns(sys), ps = parameters(sys), u0 = nothing; version = nothing, tgrad = false, jac = false, p = nothing, @@ -560,14 +560,12 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys), steady_state = false, sparsity = false, observedfun_exp = nothing, - kwargs...) where {iip} + kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunctionExpr`") end f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...) - dict = Dict() - fsym = gensym(:f) _f = :($fsym = $ODEFunctionClosure($f_oop, $f_iip)) tgradsym = gensym(:tgrad) @@ -590,30 +588,28 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys), _jac = :($jacsym = nothing) end + Msym = gensym(:M) M = calculate_massmatrix(sys) - - _M = if sparse && !(u0 === nothing || M === I) - SparseArrays.sparse(M) + if sparse && !(u0 === nothing || M === I) + _M = :($Msym = $(SparseArrays.sparse(M))) elseif u0 === nothing || M === I - M + _M = :($Msym = $M) else - ArrayInterface.restructure(u0 .* u0', M) + _M = :($Msym = $(ArrayInterface.restructure(u0 .* u0', M))) end jp_expr = sparse ? :($similar($(get_jac(sys)[]), Float64)) : :nothing ex = quote - $_f - $_tgrad - $_jac - M = $_M - ODEFunction{$iip}($fsym, - sys = $sys, - jac = $jacsym, - tgrad = $tgradsym, - mass_matrix = M, - jac_prototype = $jp_expr, - sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing), - observed = $observedfun_exp) + let $_f, $_tgrad, $_jac, $_M + ODEFunction{$iip, $specialize}($fsym, + sys = $sys, + jac = $jacsym, + tgrad = $tgradsym, + mass_matrix = $Msym, + jac_prototype = $jp_expr, + sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing), + observed = $observedfun_exp) + end end !linenumbers ? Base.remove_linenums!(ex) : ex end @@ -622,6 +618,14 @@ function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...) ODEFunctionExpr{true}(sys, args...; kwargs...) end +function ODEFunctionExpr{true}(sys::AbstractODESystem, args...; kwargs...) + return ODEFunctionExpr{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) +end + +function ODEFunctionExpr{false}(sys::AbstractODESystem, args...; kwargs...) + return ODEFunctionExpr{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) +end + """ ```julia DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys), diff --git a/test/odesystem.jl b/test/odesystem.jl index 87c192afc2..8281fafe93 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -97,6 +97,25 @@ f.f(du, u, p, 0.1) @test du == [4, 0, -16] @test_throws ArgumentError f.f(u, p, 0.1) +#check iip +f = eval(ODEFunctionExpr(de, [x, y, z], [σ, ρ, β])) +f2 = ODEFunction(de, [x, y, z], [σ, ρ, β]) +@test SciMLBase.isinplace(f) === SciMLBase.isinplace(f2) +@test SciMLBase.specialization(f) === SciMLBase.specialization(f2) +for iip in (true, false) + f = eval(ODEFunctionExpr{iip}(de, [x, y, z], [σ, ρ, β])) + f2 = ODEFunction{iip}(de, [x, y, z], [σ, ρ, β]) + @test SciMLBase.isinplace(f) === SciMLBase.isinplace(f2) === iip + @test SciMLBase.specialization(f) === SciMLBase.specialization(f2) + + for specialize in (SciMLBase.AutoSpecialize, SciMLBase.FullSpecialize) + f = eval(ODEFunctionExpr{iip, specialize}(de, [x, y, z], [σ, ρ, β])) + f2 = ODEFunction{iip, specialize}(de, [x, y, z], [σ, ρ, β]) + @test SciMLBase.isinplace(f) === SciMLBase.isinplace(f2) === iip + @test SciMLBase.specialization(f) === SciMLBase.specialization(f2) === specialize + end +end + #check sparsity f = eval(ODEFunctionExpr(de, [x, y, z], [σ, ρ, β], sparsity = true)) @test f.sparsity == ModelingToolkit.jacobian_sparsity(de)