#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

#include <limits>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

class GraphOpt : public ::testing::Test {
 public:
  void SetUp() override {
    old_cat_wo_conditionals_ = getCatWoConditionals();
    getCatWoConditionals() = true;
  }

  void TearDown() override {
    getCatWoConditionals() = old_cat_wo_conditionals_;
  }

 private:
  bool old_cat_wo_conditionals_;
};

TEST_F(GraphOpt, OptimizeCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` op must be moved to the inputs of `aten::cat`.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::log(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCat2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` and `aten::tanh` ops must be moved to the inputs of
  // `aten::cat`.
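  // A sketch of the expected optimized graph, assuming the moved ops appear
  // grouped per op (value names are illustrative; types and shapes elided):
  //   %x.log = aten::log(%x)
  //   %y.log = aten::log(%y)
  //   %z.log = aten::log(%z)
  //   %x.tanh = aten::tanh(%x.log)
  //   %y.tanh = aten::tanh(%y.log)
  //   %z.tanh = aten::tanh(%z.log)
  //   %res = aten::cat([%x.tanh, %y.tanh, %z.tanh], %dim)
  // The FileCheck directives below verify this ordering.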
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::log(at::cat({x, y, z}, 0)));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCat3) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%a : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // But the `aten::mul` op must not be moved, since it is not a single-tensor
  // op (it has two tensor inputs).
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto a = at::rand({60}, at::kFloat);
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::cat({x, y, z}, 0)) * a;

  std::vector<at::Tensor> inputs = {a, x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Int(10, strides=[1], device=cpu),
          %y : Int(20, strides=[1], device=cpu),
          %z : Int(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // The scalar type of the inputs to `cat` should now be `Float`, since they
  // are the result of `tanh`, which performs the type promotion.
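  // For example (illustrative names), an input `%x : Int(10, ...)` should now
  // feed `%x.tanh : Float(10, ...) = aten::tanh(%x)`, so `aten::cat` sees
  // only Float inputs and needs no promotion of its own.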
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto x = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
  auto y = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
  auto z = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
  auto ref = at::tanh(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Double(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation should have happened because the `aten::cat` op
  // performs type promotion. This case is currently not handled.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::log")
      ->check_not("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation is expected since the consumers of cat are not
  // single-tensor element-wise ops.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %1 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %one : int = prim::Constant[value=1]()
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation is expected since the consumers of cat are not
  // single-tensor element-wise ops.
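  // Presumably this is because `aten::mul` and `aten::add` each read a second
  // full-size tensor (%0, %1); moving them past the cat would require
  // splitting those operands to match the cat inputs.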
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check("aten::add")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->check_not("aten::add")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, AOTGraphPrepPasses) {
  const auto graph_string = R"IR(
    graph(%x, %y, %z, %i : int):
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      return (%xyz_list, %i))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  // Drop the second graph output (%i), replace the remaining list output
  // with a tuple, and lower the tuple so the graph returns plain tensors.
  removeGraphOutput(g, 1);
  replaceListOutputWithTuple(g);
  LowerAllTuples(g);

  testing::FileCheck().check("return (%x, %y, %z)")->run(*g);
}

} // namespace jit
} // namespace torch