pytorch/pytorch · Commit 6bbf293
use backend enum, use 71668 config parser, de-inline format_size, remove ugly macro stubbing, Pytorch->PyTorch
1 parent: 4cbde6f

File tree: 12 files changed, +191 −173 lines
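
The commit replaces string-valued backend queries with an enum compared at the call sites. The CUDACachingAllocator.h side of that change is not among the three diffs excerpted below; inferred from the call sites, the public declarations would plausibly take the following shape (a sketch only, not the verbatim header; the scoped-enum choice and namespace nesting are assumptions):

// Sketch inferred from usage in the diffs below; not copied from the real
// CUDACachingAllocator.h. The enum kind and nested-namespace layout here are
// assumptions.
#include <cstddef>

namespace c10 { namespace cuda { namespace CUDACachingAllocator {

enum class AllocatorBackend {
  NATIVE,           // the classic caching allocator
  CUDAMALLOCASYNC,  // CUDA 11.4+ stream-ordered (cudaMallocAsync) allocator
};

// Both are backed by the PYTORCH_CUDA_ALLOC_CONF parser shown further down.
AllocatorBackend allocatorBackend();
size_t maxSplitSize();

}}} // namespace c10::cuda::CUDACachingAllocator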

aten/src/ATen/cuda/PeerToPeerAccess.cpp

Lines changed: 7 additions & 10 deletions
@@ -39,10 +39,8 @@ bool get_p2p_access(int dev, int dev_to_access) {
               dev_to_access, " is not a device");
   TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized");
 
-#if CUDA_VERSION >= 11040
-  static bool using_cudaMallocAsync = std::strcmp(CUDACachingAllocator::allocatorBackend(),
-                                                  "cudaMallocAsync") == 0;
-#endif
+  static bool using_cudaMallocAsync = (CUDACachingAllocator::allocatorBackend() ==
+                                       CUDACachingAllocator::AllocatorBackend::CUDAMALLOCASYNC);
 
   auto &cache = p2pAccessEnabled_[dev * num_devices_ + dev_to_access];
 
@@ -55,8 +53,10 @@ bool get_p2p_access(int dev, int dev_to_access) {
   int access = 0;
   C10_CUDA_CHECK(cudaDeviceCanAccessPeer(&access, dev, dev_to_access));
   if (access) {
-#if CUDA_VERSION >= 11040
     if (using_cudaMallocAsync) {
+      // Double-checks allocator backend hasn't changed, which would definitely be an error.
+      TORCH_INTERNAL_ASSERT(CUDACachingAllocator::allocatorBackend() ==
+                            CUDACachingAllocator::AllocatorBackend::CUDAMALLOCASYNC);
       // cudaMallocAsync pools are unaffected by cudaDeviceEnablePeerAccess.
       // We need pool-specific enablement. See
       // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/
@@ -68,19 +68,16 @@ bool get_p2p_access(int dev, int dev_to_access) {
       desc.flags = cudaMemAccessFlagsProtReadWrite;
       C10_CUDA_CHECK(cudaMemPoolSetAccess(mempool, &desc, 1 /* numDescs */));
     } else {
-      TORCH_INTERNAL_ASSERT(std::strcmp(c10::cuda::CUDACachingAllocator::allocatorBackend(),
-                                        "native") == 0);
-#endif
+      TORCH_INTERNAL_ASSERT(CUDACachingAllocator::allocatorBackend() ==
+                            CUDACachingAllocator::AllocatorBackend::NATIVE);
       cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
       if (err == cudaErrorPeerAccessAlreadyEnabled) {
         // ignore and clear the error if access was already enabled
         cudaGetLastError();
       } else {
         C10_CUDA_CHECK(err);
       }
-#if CUDA_VERSION >= 11040
     }
-#endif
     cache = 1;
   } else {
     cache = 0;
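
Callers such as get_p2p_access above rely on per-pool access grants instead of cudaDeviceEnablePeerAccess when the cudaMallocAsync backend is active. Stripped of the PyTorch macros, the pool-specific enablement reduces to the following standalone sketch (plain CUDA runtime API; illustrative only, error handling elided):

#include <cuda_runtime.h>

// Grant `accessing_dev` read/write access to memory served by the default
// stream-ordered pool of `owning_dev`. Unlike cudaDeviceEnablePeerAccess,
// this acts on one memory pool, not on the whole device.
void enable_pool_peer_access(int accessing_dev, int owning_dev) {
  cudaMemPool_t mempool;
  cudaDeviceGetDefaultMemPool(&mempool, owning_dev);

  cudaMemAccessDesc desc{};
  desc.location.type = cudaMemLocationTypeDevice;
  desc.location.id = accessing_dev;
  desc.flags = cudaMemAccessFlagsProtReadWrite;

  cudaMemPoolSetAccess(mempool, &desc, 1 /* numDescs */);
}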

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 2 additions & 6 deletions
@@ -93,12 +93,11 @@ void copy_device_to_device(TensorIterator& iter,
     void *src = iter.data_ptr(1);
     size_t size = numel * iter.element_size(0);
     if (src != dst || src_device != dst_device) {
-#if CUDA_VERSION >= 11040
       // Due to bizarre cuda driver intricacies, copies of
       // cudaMallocAsynced memory between devices that aren't
       // peer-to-peer-capable need "cudaMemcpyPeerAsync".
-      static bool using_cudaMallocAsync = std::strcmp(CUDACachingAllocator::allocatorBackend(),
-                                                      "cudaMallocAsync") == 0;
+      static bool using_cudaMallocAsync = (CUDACachingAllocator::allocatorBackend() ==
+                                           CUDACachingAllocator::AllocatorBackend::CUDAMALLOCASYNC);
       bool needs_MemcpyPeer = (src_device != dst_device &&
                                using_cudaMallocAsync &&
                                !p2p_enabled);
@@ -108,14 +107,11 @@ void copy_device_to_device(TensorIterator& iter,
                                 src, src_device.index(),
                                 size, copy_stream));
       } else {
-#endif
         AT_CUDA_CHECK(cudaMemcpyAsync(
             dst, src, size,
             cudaMemcpyDeviceToDevice,
             copy_stream));
-#if CUDA_VERSION >= 11040
       }
-#endif
     }
   } else {
     if (same_neg) {
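
With the #if/#endif stubbing gone, the copy dispatch that remains is easy to restate as a standalone sketch (illustrative only; it uses raw CUDA runtime calls rather than the AT_CUDA_CHECK/TensorIterator machinery above):

#include <cuda_runtime.h>

// Cross-device copies of cudaMallocAsync-backed memory between devices that
// are not peer-enabled must go through cudaMemcpyPeerAsync; everything else
// can use an ordinary device-to-device cudaMemcpyAsync.
void d2d_copy(void* dst, int dst_dev, const void* src, int src_dev,
              size_t size, cudaStream_t stream,
              bool using_cudaMallocAsync, bool p2p_enabled) {
  bool needs_MemcpyPeer = (src_dev != dst_dev) && using_cudaMallocAsync && !p2p_enabled;
  if (needs_MemcpyPeer) {
    cudaMemcpyPeerAsync(dst, dst_dev, src, src_dev, size, stream);
  } else {
    cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
  }
}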

c10/cuda/CUDACachingAllocator.cpp

Lines changed: 119 additions & 83 deletions
@@ -1613,100 +1613,136 @@ std::shared_ptr<void> getIpcDevPtr(std::string handle) {
 
 } // namespace Native
 
-// Define config stuff here, rather than its own .cpp file,
+
+// General caching allocator utilities
+
+// Environment config parser
+// Defined here, rather than its own .cpp file,
 // because parseArgs needs to know kLargeBuffer.
-bool parsed = false;
-
-size_t m_max_split_size = std::numeric_limits<size_t>::max();
-std::string m_allocator_backend = "native";
-
-void parseArgs() {
-  const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF");
-  if (val != NULL) {
-    const std::string config(val);
-
-    std::regex exp("[\\s,]+");
-    std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
-    std::sregex_token_iterator end;
-    std::vector<std::string> options(it, end);
-
-    bool used_max_split_size_mb(false);
-    bool used_cudaMallocAsync(false);
-
-    for (auto option : options) {
-      std::regex exp2("[:]+");
-      std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
-      std::sregex_token_iterator end2;
-      std::vector<std::string> kv(it2, end2);
-      if (kv.size() >= 2) {
-        /* Maximum split size in MB. Limited to large size blocks */
-        if (kv[0].compare("max_split_size_mb") == 0) {
-          size_t val2 = stoi(kv[1]);
-          TORCH_CHECK(
-              val2 > Native::kLargeBuffer / (1024 * 1024),
-              "CachingAllocator option max_split_size_mb too small, must be > ",
-              Native::kLargeBuffer / (1024 * 1024),
-              "");
-          val2 = std::max(val2, Native::kLargeBuffer / (1024 * 1024));
-          val2 = std::min(
-              val2, (std::numeric_limits<size_t>::max() / (1024 * 1024)));
-          m_max_split_size = val2 * 1024 * 1024;
-          used_max_split_size_mb = true;
-        } else if (kv[0].compare("backend") == 0) {
-          TORCH_CHECK(((kv[1].compare("native") == 0) ||
-                       (kv[1].compare("cudaMallocAsync") == 0)),
-                      "Unknown allocator backend, "
-                      "options are native and cudaMallocAsync");
-          m_allocator_backend = kv[1];
-          used_cudaMallocAsync = (kv[1].compare("cudaMallocAsync") == 0);
-          if (used_cudaMallocAsync) {
+class CachingAllocatorConfig {
+ public:
+  static AllocatorBackend allocator_backend() {
+    return instance().m_allocator_backend;
+  }
+
+  static size_t max_split_size() {
+    return instance().m_max_split_size;
+  }
+
+ private:
+  static CachingAllocatorConfig& instance() {
+    static CachingAllocatorConfig* s_instance = ([]() {
+      auto inst = new CachingAllocatorConfig();
+      inst->parseArgs();
+      return inst;
+    })();
+    return *s_instance;
+  }
+
+  CachingAllocatorConfig()
+      : m_allocator_backend{AllocatorBackend::NATIVE},
+        m_max_split_size{std::numeric_limits<size_t>::max()} {}
+  AllocatorBackend m_allocator_backend;
+  size_t m_max_split_size;
+
+  void parseArgs() {
+    const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF");
+    if (val != NULL) {
+      const std::string config(val);
+
+      std::regex exp("[\\s,]+");
+      std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
+      std::sregex_token_iterator end;
+      std::vector<std::string> options(it, end);
+
+      bool used_max_split_size_mb(false);
+      bool used_cudaMallocAsync(false);
+
+      for (auto option : options) {
+        std::regex exp2("[:]+");
+        std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
+        std::sregex_token_iterator end2;
+        std::vector<std::string> kv(it2, end2);
+        if (kv.size() >= 2) {
+          /* Maximum split size in MB. Limited to large size blocks */
+          if (kv[0].compare("max_split_size_mb") == 0) {
+            size_t val2 = stoi(kv[1]);
+            TORCH_CHECK(
+                val2 > Native::kLargeBuffer / (1024 * 1024),
+                "CachingAllocator option max_split_size_mb too small, must be > ",
+                Native::kLargeBuffer / (1024 * 1024),
+                "");
+            val2 = std::max(val2, Native::kLargeBuffer / (1024 * 1024));
+            val2 = std::min(
+                val2, (std::numeric_limits<size_t>::max() / (1024 * 1024)));
+            m_max_split_size = val2 * 1024 * 1024;
+            used_max_split_size_mb = true;
+          } else if (kv[0].compare("backend") == 0) {
+            TORCH_CHECK(((kv[1].compare("native") == 0) ||
+                         (kv[1].compare("cudaMallocAsync") == 0)),
+                        "Unknown allocator backend, "
+                        "options are native and cudaMallocAsync");
+            used_cudaMallocAsync = (kv[1].compare("cudaMallocAsync") == 0);
+            if (used_cudaMallocAsync) {
 #if CUDA_VERSION >= 11040
-            int version;
-            C10_CUDA_CHECK(cudaDriverGetVersion(&version));
-            TORCH_CHECK(version >= 11040,
-                        "backend:cudaMallocAsync requires CUDA runtime "
-                        "11.4 or newer, but cudaDriverGetVersion returned ",
-                        version);
+              int version;
+              C10_CUDA_CHECK(cudaDriverGetVersion(&version));
+              TORCH_CHECK(version >= 11040,
+                          "backend:cudaMallocAsync requires CUDA runtime "
+                          "11.4 or newer, but cudaDriverGetVersion returned ",
+                          version);
+              m_allocator_backend = AllocatorBackend::CUDAMALLOCASYNC;
 #else
-            TORCH_CHECK(false,
-                        "backend:cudaMallocAsync requires Pytorch to be built with "
-                        "CUDA 11.4 or newer, but CUDA_VERSION is ",
-                        CUDA_VERSION);
+              TORCH_CHECK(false,
+                          "backend:cudaMallocAsync requires PyTorch to be built with "
+                          "CUDA 11.4 or newer, but CUDA_VERSION is ",
+                          CUDA_VERSION);
 #endif
-          }
-        } else {
-          TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]);
-        }
+            }
+          } else {
+            TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]);
+          }
+        }
+      }
+
+      if (used_max_split_size_mb && used_cudaMallocAsync) {
+        TORCH_WARN("backend:cudaMallocAsync ignores max_split_size_mb");
+      }
     }
   }
+};
 
-  if (used_max_split_size_mb && used_cudaMallocAsync) {
-    TORCH_WARN("backend:cudaMallocAsync ignores max_split_size_mb");
-  }
-}
+// External config interface (declared in CUDACachingAllocator.h)
+// Should we bother having these two functions?
+// They are basically useless layers of indirection, but a minor
+// code-cleanliness benefit is they alleviate the need to define
+// CachingAllocatorConfig itself in CUDACachingAllocator.h.
+AllocatorBackend allocatorBackend() {
+  return CachingAllocatorConfig::allocator_backend();
 }
 
-// Public interface
-const char* allocatorBackend() {
-  // Static initializer is thread-safe
-  static const std::string backend = []() {
-    if (!parsed) {
-      parseArgs();
-    }
-    return m_allocator_backend;
-  }();
-  return backend.c_str();
+size_t maxSplitSize() {
+  return CachingAllocatorConfig::max_split_size();
 }
 
-size_t maxSplitSize() {
-  // Static initializer is thread-safe
-  static const size_t size = []() {
-    if (!parsed) {
-      parseArgs();
-    }
-    return m_max_split_size;
-  }();
-  return size;
+// Size pretty-printer
+inline std::string format_size(uint64_t size) {
+  std::ostringstream os;
+  os.precision(2);
+  os << std::fixed;
+  if (size <= 1024) {
+    os << size << " bytes";
+  } else if (size <= 1048576) {
+    os << (size / 1024.0);
+    os << " KiB";
+  } else if (size <= 1073741824ULL) {
+    os << size / 1048576.0;
+    os << " MiB";
+  } else {
+    os << size / 1073741824.0;
+    os << " GiB";
+  }
+  return os.str();
 }
 
 } // namespace CUDACachingAllocator
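
In the new CachingAllocatorConfig, parsing happens exactly once, the first time any setting is read: the function-local static in instance() gives thread-safe lazy initialization (C++11 magic statics), and the instance is deliberately leaked so it stays valid through static destruction at shutdown. Settings come from PYTORCH_CUDA_ALLOC_CONF as comma- or whitespace-separated key:value pairs, e.g. PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync or PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512. A minimal standalone sketch of the same lazy-singleton pattern, with illustrative names rather than the PyTorch API:

#include <cstdlib>
#include <string>

enum class Backend { NATIVE, CUDAMALLOCASYNC };

class ExampleConfig {
 public:
  // First call triggers construction and a one-time parse; later calls just read.
  static Backend backend() { return instance().m_backend; }

 private:
  static ExampleConfig& instance() {
    static ExampleConfig* s_instance = []() {
      auto* inst = new ExampleConfig();  // leaked on purpose, never destroyed
      inst->parse();                     // runs exactly once, before first use
      return inst;
    }();
    return *s_instance;
  }

  ExampleConfig() : m_backend{Backend::NATIVE} {}

  void parse() {
    // Hypothetical env var; the real code reads PYTORCH_CUDA_ALLOC_CONF and
    // splits it into key:value pairs with std::regex, as in the diff above.
    const char* val = std::getenv("EXAMPLE_ALLOC_CONF");
    if (val != nullptr &&
        std::string(val).find("backend:cudaMallocAsync") != std::string::npos) {
      m_backend = Backend::CUDAMALLOCASYNC;
    }
  }

  Backend m_backend;
};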
