Making batching rule for F.embedding DTensor-aware (#162117) · mansiag05/pytorch@439faa2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 439faa2

Browse files
zou3519 authored and mansiag05 committed
Making batching rule for F.embedding DTensor-aware (pytorch#162117)
`vmap(F.embedding)(DTensor, DTensor)` was failing because F.embedding's batching rule generates a new tensor via at::arange, at::arange generates a regular tensor, and DTensor rightfully errors on mixed DTensor-regular Tensor operations. This PR fixes the problem by activating DTensor implicit replication on just the at::arange and the subsequent add operation. In order to accomplish this I move the DTensor implicit replication flag to C++ (most batching rules are in C++). Test Plan: - new test Pull Request resolved: pytorch#162117 Approved by: https://github.com/bdhirsh
1 parent 2e6b5e7 commit 439faa2

File tree

10 files changed

+112
-7
lines changed

10 files changed

+112
-7
lines changed

aten/src/ATen/DTensorState.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include <ATen/DTensorState.h>

namespace at {

namespace {
// Thread-local flag controlling whether DTensor operations may implicitly
// treat plain (non-DTensor) Tensors as replicated. Defaults to off.
thread_local bool dtensor_implicit_replication_enabled = false;
} // namespace

// Reads the current thread's implicit-replication setting.
bool get_dtensor_allow_implicit_replication() {
  return dtensor_implicit_replication_enabled;
}

// Overwrites the current thread's implicit-replication setting.
void set_dtensor_allow_implicit_replication(bool enabled) {
  dtensor_implicit_replication_enabled = enabled;
}

} // namespace at

aten/src/ATen/DTensorState.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#pragma once

#include <c10/macros/Macros.h>

namespace at {

// Thread-local flag: when true, DTensor ops may treat plain (non-DTensor)
// Tensors mixed into a DTensor operation as implicitly replicated instead of
// raising an error.
TORCH_API bool get_dtensor_allow_implicit_replication();
TORCH_API void set_dtensor_allow_implicit_replication(bool enabled);

// RAII guard that enables DTensor implicit replication for its lifetime and
// restores the previous thread-local value on destruction.
//
// [[nodiscard]] so that an accidental discarded temporary, e.g.
//   DTensorAllowImplicitReplication();
// (which would enable and immediately restore the flag) is flagged by the
// compiler; the intended use is a named local:
//   DTensorAllowImplicitReplication guard;
struct [[nodiscard]] DTensorAllowImplicitReplication {
  DTensorAllowImplicitReplication()
      : prev_dtensor_allow_implicit_replication_(
            get_dtensor_allow_implicit_replication()) {
    set_dtensor_allow_implicit_replication(true);
  }

  // Non-copyable and non-movable: the guard is tied to the scope (and thread)
  // that created it.
  DTensorAllowImplicitReplication(const DTensorAllowImplicitReplication&) =
      delete;
  DTensorAllowImplicitReplication& operator=(
      const DTensorAllowImplicitReplication&) = delete;
  DTensorAllowImplicitReplication(DTensorAllowImplicitReplication&&) = delete;
  DTensorAllowImplicitReplication& operator=(
      DTensorAllowImplicitReplication&&) = delete;

  ~DTensorAllowImplicitReplication() {
    set_dtensor_allow_implicit_replication(
        prev_dtensor_allow_implicit_replication_);
  }

 private:
  // Value to restore when the guard goes out of scope.
  bool prev_dtensor_allow_implicit_replication_;
};

} // namespace at

aten/src/ATen/ThreadLocalState.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <ATen/record_function.h>
99
#include <ATen/SavedTensorHooks.h>
1010
#include <ATen/FunctionalTensorWrapper.h>
11+
#include <ATen/DTensorState.h>
1112

1213
namespace at {
1314

@@ -19,6 +20,7 @@ ThreadLocalState::ThreadLocalState()
1920
torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
2021
python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
2122
saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
23+
dtensor_allow_implicit_replication_(at::get_dtensor_allow_implicit_replication()),
2224
saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
2325
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
2426
for(size_t i=0; i<autocast_dtypes_.size(); i++) {
@@ -52,6 +54,8 @@ void ThreadLocalState::setThreadLocalState(
5254

5355
c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
5456

57+
at::set_dtensor_allow_implicit_replication(state.dtensor_allow_implicit_replication_);
58+
5559
c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
5660

5761
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);

aten/src/ATen/ThreadLocalState.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ class TORCH_API ThreadLocalState {
7575

7676
bool functionalization_reapply_views_state_;
7777

78+
bool dtensor_allow_implicit_replication_;
79+
7880
// TLS for arbitrary python objects that is registered via hooks
7981
at::impl::ThreadLocalPythonObjects saved_objects_;
8082

aten/src/ATen/functorch/BatchRulesModules.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <ATen/functorch/BatchRulesHelper.h>
88
#include <ATen/functorch/PlumbingHelper.h>
99
#include <ATen/core/dispatch/Dispatcher.h>
10+
#include <ATen/DTensorState.h>
1011

1112
#include <utility>
1213

@@ -44,8 +45,13 @@ static std::tuple<Tensor, std::optional<int64_t>> embedding_batch_rule(
4445
const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
4546
auto indices_ = moveBatchDimToFront(indices, indices_bdim);
4647

47-
const auto range = getStepTensor(indices, batch_size, num_embeddings);
48-
indices_ = indices_ + range;
48+
{
49+
// getStepTensor returns a regular Tensor. If indices_ is a DTensor
50+
// we want to allow this mixed DTensor-Tensor operation.
51+
at::DTensorAllowImplicitReplication guard;
52+
const auto range = getStepTensor(indices, batch_size, num_embeddings);
53+
indices_ = indices_ + range;
54+
}
4955
auto result = at::embedding_symint(weight_, indices_, std::move(padding_idx), scale_grad_by_freq, sparse);
5056
return std::make_tuple(std::move(result), 0);
5157
}

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,7 @@ aten_cpu_source_non_codegen_list = [
10881088
"aten/src/ATen/DeviceAccelerator.cpp",
10891089
"aten/src/ATen/Context.cpp",
10901090
"aten/src/ATen/DLConvertor.cpp",
1091+
"aten/src/ATen/DTensorState.cpp",
10911092
"aten/src/ATen/EmptyTensor.cpp",
10921093
"aten/src/ATen/ExpandUtils.cpp",
10931094
"aten/src/ATen/CachedTensorUtils.cpp",

test/distributed/tensor/test_dtensor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,30 @@ def test_implicit_replication(self):
848848
self.assertEqual(local_shard.shape, (4, 3))
849849
self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3))
850850

851+
@with_comms
852+
def test_vmap_embedding(self):
853+
mesh = self.build_device_mesh()
854+
batch_size, seq_len = 2, 6
855+
output_dim = 32
856+
857+
indices = torch.zeros(*(batch_size, seq_len), dtype=torch.int64)
858+
indices[0, 1] = 1
859+
indices[1, 3] = 1
860+
indices[1, 5] = 1
861+
indices = DTensor.from_local(indices, mesh, [Shard(0)])
862+
863+
emb = torch.randn(
864+
*(batch_size, 8, output_dim),
865+
dtype=torch.float32,
866+
)
867+
emb = DTensor.from_local(emb, mesh, [Shard(0)])
868+
result = torch.vmap(F.embedding)(indices, emb)
869+
expected = [F.embedding(indices[i], emb[i]) for i in range(batch_size)]
870+
expected = torch.stack(expected)
871+
local_result = result.to_local()
872+
local_expected = expected.to_local()
873+
self.assertEqual(local_result, local_expected)
874+
851875
@with_comms
852876
def test_auto_implicit_replication(self):
853877
mesh = self.build_device_mesh()

torch/_C/__init__.pyi.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,6 +1852,9 @@ class _SetExcludeDispatchKeyGuard:
18521852
def __enter__(self): ...
18531853
def __exit__(self, *exc_info: object) -> None: ...
18541854

1855+
def _get_dtensor_allow_implicit_replication() -> _bool: ...
1856+
def _set_dtensor_allow_implicit_replication(value: _bool) -> None: ...
1857+
18551858
# Defined in torch/csrc/utils/schema_info.h
18561859

18571860
class _SchemaInfo:

torch/csrc/utils/python_dispatch.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <torch/csrc/utils/python_dispatch.h>
33

44
#include <ATen/ATen.h>
5+
#include <ATen/DTensorState.h>
56
#include <ATen/FuncTorchTLS.h>
67
#include <ATen/FunctionalTensorWrapper.h>
78
#include <ATen/TensorSubclassLikeUtils.h>
@@ -1045,6 +1046,13 @@ void initDispatchBindings(PyObject* module) {
10451046
m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
10461047
m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);
10471048

1049+
m.def(
1050+
"_get_dtensor_allow_implicit_replication",
1051+
&at::get_dtensor_allow_implicit_replication);
1052+
m.def(
1053+
"_set_dtensor_allow_implicit_replication",
1054+
&at::set_dtensor_allow_implicit_replication);
1055+
10481056
using c10::impl::TorchDispatchModeKey;
10491057
py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
10501058
.value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)

torch/distributed/tensor/_dispatch.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,17 @@ def __init__(self) -> None:
121121
aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
122122
}
123123

124-
# This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
125-
# as implicitly replicated or we throw error to user.
126-
# NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
127-
# it as False by default.
128-
self._allow_implicit_replication = False
124+
# Internal switch deciding whether a plain torch.Tensor mixed into a DTensor
# operation is treated as implicitly replicated, or an error is raised.
# NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we
# intentionally leave it as False by default. The flag itself lives in C++
# thread-local state so C++ code paths can toggle it as well; this property
# is just a Python-side view of it.
@property
def _allow_implicit_replication(self) -> bool:
    return torch._C._get_dtensor_allow_implicit_replication()

@_allow_implicit_replication.setter
def _allow_implicit_replication(self, value: bool) -> None:
    torch._C._set_dtensor_allow_implicit_replication(value)
129135

130136
def dispatch(
131137
self,

0 commit comments

Comments
 (0)
0