@@ -43,8 +43,8 @@ CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
43
43
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);
44
44
45
45
#define LOAD_CUPTI_SYM (p, lib, x ) \
46
- p-> x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
47
- " cupti" #x);
46
+ p. x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
47
+ " cupti" #x);
48
48
49
49
#else
50
50
using tracing_event_t = void *;
@@ -55,15 +55,21 @@ using cuptiEnableDomain_fn = void *;
55
55
using cuptiEnableCallback_fn = void *;
56
56
#endif // XPTI_ENABLE_INSTRUMENTATION
57
57
58
+ struct cupti_table_t_ {
59
+ cuptiSubscribe_fn Subscribe = nullptr ;
60
+ cuptiUnsubscribe_fn Unsubscribe = nullptr ;
61
+ cuptiEnableDomain_fn EnableDomain = nullptr ;
62
+ cuptiEnableCallback_fn EnableCallback = nullptr ;
63
+
64
+ bool isInitialized () const ;
65
+ };
66
+
58
67
struct cuda_tracing_context_t_ {
59
68
tracing_event_t CallEvent = nullptr ;
60
69
tracing_event_t DebugEvent = nullptr ;
61
70
subscriber_handle_t Subscriber = nullptr ;
62
71
ur_loader::LibLoader::Lib Library;
63
- cuptiSubscribe_fn Subscribe = nullptr ;
64
- cuptiUnsubscribe_fn Unsubscribe = nullptr ;
65
- cuptiEnableDomain_fn EnableDomain = nullptr ;
66
- cuptiEnableCallback_fn EnableCallback = nullptr ;
72
+ cupti_table_t_ Cupti;
67
73
};
68
74
69
75
#ifdef XPTI_ENABLE_INSTRUMENTATION
@@ -132,6 +138,10 @@ void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx) {
132
138
#endif // XPTI_ENABLE_INSTRUMENTATION
133
139
}
134
140
141
+ bool cupti_table_t_::isInitialized () const {
142
+ return Subscribe && Unsubscribe && EnableDomain && EnableCallback;
143
+ }
144
+
135
145
bool loadCUDATracingLibrary (cuda_tracing_context_t_ *Ctx) {
136
146
#if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(CUPTI_LIB_PATH)
137
147
if (!Ctx)
@@ -141,16 +151,16 @@ bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
141
151
auto Lib{ur_loader::LibLoader::loadAdapterLibrary (CUPTI_LIB_PATH)};
142
152
if (!Lib)
143
153
return false ;
144
- LOAD_CUPTI_SYM (Ctx, Lib, Subscribe)
145
- LOAD_CUPTI_SYM (Ctx, Lib, Unsubscribe)
146
- LOAD_CUPTI_SYM (Ctx, Lib, EnableDomain)
147
- LOAD_CUPTI_SYM (Ctx, Lib, EnableCallback)
148
- if (!Ctx->Subscribe || !Ctx->Unsubscribe || !Ctx->EnableDomain ||
149
- !Ctx->EnableCallback ) {
150
- unloadCUDATracingLibrary (Ctx);
154
+ cupti_table_t_ Table;
155
+ LOAD_CUPTI_SYM (Table, Lib, Subscribe)
156
+ LOAD_CUPTI_SYM (Table, Lib, Unsubscribe)
157
+ LOAD_CUPTI_SYM (Table, Lib, EnableDomain)
158
+ LOAD_CUPTI_SYM (Table, Lib, EnableCallback)
159
+ if (!Table.isInitialized ()) {
151
160
return false ;
152
161
}
153
162
Ctx->Library = std::move (Lib);
163
+ Ctx->Cupti = Table;
154
164
return true ;
155
165
#else
156
166
(void )Ctx;
@@ -160,14 +170,10 @@ bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
160
170
161
171
void unloadCUDATracingLibrary (cuda_tracing_context_t_ *Ctx) {
162
172
#ifdef XPTI_ENABLE_INSTRUMENTATION
163
- if (!Ctx || !Ctx-> Library )
173
+ if (!Ctx)
164
174
return ;
165
- Ctx->Subscribe = nullptr ;
166
- Ctx->Unsubscribe = nullptr ;
167
- Ctx->EnableDomain = nullptr ;
168
- Ctx->EnableCallback = nullptr ;
169
-
170
175
Ctx->Library .reset ();
176
+ Ctx->Cupti = cupti_table_t_ ();
171
177
#else
172
178
(void )Ctx;
173
179
#endif // XPTI_ENABLE_INSTRUMENTATION
@@ -207,12 +213,12 @@ void enableCUDATracing(cuda_tracing_context_t_ *Ctx) {
207
213
xptiMakeEvent (" CUDA Plugin Debug Layer" , &CUDADebugPayload,
208
214
xpti::trace_algorithm_event, xpti_at::active, &Dummy);
209
215
210
- Ctx->Subscribe (&Ctx->Subscriber , cuptiCallback, Ctx);
211
- Ctx->EnableDomain (1 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API);
212
- Ctx->EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
213
- CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
214
- Ctx->EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
215
- CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
216
+ Ctx->Cupti . Subscribe (&Ctx->Subscriber , cuptiCallback, Ctx);
217
+ Ctx->Cupti . EnableDomain (1 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API);
218
+ Ctx->Cupti . EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
219
+ CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
220
+ Ctx->Cupti . EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
221
+ CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
216
222
#else
217
223
(void )Ctx;
218
224
#endif
@@ -223,8 +229,8 @@ void disableCUDATracing(cuda_tracing_context_t_ *Ctx) {
223
229
if (!Ctx || !xptiTraceEnabled ())
224
230
return ;
225
231
226
- if (Ctx->Subscriber ) {
227
- Ctx->Unsubscribe (Ctx->Subscriber );
232
+ if (Ctx->Subscriber && Ctx-> Cupti . isInitialized () ) {
233
+ Ctx->Cupti . Unsubscribe (Ctx->Subscriber );
228
234
Ctx->Subscriber = nullptr ;
229
235
}
230
236
0 commit comments