forked from SciML/ModelingToolkit.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdifferentials.jl
170 lines (129 loc) · 3.77 KB
/
differentials.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
export Differential, expand_derivatives, @derivatives
"""
$(TYPEDEF)
Represents a differential operator.
# Fields
$(FIELDS)
# Examples
```jldoctest
julia> using ModelingToolkit
julia> @variables x y;
julia> D = Differential(x)
(D'~x())
julia> D(y) # Differentiate y wrt. x
(D'~x())(y())
```
"""
struct Differential <: Function
"""The variable or expression to differentiate with respect to."""
x::Expression
end
(D::Differential)(x) = Operation(D, Expression[x])
Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")")
Base.convert(::Type{Expr}, D::Differential) = D
Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)
"""
$(SIGNATURES)
TODO
"""
function expand_derivatives(O::Operation)
@. O.args = expand_derivatives(O.args)
if isa(O.op, Differential)
(D, o) = (O.op, O.args[1])
isequal(o, D.x) && return Constant(1)
occursin(D.x, o) || return Constant(0)
isa(o, Operation) || return O
isa(o.op, Variable) && return O
return sum(1:length(o.args)) do i
derivative(o, i) * expand_derivatives(D(o.args[i]))
end |> simplify_constants
end
return O
end
# Fallback for anything without a more specific method: there is nothing to
# expand, so the input is handed back unchanged.
function expand_derivatives(x)
    return x
end
# Generic entry point: the rule methods generated from DiffRules dispatch on the
# operator's type, a tuple of arguments and the argument position lifted into a
# `Val`. This method merely unpacks an `Operation` into that form.
"""
$(SIGNATURES)

Return the partial derivative of the operation `O` with respect to its argument
at position `idx`.

# Examples

```jldoctest label1
julia> using ModelingToolkit

julia> @variables x y;

julia> ModelingToolkit.derivative(sin(x), 1)
cos(x())
```

Only the outermost operation is differentiated. The arguments are treated as
opaque symbols, so the chain rule is not applied:

```jldoctest label1
julia> myop = sin(x) * y^2
sin(x()) * y() ^ 2

julia> typeof(myop.op) # Op is multiplication function
typeof(*)

julia> ModelingToolkit.derivative(myop, 1) # wrt. sin(x)
y() ^ 2

julia> ModelingToolkit.derivative(myop, 2) # wrt. y^2
sin(x())
```
"""
function derivative(O::Operation, idx)
    argtuple = Tuple(O.args)
    return derivative(O.op, argtuple, Val(idx))
end
# Pre-defined derivatives
import DiffRules, SpecialFunctions, NaNMath
# Generate one `derivative(::typeof(f), args, ::Val{i})` method per function `f`
# known to DiffRules and per argument position `i`.
for (modu, fun, arity) ∈ DiffRules.diffrules(), i ∈ 1:arity
    # Unary rules yield a single partial expression; n-ary rules yield one entry
    # per argument, from which the `i`-th is selected.
    select_partial = arity == 1 ? :partials : :(partials[$i])
    @eval function derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i})
        partials = DiffRules.diffrule($(QuoteNode(modu)), $(QuoteNode(fun)), args...)
        dx = $select_partial
        return convert(Expression, dx)
    end
end
# Parse the left-hand side of a `@derivatives` declaration such as `D''`.
# Returns `(order, name)`: `order` is how many adjoint (prime) markers wrap the
# name, and `name` is the innermost symbol, e.g. `(2, :D)` for `D''`. A bare
# symbol is rejected because it would declare a derivative of order zero.
function count_order(x)
    # Thrown explicitly rather than via `@assert`, which may be disabled at some
    # optimization levels; the exception type and message are kept unchanged.
    if x isa Symbol
        throw(AssertionError(
            "The variable $x must have an order of differentiation that is greater or equal to 1!"))
    end
    n = 1
    while !(x.args[1] isa Symbol)
        n += 1
        x = x.args[1]
    end
    return n, x.args[1]
end
# Compose `f` with itself `n` times (`f ∘ f ∘ … ∘ f`); for `n == 1` this is `f`
# itself. Used to turn a first-order differential into an `n`-th order one.
function _repeat_apply(f, n)
    # Reject `n < 1` up front; otherwise the recursion below never reaches its
    # `n == 1` base case.
    n >= 1 || throw(ArgumentError("repetition count must be at least 1, got $n"))
    return n == 1 ? f : f ∘ _repeat_apply(f, n - 1)
end
# Build the code emitted by `@derivatives`.
#
# `x` is the tuple of macro arguments. Each declaration must look like
# `Name'…'~var`: the number of primes is the differentiation order (see
# `count_order`), and `Name` is bound to that many nested applications of
# `Differential(var)`. The generated block evaluates to a tuple of the newly
# bound names.
function _differential_macro(x)
    ex = Expr(:block)
    lhss = Symbol[]
    # `@derivatives (D'~t), (E'~t)` arrives as one `:tuple` expression; unwrap it
    # so it is handled like `@derivatives D'~t E'~t`. The extra checks prevent an
    # empty call or a bare symbol from failing here with an unrelated
    # field-access error instead of the message below.
    if x isa Tuple && !isempty(x) && first(x) isa Expr && first(x).head == :tuple
        x = first(x).args
    end
    x = flatten_expr!(x)
    for di in x
        # Explicit throw instead of `@assert`, which may be disabled; the
        # exception type and message are unchanged.
        if !(di isa Expr && di.args[1] == :~)
            throw(AssertionError("@derivatives expects a form that looks like `@derivatives D''~t E'~t` or `@derivatives (D''~t), (E'~t)`"))
        end
        lhs = di.args[2]
        rhs = di.args[3]
        order, lhs = count_order(lhs)
        push!(lhss, lhs)
        # `Differential` is spliced in as a value, like `_repeat_apply`, so the
        # generated code works even where the caller has not imported the name.
        expr = :($lhs = $_repeat_apply($Differential($rhs), $order))
        push!(ex.args, expr)
    end
    push!(ex.args, Expr(:tuple, lhss...))
    return ex
end
"""
$(SIGNATURES)
Define one or more differentials.
# Examples
```jldoctest
julia> using ModelingToolkit
julia> @variables x y z;
julia> @derivatives Dx'~x Dy'~y # Create differentials wrt. x and y
((D'~x()), (D'~y()))
julia> Dx(z) # Differentiate z wrt. x
(D'~x())(z())
julia> Dy(z) # Differentiate z wrt. y
(D'~y())(z())
```
"""
macro derivatives(x...)
esc(_differential_macro(x))
end
# Symbolic Jacobian skeleton: entry `[i, j]` is the unevaluated derivative of
# `eqs[i]` with respect to `dvs[j]`. The entries are not expanded here.
function calculate_jacobian(eqs, dvs)
    jacobian = Expression[Differential(var)(eq) for eq in eqs, var in dvs]
    return jacobian
end