Skip to content

Commit aeadbbd

Browse files
committed
Updated DispatchKeyExtractor to expect TensorOptions
1 parent 6f37a53 commit aeadbbd

File tree

1 file changed

+55
-1
lines changed

1 file changed

+55
-1
lines changed

aten/src/ATen/core/dispatch/DispatchKeyExtractor.h

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,65 @@ namespace detail {
4848
}
4949
};
5050

51+
struct TensorOptionsAccumulator : at::IterArgs<TensorOptionsAccumulator> {
52+
TensorOptions options;
53+
void operator()(c10::optional<ScalarType> dtype) {
54+
if (dtype.has_value())
55+
options = options.dtype(*dtype);
56+
else
57+
options = options.dtype(at::get_default_dtype());
58+
}
59+
void operator()(c10::optional<Device> device) {
60+
if (device.has_value())
61+
options = options.device(*device);
62+
}
63+
void operator()(c10::optional<Layout> layout) {
64+
if (layout.has_value())
65+
options = options.layout(*layout);
66+
}
67+
void operator()(c10::optional<bool> pin_memory) {
68+
if (pin_memory.has_value())
69+
options = options.pinned_memory(*pin_memory);
70+
}
71+
void operator()(ScalarType dtype) {
72+
options = options.dtype(dtype);
73+
}
74+
void operator()(Device device) {
75+
options = options.device(device);
76+
}
77+
void operator()(Layout layout) {
78+
options = options.layout(layout);
79+
}
80+
void operator()(bool pin_memory) {
81+
options = options.pinned_memory(pin_memory);
82+
}
83+
template <typename T>
84+
void operator()(const T& x) {
85+
// do nothing
86+
}
87+
};
88+
89+
template<class Arg> using arg_is_tensor_option_arg = guts::typelist::contains<
90+
guts::typelist::typelist<c10::optional<ScalarType>, c10::optional<Layout>,
91+
c10::optional<Device>, c10::optional<bool>, ScalarType, Layout, Device, bool>,
92+
guts::remove_const_t<guts::remove_reference_t<Arg>>>;
93+
94+
template<class... Args> using args_have_tensor_options = guts::disjunction<
95+
arg_is_tensor_option_arg<Args>...>;
96+
5197
// NB: take by const reference (Don't do universal forwarding here! You
5298
// don't want to move into this function!)
5399
template <typename... Args>
54100
TensorTypeSet multi_dispatch_tensor_type_set(const Args&... args) {
55-
return MultiDispatchTensorTypeSet().apply(args...).ts;
101+
auto type_set = MultiDispatchTensorTypeSet().apply(args...);
102+
103+
if (args_have_tensor_options<Args...>::value) {
104+
TensorOptions tensorOptions = TensorOptionsAccumulator().apply(args...).options;
105+
if (tensorOptions.has_dtype() && tensorOptions.has_device() && tensorOptions.has_layout()) {
106+
type_set(tensorOptions);
107+
}
108+
}
109+
return type_set.ts;
56110
}
57111
}
58112

0 commit comments

Comments
 (0)