8000 Use thread-safe getenv wrapper · pytorch/pytorch@9f65055 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9f65055

Browse files
committed
Use thread-safe getenv wrapper
1 parent 1f3edbf commit 9f65055

File tree

6 files changed

+30
-26
lines changed

6 files changed

+30
-26
lines changed

aten/src/ATen/ParallelCommon.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <ATen/Config.h>
44
#include <ATen/PTThreadPool.h>
55
#include <ATen/Version.h>
6+
#include <c10/util/env.h>
67

78
#include <sstream>
89
#include <thread>
@@ -23,17 +24,17 @@ namespace at {
2324

2425
namespace {
2526

26-
const char* get_env_var(
27+
std::string get_env_var(
2728
const char* var_name, const char* def_value = nullptr) {
28-
const char* value = std::getenv(var_name);
29-
return value ? value : def_value;
29+
auto env = c10::utils::get_env(var_name);
30+
return env.has_value() ? env.value() : def_value;
3031
}
3132

3233
#ifndef C10_MOBILE
3334
size_t get_env_num_threads(const char* var_name, size_t def_value = 0) {
3435
try {
35-
if (auto* value = std::getenv(var_name)) {
36-
int nthreads = std::stoi(value);
36+
if (auto value = c10::utils::get_env(var_name)) {
37+
int nthreads = std::stoi(value.value());
3738
TORCH_CHECK(nthreads > 0);
3839
return nthreads;
3940
}

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <c10/cuda/CUDACachingAllocator.h>
1212
#include <c10/cuda/CUDAFunctions.h>
1313
#include <c10/macros/Export.h>
14+
#include <c10/util/env.h>
1415
#include <c10/util/irange.h>
1516

1617
#ifdef USE_ROCM
@@ -180,17 +181,17 @@ uint32_t _getAlignment(uintptr_t address) {
180181
#endif
181182

182183
static size_t _parseChosenWorkspaceSize() {
183-
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
184+
auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE");
184185
#ifdef USE_ROCM
185-
if (!val) {
186+
if (!val.has_value()) {
186187
// accept either env var
187-
val = getenv("HIPBLASLT_WORKSPACE_SIZE");
188+
val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE");
188189
}
189190
#endif
190191
size_t workspace_size = 1024; /* default size in KiB according to #73328 */
191-
if (val) {
192+
if (val.has_value()) {
192193
try {
193-
workspace_size = std::stoi(val);
194+
workspace_size = std::stoi(val.value());
194195
} catch(std::invalid_argument const& e) {
195196
TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
196197
" using default workspace size of ", workspace_size, " KiB.");

torch/csrc/distributed/c10d/ProcessGroupUCC.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifdef USE_C10D_UCC
22

33
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
4+
#include <c10/util/env.h>
45
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
56
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
67
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
@@ -157,11 +158,10 @@ void read_config() {
157158
torch_ucc_config.enable_comms_logger = false;
158159

159160
// read all torch_ucc env. variables and update the map
160-
char* env;
161-
for (auto& torch_ucc_env : torch_ucc_envs_map) {
162-
env = std::getenv(torch_ucc_env.first.c_str());
163-
if (env) {
164-
torch_ucc_envs_map[torch_ucc_env.first] = std::string(env);
161+
for (auto& [env_name, value] : torch_ucc_envs_map) {
162+
auto env = c10::utils::get_env(env_name.c_str());
163+
if (env.has_value()) {
164+
value = std::move(env.value());
165165
}
166166
}
167167

torch/csrc/distributed/c10d/UCCTracing.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#ifdef USE_C10D_UCC
22

3+
#include <c10/util/env.h>
34
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
45
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
56

@@ -32,9 +33,9 @@ void ProcessGroupUCCLogger::flushComms(int rank, int world_size) {
3233
}
3334

3435
std::string fullpath = "/tmp/" + dirname;
35-
char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR");
36-
if (user_path) {
37-
fullpath = user_path;
36+
auto user_path = c10::utils::get_env("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR");
37+
if (user_path.has_value()) {
38+
fullpath = std::move(user_path.value());
3839
}
3940
std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json");
4041
std::ofstream _outfile;

torch/csrc/distributed/c10d/debug.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// This source code is licensed under the BSD-style license found in the
55
// LICENSE file in the root directory of this source tree.
66

7+
#include <c10/util/env.h>
78
#include <torch/csrc/distributed/c10d/debug.h>
89

910
#include <algorithm>
@@ -19,15 +20,15 @@ namespace detail {
1920
namespace {
2021

2122
DebugLevel loadDebugLevelFromEnvironment() {
22-
char* env_value = std::getenv("TORCH_DISTRIBUTED_DEBUG");
23+
auto env_value = c10::utils::get_env("TORCH_DISTRIBUTED_DEBUG");
2324

24-
if (env_value == nullptr) {
25+
if (!env_value.has_value()) {
2526
return DebugLevel::Off;
2627
}
2728

2829
DebugLevel level{};
2930

30-
std::string level_str{env_value};
31+
std::string level_str = std::move(env_value.value());
3132

3233
std::transform(
3334
level_str.begin(),

torch/csrc/distributed/rpc/tensorpipe_agent.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,14 @@ C10_DEFINE_REGISTRY_WITHOUT_WARNING(
161161

162162
const std::string& TensorPipeAgent::guessAddress() {
163163
static const std::string uvAddress = []() {
164-
char* ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str());
165-
if (ifnameEnv != nullptr) {
164+
auto ifnameEnv = c10::utils::get_env(kSocketIfnameEnvVar.c_str());
165+
if (ifnameEnv.has_value()) {
166166
auto [error, result] =
167-
tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
167+
tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv.value());
168168
if (error) {
169169
LOG(WARNING) << "Failed to look up the IP address for interface "
170-
<< ifnameEnv << " (" << error.what() << "), defaulting to "
171-
<< kDefaultUvAddress;
170+
<< ifnameEnv.value() << " (" << error.what()
171+
<< "), defaulting to " << kDefaultUvAddress;
172172
return kDefaultUvAddress;
173173
}
174174
return result;

0 commit comments

Comments
 (0)
0