Skip to content

Commit 68fdbaf

Browse files
get direct c compilation working
1 parent 7a5b376 commit 68fdbaf

File tree

5 files changed

+120
-45
lines changed

5 files changed

+120
-45
lines changed

docs/src/build_function.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Function Building and Compilation (build_function)

src/build_function.jl

+87-40
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ Keyword Arguments:
6363
programming language
6464
- `MATLABTarget`: Generates an anonymous function for use in MATLAB and Octave
6565
environments
66+
- `fname`: Used by some targets for the name of the function in the target space.
67+
68+
Note that not all build targets support the full compilation interface. Check the
69+
individual target documentation for details.
6670
"""
6771
function build_function(args...;target = JuliaTarget(),kwargs...)
6872
_build_function(target,args...;kwargs...)
@@ -455,11 +459,73 @@ function numbered_expr(de::ModelingToolkit.Equation,args...;varordering = args[1
455459
end
456460
numbered_expr(c::ModelingToolkit.Constant,args...;kwargs...) = c.value
457461

458-
function _build_function(target::StanTarget, eqs, vs, ps, iv,
459-
conv = simplified_expr, expression = Val{true};
462+
"""
463+
Build function target: CTarget
464+
465+
```julia
466+
function _build_function(target::CTarget, eqs::Array{<:Equation}, args...;
467+
conv = simplified_expr, expression = Val{true},
468+
fname = :diffeqf,
469+
lhsname=:du,rhsnames=[Symbol("RHS$i") for i in 1:length(args)],
470+
libpath=tempname(),compiler=:gcc)
471+
```
472+
473+
This builds an in-place C function. Only works on arrays of equations. If
474+
`expression == Val{false}`, then this builds a function in C, compiles it,
475+
and returns a lambda to that compiled function. These special keyword arguments
476+
control the compilation:
477+
478+
- libpath: the path to store the binary. Defaults to a temporary path.
479+
- compiler: which C compiler to use. Defaults to :gcc, which is currently the
480+
only available option.
481+
"""
482+
function _build_function(target::CTarget, eqs::Array{<:Equation}, args...;
483+
conv = simplified_expr, expression = Val{true},
484+
fname = :diffeqf,
485+
lhsname=:du,rhsnames=[Symbol("RHS$i") for i in 1:length(args)],
486+
libpath=tempname(),compiler=:gcc)
487+
differential_equation = string(join([numbered_expr(eq,args...,lhsname=lhsname,
488+
rhsnames=rhsnames,offset=-1) for
489+
(i, eq) enumerate(eqs)],";\n "),";")
490+
491+
argstrs = join(vcat("double* $(lhsname)",[typeof(args[i])<:Array ? "double* $(rhsnames[i])" : "double $(rhsnames[i])" for i in 1:length(args)]),", ")
492+
ex = """
493+
void $fname($(argstrs...)) {
494+
$differential_equation
495+
}
496+
"""
497+
498+
if expression == Val{true}
499+
return ex
500+
else
501+
@assert compiler == :gcc
502+
ex = build_function(eqs,args...;target=ModelingToolkit.CTarget())
503+
open(`gcc -fPIC -O3 -msse3 -xc -shared -o $(libpath * "." * Libdl.dlext) -`, "w") do f
504+
print(f, ex)
505+
end
506+
eval(:((du::Array{Float64},u::Array{Float64},p::Array{Float64},t::Float64) -> ccall(("diffeqf", $libpath), Cvoid, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64), du, u, p, t)))
507+
end
508+
end
509+
510+
"""
511+
Build function target: StanTarget
512+
513+
```julia
514+
function _build_function(target::StanTarget, eqs::Array{<:Equation}, vs, ps, iv;
515+
conv = simplified_expr, expression = Val{true},
460516
fname = :diffeqf, lhsname=:internal_var___du,
461-
varname=:internal_var___u,paramname=:internal_var___p)
462-
rhsnames=[varname,paramname]
517+
rhsnames=[:internal_var___u,:internal_var___p,:internal_var___t])
518+
```
519+
520+
This builds an in-place Stan function compatible with the Stan differential equation solvers.
521+
Unlike other build targets, this one requestions (vs, ps, iv) as the function arguments.
522+
Only allowed on arrays of equations.
523+
"""
524+
function _build_function(target::StanTarget, eqs::Array{<:Equation}, vs, ps, iv;
525+
conv = simplified_expr, expression = Val{true},
526+
fname = :diffeqf, lhsname=:internal_var___du,
527+
rhsnames=[:internal_var___u,:internal_var___p,:internal_var___t])
528+
@assert expression == Val{true}
463529
differential_equation = string(join([numbered_expr(eq,vs,ps,lhsname=lhsname,
464530
rhsnames=rhsnames) for
465531
(i, eq) enumerate(eqs)],";\n "),";")
@@ -472,26 +538,26 @@ function _build_function(target::StanTarget, eqs, vs, ps, iv,
472538
"""
473539
end
474540

475-
function _build_function(target::CTarget, eqs, vs, ps, iv;
541+
"""
542+
Build function target: MATLABTarget
543+
544+
```julia
545+
function _build_function(target::MATLABTarget, eqs::Array{<:Equation}, args...;
476546
conv = simplified_expr, expression = Val{true},
477-
fname = :diffeqf, derivname=:internal_var___du,
478-
varname=:internal_var___u,paramname=:internal_var___p)
479-
differential_equation = string(join([numbered_expr(eq,vs,ps,lhsname=derivname,
480-
rhsnames=[varname,paramname],offset=-1) for
481-
(i, eq) enumerate(eqs)],";\n "),";")
482-
"""
483-
void $fname(double* $derivname, double* $varname, double* $paramname, double $iv) {
484-
$differential_equation
485-
}
486-
"""
487-
end
547+
lhsname=:internal_var___du,
548+
rhsnames=[:internal_var___u,:internal_var___p,:internal_var___t])
549+
```
488550
489-
function _build_function(target::MATLABTarget, eqs, vs, ps, iv;
551+
This builds an out of place anonymous function @(t,rhsnames[1]) to be used in MATLAB.
552+
Compatible with the MATLAB differential equation solvers. Only allowed on arrays
553+
of equations.
554+
"""
555+
function _build_function(target::MATLABTarget, eqs::Array{<:Equation}, args...;
490556
conv = simplified_expr, expression = Val{true},
491-
fname = :diffeqf, derivname=:internal_var___du,
492-
varname=:internal_var___u,paramname=:internal_var___p)
493-
rhsnames=[varname,paramname]
494-
matstr = join([numbered_expr(eq.rhs,vs,ps,lhsname=derivname,
557+
fname = :diffeqf, lhsname=:internal_var___du,
558+
rhsnames=[:internal_var___u,:internal_var___p,:internal_var___t])
559+
@assert expression == Val{true}
560+
matstr = join([numbered_expr(eq.rhs,args...,lhsname=lhsname,
495561
rhsnames=rhsnames) for
496562
(i, eq) enumerate(eqs)],"; ")
497563

@@ -500,22 +566,3 @@ function _build_function(target::MATLABTarget, eqs, vs, ps, iv;
500566
matstr = "$fname = @(t,$(rhsnames[1])) ["*matstr*"];"
501567
matstr
502568
end
503-
504-
"""
505-
compile_cfunction(eqs,args...;libpath=tempname(),compiler=:gcc)
506-
507-
Builds a function in C, compiles it, and returns a lambda to that compiled function.
508-
Arguments match those of `build_function`. Keyword arguments:
509-
510-
- libpath: the path to store the binary. Defaults to a temporary path.
511-
- compiler: which C compiler to use. Defaults to :gcc, which is currently the
512-
only available option.
513-
"""
514-
function compile_cfunction(eqs,args...;libpath=tempname(),compiler=:gcc)
515-
@assert compiler == :gcc
516-
ex = build_function(eqs,args...;target=ModelingToolkit.CTarget())
517-
open(`gcc -fPIC -O3 -msse3 -xc -shared -o $(libpath * "." * Libdl.dlext) -`, "w") do f
518-
print(f, ex)
519-
end
520-
eval(:((du::Array{Float64},u::Array{Float64},p::Array{Float64},t::Float64) -> ccall(("diffeqf", $libpath), Cvoid, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64), du, u, p, t)))
521-
end

test/build_targets.jl

+15-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ eqs = [D(x) ~ a*x - x*y,
1414
}
1515
"""
1616

17-
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget()) ==
17+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget(),
18+
lhsname=:internal_var___du,
19+
rhsnames=[:internal_var___u,:internal_var___p,:t]) ==
1820
"""
1921
void diffeqf(double* internal_var___du, double* internal_var___u, double* internal_var___p, double t) {
2022
internal_var___du[0] = internal_var___p[0] * internal_var___u[0] - internal_var___u[0] * internal_var___u[1];
@@ -28,17 +30,25 @@ eqs = [D(x) ~ a*x - x*y,
2830

2931
sys = ODESystem(eqs,t,[x,y],[a])
3032

31-
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget()) ==
32-
ModelingToolkit.build_function(sys.eqs,[x,y],[a],t,target = ModelingToolkit.CTarget())
33+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget(),
34+
lhsname=:internal_var___du,
35+
rhsnames=[:internal_var___u,:internal_var___p,:t]) ==
36+
ModelingToolkit.build_function(sys.eqs,[x,y],[a],t,target = ModelingToolkit.CTarget(),
37+
lhsname=:internal_var___du,
38+
rhsnames=[:internal_var___u,:internal_var___p,:t])
3339

3440
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.StanTarget()) ==
3541
ModelingToolkit.build_function(sys.eqs,[x,y],[a],t,target = ModelingToolkit.StanTarget())
3642

3743
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.MATLABTarget()) ==
3844
ModelingToolkit.build_function(sys.eqs,[x,y],[a],t,target = ModelingToolkit.MATLABTarget())
3945

40-
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget()) ==
41-
ModelingToolkit.build_function(sys.eqs,sys.states,sys.ps,sys.iv,target = ModelingToolkit.CTarget())
46+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget(),
47+
lhsname=:internal_var___du,
48+
rhsnames=[:internal_var___u,:internal_var___p,:t]) ==
49+
ModelingToolkit.build_function(sys.eqs,sys.states,sys.ps,sys.iv,target = ModelingToolkit.CTarget(),
50+
lhsname=:internal_var___du,
51+
rhsnames=[:internal_var___u,:internal_var___p,:t])
4252

4353
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.StanTarget()) ==
4454
ModelingToolkit.build_function(sys.eqs,sys.states,sys.ps,sys.iv,target = ModelingToolkit.StanTarget())

test/ccompile.jl

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using ModelingToolkit, Test
2+
@parameters t a
3+
@variables x y
4+
@derivatives D'~t
5+
eqs = [D(x) ~ a*x - x*y,
6+
D(y) ~ -3y + x*y]
7+
f = build_function(eqs,[x,y],[a],t,expression=Val{false},target=ModelingToolkit.CTarget())
8+
f2 = eval(build_function([x.rhs for x in eqs],[x,y],[a],t)[2])
9+
du = rand(2); du2 = rand(2)
10+
u = rand(2)
11+
p = rand(1)
12+
_t = rand()
13+
f(du,u,p,_t)
14+
f2(du2,u,p,_t)
15+
@test du == du2

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@ using SafeTestsets, Test
2929
#@testset "Latexify recipes Test" begin include("latexify.jl") end
3030
@testset "Distributed Test" begin include("distributed.jl") end
3131
@testset "Variable Utils Test" begin include("variable_utils.jl") end
32+
println("Last test requires gcc available in the path!")
33+
@safetestset "C Compilation Test" begin include("ccompile.jl") end

0 commit comments

Comments
 (0)