Skip to content

Commit b18ecc0

Browse files
Merge pull request #3323 from vyudu/BVP-with-constraints
BVProblem with constraints
2 parents 57c79e9 + 3642e1b commit b18ecc0

9 files changed

+559
-12
lines changed

Project.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ AbstractTrees = "0.3, 0.4"
8383
ArrayInterface = "6, 7"
8484
BifurcationKit = "0.4"
8585
BlockArrays = "1.1"
86+
BoundaryValueDiffEqAscher = "1.1.0"
87+
BoundaryValueDiffEqMIRK = "1.4.0"
8688
ChainRulesCore = "1"
8789
Combinatorics = "1"
8890
CommonSolve = "0.2.4"
@@ -157,6 +159,8 @@ julia = "1.9"
157159
[extras]
158160
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
159161
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
162+
BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
163+
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
160164
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
161165
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
162166
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
@@ -189,4 +193,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
189193
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
190194

191195
[targets]
192-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging"]
196+
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging"]

src/ModelingToolkit.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ include("systems/codegen_utils.jl")
155155
include("systems/problem_utils.jl")
156156
include("linearization.jl")
157157

158+
include("systems/optimization/constraints_system.jl")
159+
include("systems/optimization/optimizationsystem.jl")
160+
include("systems/optimization/modelingtoolkitize.jl")
161+
158162
include("systems/nonlinear/nonlinearsystem.jl")
159163
include("systems/nonlinear/homotopy_continuation.jl")
160164
include("systems/diffeqs/odesystem.jl")
@@ -170,10 +174,6 @@ include("systems/discrete_system/discrete_system.jl")
170174

171175
include("systems/jumps/jumpsystem.jl")
172176

173-
include("systems/optimization/constraints_system.jl")
174-
include("systems/optimization/optimizationsystem.jl")
175-
include("systems/optimization/modelingtoolkitize.jl")
176-
177177
include("systems/pde/pdesystem.jl")
178178

179179
include("systems/sparsematrixclil.jl")

src/systems/abstractsystem.jl

+1
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,7 @@ for prop in [:eqs
823823
:structure
824824
:op
825825
:constraints
826+
:constraintsystem
826827
:controls
827828
:loss
828829
:bcs

src/systems/diffeqs/abstractodesystem.jl

+164
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
735735
if !iscomplete(sys)
736736
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
737737
end
738+
739+
if !isnothing(get_constraintsystem(sys))
740+
error("An ODESystem with constraints cannot be used to construct a regular ODEProblem.
741+
Consider a BVProblem instead.")
742+
end
743+
738744
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
739745
t = tspan !== nothing ? tspan[1] : tspan,
740746
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
@@ -757,6 +763,164 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
757763
end
758764
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
759765

766+
"""
767+
```julia
768+
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
769+
parammap = DiffEqBase.NullParameters();
770+
constraints = nothing, guesses = nothing,
771+
version = nothing, tgrad = false,
772+
jac = true, sparse = true,
773+
simplify = false,
774+
kwargs...) where {iip}
775+
```
776+
777+
Create a boundary value problem from the [`ODESystem`](@ref).
778+
779+
`u0map` is used to specify fixed initial values for the states. Every variable
780+
must have either an initial guess supplied using `guesses` or a fixed initial
781+
value specified using `u0map`.
782+
783+
Boundary value conditions are supplied to ODESystems
784+
in the form of a ConstraintsSystem. These equations
785+
should specify values that state variables should
786+
take at specific points, as in `x(0.5) ~ 1`). More general constraints that
787+
should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be
788+
specified as one of the equations used to build the `ODESystem`.
789+
790+
If an ODESystem without `constraints` is specified, it will be treated as an initial value problem.
791+
792+
```julia
793+
@parameters g t_c = 0.5
794+
@variables x(..) y(t) [state_priority = 10] λ(t)
795+
eqs = [D(D(x(t))) ~ λ * x(t)
796+
D(D(y)) ~ λ * y - g
797+
x(t)^2 + y^2 ~ 1]
798+
cstr = [x(0.5) ~ 1]
799+
@named cstrs = ConstraintsSystem(cstr, t)
800+
@mtkbuild pend = ODESystem(eqs, t)
801+
802+
tspan = (0.0, 1.5)
803+
u0map = [x(t) => 0.6, y => 0.8]
804+
parammap = [g => 1]
805+
guesses = [λ => 1]
806+
constraints = [x(0.5) ~ 1]
807+
808+
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
809+
```
810+
811+
If the `ODESystem` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
812+
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
813+
"""
814+
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
815+
BVProblem{true}(sys, args...; kwargs...)
816+
end
817+
818+
function SciMLBase.BVProblem(sys::AbstractODESystem,
819+
u0map::StaticArray,
820+
args...;
821+
kwargs...)
822+
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
823+
end
824+
825+
function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
826+
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
827+
end
828+
829+
function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
830+
BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
831+
end
832+
833+
function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
834+
tspan = get_tspan(sys),
835+
parammap = DiffEqBase.NullParameters();
836+
guesses = Dict(),
837+
version = nothing, tgrad = false,
838+
callback = nothing,
839+
check_length = true,
840+
warn_initialize_determined = true,
841+
eval_expression = false,
842+
eval_module = @__MODULE__,
843+
kwargs...) where {iip, specialize}
844+
845+
if !iscomplete(sys)
846+
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
847+
end
848+
!isnothing(callback) && error("BVP solvers do not support callbacks.")
849+
850+
has_alg_eqs(sys) && error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
851+
852+
sts = unknowns(sys)
853+
ps = parameters(sys)
854+
constraintsys = get_constraintsystem(sys)
855+
856+
if !isnothing(constraintsys)
857+
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
858+
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
859+
end
860+
861+
# ODESystems without algebraic equations should use both fixed values + guesses
862+
# for initialization.
863+
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
864+
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
865+
t = tspan !== nothing ? tspan[1] : tspan, guesses,
866+
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
867+
868+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
869+
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
870+
871+
fns = generate_function_bc(sys, u0, u0_idxs, tspan)
872+
bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module)
873+
bc(sol, p, t) = bc_oop(sol, p, t)
874+
bc(resid, u, p, t) = bc_iip(resid, u, p, t)
875+
876+
return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
877+
end
878+
879+
get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
880+
881+
"""
882+
generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan)
883+
884+
Given an ODESystem with constraints, generate the boundary condition function to pass to boundary value problem solvers.
885+
Expression uses the constraints and the provided initial conditions.
886+
"""
887+
function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
888+
iv = get_iv(sys)
889+
sts = unknowns(sys)
890+
ps = parameters(sys)
891+
np = length(ps)
892+
ns = length(sts)
893+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
894+
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
895+
896+
@variables sol(..)[1:ns]
897+
898+
conssys = get_constraintsystem(sys)
899+
cons = Any[]
900+
if !isnothing(conssys)
901+
cons = [con.lhs - con.rhs for con in constraints(conssys)]
902+
903+
for st in get_unknowns(conssys)
904+
x = operation(st)
905+
t = only(arguments(st))
906+
idx = stidxmap[x(iv)]
907+
908+
cons = map(c -> Symbolics.substitute(c, Dict(x(t) => sol(t)[idx])), cons)
909+
end
910+
end
911+
912+
init_conds = Any[]
913+
for i in u0_idxs
914+
expr = sol(tspan[1])[i] - u0[i]
915+
push!(init_conds, expr)
916+
end
917+
918+
exprs = vcat(init_conds, cons)
919+
_p = reorder_parameters(sys, ps)
920+
921+
build_function_wrapper(sys, exprs, sol, _p..., t; output_type = Array, kwargs...)
922+
end
923+
760924
"""
761925
```julia
762926
DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,

src/systems/diffeqs/odesystem.jl

+71-6
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ struct ODESystem <: AbstractODESystem
4949
ctrls::Vector
5050
"""Observed variables."""
5151
observed::Vector{Equation}
52+
"""System of constraints that must be satisfied by the solution to the system."""
53+
constraintsystem::Union{Nothing, ConstraintsSystem}
5254
"""
5355
Time-derivative matrix. Note: this field will not be defined until
5456
[`calculate_tgrad`](@ref) is called on the system.
@@ -191,7 +193,7 @@ struct ODESystem <: AbstractODESystem
191193
"""
192194
parent::Any
193195

194-
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
196+
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
195197
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
196198
torn_matching, initializesystem, initialization_eqs, schedule,
197199
connector_type, preface, cevents,
@@ -212,7 +214,7 @@ struct ODESystem <: AbstractODESystem
212214
u = __get_unit_type(dvs, ps, iv)
213215
check_units(u, deqs)
214216
end
215-
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
217+
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad, jac,
216218
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
217219
initializesystem, initialization_eqs, schedule, connector_type, preface,
218220
cevents, devents, parameter_dependencies, assertions, metadata,
@@ -224,6 +226,7 @@ end
224226
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
225227
controls = Num[],
226228
observed = Equation[],
229+
constraintsystem = nothing,
227230
systems = ODESystem[],
228231
tspan = nothing,
229232
name = nothing,
@@ -297,17 +300,29 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
297300
if is_dde === nothing
298301
is_dde = _check_if_dde(deqs, iv′, systems)
299302
end
303+
304+
if !isempty(systems) && !isnothing(constraintsystem)
305+
conssystems = ConstraintsSystem[]
306+
for sys in systems
307+
cons = get_constraintsystem(sys)
308+
cons !== nothing && push!(conssystems, cons)
309+
end
310+
@show conssystems
311+
@set! constraintsystem.systems = conssystems
312+
end
313+
300314
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
315+
301316
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
302-
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
317+
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
303318
ctrl_jac, Wfact, Wfact_t, name, description, systems,
304319
defaults, guesses, nothing, initializesystem,
305320
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
306321
disc_callbacks, parameter_dependencies, assertions,
307322
metadata, gui_metadata, is_dde, tstops, checks = checks)
308323
end
309324

310-
function ODESystem(eqs, iv; kwargs...)
325+
function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
311326
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
312327

313328
for eq in get(kwargs, :parameter_dependencies, Equation[])
@@ -339,8 +354,22 @@ function ODESystem(eqs, iv; kwargs...)
339354
end
340355
algevars = setdiff(allunknowns, diffvars)
341356

342-
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
343-
collect(new_ps); kwargs...)
357+
consvars = OrderedSet()
358+
constraintsystem = nothing
359+
if !isempty(constraints)
360+
constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
361+
for st in get_unknowns(constraintsystem)
362+
iscall(st) ?
363+
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
364+
!in(st, allunknowns) && push!(consvars, st)
365+
end
366+
for p in parameters(constraintsystem)
367+
!in(p, new_ps) && push!(new_ps, p)
368+
end
369+
end
370+
371+
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
372+
collect(new_ps); constraintsystem, kwargs...)
344373
end
345374

346375
# NOTE: equality does not check cached Jacobian
@@ -668,3 +697,39 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
668697

669698
return nothing
670699
end
700+
701+
# Validate that all the variables in the BVP constraints are well-formed states or parameters.
702+
# - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
703+
# - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
704+
function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
705+
isempty(constraints) && return nothing
706+
707+
constraintsts = OrderedSet()
708+
constraintps = OrderedSet()
709+
710+
for cons in constraints
711+
collect_vars!(constraintsts, constraintps, cons, iv)
712+
end
713+
714+
# Validate the states.
715+
for var in constraintsts
716+
if !iscall(var)
717+
occursin(iv, var) && (var sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
718+
elseif length(arguments(var)) > 1
719+
throw(ArgumentError("Too many arguments for variable $var."))
720+
elseif length(arguments(var)) == 1
721+
arg = only(arguments(var))
722+
operation(var)(iv) sts ||
723+
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
724+
725+
isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat ||
726+
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
727+
728+
isparameter(arg) && push!(constraintps, arg)
729+
else
730+
var sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
731+
end
732+
end
733+
734+
ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
735+
end

src/systems/optimization/constraints_system.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function ConstraintsSystem(constraints, unknowns, ps;
123123
name === nothing &&
124124
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
125125

126-
cstr = value.(Symbolics.canonical_form.(scalarize(constraints)))
126+
cstr = value.(Symbolics.canonical_form.(vcat(scalarize(constraints)...)))
127127
unknowns′ = value.(scalarize(unknowns))
128128
ps′ = value.(ps)
129129

0 commit comments

Comments
 (0)