@@ -48,11 +48,65 @@ namespace detail {
4848 }
4949 };
5050
51+ struct TensorOptionsAccumulator : at::IterArgs<TensorOptionsAccumulator> {
52+ TensorOptions options;
53+ void operator ()(c10::optional<ScalarType> dtype) {
54+ if (dtype.has_value ())
55+ options = options.dtype (*dtype);
56+ else
57+ options = options.dtype (at::get_default_dtype ());
58+ }
59+ void operator ()(c10::optional<Device> device) {
60+ if (device.has_value ())
61+ options = options.device (*device);
62+ }
63+ void operator ()(c10::optional<Layout> layout) {
64+ if (layout.has_value ())
65+ options = options.layout (*layout);
66+ }
67+ void operator ()(c10::optional<bool > pin_memory) {
68+ if (pin_memory.has_value ())
69+ options = options.pinned_memory (*pin_memory);
70+ }
71+ void operator ()(ScalarType dtype) {
72+ options = options.dtype (dtype);
73+ }
74+ void operator ()(Device device) {
75+ options = options.device (device);
76+ }
77+ void operator ()(Layout layout) {
78+ options = options.layout (layout);
79+ }
80+ void operator ()(bool pin_memory) {
81+ options = options.pinned_memory (pin_memory);
82+ }
83+ template <typename T>
84+ void operator ()(const T& x) {
85+ // do nothing
86+ }
87+ };
88+
89+ template <class Arg > using arg_is_tensor_option_arg = guts::typelist::contains<
90+ guts::typelist::typelist<c10::optional<ScalarType>, c10::optional<Layout>,
91+ c10::optional<Device>, c10::optional<bool >, ScalarType, Layout, Device, bool >,
92+ guts::remove_const_t <guts::remove_reference_t <Arg>>>;
93+
94+ template <class ... Args> using args_have_tensor_options = guts::disjunction<
95+ arg_is_tensor_option_arg<Args>...>;
96+
5197 // NB: take by const reference (Don't do universal forwarding here! You
5298 // don't want to move into this function!)
5399 template <typename ... Args>
54100 TensorTypeSet multi_dispatch_tensor_type_set (const Args&... args) {
55- return MultiDispatchTensorTypeSet ().apply (args...).ts ;
101+ auto type_set = MultiDispatchTensorTypeSet ().apply (args...);
102+
103+ if (args_have_tensor_options<Args...>::value) {
104+ TensorOptions tensorOptions = TensorOptionsAccumulator ().apply (args...).options ;
105+ if (tensorOptions.has_dtype () && tensorOptions.has_device () && tensorOptions.has_layout ()) {
106+ type_set (tensorOptions);
107+ }
108+ }
109+ return type_set.ts ;
56110 }
57111}
58112
0 commit comments