8000 [Environment Variable][4/N] Use thread-safe getenv functions (#137843) · pytorch/pytorch@239ad73 · GitHub
[go: up one dir, main page]

Skip to content

Commit 239ad73

Browse files
cyyeverpytorchmergebot
authored andcommitted
[Environment Variable][4/N] Use thread-safe getenv functions (#137843)
Follows #137328 Pull Request resolved: #137843 Approved by: https://github.com/ezyang
1 parent 07fd61e commit 239ad73

File tree

7 files changed

+40
-63
lines changed

7 files changed

+40
-63
lines changed

aten/src/ATen/cuda/tunable/TunableGemm.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,15 +197,15 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
197197
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
198198

199199
#ifdef USE_ROCM
200-
static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
201-
if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
200+
static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
201+
if (!env_rocblas.has_value() || env_rocblas.value()) {
202202
for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
203203
this->RegisterOp(std::move(name), std::move(op));
204204
}
205205
}
206206

207-
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
208-
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
207+
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
208+
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
209209
// disallow tuning of hipblaslt with c10::complex
210210
if constexpr (
211211
!std::is_same_v<T, c10::complex<float>> &&
@@ -230,8 +230,8 @@ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer>
230230
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
231231

232232
#ifdef USE_ROCM
233-
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
234-
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
233+
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
234+
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
235235
// disallow tuning of hipblaslt with c10::complex
236236
if constexpr (
237237
!std::is_same_v<T, c10::complex<float>> &&
@@ -256,15 +256,15 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
256256
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
257257

258258
#ifdef USE_ROCM
259-
static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
260-
if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
259+
static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
260+
if (!env_rocblas.has_value() || env_rocblas.value()) {
261261
for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
262262
this->RegisterOp(std::move(name), std::move(op));
263263
}
264264
}
265265

266-
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
267-
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
266+
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
267+
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
268268
// disallow tuning of hipblaslt with c10::complex
269269
if constexpr (
270270
!std::is_same_v<T, c10::complex<float>> &&

torch/csrc/distributed/c10d/GlooDeviceFactory.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <cstdlib>
66

77
#include <c10/util/Exception.h>
8+
#include <c10/util/env.h>
89

910
#if GLOO_HAVE_TRANSPORT_TCP
1011
#include <gloo/transport/tcp/device.h>
@@ -84,14 +85,12 @@ static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice(
8485
} else {
8586
attr.hostname = hostname;
8687
}
87-
const auto pkey =
88-
cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY"));
89-
const auto cert =
90-
cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT"));
88+
const auto pkey = c10::utils::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY");
89+
const auto cert = c10::utils::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT");
9190
const auto caFile =
92-
cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE"));
91+
c10::utils::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE");
9392
const auto caPath =
94-
cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH"));
93+
c10::utils::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH");
9594
return ::gloo::transport::tcp::tls::CreateDevice(
9695
attr, pkey, cert, caFile, caPath);
9796
}

torch/csrc/distributed/c10d/Utils.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <ATen/ATen.h>
44
#include <c10/util/Exception.h>
55
#include <c10/util/accumulate.h>
6+
#include <c10/util/env.h>
67
#include <c10/util/irange.h>
78
#include <torch/csrc/distributed/c10d/Types.hpp>
89

@@ -92,7 +93,7 @@ inline std::vector<std::string> split(
9293
inline std::string getCvarString(
9394
const std::vector<std::string>& env,
9495
const char* def) {
95-
const char* ret = def;
96+
std::string ret(def);
9697

9798
if (env.empty()) {
9899
TORCH_CHECK(false, "No environment variables passed");
@@ -103,14 +104,14 @@ inline std::string getCvarString(
103104
* versions of a variable get higher priority than the latter
104105
* versions of the same variable */
105106
for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
106-
const char* val = std::getenv(env[i].c_str());
107-
if (val == nullptr) {
107+
auto val = c10::utils::get_env(env[i].c_str());
108+
if (!val) {
108109
continue;
109110
} else if (i) {
110111
WARN_ENV_VAR_ONCE(env[i], env[0]);
111112
}
112113

113-
ret = val;
114+
ret = val.value();
114115
}
115116

116117
return ret;
@@ -157,15 +158,14 @@ inline bool getCvarBool(const std::vector<std::string>& env, bool def) {
157158
* versions of a variable get higher priority than the latter
158159
* versions of the same variable */
159160
for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
160-
char* val_ = std::getenv(env[i].c_str());
161-
if (val_ == nullptr) {
161+
auto val = c10::utils::get_env(env[i].c_str());
162+
if (!val.has_value()) {
162163
continue;
163164
} else if (i) {
164165
WARN_ENV_VAR_ONCE(env[i], env[0]);
165166
}
166167

167-
std::string val = std::string(val_);
168-
for (auto& x : val) {
168+
for (auto& x : val.value()) {
169169
// NOLINTNEXTLINE(*-narrowing-conversions)
170170
x = std::tolower(x);
171171
}

torch/csrc/lazy/core/ir.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <c10/util/env.h>
12
#include <torch/csrc/lazy/backend/backend_interface.h>
23
#include <torch/csrc/lazy/core/cache.h>
34
#include <torch/csrc/lazy/core/config.h>
@@ -57,7 +58,7 @@ hash_t OpKind::hash() const {
5758
}
5859

5960
bool Node::enableDynamicShape() {
60-
static bool enabled = std::getenv("LTC_ENABLE_DYNAMIC_SHAPES") != nullptr;
61+
static bool enabled = c10::utils::has_env("LTC_ENABLE_DYNAMIC_SHAPES");
6162
return enabled || FLAGS_ltc_enable_dynamic_shapes;
6263
}
6364

torch/csrc/lazy/ts_backend/ts_backend_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface {
4242
public:
4343
TSBackendImpl() {
4444
// TODO(whc) unify how all our flags are set and parsed as envs
45-
static bool env_use_cuda = std::getenv("LTC_TS_CUDA") != nullptr;
45+
static bool env_use_cuda = c10::utils::has_env("LTC_TS_CUDA");
4646
auto type =
4747
(env_use_cuda || FLAGS_torch_lazy_ts_cuda) ? at::kCUDA : at::kCPU;
4848
default_device_type_ = std::make_shared<TSBackendDeviceType>(type);

torch/csrc/lazy/ts_backend/ts_node.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
#include <c10/util/env.h>
12
#include <torch/csrc/lazy/core/debug_util.h>
23
#include <torch/csrc/lazy/ts_backend/ts_node.h>
34

45
namespace {
56
std::string GetFirstUserFrameInPythonIfEnabled() {
67
static const auto LTC_ENABLE_SOURCE_INFO =
7-
std::getenv("LTC_ENABLE_SOURCE_INFO");
8-
if (!LTC_ENABLE_SOURCE_INFO) {
8+
c10::utils::has_env("LTC_ENABLE_SOURCE_INFO");
9+
if (LTC_ENABLE_SOURCE_INFO) {
910
return {};
1011
}
1112

torch/csrc/utils/cpp_stacktraces.cpp

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,41 +4,18 @@
44
#include <cstring>
55

66
#include <c10/util/Exception.h>
7+
#include <c10/util/env.h>
78

89
namespace torch {
910
namespace {
1011
bool compute_cpp_stack_traces_enabled() {
11-
auto envar = std::getenv("TORCH_SHOW_CPP_STACKTRACES");
12-
if (envar) {
13-
if (strcmp(envar, "0") == 0) {
14-
return false;
15-
}
16-
if (strcmp(envar, "1") == 0) {
17-
return true;
18-
}
19-
TORCH_WARN(
20-
"ignoring invalid value for TORCH_SHOW_CPP_STACKTRACES: ",
21-
envar,
22-
" valid values are 0 or 1.");
23-
}
24-
return false;
12+
auto envvar = c10::utils::check_env("TORCH_SHOW_CPP_STACKTRACES");
13+
return envvar.has_value() && envvar.value();
2514
}
2615

2716
bool compute_disable_addr2line() {
28-
auto envar = std::getenv("TORCH_DISABLE_ADDR2LINE");
29-
if (envar) {
30-
if (strcmp(envar, "0") == 0) {
31-
return false;
32-
}
33-
if (strcmp(envar, "1") == 0) {
34-
return true;
35-
}
36-
TORCH_WARN(
37-
"ignoring invalid value for TORCH_DISABLE_ADDR2LINE: ",
38-
envar,
39-
" valid values are 0 or 1.");
40-
}
41-
return false;
17+
auto envvar = c10::utils::check_env("TORCH_DISABLE_ADDR2LINE");
18+
return envvar.has_value() && envvar.value();
4219
}
4320
} // namespace
4421

@@ -48,20 +25,19 @@ bool get_cpp_stacktraces_enabled() {
4825
}
4926

5027
static torch::unwind::Mode compute_symbolize_mode() {
51-
auto envar_c = std::getenv("TORCH_SYMBOLIZE_MODE");
52-
if (envar_c) {
53-
std::string envar = envar_c;
54-
if (envar == "dladdr") {
28+
auto envar_c = c10::utils::get_env("TORCH_SYMBOLIZE_MODE");
29+
if (envar_c.has_value()) {
30+
if (envar_c == "dladdr") {
5531
return unwind::Mode::dladdr;
56-
} else if (envar == "addr2line") {
32+
} else if (envar_c == "addr2line") {
5733
return unwind::Mode::addr2line;
58-
} else if (envar == "fast") {
34+
} else if (envar_c == "fast") {
5935
return unwind::Mode::fast;
6036
} else {
6137
TORCH_CHECK(
6238
false,
6339
"expected {dladdr, addr2line, fast} for TORCH_SYMBOLIZE_MODE, got ",
64-
envar);
40+
envar_c.value());
6541
}
6642
} else {
6743
return compute_disable_addr2line() ? unwind::Mode::dladdr

0 commit comments

Comments
 (0)
0