Skip to content

Commit aa99eba

Browse files
committed
Add tracing support for optional Device and Layout
1 parent 190dac1 commit aa99eba

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

torch/csrc/jit/tracer.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,30 @@ void addInputs(
504504
n->addInput(none);
505505
}
506506
}
507+
void addInputs(
508+
Node* n,
509+
const char* name,
510+
const c10::optional<at::Layout>& value) {
511+
if (value.has_value()) {
512+
detail::genericAddInput(n, static_cast<int64_t>(*value));
513+
} else {
514+
Graph* g = n->owningGraph();
515+
Value* none = g->insertNode(g->createNone())->output();
516+
n->addInput(none);
517+
}
518+
}
519+
void addInputs(
520+
Node* n,
521+
const char* name,
522+
const c10::optional<at::Device>& value) {
523+
if (value.has_value()) {
524+
detail::genericAddInput(n, value);
525+
} else {
526+
Graph* g = n->owningGraph();
527+
Value* none = g->insertNode(g->createNone())->output();
528+
n->addInput(none);
529+
}
530+
}
507531
#ifdef BUILD_NAMEDTENSOR
508532
void addInputs(
509533
Node* n,

torch/csrc/jit/tracer.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,14 @@ TORCH_API void addInputs(
259259
Node* n,
260260
const char* name,
261261
const c10::optional<at::ScalarType>& value);
262+
TORCH_API void addInputs(
263+
Node* n,
264+
const char* name,
265+
const c10::optional<at::Device>& value);
266+
TORCH_API void addInputs(
267+
Node* n,
268+
const char* name,
269+
const c10::optional<at::Layout>& value);
262270
TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value);
263271
#ifdef BUILD_NAMEDTENSOR
264272
TORCH_API void addInputs(Node* n, const char* name, c10::optional<at::DimnameList> value);

0 commit comments

Comments
 (0)