8000 Support torch.Event elapsed_time method on XPU (#140865) · pytorch/pytorch@b1a8be6 · GitHub
[go: up one dir, main page]

Skip to content

Commit b1a8be6

Browse files
guangyey and pytorchmergebot
authored and committed
Support torch.Event elapsed_time method on XPU (#140865)
# Motivation This PR aims to support c10::Event/torch.Event elapsed_time method on XPU. We create a profiling tag Event when the timing flag is enabled. Pull Request resolved: #140865 Approved by: https://github.com/Samkm0084, https://github.com/gujinghui
1 parent d70b702 commit b1a8be6

File tree

2 files changed

+46
-11
lines changed

2 files changed

+46
-11
lines changed

c10/xpu/impl/XPUGuardImpl.h

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,19 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
103103
// Delete the event previously recorded.
104104
if (xpu_event)
105105
delete xpu_event;
106+
#if SYCL_COMPILER_VERSION >= 20250000
107+
if (flag == EventFlag::BACKEND_DEFAULT) {
108+
// Use the profiling tag to record the event to enable timing feature.
109+
xpu_event =
110+
new sycl::event(sycl::ext::oneapi::experimental::submit_profiling_tag(
111+
xpu_stream.queue()));
112+
} else {
113+
xpu_event =
114+
new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier());
115+
}
116+
#else
106117
xpu_event = new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier());
118+
#endif
107119
*event = reinterpret_cast<void*>(xpu_event);
108120

109121
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
@@ -140,6 +152,30 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
140152
event_command_status::complete;
141153
}
142154

155+
double elapsedTime(
156+
void* start_event,
157+
void* end_event,
158+
const DeviceIndex device_index) const override {
159+
#if SYCL_COMPILER_VERSION < 20250000
160+
TORCH_CHECK_NOT_IMPLEMENTED(
161+
false,
162+
"elapsedTime requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
163+
#endif
164+
TORCH_CHECK(
165+
start_event && end_event,
166+
"Both events must be recorded before calculating elapsed time.");
167+
auto* xpu_start_event = reinterpret_cast<sycl::event*>(start_event);
168+
auto* xpu_end_event = reinterpret_cast<sycl::event*>(end_event);
169+
170+
using namespace sycl::info::event_profiling;
171+
// Block until both of the recorded events are completed.
172+
uint64_t end_time_ns = xpu_end_event->get_profiling_info<command_end>();
173+
uint64_t start_time_ns = xpu_start_event->get_profiling_info<command_end>();
174+
// Return the elapsed time in milliseconds.
175+
return 1e-6 *
176+
(static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
177+
}
178+
143179
// Stream-related functions
144180
bool queryStream(const Stream& stream) const override {
145181
const XPUStream xpu_stream{stream};
@@ -176,12 +212,6 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
176212
const XPUStream xpu_stream{stream};
177213
XPUCachingAllocator::recordStream(data_ptr, xpu_stream);
178214
}
179-
180-
double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
181-
const override {
182-
TORCH_CHECK_NOT_IMPLEMENTED(
183-
false, "elapsedTime is not supported by XPU backend.");
184-
}
185215
};
186216

187217
} // namespace c10::xpu::impl

test/test_xpu.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ def test_generic_stream_event(self):
250250
self.assertEqual(stream.stream_id, xpu_stream.stream_id)
251251
self.assertNotEqual(stream.stream_id, torch.xpu.current_stream().stream_id)
252252

253-
event1 = torch.Event("xpu")
254-
event2 = torch.Event("xpu")
253+
event1 = torch.Event("xpu", enable_timing=True)
254+
event2 = torch.Event("xpu", enable_timing=True)
255255
self.assertEqual(event1.event_id, 0)
256256
a = torch.randn(1000)
257257
b = torch.randn(1000)
@@ -263,15 +263,20 @@ def test_generic_stream_event(self):
263263
event1.synchronize()
264264
self.assertTrue(event1.query())
265265
c_xpu = a_xpu + b_xpu
266+
# Here intentionally records another stream.
266267
event2.record()
267268
event2.synchronize()
268269
self.assertTrue(event2.query())
269270
self.assertNotEqual(event1.event_id, event2.event_id)
270271
self.assertEqual(c_xpu.cpu(), a + b)
271-
with self.assertRaisesRegex(
272-
NotImplementedError, "elapsedTime is not supported by XPU backend."
273-
):
272+
if int(torch.version.xpu) >= 20250000:
274273
event1.elapsed_time(event2)
274+
else:
275+
with self.assertRaisesRegex(
276+
NotImplementedError,
277+
"elapsedTime requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.",
278+
):
279+
event1.elapsed_time(event2)
275280
xpu_event = torch.xpu.Event()
276281
self.assertIsInstance(xpu_event, torch.Event)
277282
self.assertTrue(issubclass(type(xpu_event), torch.Event))

0 commit comments

Comments
 (0)
0