forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoperator_sets.h
46 lines (40 loc) · 1.74 KB
/
operator_sets.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#pragma once
#include <functional>

#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch,
1,
SparseLengthsSumFused8BitRowwise);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSum);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsWeightedSum);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchGather);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, DotProduct);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, FCTransposed);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchMatMul);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, ExpandDims);
// Iterate over schema from ai.onnx.pytorch domain opset 1
class OpSet_PyTorch_ver1 {
public:
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, SparseLengthsSumFused8BitRowwise)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, SparseLengthsSum)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, SparseLengthsWeightedSum)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, BatchGather)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, DotProduct)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, FCTransposed)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, BatchMatMul)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, ExpandDims)>());
}
};
inline void RegisterPyTorchOperatorSetSchema() {
RegisterOpSetSchema<OpSet_PyTorch_ver1>();
}
} // namespace ONNX_NAMESPACE