Skip to content

Commit 81e4739

Browse files
xcheng16facebook-github-bot
authored andcommitted
Move QScheme ops to c10 (pytorch#30134)
Summary: Pull Request resolved: pytorch#30134 ghstack-source-id: 95055387 Test Plan: buck build mode/dev caffe2:generate-code Differential Revision: D18609716 fbshipit-source-id: fec39359e0b97387a9b13f8179d72a731cc61808
1 parent d6ddfab commit 81e4739

File tree

7 files changed

+43
-5
lines changed

7 files changed

+43
-5
lines changed

aten/src/ATen/core/boxing/kernel_functor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ namespace detail {
3939
bool,
4040
std::string,
4141
at::Tensor,
42-
at::Scalar
42+
at::Scalar,
43+
c10::QScheme
4344
>;
4445

4546
template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_input_type {

aten/src/ATen/core/jit_type.h

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ using OptNameList = c10::optional<std::vector<std::string>>;
4949
_(FunctionType) \
5050
_(ClassType) \
5151
_(CapsuleType) \
52-
_(InterfaceType)
52+
_(InterfaceType) \
53+
_(QSchemeType)
5354

5455
enum class TypeKind {
5556
#define DEFINE_TYPE(T) T,
@@ -1090,6 +1091,28 @@ struct CAFFE2_API GeneratorType : public Type {
10901091
GeneratorType() : Type(TypeKind::GeneratorType) {}
10911092
};
10921093

1094+
struct QSchemeType;
1095+
using QSchemeTypePtr = std::shared_ptr<QSchemeType>;
1096+
// This type represents a QScheme
1097+
struct CAFFE2_API QSchemeType : public Type {
1098+
static QSchemeTypePtr create() {
1099+
return QSchemeTypePtr(
1100+
new QSchemeType()); // NOLINT(modernize-make-shared)
1101+
}
1102+
bool operator==(const Type& rhs) const override {
1103+
return rhs.kind() == kind();
1104+
}
1105+
std::string str() const override {
1106+
return "QScheme";
1107+
}
1108+
static const TypeKind Kind = TypeKind::QSchemeType;
1109+
// global singleton
1110+
static QSchemeTypePtr get();
1111+
1112+
private:
1113+
QSchemeType() : Type(TypeKind::QSchemeType) {}
1114+
};
1115+
10931116
struct DeviceObjType;
10941117
using DeviceObjTypePtr = std::shared_ptr<DeviceObjType>;
10951118
// This type represents a Generator
@@ -1261,6 +1284,12 @@ struct getTypePtr_<at::Scalar> final {
12611284
}
12621285
};
12631286
template <>
1287+
struct getTypePtr_<c10::QScheme> final {
1288+
static TypePtr call() {
1289+
return QSchemeType::get();
1290+
}
1291+
};
1292+
template <>
12641293
struct getTypePtr_<at::Generator*> final {
12651294
static TypePtr call() {
12661295
return OptionalType::create(GeneratorType::get());
@@ -1392,8 +1421,8 @@ struct CAFFE2_API ClassType : public NamedType {
13921421
}
13931422

13941423
std::string str() const override {
1395-
return python_str();
1396-
}
1424+
return python_str();
1425+
}
13971426

13981427
std::string python_str() const override {
13991428
const auto& n = name().value();

aten/src/ATen/core/type.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ GeneratorTypePtr GeneratorType::get() {
104104
static auto value = GeneratorType::create();
105105
return value;
106106
}
107+
QSchemeTypePtr QSchemeType::get() {
108+
static auto value = QSchemeType::create();
109+
return value;
110+
}
107111
StringTypePtr StringType::get() {
108112
static auto value = StringType::create();
109113
return value;

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3533,6 +3533,7 @@
35333533
CPU: make_per_channel_quantized_tensor_cpu
35343534

35353535
- func: qscheme(Tensor self) -> QScheme
3536+
use_c10_dispatcher: full
35363537
variants: method
35373538
dispatch:
35383539
QuantizedCPU: qscheme_quant

torch/csrc/jit/pybind_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ inline IValue toIValue(
541541
case TypeKind::GeneratorType:
542542
case TypeKind::VarType:
543543
case TypeKind::FutureType:
544+
case TypeKind::QSchemeType:
544545
break;
545546
case TypeKind::FunctionType:
546547
AT_ERROR("Function Values aren't yet supported");

torch/csrc/jit/script/schema_type_parser.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using c10::NumberType;
2222
using c10::OptionalType;
2323
using c10::StringType;
2424
using c10::Symbol;
25+
using c10::QSchemeType;
2526
using c10::TensorType;
2627
using c10::TupleType;
2728
using c10::VarType;
@@ -38,7 +39,7 @@ TypeAndAlias SchemaTypeParser::parseBaseType() {
3839
{"Layout", IntType::get()},
3940
{"MemoryFormat", IntType::get()},
4041
{"Storage", IntType::get()},
41-
{"QScheme", IntType::get()},
42+
{"QScheme", QSchemeType::get()},
4243
{"ConstQuantizerPtr", IntType::get()}, // TODO This type should be removed from the schema parser, it should use the custom class mechanism instead. @jerryzh
4344
{"Device", DeviceObjType::get()},
4445
{"Scalar", NumberType::get()},

torch/csrc/jit/unpickler.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
8383
case StringType::Kind:
8484
case FunctionType::Kind:
8585
case DeviceObjType::Kind:
86+
case QSchemeType::Kind:
8687
// no op, there is nothing to tag
8788
break;
8889
case AnyType::Kind:

0 commit comments

Comments
 (0)