forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmemory_dag.h
176 lines (146 loc) · 6.28 KB
/
memory_dag.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#pragma once
#include <ATen/core/jit_type.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/sparse_bitset.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>

#include <functional>
#include <memory>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
// A set of memory locations, stored as a sparse bit vector keyed by Element
// index. This compressed index representation makes set comparisons fast.
using MemoryLocations = c10::SparseBitVector<256>;
namespace torch {
namespace jit {
// Forward declaration: this header only refers to `Value` through pointers.
struct Value;
// A list of types. Not referenced elsewhere in this header; presumably
// consumed by alias analysis (AliasDb) — confirm against callers.
using AliasTypeSet = std::vector<TypePtr>;
// `Element` represents a vertex in the points-to graph. It represents
// anything that could have an aliasing relationship -- mostly IR
// `Value`s, but also wildcards or the type inside a container (e.g. `T`
// in `List[T]`).
//
// Each Element is identified by its `index`, the compressed index
// representation used by the `MemoryLocations` bit vectors.
struct Element {
// Element standing for the IR value `value_`, with bit index `index_`.
// Defined out of line; presumably records `value_` in `values` — confirm
// in memory_dag.cpp.
Element(const Value* value_, unsigned index_);
// wildcard constructor: an Element that is not tied to any `Value`.
explicit Element(unsigned index_);
// Index into the owning DAG's bit vector that represents this element.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
unsigned index;
// All elements that this element *may* point to. It's possible to have
// multiple elements that you might point to due to control flow/complex ops
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
MemoryLocations pointsTo;
// Backreference for `pointsTo`: presumably the elements whose `pointsTo`
// contains this element.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
MemoryLocations pointedFrom;
// Elements can contain other elements (e.g. List[Tensor]); presumably this
// holds the elements contained directly in this one.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
MemoryLocations containedElements;
// The values that this element corresponds to. May be empty if this element
// doesn't represent a first-class value (e.g. a wildcard Element built with
// the index-only constructor).
// This is for debug information only.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::unordered_set<const Value*> values;
private:
// Make `from` point at `to`.
// NOTE(review): graph construction goes through the public
// MemoryDAGBuilder::makePointerTo; this private member looks unused —
// confirm before relying on it.
void makePointerTo(Element* from, Element* to);
// Lets MemoryDAG read and populate the private caches below.
friend class MemoryDAG;
// We memoize the results of `getMemoryLocations` to speed up queries.
// A nullopt means that this cache is not yet populated. The caches are
// `mutable` so that const queries can fill them in lazily.
// MemoryDAG is immutable except for its explicitly-marked mutating methods
// (e.g. `setWildcards`, `unsafeMakeFreshValue`), which are responsible for
// keeping these caches consistent; nothing else should invalidate them.
mutable std::optional<MemoryLocations> cachedMemoryLocations_;
// The equivalent memoization for contained memory locations (presumably
// backing `MemoryDAG::getAllContainedMemoryLocations` — confirm in the .cpp).
mutable std::optional<MemoryLocations> cachedAllContainedMemoryLocations_;
};
// class MemoryDAG
//
// This class tracks the "A points to B" graph for all values. It is used by
// AliasDb to provide a higher-level API.
//
// We maintain a DAG where:
// - Vertices (called "Elements") represent Values and
// other aliasing entities (e.g. the stuff inside a list)
// - Edges represent a "points-to" relationship.
//
// Leaves in this DAG are entities that don't point to anything, and thus
// correspond to unique "memory locations".
//
// So, by traversing the "points-to" graph to the leaves, you can determine
// which memory locations an element may point to.
class TORCH_API MemoryDAG {
public:
explicit MemoryDAG(std::vector<std::unique_ptr<Element>> indexToElementMap)
: indexToElementMap_(std::move(indexToElementMap)) {}
// explicitly delete copy constructor because otherwise windows build is
// confused for an exported class see
// https://stackoverflow.com/a/51033485/105137
MemoryDAG(const MemoryDAG&) = delete;
MemoryDAG& operator=(const MemoryDAG&) = delete;
// Return the unique memory locations that `Element` might represent.
const MemoryLocations& getMemoryLocations(const Element* e) const;
// Do `a` and `b` potentially share a memory location?
bool mayAlias(const Element* a, const Element* b) const;
// Does `a` hold reference to any memory that is stored in `b`, or vice versa?
bool mayContainAlias(const Element* a, const Element* b) const;
bool mayContainAlias(const Element* a, const at::ArrayRef<Element*> b) const;
bool mayContainAlias(
const at::ArrayRef<Element*> a,
const at::ArrayRef<Element*> b) const;
// Converts from the compressed index representation
const Element* fromIndex(unsigned x) const;
Element* fromIndex(unsigned x);
void collectAllContainedMemoryLocations(
const Element* elem,
MemoryLocations& cont) const;
/**
* The following methods are special cases where we need to mutate the
* internals of MemoryDAG for efficiency reasons. Don't call them unless you
* know what you're doing! In particular, don't add new mutating methods
* without ensuring that you are maintaining cache consistency for memory
* locations.
*/
// Adding wildcards can trigger extremely expensive cache invalidations. This
// method adds them in a more efficient cache-aware way.
void setWildcards(
const std::unordered_set<const Value*>& wildcards,
const ska::flat_hash_map<const Value*, Element*>& elementMap,
const std::function<Element*(const Value*)>& getWildcardElement);
Element* unsafeMakeFreshValue(const Value* v);
private:
const MemoryLocations& getAllContainedMemoryLocations(
const Element* elem) const;
void collectAllContainedMemoryLocationsImpl(
const Element* elem,
MemoryLocations& cont) const;
std::vector<std::unique_ptr<Element>> indexToElementMap_;
};
// MemoryDAGBuilder
//
// Accumulates the points-to graph and then hands it over to a MemoryDAG.
//
// Construction lives in a separate class so that MemoryDAG can cache query
// results internally without having to worry about the DAG structure being
// mutated underneath it.
class TORCH_API MemoryDAGBuilder {
 public:
  MemoryDAGBuilder() = default;

  // Non-copyable.
  MemoryDAGBuilder(const MemoryDAGBuilder&) = delete;
  MemoryDAGBuilder& operator=(const MemoryDAGBuilder&) = delete;

  // Create a fresh Element (one that does not point to anything) and return
  // it. The returned pointer is non-owning; presumably the Element is kept
  // in `indexToElementMap_`.
  Element* makeFreshValue(const Value* v);

  // Add a points-to edge: `from` now points at `to`.
  void makePointerTo(Element* from, Element* to);

  // Record `contained` as held inside `container` (presumably via
  // `container->containedElements`).
  void addToContainedElements(Element* contained, Element* container);

  // Consumes the builder: call it on an rvalue, e.g.
  // `std::move(builder).createMemoryDAG()`. Every Element created so far is
  // moved into the returned MemoryDAG.
  std::unique_ptr<MemoryDAG> createMemoryDAG() && {
    auto dag = std::make_unique<MemoryDAG>(std::move(indexToElementMap_));
    return dag;
  }

 private:
  // MemoryDAG is granted access to the builder's internals.
  friend MemoryDAG;

  // All Elements built so far, addressed by compressed index; handed to the
  // MemoryDAG by createMemoryDAG().
  std::vector<std::unique_ptr<Element>> indexToElementMap_;
};
} // namespace jit
} // namespace torch