-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathcontractor.jl
368 lines (251 loc) · 8.98 KB
/
contractor.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
# Own version of gensym: a module-level counter used by `make_symbol`. It is held
# in a one-element array so that it can be bumped in place although it is `const`.
const symbol_number = [1]

"""Return a new, unique symbol like _z10_ (successive calls yield `_z1_`, `_z2_`, ...
as the counter `symbol_number` advances)."""
function make_symbol()
    n = symbol_number[1]
    symbol_number[1] = n + 1
    return symbol("_z", n, "_")
end
# Leaf case for anything that is neither an `Expr` nor a `Symbol` (e.g. a number
# literal): it is returned unchanged, contributing no variables and no code.
function insert_variables(ex)
    no_user_vars = Symbol[]
    no_generated_vars = Symbol[]
    return ex, no_user_vars, no_generated_vars, quote end
end
# Leaf case for a bare symbol: it is a user variable, acts as its own head
# variable, and generates no code.
function insert_variables(ex::Symbol)
    return ex, Symbol[ex], Symbol[], quote end
end
doc"""
`insert_variables` takes a Julia `Expr`ession and
recursively replaces operations like `a+b` by assignments
of the form `_z10_ = a+b`, where `_z10_` is a distinct symbol,
created using `make_symbol` (which is like `gensym`, but more readable).

`+` and `*` calls with more than two arguments are first rewritten as nested
binary calls, e.g. `+(a, b, c)` becomes `+(a, +(b, c))`.

Returns:
1. generated variable at head of tree;
2. sorted vector of leaf (user) variables contained in tree;
3. vector of generated intermediate variables, in the order in which they are
   assigned in the generated code (so the head variable comes last);
4. generated code (a `quote` block containing the assignments).

Usage: `IntervalConstraintProgramming.insert_variables(:(x^2 + y^2))`
"""
function insert_variables(ex::Expr)
# NOTE(review): assumes `ex` is a call expression (`ex.head == :call`), so that
# `ex.args[1]` is the operator and `ex.args[2:end]` its arguments -- TODO confirm
# how other heads (e.g. `:ref`, `:comparison`) should be handled.
op = ex.args[1]
# rewrite +(a,b,c) as +(a,+(b,c)):
# TODO: Use @match here!
if op in (:+, :*) && length(ex.args) > 3
return insert_variables( :( ($op)($(ex.args[2]), ($op)($(ex.args[3:end]...) )) ))
end
new_code = quote end # accumulates the assignments for this subtree
current_args = [] # the arguments in the current expression that will be added
all_vars = Set{Symbol}() # all variables contained in the sub-expressions
generated_variables = Symbol[]
# Recurse into each argument; leaves (symbols, literals) contribute no code.
for arg in ex.args[2:end]
top, contained_vars, generated, code = insert_variables(arg)
push!(current_args, top)
union!(all_vars, contained_vars)
append!(new_code.args, code.args) # add previously-generated code
append!(generated_variables, generated)
end
new_var = make_symbol() # variable that will hold the value of this node
push!(generated_variables, new_var)
# rename each occurrence of user-defined function:
# operators without an entry in `rev_ops` (defined elsewhere) are treated as
# user-defined functions and renamed via `increment_counter!` (defined elsewhere;
# presumably returns a use count and a fresh name -- TODO confirm).
if op ∉ keys(rev_ops)
# check if user-defined function
counter, op = increment_counter!(op)
if counter == 1
# create the function
# NOTE(review): placeholder -- this branch currently does nothing.
end
# throw(ArgumentError("Operation $op not currently supported"))
end
top_level_code = :($(new_var) = ($op)($(current_args...))) # new top-level code
push!(new_code.args, top_level_code)
return new_var, sort(collect(all_vars)), generated_variables, new_code
end
"""
    constraint_code(root_var, constraint)

Return the assignment expression `root_var = root_var ∩ constraint`, which
intersects the value held by `root_var` with `constraint` (a value or a symbol,
spliced into the expression as is).
"""
function constraint_code(root_var, constraint)
    intersection = :($root_var ∩ $constraint)
    return :($root_var = $intersection)
end
"""
    forward_pass(ex::Expr)

Decompose `ex` with `insert_variables` and build its forward-pass function
expression with the four-argument `forward_pass` method.
"""
function forward_pass(ex::Expr)
    head, user_vars, temps, body = insert_variables(ex)
    return forward_pass(head, user_vars, temps, body)
end
"""
    forward_pass(root, all_vars, generated, code)

Return an anonymous-function expression that runs `code`, taking the user
variables `all_vars` as arguments and returning the `generated` intermediate
variables. `root` is currently unused.
"""
forward_pass(root, all_vars, generated, code) = make_function(all_vars, generated, code)
"""
    backward_pass(ex::Expr)

Decompose `ex` with `insert_variables` and build its backward (reverse-mode)
function expression with the four-argument `backward_pass` method.
"""
function backward_pass(ex::Expr)   # TODO: accept a `constraint::Interval`
    head, user_vars, temps, body = insert_variables(ex)
    return backward_pass(head, user_vars, temps, body)
end
"""`backward_pass` replaces e.g. `z = a + b` with
the corresponding reverse-mode function, `(z, a, b) = plusRev(z, a, b)`.

The assignments in `code` (as produced by `insert_variables`) are processed in
reverse order. Arguments that are not symbols (e.g. numeric literals) cannot be
assigned to, so they are replaced by `_` on the left-hand side.

Returns an anonymous-function expression (built by `make_function`) taking the
sorted user variables `all_vars` followed by the `generated` intermediate
variables, and returning the user variables. `root_var` is currently unused.

Throws an `ArgumentError` for a line that is not of the form `var = op(args...)`.
"""
function backward_pass(root_var, all_vars, generated, code) #, constraint::Interval)
    new_code = quote end
    for line in reverse(code.args)   # run backwards
        # Line-number annotations can be `LineNumberNode`s (which have no `.head`)
        # or `Expr(:line, ...)`; skip both.
        if isa(line, LineNumberNode) || (isa(line, Expr) && line.head == :line)
            continue
        end
        (var, op, args) = @match line begin
            (var_ = op_(args__)) => (var, op, args)
            _ => throw(ArgumentError("backward_pass: cannot reverse `$line`; expected `var = op(args...)`"))
        end
        new_args = Any[var]
        append!(new_args, args)
        rev_op = rev_ops[op]   # find the reverse operation
        rev_code = :($(rev_op)($(new_args...)))
        # Only symbols can be assigned to; discard the other results via `_`.
        return_args = [isa(arg, Symbol) ? arg : :_ for arg in new_args]
        return_tuple = Expr(:tuple, return_args...)   # make tuple out of array
        push!(new_code.args, :($(return_tuple) = $(rev_code)))
    end
    all_vars = sort(all_vars)
    return make_function(vcat(all_vars, generated), all_vars, new_code)
end
"""
`forward_backward` takes in an expression like `x^2 + y^2` and outputs
code for the forward-backward contractor.

Returns `(vars, code)`:
- `vars`: the sorted user variables of `ex`. If `constraint` is the entire real
  line, the constraint is instead read from an extra variable `_A_`, which is
  always placed first so that `make_function` can recognise it and leave it out
  of the returned values.
- `code`: the forward assignments, the intersection of the root variable with
  the constraint, then the reverse-mode assignments in reverse order.

Throws an `ArgumentError` for a generated line that is not of the form
`var = op(args...)`.

TODO: Add intersections in forward direction
"""
function forward_backward(ex::Expr, constraint::Interval=entireinterval())
    new_ex = copy(ex)

    # Step 1: Forward pass using insert_variables
    root_var, all_vars, generated, code = insert_variables(new_ex)
    vars = sort(all_vars)

    # Step 2: Add constraint code
    if constraint == Interval(-∞, ∞)
        constraint_line = constraint_code(root_var, :_A_)
        # Prepend explicitly instead of sorting `_A_` in: `:_A_` sorts after
        # upper-case names such as `:X`, which would move it away from index 1.
        vars = vcat(:_A_, vars)
    else
        constraint_line = constraint_code(root_var, constraint)
    end
    new_code = copy(code)
    push!(new_code.args, constraint_line)

    # Step 3: Backwards pass
    # replace e.g. z = a + b with reverse mode function plusRev(z, a, b)
    for line in reverse(code.args)   # run backwards
        # Line-number annotations can be `LineNumberNode`s (which have no `.head`)
        # or `Expr(:line, ...)`; skip both.
        if isa(line, LineNumberNode) || (isa(line, Expr) && line.head == :line)
            continue
        end
        (var, op, args) = @match line begin
            (var_ = op_(args__)) => (var, op, args)
            _ => throw(ArgumentError("forward_backward: cannot reverse `$line`; expected `var = op(args...)`"))
        end
        new_args = Any[var]
        append!(new_args, args)
        rev_op = rev_ops[op]   # find the reverse operation
        rev_code = :($(rev_op)($(new_args...)))
        # Only symbols can be assigned to; discard the other results via `_`.
        return_args = [isa(arg, Symbol) ? arg : :_ for arg in new_args]
        return_tuple = Expr(:tuple, return_args...)   # make tuple out of array
        push!(new_code.args, :($(return_tuple) = $(rev_code)))
    end

    return vars, new_code
end
"""
    make_function(all_vars, code)

Return an anonymous-function expression that takes the variables `all_vars`
(in order) as arguments, runs `code`, and returns a tuple of all its arguments
except the constraint placeholder `_A_`, which is an input only.

NOTE: the `return` statement is appended to `code` in place, so the caller's
`code` block is modified.
"""
function make_function(all_vars, code)
    vars = Expr(:tuple, all_vars...)   # make a tuple of the variables
    # Filter `_A_` out wherever it appears. Unlike the old `all_vars[1] == :_A_`
    # test, this also works when `all_vars` is empty or `_A_` is not first.
    returned = Expr(:tuple, filter(v -> v != :_A_, all_vars)...)
    push!(code.args, :(return $(returned)))
    return :( $(vars) -> $(code) )
end
"""
    make_function(input_args, output_args, code)

Generate code for an anonymous function with given input arguments, output
arguments, and code block. The function takes `input_args` as its arguments,
runs a copy of `code`, and returns the tuple of `output_args`. The caller's
`code` is not modified.
"""
function make_function(input_args, output_args, code)
    body = copy(code)
    results = Expr(:tuple, output_args...)
    push!(body.args, :(return $results))
    signature = Expr(:tuple, input_args...)
    return :( $signature -> $body )
end
"""`parse_comparison` parses comparisons like `x >= 10`
into the corresponding interval, expressed as `x ∈ [10,∞]`.

Recognised forms (`a`, `b`, `c` are arbitrary sub-expressions):
- `a <= b`, `a < b` gives `a ∈ [-∞, b]`
- `a >= b`, `a > b` gives `a ∈ [b, ∞]`
- `a == b`, `a = b` gives `a ∈ [b, b]`
- `a <= b <= c` (any mix of `<`/`<=`) gives `b ∈ [a, c]`
- `a >= b >= c` (any mix of `>`/`>=`) gives `b ∈ [c, a]`
- `a ∈ [b, c]`, `a in [b, c]`, `a ∈ b..c`, `a in b..c` gives `a ∈ [b, c]`

Strict and non-strict inequalities are treated alike: the interval is closed.
Any other expression is returned unchanged with the entire interval.

Returns the expression and the constraint interval.

TODO: Allow something like [3,4]' for the complement of [3,4]
"""
function parse_comparison(ex)
    expr, limits =
    @match ex begin
        ((a_ <= b_) | (a_ < b_)) => (a, (-∞, b))
        ((a_ >= b_) | (a_ > b_)) => (a, (b, ∞))
        ((a_ == b_) | (a_ = b_)) => (a, (b, b))
        # every capture must be written `c_`: a bare `c` would only match the
        # literal symbol `c`, so mixed chains like `1 <= x < 2` fell through.
        ((a_ <= b_ <= c_)
         | (a_ < b_ < c_)
         | (a_ <= b_ < c_)
         | (a_ < b_ <= c_)) => (b, (a, c))
        ((a_ >= b_ >= c_)
         | (a_ > b_ > c_)
         | (a_ >= b_ > c_)
         | (a_ > b_ >= c_)) => (b, (c, a))
        ((a_ ∈ [b_, c_])
         | (a_ in [b_, c_])
         | (a_ ∈ b_ .. c_)
         | (a_ in b_ .. c_)) => (a, (b, c))
        _ => (ex, (-∞, ∞))
    end
    a, b = limits
    return (expr, a..b)   # expr ∈ [a,b]
end
# A compiled forward-backward contractor together with the data it was built
# from (as filled in by the `Contractor(ex::Expr)` constructor).
type Contractor
variables::Vector{Symbol} # argument names of the compiled function, in call order
constraint_expression::Expr # the constrained expression, comparison stripped off by `parse_comparison`
contractor::Function # the compiled contractor function (invoked via `C(x...)`)
code::Expr # the generated code block the function was compiled from
end
"""
    Contractor(ex::Expr)

Build a `Contractor` from a constraint such as `:(x^2 + y^2 <= 1)`. The
comparison is split by `parse_comparison`, the forward-backward code is
generated by `forward_backward`, and the resulting function expression is
compiled with `eval`.
"""
function Contractor(ex::Expr)
    lhs, interval = parse_comparison(ex)
    variables, body = forward_backward(lhs, interval)
    compiled = eval(make_function(variables, body))
    return Contractor(variables, lhs, compiled, body)
end
# new call syntax to define a "functor" (object that behaves like a function):
# calling `C(x...)` forwards all arguments to the compiled function `C.contractor`.
# `@compat` presumably comes from Compat.jl, making this syntax work on older Julia versions.
@compat (C::Contractor)(x...) = C.contractor(x...)
# Display a `Contractor` as a header line followed by its variables and its
# constrained expression (no trailing newline after the last line).
function Base.show(io::IO, C::Contractor)
    header = "Contractor:"
    vars_line = " - variables: $(C.variables)"
    constraint_line = " - constraint: $(C.constraint_expression)"
    println(io, header)
    println(io, vars_line)
    print(io, constraint_line)
end
"""
    contractor(ex)

Return the (unevaluated) anonymous-function expression for the forward-backward
contractor of `ex`, e.g. `:(x^2 + y^2 <= 1)`. The function takes as arguments
the variables contained in the expression, plus an extra `_A_` argument holding
the constraint interval when `ex` is not a recognised comparison.

Usage, via the `@contractor` macro (which builds a `Contractor` object):
```
C = @contractor(x^2 + y^2 <= 1)
x = y = @interval(0.5, 1.5)
C(x, y)
```
`@contractor` makes a function that takes as arguments the variables contained
in the expression, in lexicographic order.

TODO: Hygiene for global variables, or pass in parameters
"""
function contractor(ex)
    expr, constraint = parse_comparison(ex)
    all_vars, code = forward_backward(expr, constraint)
    return make_function(all_vars, code)
end
"""
    @contractor(ex)

Build a `Contractor` from the constraint expression `ex`, e.g.
`C = @contractor(x^2 + y^2 <= 1)`. The expression is passed to `Contractor`
quoted, i.e. unevaluated.
"""
macro contractor(ex)
    quoted_ex = Meta.quot(ex)
    return :(Contractor($quoted_ex))
end
"""
    show_code(c::Contractor)

Return the generated code block stored in `c` (its `code` field).
"""
function show_code(c::Contractor)
    return c.code
end