-
-
Notifications
You must be signed in to change notification settings - Fork 213
/
Copy pathparameters.jl
131 lines (115 loc) · 3.57 KB
/
parameters.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import SymbolicUtils: symtype, term, hasmetadata, issym
# Classification tag for a symbolic variable. The tag is stored in the variable's
# metadata under the `MTKVariableTypeCtx` key.
@enum VariableType begin
    VARIABLE
    PARAMETER
    BROWNIAN
end

# Singleton type used as the metadata key for a `VariableType` tag.
struct MTKVariableTypeCtx end
"""
    getvariabletype(x, def = VARIABLE)

Return the value stored under the `MTKVariableTypeCtx` metadata key of `unwrap(x)`.
`def` is forwarded to `getmetadata` as the fallback value.
"""
function getvariabletype(x, def = VARIABLE)
    unwrapped = unwrap(x)
    return getmetadata(unwrapped, MTKVariableTypeCtx, def)
end
"""
    isparameter(x)

Return whether `x` (after `unwrap`) is classified as a parameter.

The checks are tried in this order, and the first one that applies decides the result:

 1. `x` is `Symbolic` and carries an `MTKVariableTypeCtx` tag: the result is
    `tag === PARAMETER`.
 2. `x` is `Symbolic` and `Symbolics.getparent` finds a parent `p`: the result is true if
    `p` is a parameter, or if the first entry of `p`'s `Symbolics.VariableSource`
    metadata is `:parameters`.
 3. `x` is a call whose operation is `Symbolic`: the operation is checked instead.
 4. `x` is a `getindex` call: its first argument is checked instead.

Anything else is not a parameter.
"""
function isparameter(x)
    x = unwrap(x)
    # Bind the tag unconditionally. It used to be assigned inside the first `if`
    # condition only, so later branches read either `nothing` or an unassigned local.
    varT = x isa Symbolic ? getvariabletype(x, nothing) : nothing
    if varT !== nothing
        return varT === PARAMETER
        #TODO: Delete this branch
    elseif x isa Symbolic && Symbolics.getparent(x, false) !== false
        p = Symbolics.getparent(x)
        return isparameter(p) ||
               (hasmetadata(p, Symbolics.VariableSource) &&
                getmetadata(p, Symbolics.VariableSource)[1] == :parameters)
    elseif iscall(x) && operation(x) isa Symbolic
        # `x` itself has no tag at this point, so only the operation can decide.
        return isparameter(operation(x))
    elseif iscall(x) && operation(x) == (getindex)
        return isparameter(arguments(x)[1])
    else
        # Includes untagged `Symbolic` values that matched none of the branches above.
        return false
    end
end
"""
    iscalledparameter(x)

Return `isparameter` applied to the `CallWithParent` metadata of `unwrap(x)`. When that
metadata is absent, `nothing` is passed to `isparameter` instead.
"""
function iscalledparameter(x)
    call_parent = getmetadata(unwrap(x), CallWithParent, nothing)
    return isparameter(call_parent)
end
"""
    getcalledparameter(x)

Build a `CallWithMetadata` from the operation of `unwrap(x)` and the metadata of the
value stored under its `CallWithParent` metadata key.
"""
function getcalledparameter(x)
    unwrapped = unwrap(x)
    # The stored `CallWithParent` value is a `CallWithMetadata` carrying the correct
    # metadata but no namespacing, whereas `operation(unwrapped)` is namespaced
    # correctly but is not a `CallWithMetadata` and has no metadata.
    # Take the operation from one and the metadata from the other.
    stored_call = getmetadata(unwrapped, CallWithParent)
    return CallWithMetadata(operation(unwrapped), metadata(stored_call))
end
"""
toparam(s)
Maps the variable to a parameter.
"""
function toparam(s)
if s isa Symbolics.Arr
Symbolics.wrap(toparam(Symbolics.unwrap(s)))
elseif s isa AbstractArray
map(toparam, s)
else
setmetadata(s, MTKVariableTypeCtx, PARAMETER)
end
end
# `Num` method: convert the underlying value, then wrap the result again.
function toparam(s::Num)
    inner = value(s)
    return wrap(toparam(inner))
end
"""
tovar(s)
Maps the variable to an unknown.
"""
tovar(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, VARIABLE)
# `Num` method: convert the underlying value, then rebuild a `Num` from the result.
function tovar(s::Num)
    inner = value(s)
    return Num(tovar(inner))
end
"""
$(SIGNATURES)
Define one or more known parameters.
See also [`@independent_variables`](@ref), [`@variables`](@ref) and [`@constants`](@ref).
"""
macro parameters(xs...)
Symbolics._parse_vars(:parameters,
Real,
xs,
toparam) |> esc
end
"""
    find_types(array)

Return, for every element of `array`, an integer id identifying its concrete type.

Ids start at 1 and are handed out in order of first appearance, so elements sharing a
`typeof` share an id. The function is broadcast over `array`, so the result has the
shape that broadcasting gives it.
"""
function find_types(array)
    type_ids = Dict{Any, Int}()
    # A type seen for the first time gets the next free id. `get!` calls the default
    # function before inserting, so that id is one more than the number of types
    # already registered.
    id_of(x) = get!(() -> length(type_ids) + 1, type_ids, typeof(x))
    return id_of.(array)
end
"""
    split_parameters_by_type(ps)

Group the entries of `ps` by concrete type and return `(groups, split_idxs)`.

  - When `ps` is `SciMLBase.NullParameters()`, return `(Float64[], [])`.
  - Otherwise types are numbered in order of first appearance. `split_idxs[k]` lists the
    positions (as counted by `enumerate`) of the entries with the `k`-th type, and the
    matching group holds those entries with its element type narrowed by `identity.`.
  - When `ps` is a `StaticArray`, each group and the collection of groups are converted
    to `SArray`s.
  - `groups` is the single group itself when only one type occurs, and a tuple of groups
    otherwise.
"""
function split_parameters_by_type(ps)
    if ps === SciMLBase.NullParameters()
        # use Float64 to avoid Any type warning
        return Float64[], []
    end

    # Number the types in order of first appearance. `get!` calls the default function
    # before inserting, so a new type gets one more than the number already registered.
    type_ids = Dict{Any, Int}()
    id_of = x -> get!(() -> length(type_ids) + 1, type_ids, typeof(x))
    ids = id_of.(ps)

    # Collect the positions belonging to each type id.
    split_idxs = [Int[]]
    for (position, id) in enumerate(ids)
        # Ids are consecutive, so an unseen id is exactly one past the current end.
        id > length(split_idxs) && push!(split_idxs, Int[])
        push!(split_idxs[id], position)
    end

    # Pull out each group and let `identity.` narrow its element type.
    pick = Base.Fix1(getindex, ps)
    narrow = group -> identity.(group)
    split_ps = narrow.(pick.(split_idxs))

    if ps isa StaticArray
        static_groups = map(group -> SArray{Tuple{size(group)...}}(group), split_ps)
        split_ps = SArray{Tuple{size(static_groups)...}}(static_groups)
    end

    # Tuple not needed, only 1 type
    if length(split_ps) == 1
        return split_ps[1], split_idxs
    end
    return (split_ps...,), split_idxs
end