@@ -93,6 +93,19 @@ struct SDESystem <: AbstractODESystem
93
93
"""
94
94
defaults:: Dict
95
95
"""
96
+ The guesses to use as the initial conditions for the
97
+ initialization system.
98
+ """
99
+ guesses:: Dict
100
+ """
101
+ The system for performing the initialization.
102
+ """
103
+ initializesystem:: Union{Nothing, NonlinearSystem}
104
+ """
105
+ Extra equations to be enforced during the initialization sequence.
106
+ """
107
+ initialization_eqs:: Vector{Equation}
108
+ """
96
109
Type of the system.
97
110
"""
98
111
connector_type:: Any
@@ -144,9 +157,8 @@ struct SDESystem <: AbstractODESystem
144
157
isscheduled:: Bool
145
158
146
159
function SDESystem (tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
147
- tgrad,
148
- jac,
149
- ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
160
+ tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
161
+ guesses, initializesystem, initialization_eqs, connector_type,
150
162
cevents, devents, parameter_dependencies, metadata = nothing , gui_metadata = nothing ,
151
163
complete = false , index_cache = nothing , parent = nothing , is_scalar_noise = false ,
152
164
is_dde = false ,
@@ -171,9 +183,9 @@ struct SDESystem <: AbstractODESystem
171
183
check_units (u, deqs, neqs)
172
184
end
173
185
new (tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
174
- ctrl_jac,
175
- Wfact, Wfact_t, name, description, systems ,
176
- defaults, connector_type, cevents, devents,
186
+ ctrl_jac, Wfact, Wfact_t, name, description, systems,
187
+ defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents ,
188
+ devents,
177
189
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise,
178
190
is_dde, isscheduled)
179
191
end
@@ -187,6 +199,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
187
199
default_u0 = Dict (),
188
200
default_p = Dict (),
189
201
defaults = _merge (Dict (default_u0), Dict (default_p)),
202
+ guesses = Dict (),
203
+ initializesystem = nothing ,
204
+ initialization_eqs = Equation[],
190
205
name = nothing ,
191
206
description = " " ,
192
207
connector_type = nothing ,
@@ -207,6 +222,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
207
222
dvs′ = value .(dvs)
208
223
ps′ = value .(ps)
209
224
ctrl′ = value .(controls)
225
+ parameter_dependencies, ps′ = process_parameter_dependencies (
226
+ parameter_dependencies, ps′)
210
227
211
228
sysnames = nameof .(systems)
212
229
if length (unique (sysnames)) != length (sysnames)
@@ -217,13 +234,21 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
217
234
" `default_u0` and `default_p` are deprecated. Use `defaults` instead." ,
218
235
:SDESystem , force = true )
219
236
end
220
- defaults = todict (defaults)
221
- defaults = Dict (value (k) => value (v)
222
- for (k, v) in pairs (defaults) if value (v) != = nothing )
223
237
238
+ defaults = Dict {Any, Any} (todict (defaults))
239
+ guesses = Dict {Any, Any} (todict (guesses))
224
240
var_to_name = Dict ()
225
- process_variables! (var_to_name, defaults, dvs′)
226
- process_variables! (var_to_name, defaults, ps′)
241
+ process_variables! (var_to_name, defaults, guesses, dvs′)
242
+ process_variables! (var_to_name, defaults, guesses, ps′)
243
+ process_variables! (
244
+ var_to_name, defaults, guesses, [eq. lhs for eq in parameter_dependencies])
245
+ process_variables! (
246
+ var_to_name, defaults, guesses, [eq. rhs for eq in parameter_dependencies])
247
+ defaults = Dict {Any, Any} (value (k) => value (v)
248
+ for (k, v) in pairs (defaults) if v != = nothing )
249
+ guesses = Dict {Any, Any} (value (k) => value (v)
250
+ for (k, v) in pairs (guesses) if v != = nothing )
251
+
227
252
isempty (observed) || collect_var_to_name! (var_to_name, (eq. lhs for eq in observed))
228
253
229
254
tgrad = RefValue (EMPTY_TGRAD)
@@ -233,14 +258,13 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
233
258
Wfact_t = RefValue (EMPTY_JAC)
234
259
cont_callbacks = SymbolicContinuousCallbacks (continuous_events)
235
260
disc_callbacks = SymbolicDiscreteCallbacks (discrete_events)
236
- parameter_dependencies, ps′ = process_parameter_dependencies (
237
- parameter_dependencies, ps′)
238
261
if is_dde === nothing
239
262
is_dde = _check_if_dde (deqs, iv′, systems)
240
263
end
241
264
SDESystem (Threads. atomic_add! (SYSTEM_COUNT, UInt (1 )),
242
265
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
243
- ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
266
+ ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
267
+ initializesystem, initialization_eqs, connector_type,
244
268
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
245
269
complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks)
246
270
end
@@ -520,7 +544,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
520
544
version = nothing , tgrad = false , sparse = false ,
521
545
jac = false , Wfact = false , eval_expression = false ,
522
546
eval_module = @__MODULE__ ,
523
- checkbounds = false ,
547
+ checkbounds = false , initialization_data = nothing ,
524
548
kwargs... ) where {iip, specialize}
525
549
if ! iscomplete (sys)
526
550
error (" A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`" )
@@ -591,13 +615,13 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
591
615
592
616
observedfun = ObservedFunctionCache (sys; eval_expression, eval_module)
593
617
594
- SDEFunction {iip, specialize} (f, g,
618
+ SDEFunction {iip, specialize} (f, g;
595
619
sys = sys,
596
620
jac = _jac === nothing ? nothing : _jac,
597
621
tgrad = _tgrad === nothing ? nothing : _tgrad,
598
622
Wfact = _Wfact === nothing ? nothing : _Wfact,
599
623
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
600
- mass_matrix = _M,
624
+ mass_matrix = _M, initialization_data,
601
625
observed = observedfun)
602
626
end
603
627
@@ -714,7 +738,7 @@ function DiffEqBase.SDEProblem{iip, specialize}(
714
738
end
715
739
f, u0, p = process_SciMLProblem (
716
740
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
717
- kwargs... )
741
+ t = tspan === nothing ? nothing : tspan[ 1 ], kwargs... )
718
742
cbs = process_events (sys; callback, kwargs... )
719
743
sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
720
744
@@ -736,6 +760,8 @@ function DiffEqBase.SDEProblem{iip, specialize}(
736
760
noise = nothing
737
761
end
738
762
763
+ kwargs = filter_kwargs (kwargs)
764
+
739
765
SDEProblem {iip} (f, u0, tspan, p; callback = cbs, noise,
740
766
noise_rate_prototype = noise_rate_prototype, kwargs... )
741
767
end
0 commit comments