Commit dffcf99

jansel authored and pytorchmergebot committed
Misc changes from compiled autograd branch (pytorch#104316)
This PR pulls out some standalone changes from pytorch#103822.

Pull Request resolved: pytorch#104316
Approved by: https://github.com/ezyang
1 parent: e80787c

File tree: 6 files changed (+57 -24 lines)


aten/src/ATen/TensorGeometry.h (20 additions, 0 deletions)

@@ -113,6 +113,26 @@ struct TORCH_API TensorGeometry {
     return r;
   }

+  std::vector<c10::SymInt>& mutable_sizes() {
+    return sizes_;
+  }
+  std::vector<c10::SymInt>& mutable_strides() {
+    return strides_;
+  }
+  c10::SymInt& mutable_storage_offset() {
+    return storage_offset_;
+  }
+  void recompute() {
+    // recalculate numel after a change
+    c10::SymInt numel = 1;
+    for (const auto& i : sizes_) {
+      numel = numel * i;
+    }
+    numel_ = std::move(numel);
+    has_symbolic_sizes_strides_ =
+        !c10::asIntArrayRefSlowOpt(sizes_).has_value();
+  }
+
  private:
   std::vector<c10::SymInt> sizes_;
   std::vector<c10::SymInt> strides_;
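For orientation, here is a minimal sketch (not part of the commit; the helper name is illustrative) of the pattern these accessors enable: edit sizes and strides in place through the mutable_* references, then call recompute() so numel_ and has_symbolic_sizes_strides_ stay consistent with the edited sizes.

#include <ATen/TensorGeometry.h>
#include <utility>

// Hypothetical helper: swap the last two dimensions of a geometry in place.
// Any edit made through the mutable_* accessors should be followed by
// recompute(), which re-derives numel_ and the symbolic-shapes flag.
void transpose_last_two(at::TensorGeometry& geom) {
  auto& sizes = geom.mutable_sizes();
  auto& strides = geom.mutable_strides();
  const auto n = sizes.size();
  if (n < 2) {
    return;
  }
  std::swap(sizes[n - 1], sizes[n - 2]);
  std::swap(strides[n - 1], strides[n - 2]);
  geom.recompute(); // numel is unchanged by a transpose, but stays validated
}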

c10/core/SafePyObject.h (13 additions, 1 deletion)

@@ -22,21 +22,33 @@ struct C10_API SafePyObject {
   // Steals a reference to data
   SafePyObject(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
       : data_(data), pyinterpreter_(pyinterpreter) {}
+  SafePyObject(SafePyObject&& other)
+      : data_(std::exchange(other.data_, nullptr)),
+        pyinterpreter_(other.pyinterpreter_) {}

   // In principle this could be copyable if we add an incref to PyInterpreter
   // but for now it's easier to just disallow it.
   SafePyObject(SafePyObject const&) = delete;
   SafePyObject& operator=(SafePyObject const&) = delete;

   ~SafePyObject() {
-    (*pyinterpreter_)->decref(data_, /*is_tensor*/ false);
+    if (data_ != nullptr) {
+      (*pyinterpreter_)->decref(data_, /*is_tensor*/ false);
+    }
   }

   c10::impl::PyInterpreter& pyinterpreter() const {
     return *pyinterpreter_;
   }
   PyObject* ptr(const c10::impl::PyInterpreter*) const;

+  // stop tracking the current object, and return it
+  PyObject* release() {
+    auto rv = data_;
+    data_ = nullptr;
+    return rv;
+  }
+
  private:
   PyObject* data_;
   c10::impl::PyInterpreter* pyinterpreter_;
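A short sketch (illustrative names, not from the commit) of what the new move constructor and release() allow, assuming a valid owned PyObject* and its interpreter: a moved-from or released SafePyObject has data_ set to nullptr, and the null check added to the destructor makes that state safe.

#include <c10/core/SafePyObject.h>

// Hypothetical factory: returning by value needs the new move constructor,
// since the copy constructor remains deleted.
c10::SafePyObject make_guard(PyObject* obj, c10::impl::PyInterpreter* interp) {
  c10::SafePyObject guard(obj, interp); // steals the reference to obj
  return guard; // moved out; the moved-from guard's data_ is nullptr,
                // so its destructor (now null-checked) skips the decref
}

// Hypothetical consumer: release() hands the reference back to the caller,
// who becomes responsible for the eventual decref.
PyObject* take_back(c10::SafePyObject& guard) {
  return guard.release();
}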

tools/autograd/templates/Functions.h (0 additions, 1 deletion)

@@ -52,7 +52,6 @@ struct TypeAndSize {

   Tensor zeros() { return at::zeros_symint(sym_sizes, options); }

- private:
   std::vector<c10::SymInt> sym_sizes;
   at::TensorOptions options;
 };
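With the access specifier removed, code outside the struct (e.g. on the compiled autograd branch this PR was split from) can read the stored metadata directly. A small sketch of what this permits (the free function is illustrative, assuming the generated TypeAndSize is in scope):

// Hypothetical: sym_sizes and options are now public, so a caller can make
// the same at::zeros_symint call that TypeAndSize::zeros() makes internally.
at::Tensor zeros_like_saved(TypeAndSize& tas) {
  return at::zeros_symint(tas.sym_sizes, tas.options);
}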

torch/_dynamo/variables/builder.py (3 additions, 12 deletions)

@@ -1028,19 +1028,10 @@ def wrap_unspecialized_primitive(self, value):
                 guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
             )

-            wrapped_value = shape_env.create_symintnode(
-                # TODO: This is wrong wrong wrong, create_symbol will
-                # generate something that is non-negative, but this is
-                # not a sound assumption to make.
-                # Not fixing as this was a preexisting condition.
-                shape_env.create_symbol(
-                    value,
-                    source=self.source,
-                    dynamic_dim=dynamic_dim,
-                    constraint_dim=None,
-                ),
-                hint=value,
+            wrapped_value = shape_env.create_symint_and_symbol(
+                value,
                 source=self.source,
+                dynamic_dim=dynamic_dim,
             )
             self.tx.output.tracked_fakes.append(
                 TrackedFake(wrapped_value, self.source, None)

torch/_functorch/aot_autograd.py (6 additions, 10 deletions)

@@ -2799,12 +2799,6 @@ def forward(ctx, *deduped_flat_tensor_args):
            )

            num_outputs = CompiledFunction.metadata.num_outputs
-           num_outputs_aliased_to_inputs = (
-               CompiledFunction.metadata.num_outputs_aliased_to_inputs
-           )
-           num_outputs_aliased_to_intermediates = (
-               CompiledFunction.metadata.num_outputs_aliased_to_intermediates
-           )
            num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased
            num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases
            num_symints_saved_for_bw = CompiledFunction.num_symints_saved_for_bw

@@ -2979,10 +2973,12 @@ def backward(ctx, *flat_args):
                # Add the seed and offset to args
                rng_args = CUDARngStateHelper.get_torch_state_as_tuple()

-           all_args = (
-               list(ctx.symints) + list(ctx.saved_tensors) + list(contiguous_args) + list(rng_args)
-           )
-
+           all_args = [
+               *ctx.symints,
+               *ctx.saved_tensors,
+               *contiguous_args,
+               *rng_args
+           ]
            del contiguous_args

            def call_compiled_backward():

torch/fx/experimental/symbolic_shapes.py (15 additions, 0 deletions)

@@ -2245,6 +2245,21 @@ def create_symintnode(
            return int(sym)
        return SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))

+    def create_symint_and_symbol(self, value, source, dynamic_dim):
+        # TODO: This is wrong wrong wrong, create_symbol will
+        # generate something that is non-negative, but this is
+        # not a sound assumption to make.
+        # Not fixing as this was a preexisting condition.
+        return self.create_symintnode(
+            self.create_symbol(
+                value,
+                source=source,
+                dynamic_dim=dynamic_dim,
+            ),
+            hint=value,
+            source=source,
+        )
+
    def create_symboolnode(self, sym: "sympy.Expr"):
        # This function is only being used in serialization, so we do not track it
        # for validation.
