-
-
Notifications
You must be signed in to change notification settings - Fork 213
/
Copy pathcode_generation.jl
80 lines (65 loc) · 3.05 KB
/
code_generation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface
using ModelingToolkit: t_nounits as t, D_nounits as D
@testset "`generate_custom_function`" begin
    # Parameters shared by both system kinds below: a Float64 scalar, a Float64 array,
    # an Int and a Bool, so generated code must handle several parameter types.
    @parameters p1=1.0 p2[1:3]=[1.0, 2.0, 3.0] p3::Int=1 p4::Bool=false
    # Values for the 4 unknowns `[x; y[1:3]]` of both systems: x = 1, y = [2, 3, 4].
    u0 = [1.0, 2.0, 3.0, 4.0]

    @testset "ODESystem" begin
        @variables x(t) y(t)[1:3]
        sys = complete(ODESystem(Equation[], t, [x; y], [p1, p2, p3, p4]; name = :sys))
        p = MTKParameters(sys, [])

        # Scalar expression, called as (u, p, t):
        # x + y[1] + p1 + p2[1] + p3 * t = 1 + 2 + 1 + 1 + 1 * 0 = 5.
        fn1 = generate_custom_function(
            sys, x + y[1] + p1 + p2[1] + p3 * t; expression = Val(false))
        @test fn1(u0, p, 0.0) == 5.0

        # Explicit, restricted unknowns (`[x]`) and parameters (`[p1, p2, p3]`).
        fn2 = generate_custom_function(sys, x + y[1] + p1 + p2[1] + p3 * t, [x],
            [p1, p2, p3]; expression = Val(false))
        # The expression uses `y[1]`, which is not among the restricted unknowns, so
        # only successful generation of `fn2` is asserted here (the previous line
        # re-tested `fn1` and never exercised `fn2`).
        # NOTE(review): confirm whether `fn2` is expected to be callable with `u0`.
        @test fn2 isa Function

        # A vector of expressions yields both an out-of-place and an in-place function.
        fn3_oop, fn3_iip = generate_custom_function(
            sys, [x + y[2], y[3] + p2[2], p1 + p3, 3t]; expression = Val(false))
        buffer = zeros(4)
        fn3_iip(buffer, u0, p, 1.0)
        @test buffer == [4.0, 6.0, 2.0, 3.0]
        @test fn3_oop(u0, p, 1.0) == [4.0, 6.0, 2.0, 3.0]

        # `ifelse` on the Bool parameter: p4 = false selects the else-branch p2[2] = 2,
        # while `!p4` selects p1 = 1.
        fn4 = generate_custom_function(sys, ifelse(p4, p1, p2[2]); expression = Val(false))
        @test fn4(u0, p, 1.0) == 2.0
        fn5 = generate_custom_function(sys, ifelse(!p4, p1, p2[2]); expression = Val(false))
        @test fn5(u0, p, 1.0) == 1.0
    end

    @testset "NonlinearSystem" begin
        # Time-independent system: generated functions are called as (u, p) for
        # out-of-place and (buffer, u, p) for in-place, with no time argument.
        @variables x y[1:3]
        sys = complete(NonlinearSystem(Equation[], [x; y], [p1, p2, p3, p4]; name = :sys))
        p = MTKParameters(sys, [])

        # x + y[1] + p1 + p2[1] + p3 = 1 + 2 + 1 + 1 + 1 = 6.
        fn1 = generate_custom_function(
            sys, x + y[1] + p1 + p2[1] + p3; expression = Val(false))
        @test fn1(u0, p) == 6.0

        fn2 = generate_custom_function(sys, x + y[1] + p1 + p2[1] + p3, [x],
            [p1, p2, p3]; expression = Val(false))
        # As in the ODE case, `y[1]` is outside the restricted unknowns `[x]`, so only
        # generation is asserted. NOTE(review): confirm whether `fn2` should be callable.
        @test fn2 isa Function

        fn3_oop, fn3_iip = generate_custom_function(
            sys, [x + y[2], y[3] + p2[2], p1 + p3]; expression = Val(false))
        buffer = zeros(3)
        fn3_iip(buffer, u0, p)
        @test buffer == [4.0, 6.0, 2.0]
        @test fn3_oop(u0, p) == [4.0, 6.0, 2.0]

        fn4 = generate_custom_function(sys, ifelse(p4, p1, p2[2]); expression = Val(false))
        @test fn4(u0, p) == 2.0
        fn5 = generate_custom_function(sys, ifelse(!p4, p1, p2[2]); expression = Val(false))
        @test fn5(u0, p) == 1.0
    end
end
@testset "Non-standard array variables" begin
    # `p` is indexed over 0:2 instead of 1:n, and `f` is a `Function`-typed
    # callable parameter that receives the whole array `p`.
    @variables x(t)
    @parameters p[0:2] (f::Function)(..)
    eq = D(x) ~ p[0] * x + p[1] * t + p[2] + f(p)
    @mtkbuild sys = ODESystem(eq, t)

    u0map = [x => 1.0]
    tspan = (0.0, 1.0)
    pmap = [p => [1.0, 2.0, 3.0], f => sum]
    odeprob = ODEProblem(sys, u0map, tspan, pmap)

    # Both the whole array and a 0-based element are readable by symbolic index.
    @test odeprob.ps[p] == [1.0, 2.0, 3.0]
    @test odeprob.ps[p[0]] == 1.0

    odesol = solve(odeprob, Tsit5())
    @test SciMLBase.successful_retcode(odesol)

    @testset "Array split across buffers" begin
        # The unknown array `x` is indexed over 0:2; the equations below only
        # differentiate x[0] and x[1], while f receives the whole array `x`.
        @variables x(t)[0:2]
        @parameters p[1:2] (f::Function)(..)
        eqs = [D(x[0]) ~ p[1] * x[0] + x[2], D(x[1]) ~ p[2] * f(x) + x[2]]
        @named sys = ODESystem(eqs, t)

        # `x[2]` goes in the first slot of the `io` tuple (presumably the inputs);
        # the simplified system is expected to treat it as a parameter, so part of
        # the array `x` ends up among the parameters.
        inputs = [x[2]]
        sys, = structural_simplify(sys, (inputs, []))
        @test is_parameter(sys, x[2])

        u0map = [x[0] => 1.0, x[1] => 1.0]
        pmap = [p => ones(2), f => sum, x[2] => 2.0]
        odeprob = ODEProblem(sys, u0map, (0.0, 1.0), pmap)
        odesol = solve(odeprob, Tsit5())
        @test SciMLBase.successful_retcode(odesol)
    end
end