forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathInlineEvent.h
139 lines (119 loc) · 3.75 KB
/
InlineEvent.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/Exception.h>
namespace c10::impl {
/// Move-only event wrapper that forwards every operation to a backend object
/// of type `T`.
///
/// The wrapper stores an opaque `void*` event handle together with the device
/// type, device index and flag it is associated with. `T` must provide the
/// calls used below: `record`, `block`, `queryEvent`, `elapsedTime`,
/// `synchronizeEvent` and `destroyEvent`, and must be constructible from a
/// `DeviceType`.
template <typename T>
struct InlineEvent final {
  /// A device type is mandatory, so default construction is disabled.
  InlineEvent() = delete;

  /// Creates an event bound to `_device_type` with the given `_flag`.
  ///
  /// The backend object is constructed from the same device type. The event
  /// handle starts out null and the device index starts at -1.
  InlineEvent(
      const DeviceType _device_type,
      const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
      : backend_{_device_type}, device_type_{_device_type}, flag_{_flag} {}

  // Copying is disabled: the destructor destroys the handle, so two copies
  // would both try to destroy the same event.
  InlineEvent(const InlineEvent&) = delete;
  InlineEvent& operator=(const InlineEvent&) = delete;

  /// Move-constructs from `other`, taking over its event handle.
  ///
  /// Only `other.event_` is reset afterwards (to null) so that `other`'s
  /// destructor does not destroy the handle now held by this object.
  InlineEvent(InlineEvent&& other) noexcept
      : event_(other.event_),
        backend_(std::move(other.backend_)),
        device_type_(other.device_type_),
        device_index_(other.device_index_),
        flag_(other.flag_),
        was_marked_for_recording_(other.was_marked_for_recording_) {
    other.event_ = nullptr;
  }

  /// Move-assigns by exchanging the full state with `other`.
  ///
  /// The state previously held by this object ends up in `other` and is
  /// released when `other` is destroyed.
  InlineEvent& operator=(InlineEvent&& other) noexcept {
    this->swap(other);
    return *this;
  }

  /// Exchanges every data member with `other`.
  void swap(InlineEvent& other) noexcept {
    // The member swaps are independent of one another.
    std::swap(was_marked_for_recording_, other.was_marked_for_recording_);
    std::swap(flag_, other.flag_);
    std::swap(device_index_, other.device_index_);
    std::swap(device_type_, other.device_type_);
    std::swap(backend_, other.backend_);
    std::swap(event_, other.event_);
  }

  /// Asks the backend to destroy the event handle, if one is held.
  ~InlineEvent() noexcept {
    if (event_) {
      backend_.destroyEvent(event_, device_index_);
    }
  }

  /// Device type this event was constructed for.
  DeviceType device_type() const noexcept {
    return device_type_;
  }

  /// Device index of the most recently recorded stream (-1 before any
  /// `record()` call).
  DeviceIndex device_index() const noexcept {
    return device_index_;
  }

  /// Flag supplied at construction.
  EventFlag flag() const noexcept {
    return flag_;
  }

  /// True once `record()` has completed at least once.
  bool was_marked_for_recording() const noexcept {
    return was_marked_for_recording_;
  }

  /// Records on `stream` only if this event has never been recorded.
  void recordOnce(const Stream& stream) {
    if (was_marked_for_recording_) {
      return;
    }
    record(stream);
  }

  /// Records this event on `stream`.
  ///
  /// Fails the `TORCH_CHECK` when the stream's device type differs from the
  /// event's device type.
  void record(const Stream& stream) {
    TORCH_CHECK(
        stream.device_type() == device_type_,
        "Event device type ",
        DeviceTypeName(device_type_),
        " does not match recording stream's device type ",
        DeviceTypeName(stream.device_type()),
        ".");

    // The backend receives the address of the handle and the device index
    // from *before* this call; the stored index and the recorded flag are
    // updated only after the backend call returns.
    backend_.record(&event_, stream, device_index_, flag_);
    was_marked_for_recording_ = true;
    device_index_ = stream.device_index();
  }

  /// Forwards to the backend's `block` for `stream` and this event.
  ///
  /// Does nothing when the event was never recorded. Otherwise the stream's
  /// device type must match the event's device type.
  void block(const Stream& stream) const {
    if (was_marked_for_recording_) {
      TORCH_CHECK(
          stream.device_type() == device_type_,
          "Event device type ",
          DeviceTypeName(device_type_),
          " does not match blocking stream's device type ",
          DeviceTypeName(stream.device_type()),
          ".");
      backend_.block(event_, stream);
    }
  }

  /// Returns the backend's `queryEvent` result for this event, or `true`
  /// without consulting the backend when the event was never recorded.
  bool query() const {
    if (was_marked_for_recording_) {
      return backend_.queryEvent(event_);
    }
    return true;
  }

  /// Raw opaque event handle (null while no handle is held).
  void* eventId() const {
    return event_;
  }

  /// Returns the backend's `elapsedTime` for this event and `other`.
  ///
  /// Both events must have been recorded and must share a device type;
  /// the checks run in the order: `other` recorded, self recorded, device
  /// types equal. The unit of the result is whatever the backend returns.
  double elapsedTime(const InlineEvent& other) const {
    TORCH_CHECK(
        other.was_marked_for_recording(),
        "other was not marked for recording.");
    TORCH_CHECK(
        was_marked_for_recording(), "self was not marked for recording.");
    TORCH_CHECK(
        other.device_type() == device_type_,
        "Event device type ",
        DeviceTypeName(device_type_),
        " does not match other's device type ",
        DeviceTypeName(other.device_type()),
        ".");
    return backend_.elapsedTime(event_, other.event_, device_index_);
  }

  /// Forwards to the backend's `synchronizeEvent`; a no-op when the event
  /// was never recorded.
  void synchronize() const {
    if (was_marked_for_recording_) {
      backend_.synchronizeEvent(event_);
    }
  }

 private:
  // Opaque backend event handle; null until the backend fills it in.
  void* event_ = nullptr;
  // Backend object that implements the actual event operations.
  T backend_;
  // Device type fixed at construction (exchanged only by swap()).
  DeviceType device_type_;
  // Device index of the last recorded stream; -1 until record() runs.
  DeviceIndex device_index_ = -1;
  // Flag passed to the backend on every record() call.
  EventFlag flag_ = EventFlag::PYTORCH_DEFAULT;
  // Set by record(); gates block(), query() and synchronize().
  bool was_marked_for_recording_ = false;
};
} // namespace c10::impl