8000 Use object identity for deepcopy memo · pytorch/pytorch@f2eff2d · GitHub
[go: up one dir, main page]

Skip to content

Commit f2eff2d

Browse files
committed
Use object identity for deepcopy memo
ghstack-source-id: 87ce5ed Pull Request resolved: #126126
1 parent afda668 commit f2eff2d

File tree

9 files changed

+95
-14
lines changed

9 files changed

+95
-14
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -887,12 +887,12 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::create(
887887
}
888888

889889
IValue IValue::deepcopy(c10::optional<at::Device> device) const {
890-
IValue::HashAliasedIValueMap memo;
890+
IValue::HashIdentityIValueMap memo;
891891
return deepcopy(memo, device);
892892
}
893893

894894
IValue IValue::deepcopy(
895-
IValue::HashAliasedIValueMap& memo,
895+
IValue::HashIdentityIValueMap& memo,
896896
c10::optional<at::Device> device) const {
897897
if (memo.count(*this)) {
898898
return memo.at(*this);
@@ -1028,12 +1028,12 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy_to_weak_compilation_ref(
10281028

10291029
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
10301030
c10::optional<at::Device> device) const {
1031-
IValue::HashAliasedIValueMap memo;
1031+
IValue::HashIdentityIValueMap memo;
10321032
return deepcopy(memo, device);
10331033
}
10341034

10351035
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
1036-
IValue::HashAliasedIValueMap& memo,
1036+
IValue::HashIdentityIValueMap& memo,
10371037
c10::optional<at::Device> device) const {
10381038
auto cu = type_.cu_;
10391039
auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());

aten/src/ATen/core/ivalue.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,23 @@ struct TORCH_API IValue final {
11171117
using HashAliasedIValueMap =
11181118
std::unordered_map<IValue, IValue, HashAliasedIValue, CompAliasedIValues>;
11191119

1120+
struct HashIdentityIValue {
1121+
size_t operator()(const IValue& val) const {
1122+
return val.payload.u.as_int;
1123+
}
1124+
};
1125+
1126+
struct CompIdentityIValues {
1127+
bool operator()(const IValue& lhs, const IValue& rhs) const {
1128+
return lhs.is(rhs);
1129+
}
1130+
};
1131+
1132+
using HashIdentityIValues =
1133+
std::unordered_set<IValue, HashIdentityIValue, CompIdentityIValues>;
1134+
using HashIdentityIValueMap =
1135+
std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;
1136+
11201137
// Chechs if this and rhs has a subvalues in common.
11211138
// [t1,t2] and [t2, t3] returns true.
11221139
bool overlaps(const IValue& rhs) const;
@@ -1130,7 +1147,7 @@ struct TORCH_API IValue final {
11301147
void visit(const std::function<bool(const IValue&)>& visitor) const;
11311148
IValue deepcopy(c10::optional<at::Device> device = c10::nullopt) const;
11321149
IValue deepcopy(
1133-
HashAliasedIValueMap& memo,
1150+
HashIdentityIValueMap& memo,
11341151
c10::optional<at::Device> device = c10::nullopt) const;
11351152

11361153
private:

aten/src/ATen/core/ivalue_inl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
15891589
c10::optional<at::Device> device = c10::nullopt) const;
15901590

15911591
c10::intrusive_ptr<Object> deepcopy(
1592-
IValue::HashAliasedIValueMap& memo,
1592+
IValue::HashIdentityIValueMap& memo,
15931593
c10::optional<at::Device> device = c10::nullopt) const;
15941594

15951595
bool is_weak_compilation_ref() const {

test/cpp/api/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ set(TORCH_API_TEST_SOURCES
1010
${TORCH_API_TEST_DIR}/functional.cpp
1111
${TORCH_API_TEST_DIR}/init.cpp
1212
${TORCH_API_TEST_DIR}/integration.cpp
13+
${TORCH_API_TEST_DIR}/ivalue.cpp
1314
${TORCH_API_TEST_DIR}/jit.cpp
1415
${TORCH_API_TEST_DIR}/memory.cpp
1516
${TORCH_API_TEST_DIR}/meta_tensor.cpp

test/cpp/api/ivalue.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <ATen/core/ivalue.h>
4+
5+
#include <c10/util/flat_hash_map.h>
6+
#include <c10/util/irange.h>
7+
#include <c10/util/tempfile.h>
8+
9+
#include <torch/torch.h>
10+
11+
#include <test/cpp/api/support.h>
12+
13+
#include <cstdio>
14+
#include <memory>
15+
#include <sstream>
16+
#include <string>
17+
#include <vector>
18+
19+
using namespace torch::test;
20+
using namespace torch::nn;
21+
using namespace torch::optim;
22+
23+
TEST(IValueTest, DeepcopyTensors) {
24+
torch::Tensor t0 = torch::randn({2, 3});
25+
torch::Tensor t1 = torch::randn({3, 4});
26+
torch::Tensor t2 = t0.detach();
27+
torch::Tensor t3 = t0;
28+
torch::Tensor t4 = t1.as_strided({2, 3}, {3, 1}, 2);
29+
std::vector<torch::Tensor> tensor_vector = {t0, t1, t2, t3, t4};
30+
c10::List<torch::Tensor> tensor_list(tensor_vector);
31+
torch::IValue tensor_list_ivalue(tensor_list);
32+
33+
c10::IValue::CompIdentityIValues ivalue_compare;
34+
35+
// Make sure our setup configuration is correct
36+
ASSERT_TRUE(ivalue_compare(tensor_list[0].get(), tensor_list[3].get()));
37+
ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[1].get()));
38+
ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[2].get()));
39+
ASSERT_FALSE(ivalue_compare(tensor_list[1].get(), tensor_list[4].get()));
40+
ASSERT_TRUE(tensor_list[0].get().isAliasOf(tensor_list[2].get()));
41+
42+
c10::IValue copied_ivalue = tensor_list_ivalue.deepcopy();
43+
c10::List<torch::IValue> copied_list = copied_ivalue.toList();
44+
45+
// Make sure our setup configuration is correct
46+
ASSERT_TRUE(ivalue_compare(copied_list[0].get(), copied_list[3].get()));
47+
ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[1].get()));
48+
ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[2].get()));
49+
ASSERT_FALSE(ivalue_compare(copied_list[1].get(), copied_list[4].get()));
50+
// NOTE: this is actually incorrect. Ideally, these _should_ be aliases.
51+
ASSERT_FALSE(copied_list[0].get().isAliasOf(copied_list[2].get()));
52+
53+
ASSERT_TRUE(copied_list[0].get().toTensor().allclose(
54+
tensor_list[0].get().toTensor()));
55+
ASSERT_TRUE(copied_list[1].get().toTensor().allclose(
56+
tensor_list[1].get().toTensor()));
57+
ASSERT_TRUE(copied_list[2].get().toTensor().allclose(
58+
tensor_list[2].get().toTensor()));
59+
ASSERT_TRUE(copied_list[3].get().toTensor().allclose(
60+
tensor_list[3].get().toTensor()));
61+
ASSERT_TRUE(copied_list[4].get().toTensor().allclose(
62+
tensor_list[4].get().toTensor()));
63+
}

torch/csrc/jit/api/module.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ Module Module::deepcopy(c10::optional<at::Device> device) const {
323323

324324
Module Module::clone(bool inplace) const {
325325
std::unordered_map<TypePtr, TypePtr> type_remap;
326-
IValue::HashAliasedIValueMap memo;
326+
IValue::HashIdentityIValueMap memo;
327327
const std::unordered_set<std::string> ignored_methods;
328328
const std::unordered_set<std::string> ignored_attributes;
329329
return clone_impl(
@@ -335,15 +335,15 @@ Module Module::clone(
335335
const std::unordered_set<std::string>& ignored_methods,
336336
const std::unordered_set<std::string>& ignored_attributes) const {
337337
std::unordered_map<TypePtr, TypePtr> type_remap;
338-
IValue::HashAliasedIValueMap memo;
338+
IValue::HashIdentityIValueMap memo;
339339
return clone_impl(
340340
type_remap, inplace, memo, ignored_methods, ignored_attributes);
341341
}
342342

343343
Module Module::clone_impl(
344344
std::unordered_map<TypePtr, TypePtr>& type_remap,
345345
bool inplace,
346-
IValue::HashAliasedIValueMap memo,
346+
IValue::HashIdentityIValueMap memo,
347347
const std::unordered_set<std::string>& ignored_methods,
348348
const std::unordered_set<std::string>& ignored_attributes) const {
349349
// Create a new _ivalue in the same compilation unit.

torch/csrc/jit/api/module.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ struct TORCH_API Module : public Object {
301301
Module clone_impl(
302302
std::unordered_map<TypePtr, TypePtr>& type_remap,
303303
bool inplace,
304-
IValue::HashAliasedIValueMap memo,
304+
IValue::HashIdentityIValueMap memo,
305305
const std::unordered_set<std::string>& ignored_methods,
306306
const std::unordered_set<std::string>& ignored_attributes) const;
307307

torch/csrc/jit/passes/quantization/insert_observers.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class ModuleCloneHelper {
9292
const ModuleQConfigMap& module_qconfig_map,
9393
bool inplace = false) {
9494
std::unordered_map<TypePtr, QConfigTypePtrMap> type_remap;
95-
IValue::HashAliasedIValueMap memo;
95+
IValue::HashIdentityIValueMap memo;
9696
return clone_impl(
9797
module, module_qconfig_map, type_remap, inplace, std::move(memo));
9898
}
@@ -103,7 +103,7 @@ class ModuleCloneHelper {
103103
const ModuleQConfigMap& module_qconfig_map,
104104
std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap,
105105
bool inplace,
106-
IValue::HashAliasedIValueMap memo) {
106+
IValue::HashIdentityIValueMap memo) {
107107
auto qconfig = module_qconfig_map.at(module._ivalue());
108108
auto type = module.type();
109109
// Create a new _ivalue in the same compilation unit.

torch/csrc/jit/python/script_init.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -668,13 +668,13 @@ static constexpr std::array<const char*, 48> magic_method_names = {
668668
};
669669

670670
struct DeepCopyMemoTable {
671-
std::shared_ptr<IValue::HashAliasedIValueMap> map;
671+
std::shared_ptr<IValue::HashIdentityIValueMap> map;
672672
};
673673

674674
IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) {
675675
if (!memo.contains(py::str("__torch_script_memo_table"))) {
676676
memo["__torch_script_memo_table"] =
677-
DeepCopyMemoTable{std::make_shared<IValue::HashAliasedIValueMap>()};
677+
DeepCopyMemoTable{std::make_shared<IValue::HashIdentityIValueMap>()};
678678
}
679679
auto& ivalue_memo =
680680
*py::cast<DeepCopyMemoTable>(memo["__torch_script_memo_table"]).map;

0 commit comments

Comments
 (0)
0