1
1
#if !defined(C10_MOBILE) && !defined(ANDROID)
2
2
3
3
#include < c10/util/error.h>
4
- #include < c10/util/string_view.h>
5
4
#include < torch/csrc/inductor/aoti_package/model_package_loader.h>
6
5
#include < torch/csrc/inductor/aoti_runner/model_container_runner.h>
7
6
#include < torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
12
11
#include < fstream>
13
12
#include < iostream>
14
13
15
- #ifndef _WIN32
16
- #include < dirent.h>
17
- #include < sys/stat.h>
18
- #else
19
- #include < filesystem>
20
- namespace fs = std::filesystem;
21
- #endif
22
-
23
- // TODO: C++17 has the filesystem header, which may replace these
24
- #ifdef _WIN32
25
- // On Windows, the POSIX implementations are considered deprecated. We simply
26
- // map to the newer variant.
27
- #include < direct.h>
28
- #include < io.h>
29
- #include < process.h>
30
- #define access _access
31
- #define F_OK 0
32
- #else
33
- #include < sys/types.h>
34
- #include < unistd.h>
35
- #endif
36
-
37
14
namespace {
38
- bool file_exists (const std::string& path) {
39
- #ifdef _WIN32
40
- return fs::exists (path);
41
- #else
42
- struct stat rc {};
43
- return lstat (path.c_str (), &rc) == 0 ;
44
- #endif
45
- }
46
15
47
16
std::string create_temp_dir () {
48
17
#ifdef _WIN32
@@ -57,20 +26,14 @@ std::string create_temp_dir() {
57
26
return temp_dir;
58
27
#endif
59
28
}
60
-
61
- #ifdef _WIN32
62
- const std::string k_separator = " \\ " ;
63
- #else
64
- const std::string k_separator = " /" ;
65
- #endif
66
29
} // namespace
67
30
68
31
namespace torch ::inductor {
69
32
70
33
namespace {
71
- const nlohmann::json& load_json_file (const std::string & json_path) {
72
- if (!file_exists (json_path)) {
73
- throw std::runtime_error (" File not found: " + json_path);
34
+ const nlohmann::json& load_json_file (const std::filesystem::path & json_path) {
35
+ if (!std::filesystem::exists (json_path)) {
36
+ throw std::runtime_error (" File not found: " + json_path. string () );
74
37
}
75
38
76
39
std::ifstream json_file (json_path);
@@ -98,10 +61,9 @@ std::tuple<std::string, std::string> get_cpp_compile_command(
98
61
99
62
std::string file_ext = compile_only ? " .o" : " .so" ;
100
63
std::string target_file = output_dir + filename + file_ext;
101
- std::string target_dir = output_dir;
64
+ std::filesystem::path target_dir = output_dir;
102
65
if (target_dir.empty ()) {
103
- size_t parent_path_idx = filename.find_last_of (k_separator);
104
- target_dir = filename.substr (0 , parent_path_idx);
66
+ target_dir = std::filesystem::path (filename).parent_path ();
105
67
}
106
68
107
69
std::string cflags_args;
@@ -135,11 +97,10 @@ std::tuple<std::string, std::string> get_cpp_compile_command(
135
97
}
136
98
137
99
std::string passthrough_parameters_args;
100
+ std::string target = " script.ld" ;
101
+ auto replacement = (target_dir / target).string ();
138
102
for (auto & arg : compile_options[" passthrough_args" ]) {
139
103
std::string arg_str = arg.get <std::string>();
140
- std::string target = " script.ld" ;
141
- std::string replacement = target_dir;
142
- replacement.append (k_separator).append (target);
143
104
size_t pos = arg_str.find (target);
144
105
if (pos != std::string::npos) {
145
106
arg_str.replace (pos, target.length (), replacement);
@@ -166,102 +127,6 @@ std::tuple<std::string, std::string> get_cpp_compile_command(
166
127
return std::make_tuple (cmd, target_file);
167
128
}
168
129
169
- bool recursive_mkdir (const std::string& dir) {
170
- // Creates directories recursively, copied from jit_utils.cpp
171
- // Check if current dir exists
172
- const char * p_dir = dir.c_str ();
173
- const bool dir_exists = (access (p_dir, F_OK) == 0 );
174
- if (dir_exists) {
175
- return true ;
176
- }
177
-
178
- // Try to create current directory
179
- #ifdef _WIN32
180
- int ret = _mkdir (dir.c_str ());
181
- #else
182
- int ret = mkdir (dir.c_str (), S_IRWXU | S_IRWXG | S_IRWXO);
183
- #endif
184
- // Success
185
- if (ret == 0 ) {
186
- return true ;
187
- }
188
-
189
- // Find folder separator and check if we are at the top
190
- auto pos = dir.find_last_of (k_separator);
191
- if (pos == std::string::npos) {
192
- return false ;
193
- }
194
-
195
- // Try to create parent directory
196
- if (!(recursive_mkdir (dir.substr (0 , pos)))) {
197
- return false ;
198
- }
199
-
200
- // Try to create complete path again
201
- #ifdef _WIN32
202
- ret = _mkdir (dir.c_str ());
203
- #else
204
- ret = mkdir (dir.c_str (), S_IRWXU | S_IRWXG | S_IRWXO);
205
- #endif
206
- return ret == 0 ;
207
- }
208
-
209
- bool recursive_rmdir (const std::string& path) {
210
- #ifdef _WIN32
211
- std::error_code ec;
212
- return fs::remove_all (path, ec) != static_cast <std::uintmax_t >(-1 );
213
- #else
214
- DIR* dir = opendir (path.c_str ());
215
- if (!dir) {
216
- return false ;
217
- }
218
-
219
- struct dirent * entry = nullptr ;
220
- struct stat statbuf {};
221
- bool success = true ;
222
-
223
- // Iterate through directory entries
224
- while ((entry = readdir (dir)) != nullptr ) {
225
- std::string name = entry->d_name ;
226
-
227
- // Skip "." and ".."
228
- if (name == " ." || name == " .." ) {
229
- continue ;
230
- }
231
-
232
- std::string full_path = path;
233
- full_path.append (" /" ).append (name);
234
-
235
- // Get file status
236
- if (stat (full_path.c_str (), &statbuf) != 0 ) {
237
- success = false ;
238
- continue ;
239
- }
240
-
241
- if (S_ISDIR (statbuf.st_mode )) {
242
- // Recursively delete subdirectory
243
- if (!recursive_rmdir (full_path)) {
244
- success = false ;
245
- }
246
- } else {
247
- // Delete file
248
- if (unlink (full_path.c_str ()) != 0 ) {
249
- success = false ;
250
- }
251
- }
252
- }
253
-
254
- closedir (dir);
255
-
256
- // Remove the directory itself
257
- if (rmdir (path.c_str ()) != 0 ) {
258
- success = false ;
259
- }
260
-
261
- return success;
262
- #endif
263
- }
264
-
265
130
std::string compile_so (
266
131
const std::string& cpp_filename,
267
132
const std::string& consts_filename) {
@@ -295,7 +160,7 @@ std::string compile_so(
295
160
296
161
// Move the mmapped weights onto the .so
297
162
std::string serialized_weights_path = filename + " _serialized_weights.bin" ;
298
- if (file_exists (serialized_weights_path)) {
163
+ if (std::filesystem::exists (serialized_weights_path)) {
299
164
std::ifstream serialized_weights_file (
300
165
serialized_weights_path, std::ios::binary);
301
166
if (!serialized_weights_file.is_open ()) {
@@ -370,9 +235,9 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
370
235
std::string cpp_filename;
371
236
std::string consts_filename;
372
237
std::string found_filenames; // Saving for bookkeeping
373
- std::string model_directory =
374
- " data" + k_separator + " aotinductor" + k_separator + model_name;
375
- std::string const_directory = " data" + k_separator + " constants" ;
238
+ auto model_directory =
239
+ std::filesystem::path ( " data" ) / " aotinductor" / model_name;
240
+ auto const_directory = std::filesystem::path ( " data" ) / " constants" ;
376
241
377
242
for (uint32_t i = 0 ; i < zip_archive.m_total_files ; i++) {
378
243
uint32_t filename_len =
@@ -387,55 +252,53 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
387
252
&zip_archive, i, filename_str.data (), filename_len)) {
388
253
throw std::runtime_error (" Failed to read filename" );
389
254
}
255
+ std::filesystem::path filename (filename_str.c_str ());
390
256
391
257
found_filenames += filename_str;
392
258
found_filenames += " " ;
393
259
394
260
// Only compile files in the specified model directory
395
- if (c10::starts_with (filename_str, model_directory) ||
396
- c10::starts_with (filename_str, const_directory)) {
397
- std::string output_path_str = temp_dir_;
398
-
399
- if (c10::starts_with (filename_str, model_directory)) {
400
- output_path_str += k_separator;
401
- output_path_str += filename_str;
402
- } else { // startsWith(filename_str, const_directory)
261
+ bool in_model_directory =
262
+ (!std::filesystem::relative (filename, model_directory).empty ());
263
+ bool in_const_directory =
264
+ (!std::filesystem::relative (filename, const_directory).empty ());
265
+ if (in_model_directory || in_const_directory) {
266
+ std::filesystem::path output_path = temp_dir_;
267
+
268
+ if (in_model_directory) {
269
+ output_path /= filename;
270
+ } else {
403
271
// Extract constants to the same directory as the rest of the files
404
272
// to be consistent with internal implementation
405
- size_t lastSlash = filename_str.find_last_of (k_separator);
406
- std::string filename = filename_str;
407
- if (lastSlash != std::string::npos) {
408
- filename = filename_str.substr (lastSlash + 1 );
409
- }
410
- output_path_str +=
411
- k_separator + model_directory + k_separator + filename;
273
+ output_path /= model_directory;
274
+ output_path /= filename.filename ();
412
275
}
276
+ auto output_path_str = output_path.string ();
413
277
414
- LOG (INFO) << " Extract file: " << filename_str << " to "
415
- << output_path_str;
278
+ LOG (INFO) << " Extract file: " << filename_str << " to " << output_path;
416
279
417
280
// Create the parent directory if it doesn't exist
418
- size_t parent_path_idx = output_path_str.find_last_of (k_separator);
419
- if (parent_path_idx == std::string::npos) {
281
+ if (!output_path.has_parent_path ()) {
420
282
throw std::runtime_error (
421
283
" Failed to find parent path in " + output_path_str);
422
284
}
423
- std::string parent_path = output_path_str.substr (0 , parent_path_idx);
424
- if (!recursive_mkdir (parent_path)) {
285
+ auto parent_path = output_path.parent_path ();
286
+ std::error_code ec{};
287
+ std::filesystem::create_directories (parent_path, ec);
288
+ if (!std::filesystem::is_directory (parent_path)) {
425
289
throw std::runtime_error (fmt::format (
426
290
" Failed to create directory {}: {}" ,
427
- parent_path,
428
- c10::utils::str_error (errno )));
291
+ parent_path. string () ,
292
+ ec. message ( )));
429
293
}
430
294
431
295
// Extracts file to the temp directory
432
296
mz_zip_reader_extract_file_to_file (
433
297
&zip_archive, filename_str.c_str (), output_path_str.c_str (), 0 );
434
298
435
299
// Save the file for bookkeeping
436
- size_t extension_idx = output_path_str.find_last_of (' .' );
437
- if (extension_idx != std::string::npos) {
438
- std::string filename_extension = output_path_str.substr (extension_idx);
300
+ if (output_path.has_extension ()) {
301
+ auto filename_extension = output_path.extension ();
439
302
if (filename_extension == " .cpp" ) {
440
303
cpp_filename = output_path_str;
441
304
} else if (filename_extension == " .o" ) {
@@ -483,15 +346,17 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
483
346
throw std::runtime_error (" Unsupported device found: " + device);
484
347
}
485
348
486
- std::string cubin_dir = temp_dir_ + k_separator + model_directory;
349
+ std::string cubin_dir = ( temp_dir_ / model_directory). string () ;
487
350
runner_ = registered_aoti_runner[device](
488
351
so_path, num_runners, device, cubin_dir, run_single_threaded);
489
352
}
490
353
491
354
AOTIModelPackageLoader::~AOTIModelPackageLoader () {
492
355
// Clean up the temporary directory
493
356
if (!temp_dir_.empty ()) {
494
- recursive_rmdir (temp_dir_);
357
+ std::error_code ec{};
358
+ // The noexcept version of remove_all is used
359
+ std::filesystem::remove_all (temp_dir_, ec);
495
360
}
496
361
}
497
362
0 commit comments