-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathomeinsum.jl
87 lines (77 loc) · 2.46 KB
/
omeinsum.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import OMEinsum
import ArgParse
import JSON
using KaHyPar
"""
    parse_commandline()

Parse the process command-line arguments (`ARGS`) and return them as a `Dict`
keyed by option name without the leading dashes (e.g. `"sc_target"`). Options
that are not given on the command line take the defaults listed below.
"""
function parse_commandline()
    s = ArgParse.ArgParseSettings(
        description = "Optimize an einsum contraction order with OMEinsum's TreeSA.",
    )
    # `@add_arg_table!` is the ArgParse >= 1.0 name; `@add_arg_table` is deprecated.
    ArgParse.@add_arg_table! s begin
        "--einsum_json"
            help = "path of the input JSON describing the einsum (inputs, output, size)"
            arg_type = String
            default = "einsum.json"
        "--result_json"
            help = "path the optimized contraction code is written to"
            arg_type = String
            default = "opteinsum.json"
        "--sc_target"
            help = "target space complexity for TreeSA (and for KaHyParBipartite " *
                   "when --kahypar_init is set)"
            arg_type = Int
            default = 20
        "--beta_start"
            help = "first inverse temperature of the TreeSA annealing schedule"
            arg_type = Float64
            default = 0.01
        "--beta_step"
            help = "increment between successive inverse temperatures (must be > 0)"
            arg_type = Float64
            default = 0.01
        "--beta_stop"
            help = "last inverse temperature of the TreeSA annealing schedule"
            arg_type = Float64
            default = 15.0
        "--ntrials"
            help = "number of independent TreeSA trials (must be >= 1)"
            arg_type = Int
            default = 10
        "--niters"
            help = "TreeSA iterations per inverse temperature"
            arg_type = Int
            default = 50
        "--sc_weight"
            help = "weight of the space complexity in the TreeSA objective"
            arg_type = Float64
            default = 1.0
        "--rw_weight"
            help = "weight of the read-write complexity in the TreeSA objective"
            arg_type = Float64
            default = 0.2
        "--kahypar_init"
            help = "seed TreeSA with a KaHyParBipartite result instead of a greedy order"
            action = :store_true
    end
    return ArgParse.parse_args(s)
end
"""
    _check_search_args(args)

Reject command-line settings that would give TreeSA an empty or invalid annealing
schedule (`beta_start:beta_step:beta_stop`) or no trial to take a result from.
Throws `ArgumentError`; returns `nothing` otherwise.
"""
function _check_search_args(args)
    β_start, β_step, β_stop = args["beta_start"], args["beta_step"], args["beta_stop"]
    β_step > 0 ||
        throw(ArgumentError("--beta_step must be positive, got $β_step"))
    # With a positive step, start > stop would silently produce an empty schedule.
    β_start <= β_stop || throw(ArgumentError(
        "--beta_start ($β_start) must not exceed --beta_stop ($β_stop)"))
    ntrials = args["ntrials"]
    ntrials >= 1 ||
        throw(ArgumentError("--ntrials must be at least 1, got $ntrials"))
    return nothing
end

"""
    _load_contraction(path)

Read the einsum specification stored as JSON at `path` and return
`(eincode, size_dict)`.

The JSON object must contain `"inputs"` (one array of labels per input tensor) and
`"output"` (the array of output labels). Every label gets dimension 2 by default;
an optional `"size"` object overrides the dimension of individual labels.
"""
function _load_contraction(path::AbstractString)
    spec = JSON.parsefile(path)
    # Convert the JSON arrays of labels into tuples before building the EinCode.
    ixs = Tuple(map(Tuple, spec["inputs"]))
    iy = Tuple(spec["output"])
    eincode = OMEinsum.EinCode(ixs, iy)
    size_dict = OMEinsum.uniformsize(eincode, 2)
    # NOTE(review): JSON object keys are always strings, so these overrides only
    # match labels that are strings themselves; integer labels would need their
    # keys converted — confirm against whatever produces einsum.json.
    for (label, dim) in get(spec, "size", Dict{String,Any}())
        size_dict[label] = dim
    end
    return eincode, size_dict
end

"""
    main()

Command-line entry point. Reads the einsum from `--einsum_json` and optionally
pre-optimizes it with `OMEinsum.KaHyParBipartite` (`--kahypar_init`). It then
optimizes the contraction order with `OMEinsum.TreeSA` and writes the result to
`--result_json` with `OMEinsum.writejson`.

Throws `ArgumentError` if the annealing schedule or trial count given on the
command line is invalid.
"""
function main()
    parsed_args = parse_commandline()
    _check_search_args(parsed_args)

    eincode, size_dict = _load_contraction(parsed_args["einsum_json"])

    sc_target = parsed_args["sc_target"]
    kahypar_init = parsed_args["kahypar_init"]
    if kahypar_init
        # The KaHyPar result becomes TreeSA's starting point via
        # `initializer = :specified` below.
        partitioner = OMEinsum.KaHyParBipartite(sc_target = sc_target,
                                                max_group_size = 50)
        eincode = OMEinsum.optimize_code(eincode, size_dict, partitioner)
    end

    βs = parsed_args["beta_start"]:parsed_args["beta_step"]:parsed_args["beta_stop"]
    algorithm = OMEinsum.TreeSA(
        sc_target = sc_target,
        βs = βs,
        ntrials = parsed_args["ntrials"],
        niters = parsed_args["niters"],
        sc_weight = parsed_args["sc_weight"],
        rw_weight = parsed_args["rw_weight"],
        initializer = kahypar_init ? :specified : :greedy,
    )
    optcode = OMEinsum.optimize_code(eincode, size_dict, algorithm)
    OMEinsum.writejson(parsed_args["result_json"], optcode)
    return nothing
end
# Run only when this file is executed as a script (`julia omeinsum.jl ...`), so that
# `include`-ing it elsewhere does not parse ARGS and start an optimization.
if abspath(PROGRAM_FILE) == @__FILE__
    main()
end