
Commit 2c53dc4
authored Nov 16, 2020

Add experimental functional while loop staging. (#1117)

1 parent f44d890 commit 2c53dc4
14 files changed: +451 -17 lines changed
 

Sources/CX10/BUILD (+1)

@@ -11,6 +11,7 @@ cc_library(
 cc_library(
     name = "xla_tensor_wrapper",
     srcs = [
+        "functional_while.cc",
         "xla_tensor_wrapper.cc",
         "xla_tensor_ops_wrapper.cc",
         "xla_tensor_ops_wrapper_generated.cc.inc",
Sources/CX10/functional_while.cc (+319, new file)
// Copyright 2020 TensorFlow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(_WIN32)
#define XLA_API __declspec(dllexport)
#else
#define XLA_API __attribute__((__visibility__("default")))
#endif

#include "xla_tensor_wrapper.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/ir.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"

using swift_xla::XLATensor;
using swift_xla::ir::LoweringContext;
using swift_xla::ir::Node;
using swift_xla::ir::NodePtr;
using swift_xla::ir::OpList;
using swift_xla::ir::Output;
using swift_xla::ir::Value;
using swift_xla::ir::XlaOpVector;

xla::Shape ShapeOfXlaOpList(absl::Span<const Value> ops) {
  xla::Shape result;
  result.set_element_type(xla::TUPLE);
  result.mutable_tuple_shapes()->reserve(ops.size());
  for (const auto& op : ops) {
    xla::ShapeUtil::AppendShapeToTuple(op.shape(), &result);
  }
  TF_DCHECK_OK(xla::ShapeUtil::ValidateShapeWithOptionalLayout(result));
  return result;
}
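
The node's output shape, built above, is just a tuple of the per-result shapes. As a point of reference (this snippet is not part of the commit), for two results shaped f32[2,3] and s32[] it matches what the standard ShapeUtil helpers would build directly:

#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: the ShapeOfXlaOpList result for results {f32[2,3], s32[]}.
xla::Shape expected = xla::ShapeUtil::MakeTupleShape(
    {xla::ShapeUtil::MakeShape(xla::F32, {2, 3}),
     xla::ShapeUtil::MakeShape(xla::S32, {})});
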
struct ExtraInputDiscovery {
  // TODO: color when building the graph as this can be n^2
  // in the number of for loops.
  void BackRefVisit(const Output& v, const Node* node = nullptr) {
    auto& state = state_map[v.node];
    if (!state.visited) {
      state.visited = true;
      work_list.push_back(v.node);
    }
    if (node) state.refs.push_back(node);
  }
  void PlaceholderVisit(const Node* node) {
    auto& state = state_map[node];
    if (!state.depends_on_placeholder) {
      state.depends_on_placeholder = true;
      work_list.push_back(node);
    }
  }
  void WorkListBackRefVisit() {
    while (!work_list.empty()) {
      const Node* node = work_list.back();
      work_list.pop_back();
      for (const auto& value : node->operands()) {
        BackRefVisit(value, node);
      }
    }
  }
  void WorkListPlaceholderVisit() {
    while (!work_list.empty()) {
      const Node* node = work_list.back();
      work_list.pop_back();
      for (auto* ref : state_map[node].refs) {
        PlaceholderVisit(ref);
      }
    }
  }
  void BackRefVisitExtraSearch(const Output& v, const NodePtr& n) {
    auto& state = state_map[v.node];
    if (!state.visited_looking_for_extras) {
      state.visited_looking_for_extras = true;
      if (state.depends_on_placeholder) {
        work_list.push_back(v.node);
      } else {
        results.push_back(Value(n, v.index));
      }
    }
  }
  void WorkListBackRefVisitExtraSearch() {
    while (!work_list.empty()) {
      const Node* node = work_list.back();
      work_list.pop_back();
      auto& operands = node->operands();
      auto& node_ptrs = node->operand_nodes();
      for (size_t i = 0; i < operands.size(); ++i) {
        BackRefVisitExtraSearch(operands[i], node_ptrs[i]);
      }
    }
  }
  struct State {
    State() {}
    bool visited =
        false;  // Has been fully visited if true and work_list.empty().
    bool depends_on_placeholder = false;
    bool visited_looking_for_extras = false;
    std::vector<const Node*> refs;
  };
  std::vector<const Node*> work_list;
  absl::flat_hash_map<const Node*, State> state_map;
  std::vector<Value> results;
};

std::vector<Value> DiscoverExtraInputs(absl::Span<const Value> results,
                                       const Value& index_placeholder,
                                       absl::Span<const Value> placeholders) {
  ExtraInputDiscovery state;
  for (auto& result : results) {
    state.BackRefVisit(result);
  }
  state.WorkListBackRefVisit();
  for (auto& placeholder : placeholders) {
    state.PlaceholderVisit(placeholder.node.get());
  }
  state.PlaceholderVisit(index_placeholder.node.get());
  state.WorkListPlaceholderVisit();
  for (auto& result : results) {
    state.BackRefVisitExtraSearch(result, result.node);
  }
  state.WorkListBackRefVisitExtraSearch();
  return std::move(state.results);
}
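
DiscoverExtraInputs drives ExtraInputDiscovery in three worklist phases: (1) a backward pass from the loop results that records, for every reachable node, which nodes consume it; (2) a forward pass starting at the placeholders and the index placeholder that propagates a depends-on-placeholder bit along those recorded edges; (3) a second backward pass from the results that stops at the frontier of values which do not depend on any placeholder. Exactly those frontier values are the loop-invariant "extra inputs" appended to the node's operands. A self-contained sketch of the same three-phase idea on a toy DAG (toy types, not the swift_xla IR):

#include <unordered_map>
#include <unordered_set>
#include <vector>

struct ToyNode {
  std::vector<ToyNode*> operands;
};

// Returns the loop-invariant frontier ("extra inputs") of the subgraph
// reachable from `results`, given the set of placeholder nodes.
std::vector<ToyNode*> DiscoverExtras(
    const std::vector<ToyNode*>& results,
    const std::unordered_set<ToyNode*>& placeholders) {
  std::unordered_map<ToyNode*, std::vector<ToyNode*>> users;
  std::unordered_set<ToyNode*> reachable, tainted, seen;
  std::vector<ToyNode*> work, extras;

  // Phase 1: backward reachability from the results, recording, for each
  // node, the reachable nodes that consume it.
  for (ToyNode* r : results)
    if (reachable.insert(r).second) work.push_back(r);
  while (!work.empty()) {
    ToyNode* n = work.back();
    work.pop_back();
    for (ToyNode* op : n->operands) {
      users[op].push_back(n);
      if (reachable.insert(op).second) work.push_back(op);
    }
  }

  // Phase 2: forward-propagate "depends on a placeholder" along the
  // recorded user edges.
  for (ToyNode* p : placeholders)
    if (tainted.insert(p).second) work.push_back(p);
  while (!work.empty()) {
    ToyNode* n = work.back();
    work.pop_back();
    for (ToyNode* u : users[n])
      if (tainted.insert(u).second) work.push_back(u);
  }

  // Phase 3: walk back from the results, stopping at values that do NOT
  // depend on a placeholder; that frontier is the extra-input set.
  auto visit = [&](ToyNode* n) {
    if (!seen.insert(n).second) return;
    if (tainted.count(n)) work.push_back(n);
    else extras.push_back(n);
  };
  for (ToyNode* r : results) visit(r);
  while (!work.empty()) {
    ToyNode* n = work.back();
    work.pop_back();
    for (ToyNode* op : n->operands) visit(op);
  }
  return extras;
}
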
class XLAFunctionalWhileNode : public swift_xla::ir::Node {
 public:
  static std::vector<Value> BuildArgs(absl::Span<const Value> initial,
                                      const Value& n,
                                      absl::Span<const Value> extras) {
    std::vector<Value> out(initial.begin(), initial.end());
    out.push_back(n);
    out.insert(out.end(), extras.begin(), extras.end());
    return out;
  }
  static xla::hash_t HashOfResults(absl::Span<const Value> results) {
    xla::hash_t hash = 0;
    for (auto& result : results)
      hash = xla::util::HashCombine(hash, result.hash());
    return hash;
  }
  XLAFunctionalWhileNode(absl::Span<const Value> initial, const Value& n,
                         const Value& index_placeholder,
                         absl::Span<const Value> placeholders,
                         absl::Span<const Value> results)
      : Node(swift_xla::ir::OpKind(at::aten::functional_while),
             BuildArgs(
                 initial, n,
                 DiscoverExtraInputs(results, index_placeholder, placeholders)),
             ShapeOfXlaOpList(results), results.size(), HashOfResults(results)),
        index_placeholder_(index_placeholder),
        placeholders_(placeholders.begin(), placeholders.end()),
        results_(results.begin(), results.end()) {}

  static xla::XlaOp zeroLike(xla::XlaOp op) {
    auto* b = op.builder();
    return xla::ConstantLiteral(
        b, xla::LiteralUtil::Zero(
               swift_xla::XlaHelpers::ShapeOfXlaOp(op).element_type()));
  }

  static xla::XlaOp oneLike(xla::XlaOp op) {
    auto* b = op.builder();
    return xla::ConstantLiteral(
        b, xla::LiteralUtil::One(
               swift_xla::XlaHelpers::ShapeOfXlaOp(op).element_type()));
  }

  XlaOpVector Lower(LoweringContext* loctx) const {
    size_t last_i = placeholders_.size();

    auto body_builder = loctx->builder()->CreateSubBuilder("loop_body");
    xla::XlaOp initial;
    {
      std::vector<xla::XlaOp> args;
      args.reserve(operands().size() + 1);
      for (size_t i = 0; i < last_i; ++i) {
        args.push_back(loctx->GetOutputOp(operand(i)));
      }
      auto tmp = loctx->GetOutputOp(operand(last_i));
      auto it = zeroLike(tmp);
      args.push_back(it);
      args.push_back(tmp);
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        args.push_back(loctx->GetOutputOp(operand(i)));
      }

      initial = xla::Tuple(loctx->builder(), args);
    }
    xla::XlaOp body_result;
    {
      auto* b = body_builder.get();
      swift_xla::ir::Util::EmissionMap emap;
      for (const auto& placeholder : placeholders_) {
        emap[placeholder.node.get()] = swift_xla::ir::Util::kEmitted;
      }
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        emap[operand(i).node] = swift_xla::ir::Util::kEmitted;
      }
      emap[index_placeholder_.node.get()] = swift_xla::ir::Util::kEmitted;
      swift_xla::ir::LoweringContext body_loctx(b, loctx->device(),
                                                std::move(emap));
      auto t = xla::Parameter(
          b, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
      auto p1 = xla::GetTupleElement(t, last_i);
      auto p2 = xla::GetTupleElement(t, last_i + 1);
      for (size_t i = 0; i < placeholders_.size(); ++i) {
        body_loctx.AssignOutputOp(placeholders_[i], xla::GetTupleElement(t, i));
      }
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        body_loctx.AssignOutputOp(operand(i), xla::GetTupleElement(t, i + 1));
      }
      body_loctx.AssignOutputOp(index_placeholder_, p1);

      std::vector<xla::XlaOp> tmps;
      for (auto& result : results_) {
        tmps.push_back(body_loctx.GetOutputOp(result));
      }
      tmps.push_back(p1 + oneLike(p1));
      tmps.push_back(p2);
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        tmps.push_back(body_loctx.GetOutputOp(operand(i)));
      }
      body_result = xla::Tuple(b, tmps);
    }

    auto cond_builder = loctx->builder()->CreateSubBuilder("cond_body");
    xla::XlaOp cond_result;
    {
      auto* b = cond_builder.get();
      auto t = xla::Parameter(
          b, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
      auto p1 = xla::GetTupleElement(t, last_i);
      auto p2 = xla::GetTupleElement(t, last_i + 1);
      cond_result = xla::Lt(p1, p2);
    }

    auto result = xla::While(
        cond_builder->Build(cond_result).ConsumeValueOrDie(),
        body_builder->Build(body_result).ConsumeValueOrDie(), initial);

    std::vector<xla::XlaOp> results;
    for (size_t i = 0; i < last_i; ++i) {
      results.push_back(xla::GetTupleElement(result, i));
    }
    return ReturnOps(results, loctx);
  }

  Value index_placeholder_;
  std::vector<Value> placeholders_;
  std::vector<Value> results_;
};
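
Lower() stages the node as a single xla::While. The carried tuple is laid out as [placeholder states..., induction index (started at zero via zeroLike), trip count, extra inputs...]; the body re-emits the traced result values with placeholders pre-assigned to tuple elements, appends index + 1, and passes the trip count and extras through unchanged; the condition computation is simply index < trip count, and only the first placeholders_.size() elements of the final tuple are returned. A minimal standalone sketch of the same pattern for a single f32 accumulator, assuming only the standard XlaBuilder client API (illustrative, not the node's actual lowering):

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Carries (acc, i, n); loops while i < n; the body adds 1.0f to acc and
// increments i. Mirrors the [state..., index, bound, extras...] layout.
xla::XlaOp CountedLoop(xla::XlaBuilder* b, xla::XlaOp acc0, xla::XlaOp n) {
  xla::XlaOp i0 = xla::ConstantR0<int32_t>(b, 0);
  xla::XlaOp init = xla::Tuple(b, {acc0, i0, n});
  xla::Shape tuple_shape = b->GetShape(init).ConsumeValueOrDie();

  auto cond_builder = b->CreateSubBuilder("cond");
  xla::XlaOp cond_root;
  {
    auto t = xla::Parameter(cond_builder.get(), 0, tuple_shape, "t");
    cond_root =
        xla::Lt(xla::GetTupleElement(t, 1), xla::GetTupleElement(t, 2));
  }

  auto body_builder = b->CreateSubBuilder("body");
  xla::XlaOp body_root;
  {
    auto* bb = body_builder.get();
    auto t = xla::Parameter(bb, 0, tuple_shape, "t");
    auto acc = xla::GetTupleElement(t, 0) + xla::ConstantR0<float>(bb, 1.0f);
    auto i = xla::GetTupleElement(t, 1) + xla::ConstantR0<int32_t>(bb, 1);
    body_root = xla::Tuple(bb, {acc, i, xla::GetTupleElement(t, 2)});
  }

  xla::XlaOp result =
      xla::While(cond_builder->Build(cond_root).ConsumeValueOrDie(),
                 body_builder->Build(body_root).ConsumeValueOrDie(), init);
  return xla::GetTupleElement(result, 0);
}

The real lowering differs mainly in that the body is produced by re-lowering traced swift_xla IR under a pre-seeded EmissionMap, so placeholders and loop-invariant extras resolve to tuple elements instead of being re-emitted.
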
class XLAPlaceholderNode : public swift_xla::ir::Node {
 public:
  XLAPlaceholderNode(xla::Shape shape, int id)
      : Node(swift_xla::ir::OpKind(at::aten::placeholder), {}, shape, 1,
             xla::util::MHash(id)),
        id_(id) {}
  NodePtr Clone(OpList operands) const override {
    return swift_xla::ir::MakeNode<XLAPlaceholderNode>(shape(), id_);
  }
  XlaOpVector Lower(LoweringContext* loctx) const override {
    LOG(FATAL) << "Cannot lower placeholder: " << ToString() << " id: " << id_;
  }
  std::string ToString() const override {
    std::stringstream ss;
    ss << Node::ToString() << ", id=" << id_;
    return ss.str();
  }
  int id_;
};

std::vector<Value> UnpackIrValues(OpaqueXLATensorArrayRef array) {
  std::vector<Value> out;
  out.reserve(array.size);
  for (size_t i = 0; i < array.size; ++i) {
    out.push_back(array.data[i]->GetIrValue());
  }
  return out;
}

OpaqueXLATensorArrayRef XLATensor_functional_while(
    OpaqueXLATensor* n, OpaqueXLATensorArrayRef initial,
    OpaqueXLATensorArrayRef placeholders, OpaqueXLATensor* indexPlaceholder,
    OpaqueXLATensorArrayRef results) {
  auto initial_ir = UnpackIrValues(initial);
  auto placeholders_ir = UnpackIrValues(placeholders);
  auto results_ir = UnpackIrValues(results);

  auto result_node = swift_xla::ir::MakeNode<XLAFunctionalWhileNode>(
      initial_ir, n->GetIrValue(), indexPlaceholder->GetIrValue(),
      placeholders_ir, results_ir);
  size_t count = results.size;
  auto opaque_tensors = new OpaqueXLATensor*[count];
  for (size_t i = 0; i < count; ++i) {
    opaque_tensors[i] = new XLATensor(
        results.data[i]->CreateFrom(swift_xla::ir::Value(result_node, i)));
  }
  return {opaque_tensors, count};
}

OpaqueXLATensor* XLATensor_makePlaceholder(OpaqueXLATensor* t, int id) {
  return new XLATensor(t->CreateFrom(
      swift_xla::ir::MakeNode<XLAPlaceholderNode>(t->shape(), id)));
}
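
Taken together, the intended staging flow for the new C entry points (reconstructed from this file) is: make one placeholder per loop-carried tensor plus one for the index, trace the loop body against those placeholders, then hand everything to XLATensor_functional_while, which returns one output tensor per result. A placeholder must never escape the staged loop (XLAPlaceholderNode::Lower is LOG(FATAL)); it is only ever pre-assigned a tuple element inside the body. A hedged sketch, in which the setup is hypothetical and XLATensor_mul is assumed to be one of the wrapper's existing traced entry points:

// Stages "square x, n times" as one XLA while loop. `x0` (initial state)
// and `n` (scalar trip count) are existing OpaqueXLATensor*.
OpaqueXLATensor* StageSquareNTimes(OpaqueXLATensor* x0, OpaqueXLATensor* n) {
  // One placeholder per loop-carried value (distinct ids), plus one for
  // the induction index, shaped like the trip count.
  OpaqueXLATensor* x_ph = XLATensor_makePlaceholder(x0, /*id=*/0);
  OpaqueXLATensor* i_ph = XLATensor_makePlaceholder(n, /*id=*/1);

  // Trace the loop body against the placeholder: x -> x * x.
  OpaqueXLATensor* body_result = XLATensor_mul(x_ph, x_ph);

  OpaqueXLATensor* initial[] = {x0};
  OpaqueXLATensor* placeholders[] = {x_ph};
  OpaqueXLATensor* results[] = {body_result};
  OpaqueXLATensorArrayRef out = XLATensor_functional_while(
      n, /*initial=*/{initial, 1}, /*placeholders=*/{placeholders, 1},
      /*indexPlaceholder=*/i_ph, /*results=*/{results, 1});
  return out.data[0];  // out.data was allocated with new[]; caller owns it.
}
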

Sources/CX10/xla_tensor_wrapper.cc (+5)

@@ -315,6 +315,11 @@ OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a) {
       swift_xla::ir::DumpUtil::ToText({a->GetIrValue().node.get()});
   return new std::string(ir_dag_text);
 }
+OpaqueString* XLATensor_xla_ir_text(OpaqueXLATensor* a) {
+  std::string ir_dag_text =
+      swift_xla::ir::DumpUtil::ToHlo({a->GetIrValue()}, a->GetDevice());
+  return new std::string(ir_dag_text);
+}
 OpaqueXLATensor* XLATensor_linspace(XLAScalar start, XLAScalar stop,
                                     int64_t num, const CDevice device,
                                     enum XLATensorScalarType type) {
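
The new XLATensor_xla_ir_text above dumps lowered HLO rather than the IR-node text of XLATensor_ir_text, which makes it the natural way to inspect what a staged functional while actually compiles to. A sketch, assuming (as the `new std::string` above implies) that OpaqueString is std::string on the C++ side:

#include <iostream>
#include <memory>

// Sketch: print the HLO a traced value (e.g. a staged loop result) lowers to.
void DumpStagedHlo(OpaqueXLATensor* t) {
  std::unique_ptr<OpaqueString> hlo(XLATensor_xla_ir_text(t));
  std::cout << *hlo << std::endl;
}
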

Sources/CX10/xla_tensor_wrapper.h (+6)

@@ -293,6 +293,7 @@ XLA_API OpaqueXLATensor* XLATensor_ge(OpaqueXLATensor* x, OpaqueXLATensor* y);
 XLA_API OpaqueString* XLATensor_get_annotations(OpaqueXLATensor* a);
 XLA_API OpaqueXLATensor* XLATensor_gt(OpaqueXLATensor* x, OpaqueXLATensor* y);
 XLA_API OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a);
+XLA_API OpaqueString* XLATensor_xla_ir_text(OpaqueXLATensor* a);
 XLA_API OpaqueXLATensor* XLATensor_is_finite(OpaqueXLATensor* input);
 XLA_API OpaqueXLATensor* XLATensor_is_inf(OpaqueXLATensor* input);
 XLA_API OpaqueXLATensor* XLATensor_is_nan(OpaqueXLATensor* input);

@@ -427,6 +428,11 @@ XLA_API OpaqueXLATensor* XLATensor_xla_slice(OpaqueXLATensor* input,
                                              Int64ArrayRef begin,
                                              Int64ArrayRef end,
                                              Int64ArrayRef strides);
+XLA_API OpaqueXLATensorArrayRef XLATensor_functional_while(
+    OpaqueXLATensor* n, OpaqueXLATensorArrayRef initial,
+    OpaqueXLATensorArrayRef placeholders, OpaqueXLATensor* indexPlaceholder,
+    OpaqueXLATensorArrayRef results);
+XLA_API OpaqueXLATensor* XLATensor_makePlaceholder(OpaqueXLATensor* t, int id);
 // Retrieves the device for a given tensor.
 XLA_API struct CDevice XLATensor_device(OpaqueXLATensor* t);
 // Creates a float tensor on the current device filled with random numbers in
Sources/TensorFlow/Core/Tensor.swift (+2)

@@ -22,6 +22,7 @@ infix operator .!=: ComparisonPrecedence
 public protocol AnyTensor {
   var _rawTensorHandle: CTensorHandle { get }
   var _tensorFlowDataType: TensorDataType { get }
+  var scalarType: TensorFlowScalar.Type { get }
 }

 /// A multidimensional array of elements that is a generalization of vectors and matrices to

@@ -55,6 +56,7 @@ public struct Tensor<Scalar: TensorFlowScalar> {
 extension Tensor: AnyTensor {
   public var _rawTensorHandle: CTensorHandle { return handle._cTensorHandle }
   public var _tensorFlowDataType: TensorDataType { return Scalar.tensorFlowDataType }
+  public var scalarType: TensorFlowScalar.Type { return Scalar.self }
 }

 //===------------------------------------------------------------------------------------------===//
