// Copyright 2020 TensorFlow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(_WIN32)
#define XLA_API __declspec(dllexport)
#else
#define XLA_API __attribute__((__visibility__("default")))
#endif

| 21 | +#include "xla_tensor_wrapper.h" |
| 22 | + |
| 23 | +#include "absl/container/flat_hash_set.h" |
| 24 | +#include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h" |
| 25 | +#include "tensorflow/compiler/tf2xla/xla_tensor/ir.h" |
| 26 | +#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h" |
| 27 | + |
using swift_xla::XLATensor;
using swift_xla::ir::LoweringContext;
using swift_xla::ir::Node;
using swift_xla::ir::NodePtr;
using swift_xla::ir::OpList;
using swift_xla::ir::Output;
using swift_xla::ir::Value;
using swift_xla::ir::XlaOpVector;

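// Builds the tuple shape whose elements are the shapes of the given IR
// values; used below as the result shape of the functional while node.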
xla::Shape ShapeOfXlaOpList(absl::Span<const Value> ops) {
  xla::Shape result;
  result.set_element_type(xla::TUPLE);
  result.mutable_tuple_shapes()->reserve(ops.size());
  for (const auto& op : ops) {
    xla::ShapeUtil::AppendShapeToTuple(op.shape(), &result);
  }
  TF_DCHECK_OK(xla::ShapeUtil::ValidateShapeWithOptionalLayout(result));
  return result;
}

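// Finds the values computed outside the loop that the loop body implicitly
// captures. It works in three passes over the use-def graph: (1) walk
// backwards from the loop results, recording reverse edges; (2) propagate a
// "depends on a placeholder" bit forwards from the placeholders along those
// reverse edges; (3) walk backwards from the results again, cutting the
// search at nodes that do not depend on any placeholder. Those cut points
// are exactly the extra inputs that must be threaded into the loop.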
struct ExtraInputDiscovery {
  // TODO: Color nodes while building the graph; the current search can be
  // O(n^2) in the number of for loops.
  void BackRefVisit(const Output& v, const Node* node = nullptr) {
    auto& state = state_map[v.node];
    if (!state.visited) {
      state.visited = true;
      work_list.push_back(v.node);
    }
    if (node) state.refs.push_back(node);
  }
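  // Marks a node as (transitively) depending on a placeholder and queues it
  // for forward propagation along the recorded back-references.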
  void PlaceholderVisit(const Node* node) {
    auto& state = state_map[node];
    if (!state.depends_on_placeholder) {
      state.depends_on_placeholder = true;
      work_list.push_back(node);
    }
  }
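  // Pass 1: drains the work list, visiting each node's operands so that every
  // node reachable from the results records its referencing nodes.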
  void WorkListBackRefVisit() {
    while (!work_list.empty()) {
      const Node* node = work_list.back();
      work_list.pop_back();
      for (const auto& value : node->operands()) {
        BackRefVisit(value, node);
      }
    }
  }
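  // Pass 2: drains the work list, propagating depends_on_placeholder to every
  // node that references a placeholder-dependent node.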
  void WorkListPlaceholderVisit() {
    while (!work_list.empty()) {
      const Node* node = work_list.back();
      work_list.pop_back();
      for (auto* ref : state_map[node].refs) {
        PlaceholderVisit(ref);
      }
    }
  }
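  // Pass 3: a value whose producer does not depend on any placeholder is
  // computed entirely outside the loop, so record it as an extra input;
  // placeholder-dependent producers are searched recursively.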
  void BackRefVisitExtraSearch(const Output& v, const NodePtr& n) {
    auto& state = state_map[v.node];
    if (!state.visited_looking_for_extras) {
      state.visited_looking_for_extras = true;
      if (state.depends_on_placeholder) {
        work_list.push_back(v.node);
      } else {
        results.push_back(Value(n, v.index));
      }
    }
  }
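  // Drains the work list for pass 3, walking operands alongside their owning
  // NodePtrs so each recorded extra input keeps a strong reference to its
  // producer.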
  void WorkListBackRefVisitExtraSearch() {
    while (!work_list.empty()) {
      const Node* node = work_list.back();
      work_list.pop_back();
      const auto& operands = node->operands();
      const auto& node_ptrs = node->operand_nodes();
      for (size_t i = 0; i < operands.size(); ++i) {
        BackRefVisitExtraSearch(operands[i], node_ptrs[i]);
      }
    }
  }
  struct State {
    State() {}
    // Has been fully visited if true and work_list.empty().
    bool visited = false;
    bool depends_on_placeholder = false;
    bool visited_looking_for_extras = false;
    std::vector<const Node*> refs;
  };
  std::vector<const Node*> work_list;
  absl::flat_hash_map<const Node*, State> state_map;
  std::vector<Value> results;
};

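// Runs the three passes described above and returns the values the loop body
// captures from the enclosing graph.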
| 119 | +std::vector<Value> DiscoverExtraInputs(absl::Span<const Value> results, |
| 120 | + const Value& index_placeholder, |
| 121 | + absl::Span<const Value> placeholders) { |
| 122 | + ExtraInputDiscovery state; |
| 123 | + for (auto& result : results) { |
| 124 | + state.BackRefVisit(result); |
| 125 | + } |
| 126 | + state.WorkListBackRefVisit(); |
| 127 | + for (auto& placeholder : placeholders) { |
| 128 | + state.PlaceholderVisit(placeholder.node.get()); |
| 129 | + } |
| 130 | + state.PlaceholderVisit(index_placeholder.node.get()); |
| 131 | + state.WorkListPlaceholderVisit(); |
| 132 | + for (auto& result : results) { |
| 133 | + state.BackRefVisitExtraSearch(result, result.node); |
| 134 | + } |
| 135 | + state.WorkListBackRefVisitExtraSearch(); |
| 136 | + return std::move(state.results); |
| 137 | +} |
| 138 | + |
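// IR node for a functional while loop. Its operands are laid out as
// [initial values..., trip count n, extra captured inputs...], and it lowers
// to an xla::While whose carried tuple holds the loop-carried values, the
// induction variable, the trip count, and the extra inputs.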
class XLAFunctionalWhileNode : public swift_xla::ir::Node {
 public:
  static std::vector<Value> BuildArgs(absl::Span<const Value> initial,
                                      const Value& n,
                                      absl::Span<const Value> extras) {
    std::vector<Value> out(initial.begin(), initial.end());
    out.push_back(n);
    out.insert(out.end(), extras.begin(), extras.end());
    return out;
  }
  static xla::hash_t HashOfResults(absl::Span<const Value> results) {
    xla::hash_t hash = 0;
    for (auto& result : results) {
      hash = xla::util::HashCombine(hash, result.hash());
    }
    return hash;
  }
  XLAFunctionalWhileNode(absl::Span<const Value> initial, const Value& n,
                         const Value& index_placeholder,
                         absl::Span<const Value> placeholders,
                         absl::Span<const Value> results)
      : Node(swift_xla::ir::OpKind(at::aten::functional_while),
             BuildArgs(
                 initial, n,
                 DiscoverExtraInputs(results, index_placeholder, placeholders)),
             ShapeOfXlaOpList(results), results.size(), HashOfResults(results)),
        index_placeholder_(index_placeholder),
        placeholders_(placeholders.begin(), placeholders.end()),
        results_(results.begin(), results.end()) {}

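  // Scalar 0 and 1 constants matching the element type of `op`; used to
  // initialize and increment the loop induction variable.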
  static xla::XlaOp zeroLike(xla::XlaOp op) {
    auto* b = op.builder();
    return xla::ConstantLiteral(
        b, xla::LiteralUtil::Zero(
               swift_xla::XlaHelpers::ShapeOfXlaOp(op).element_type()));
  }

  static xla::XlaOp oneLike(xla::XlaOp op) {
    auto* b = op.builder();
    return xla::ConstantLiteral(
        b, xla::LiteralUtil::One(
               swift_xla::XlaHelpers::ShapeOfXlaOp(op).element_type()));
  }

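  // Lowers the node to xla::While. The carried tuple is laid out as:
  //   [0, last_i)        loop-carried values (one per placeholder)
  //   last_i             induction variable i, starting at 0
  //   last_i + 1         trip count n
  //   [last_i + 2, end)  extra captured inputs, threaded through unchanged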
  XlaOpVector Lower(LoweringContext* loctx) const override {
    size_t last_i = placeholders_.size();

    auto body_builder = loctx->builder()->CreateSubBuilder("loop_body");
    xla::XlaOp initial;
    {
      std::vector<xla::XlaOp> args;
      args.reserve(operands().size() + 1);
      for (size_t i = 0; i < last_i; ++i) {
        args.push_back(loctx->GetOutputOp(operand(i)));
      }
      auto tmp = loctx->GetOutputOp(operand(last_i));  // Trip count n.
      auto it = zeroLike(tmp);                         // Induction variable.
      args.push_back(it);
      args.push_back(tmp);
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        args.push_back(loctx->GetOutputOp(operand(i)));
      }

      initial = xla::Tuple(loctx->builder(), args);
    }
    xla::XlaOp body_result;
    {
      auto* b = body_builder.get();
      // Pre-mark the placeholders and the extra inputs as emitted so the body
      // lowering reads them from the loop tuple instead of re-lowering them.
      swift_xla::ir::Util::EmissionMap emap;
      for (const auto& placeholder : placeholders_) {
        emap[placeholder.node.get()] = swift_xla::ir::Util::kEmitted;
      }
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        emap[operand(i).node] = swift_xla::ir::Util::kEmitted;
      }
      emap[index_placeholder_.node.get()] = swift_xla::ir::Util::kEmitted;
      swift_xla::ir::LoweringContext body_loctx(b, loctx->device(),
                                                std::move(emap));
      auto t = xla::Parameter(
          b, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
      auto p1 = xla::GetTupleElement(t, last_i);      // Induction variable i.
      auto p2 = xla::GetTupleElement(t, last_i + 1);  // Trip count n.
      for (size_t i = 0; i < placeholders_.size(); ++i) {
        body_loctx.AssignOutputOp(placeholders_[i], xla::GetTupleElement(t, i));
      }
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        body_loctx.AssignOutputOp(operand(i), xla::GetTupleElement(t, i + 1));
      }
      body_loctx.AssignOutputOp(index_placeholder_, p1);

      // Next iteration state: the new loop-carried values, i + 1, and the
      // unchanged trip count and extra inputs.
      std::vector<xla::XlaOp> tmps;
      for (auto& result : results_) {
        tmps.push_back(body_loctx.GetOutputOp(result));
      }
      tmps.push_back(p1 + oneLike(p1));
      tmps.push_back(p2);
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        tmps.push_back(body_loctx.GetOutputOp(operand(i)));
      }
      body_result = xla::Tuple(b, tmps);
    }

    auto cond_builder = loctx->builder()->CreateSubBuilder("cond_body");
    xla::XlaOp cond_result;
    {
      auto* b = cond_builder.get();
      auto t = xla::Parameter(
          b, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
      auto p1 = xla::GetTupleElement(t, last_i);
      auto p2 = xla::GetTupleElement(t, last_i + 1);
      cond_result = xla::Lt(p1, p2);  // Continue while i < n.
    }

    auto result = xla::While(
        cond_builder->Build(cond_result).ConsumeValueOrDie(),
        body_builder->Build(body_result).ConsumeValueOrDie(), initial);

    // Only the loop-carried values are returned; i, n, and the extra inputs
    // are internal to the loop.
    std::vector<xla::XlaOp> results;
    for (size_t i = 0; i < last_i; ++i) {
      results.push_back(xla::GetTupleElement(result, i));
    }
    return ReturnOps(results, loctx);
  }

  Value index_placeholder_;
  std::vector<Value> placeholders_;
  std::vector<Value> results_;
};

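// IR node standing in for a value that is only defined inside the loop body
// (the loop state or the index). It is never lowered directly; the while
// lowering above assigns its output from the loop tuple instead.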
class XLAPlaceholderNode : public swift_xla::ir::Node {
 public:
  XLAPlaceholderNode(xla::Shape shape, int id)
      : Node(swift_xla::ir::OpKind(at::aten::placeholder), {}, shape, 1,
             xla::util::MHash(id)),
        id_(id) {}
  NodePtr Clone(OpList operands) const override {
    return swift_xla::ir::MakeNode<XLAPlaceholderNode>(shape(), id_);
  }
  XlaOpVector Lower(LoweringContext* loctx) const override {
    LOG(FATAL) << "Cannot lower placeholder: " << ToString() << " id: " << id_;
  }
  std::string ToString() const override {
    std::stringstream ss;
    ss << Node::ToString() << ", id=" << id_;
    return ss.str();
  }
  int id_;
};

| 287 | +std::vector<Value> UnpackIrValues(OpaqueXLATensorArrayRef array) { |
| 288 | + std::vector<Value> out; |
| 289 | + out.reserve(array.size); |
| 290 | + for (size_t i = 0; i < array.size; ++i) { |
| 291 | + out.push_back(array.data[i]->GetIrValue()); |
| 292 | + } |
| 293 | + return out; |
| 294 | +} |
| 295 | + |
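// C entry point: builds the functional while node and wraps each of its
// outputs in a freshly allocated XLATensor. The caller takes ownership of the
// returned array and the tensors it contains.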
OpaqueXLATensorArrayRef XLATensor_functional_while(
    OpaqueXLATensor* n, OpaqueXLATensorArrayRef initial,
    OpaqueXLATensorArrayRef placeholders, OpaqueXLATensor* indexPlaceholder,
    OpaqueXLATensorArrayRef results) {
  auto initial_ir = UnpackIrValues(initial);
  auto placeholders_ir = UnpackIrValues(placeholders);
  auto results_ir = UnpackIrValues(results);

  auto result_node = swift_xla::ir::MakeNode<XLAFunctionalWhileNode>(
      initial_ir, n->GetIrValue(), indexPlaceholder->GetIrValue(),
      placeholders_ir, results_ir);
  size_t count = results.size;
  auto opaque_tensors = new OpaqueXLATensor*[count];
  for (size_t i = 0; i < count; ++i) {
    opaque_tensors[i] = new XLATensor(
        results.data[i]->CreateFrom(swift_xla::ir::Value(result_node, i)));
  }
  return {opaque_tensors, count};
}

OpaqueXLATensor* XLATensor_makePlaceholder(OpaqueXLATensor* t, int id) {
  return new XLATensor(t->CreateFrom(
      swift_xla::ir::MakeNode<XLAPlaceholderNode>(t->shape(), id)));
}