-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy path omeinsum_treesa_optimizer.py
50 lines (44 loc) · 1.48 KB
/
omeinsum_treesa_optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import List, Set, Dict, Tuple
class OMEinsumTreeSAOptimizer:
    """Base class for a contraction-order optimizer configured with TreeSA settings.

    It stores the search hyper-parameters and provides a helper that turns a
    binary contraction tree into a linear contraction path. ``__call__`` is
    abstract here: a subclass is expected to run the actual search and return
    the path.

    NOTE(review): the name suggests the settings are forwarded to the TreeSA
    optimizer of the Julia package OMEinsum; that call is not visible in this
    file, so the parameter meanings documented on ``__init__`` are assumptions.
    """

    def __init__(
        self,
        sc_target: int = 20,
        betas: Tuple[float, float, float] = (0.01, 0.01, 15),
        ntrials: int = 10,
        niters: int = 50,
        sc_weight: float = 1.0,
        rw_weight: float = 0.2,
    ):
        """Store the optimizer settings unchanged, as attributes of the same name.

        No validation is performed here. Assumed meanings (verify against
        OMEinsum's TreeSA documentation):

        Args:
            sc_target: target space complexity (log2 of the largest
                intermediate tensor).
            betas: (start, step, stop) of the inverse-temperature schedule.
            ntrials: number of independent annealing trials.
            niters: number of iterations per temperature.
            sc_weight: weight of the space-complexity term in the loss.
            rw_weight: weight of the read-write-complexity term in the loss.
        """
        self.sc_target = sc_target
        self.betas = betas
        self.ntrials = ntrials
        self.niters = niters
        self.sc_weight = sc_weight
        self.rw_weight = rw_weight

    def _contraction_tree_to_contraction_path(
        self, ei: dict, queue: List[int], path: List[List[int]], idx: int
    ) -> int:
        """Flatten a binary contraction tree into a linear contraction path.

        The tree is walked in post-order: the first child's subtree, then the
        second child's subtree, then the node itself. Each internal node adds
        one entry to ``path``. That entry lists the current *positions* in
        ``queue`` of the node's two operands, largest first. Those operands are
        then removed from ``queue``, and the node's new id is appended to the
        end. This positional convention is the one used by opt_einsum paths.

        An explicit stack is used instead of recursion. Very deep (e.g.
        sequential) trees therefore do not hit Python's recursion limit.

        Args:
            ei: root node. A leaf has ``"isleaf": True`` and a 1-based
                ``"tensorindex"``. An internal node has ``"isleaf": False``
                and its children in ``"args"``, which must hold exactly two.
            queue: ids of the operands currently alive, in position order.
                Mutated in place.
            path: output list. One ``[pos_hi, pos_lo]`` list is appended per
                internal node. Mutated in place.
            idx: id assigned to the first intermediate result. Each internal
                node consumes one id.

        Returns:
            The next unused id, i.e. ``idx`` plus the number of internal nodes.

        Side effects:
            Every leaf's ``"tensorindex"`` is decremented by 1 in place
            (1-based -> 0-based). Every internal node gets ``"tensorindex"``
            set to its assigned id. Calling this twice on the same tree
            therefore shifts leaf indices twice.

        Raises:
            AssertionError: an internal node does not have exactly two
                children. This is checked with ``assert``, so the check is
                skipped under ``python -O``.
            ValueError: an operand id is missing from ``queue``.
            KeyError: a node lacks a required key.
        """
        # Each entry is (node, children_done). An internal node is pushed once
        # to schedule its children, and again (flagged) so that it is combined
        # only after both subtrees are finished. This reproduces the visiting
        # order of the former recursive implementation.
        stack = [(ei, False)]
        while stack:
            node, children_done = stack.pop()
            if children_done:
                # Sort high-to-low so popping the larger position first keeps
                # the smaller position valid.
                positions = sorted(
                    (queue.index(child["tensorindex"]) for child in node["args"]),
                    reverse=True,
                )
                for pos in positions:
                    queue.pop(pos)
                node["tensorindex"] = idx
                path.append(positions)
                queue.append(idx)
                idx += 1
            elif node["isleaf"]:
                # OMEinsum provides 1-based indices, but the contraction path
                # uses 0-based ones.
                node["tensorindex"] -= 1
            else:
                assert len(node["args"]) == 2, "must be a binary tree"
                stack.append((node, True))
                # Push children in reverse so the first child is handled first.
                stack.extend((child, False) for child in reversed(node["args"]))
        return idx

    def __call__(
        self,
        inputs: List[Set[str]],
        output: Set[str],
        size: Dict[str, int],
        memory_limit=None,
    ) -> List[Tuple[int, int]]:
        """Compute a contraction path; must be implemented by a subclass.

        NOTE(review): the signature resembles opt_einsum's custom path
        optimizer call ``(inputs, output, size_dict, memory_limit=None)``.
        Presumably this is so instances can be passed as an optimizer; confirm
        with the subclass. Also note that the helper above produces path
        entries as lists, not tuples.

        Args:
            inputs: one set of index labels per input tensor (per the type hints).
            output: index labels of the result.
            size: dimension of each index label.
            memory_limit: not used by this base implementation.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError