Skip to content

Commit b55a3c2

Browse files
feat: run trivial initialization in problem constructor
1 parent 57c79e9 commit b55a3c2

File tree

4 files changed

+37
-10
lines changed

4 files changed

+37
-10
lines changed

src/systems/diffeqs/abstractodesystem.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
753753
kwargs1 = merge(kwargs1, (; tstops))
754754
end
755755

756-
return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
756+
return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...))
757757
end
758758
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
759759

@@ -799,8 +799,9 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
799799
kwargs1 = merge(kwargs1, (; tstops))
800800
end
801801

802-
DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars,
803-
kwargs..., kwargs1...)
802+
return remake(DAEProblem{iip}(
803+
f, du0, u0, tspan, p; differential_vars = differential_vars,
804+
kwargs..., kwargs1...))
804805
end
805806

806807
function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)
@@ -844,7 +845,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
844845
if cbs !== nothing
845846
kwargs1 = merge(kwargs1, (callback = cbs,))
846847
end
847-
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
848+
return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...))
848849
end
849850

850851
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -893,9 +894,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
893894
else
894895
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
895896
end
896-
SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
897+
return remake(SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
897898
noise_rate_prototype =
898-
noise_rate_prototype, kwargs1..., kwargs...)
899+
noise_rate_prototype, kwargs1..., kwargs...))
899900
end
900901

901902
"""

src/systems/diffeqs/sdesystem.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -792,8 +792,8 @@ function DiffEqBase.SDEProblem{iip, specialize}(
792792

793793
kwargs = filter_kwargs(kwargs)
794794

795-
SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
796-
noise_rate_prototype = noise_rate_prototype, kwargs...)
795+
return remake(SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
796+
noise_rate_prototype = noise_rate_prototype, kwargs...))
797797
end
798798

799799
function DiffEqBase.SDEProblem(sys::ODESystem, args...; kwargs...)

src/systems/nonlinear/nonlinearsystem.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
526526
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
527527
check_length, kwargs...)
528528
pt = something(get_metadata(sys), StandardNonlinearProblem())
529-
NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)
529+
return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...))
530530
end
531531

532532
"""
@@ -555,7 +555,7 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
555555
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
556556
check_length, kwargs...)
557557
pt = something(get_metadata(sys), StandardNonlinearProblem())
558-
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
558+
return remake(NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...))
559559
end
560560

561561
const TypeT = Union{DataType, UnionAll}

test/initializationsystem.jl

+26
Original file line numberDiff line numberDiff line change
@@ -1388,3 +1388,29 @@ end
13881388
integ1 = init(oprob1)
13891389
@test integ1[X1] 1.0
13901390
end
1391+
1392+
@testset "Trivial initialization is run on problem construction" begin
1393+
@variables _x(..) y(t)
1394+
@brownian a
1395+
@parameters tot
1396+
x = _x(t)
1397+
@testset "$Problem" for (Problem, lhs, rhs) in [
1398+
(ODEProblem, D, 0.0),
1399+
(SDEProblem, D, a),
1400+
(DDEProblem, D, _x(t - 0.1)),
1401+
(SDDEProblem, D, _x(t - 0.1) + a)
1402+
]
1403+
@mtkbuild sys = System([lhs(x) ~ x + rhs, x + y ~ tot], t;
1404+
guesses = [tot => 1.0], defaults = [tot => missing])
1405+
prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
1406+
@test prob.ps[tot] 2.0
1407+
end
1408+
@testset "$Problem" for Problem in [NonlinearProblem, NonlinearLeastSquaresProblem]
1409+
@parameters p1 p2
1410+
@mtkbuild sys = NonlinearSystem([x^2 + y^2 ~ p1, (x - 1)^2 + (y - 1)^2 ~ p2];
1411+
parameter_dependencies = [p2 ~ 2p1],
1412+
guesses = [p1 => 0.0], defaults = [p1 => missing])
1413+
prob = Problem(sys, [x => 1.0, y => 1.0], [p2 => 6.0])
1414+
@test prob.ps[p1] 3.0
1415+
end
1416+
end

0 commit comments

Comments
 (0)