8000 Use object identity for deepcopy memo · pytorch/pytorch@990d9b8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 990d9b8

Browse files
committed
Use object identity for deepcopy memo
ghstack-source-id: 05bda28 Pull Request resolved: #126126
1 parent afda668 commit 990d9b8

File tree

7 files changed

+31
-14
lines changed

7 files changed

+31
-14
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -887,12 +887,12 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::create(
887887
}
888888

889889
IValue IValue::deepcopy(c10::optional<at::Device> device) const {
890-
IValue::HashAliasedIValueMap memo;
890+
IValue::HashIdentityIValueMap memo;
891891
return deepcopy(memo, device);
892892
}
893893

894894
IValue IValue::deepcopy(
895-
IValue::HashAliasedIValueMap& memo,
895+
IValue::HashIdentityIValueMap& memo,
896896
c10::optional<at::Device> device) const {
897897
if (memo.count(*this)) {
898898
return memo.at(*this);
@@ -1028,12 +1028,12 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy_to_weak_compilation_ref(
10281028

10291029
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
10301030
c10::optional<at::Device> device) const {
1031-
IValue::HashAliasedIValueMap memo;
1031+
IValue::HashIdentityIValueMap memo;
10321032
return deepcopy(memo, device);
10331033
}
10341034

10351035
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
1036-
IValue::HashAliasedIValueMap& memo,
1036+
IValue::HashIdentityIValueMap& memo,
10371037
c10::optional<at::Device> device) const {
10381038
auto cu = type_.cu_;
10391039
auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());

aten/src/ATen/core/ivalue.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,23 @@ struct TORCH_API IValue final {
11171117
using HashAliasedIValueMap =
11181118
std::unordered_map<IValue, IValue, HashAliasedIValue, CompAliasedIValues>;
11191119

1120+
struct HashIdentityIValue {
1121+
size_t operator()(const IValue& val) const {
1122+
return val.payload.u.as_int;
1123+
}
1124+
};
1125+
1126+
struct CompIdentityIValues {
1127+
bool operator()(const IValue& lhs, const IValue& rhs) const {
1128+
return lhs.is(rhs);
1129+
}
1130+
};
1131+
1132+
using HashIdentityIValues =
1133+
std::unordered_set<IValue, HashIdentityIValue, CompIdentityIValues>;
1134+
using HashIdentityIValueMap =
1135+
std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;
1136+
11201137
// Chechs if this and rhs has a subvalues in common.
11211138
// [t1,t2] and [t2, t3] returns true.
11221139
bool overlaps(const IValue& rhs) const;
@@ -1130,7 +1147,7 @@ struct TORCH_API IValue final {
11301147
void visit(const std::function<bool(const IValue&)>& visitor) const;
11311148
IValue deepcopy(c10::optional<at::Device> device = c10::nullopt) const;
11321149
IValue deepcopy(
1133-
HashAliasedIValueMap& memo,
1150+
HashIdentityIValueMap& memo,
11341151
c10::optional<at::Device> device = c10::nullopt) const;
11351152

11361153
private:

aten/src/ATen/core/ivalue_inl.h

Lines changed: 1 addition & 1 deletion
< 8000 col width="52"/>
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
15891589
c10::optional<at::Device> device = c10::nullopt) const;
15901590

15911591
c10::intrusive_ptr<Object> deepcopy(
1592-
IValue::HashAliasedIValueMap& memo,
1592+
IValue::HashIdentityIValueMap& memo,
15931593
c10::optional<at::Device> device = c10::nullopt) const;
15941594

15951595
bool is_weak_compilation_ref() const {

torch/csrc/jit/api/module.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ Module Module::deepcopy(c10::optional<at::Device> device) const {
323323

324324
Module Module::clone(bool inplace) const {
325325
std::unordered_map<TypePtr, TypePtr> type_remap;
326-
IValue::HashAliasedIValueMap memo;
326+
IValue::HashIdentityIValueMap memo;
327327
const std::unordered_set<std::string> ignored_methods;
328328
const std::unordered_set<std::string> ignored_attributes;
329329
return clone_impl(
@@ -335,15 +335,15 @@ Module Module::clone(
335335
const std::unordered_set<std::string>& ignored_methods,
336336
const std::unordered_set<std::string>& ignored_attributes) const {
337337
std::unordered_map<TypePtr, TypePtr> type_remap;
338-
IValue::HashAliasedIValueMap memo;
338+
IValue::HashIdentityIValueMap memo;
339339
return clone_impl(
340340
type_remap, inplace, memo, ignored_methods, ignored_attributes);
341341
}
342342

343343
Module Module::clone_impl(
344344
std::unordered_map<TypePtr, TypePtr>& type_remap,
345345
bool inplace,
346-
IValue::HashAliasedIValueMap memo,
346+
IValue::HashIdentityIValueMap memo,
347347
const std::unordered_set<std::string>& ignored_methods,
348348
const std::unordered_set<std::string>& ignored_attributes) const {
349349
// Create a new _ivalue in the same compilation unit.

torch/csrc/jit/api/module.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ struct TORCH_API Module : public Object {
301301
Module clone_impl(
302302
std::unordered_map<TypePtr, TypePtr>& type_remap,
303303
bool inplace,
304-
IValue::HashAliasedIValueMap memo,
304+
IValue::HashIdentityIValueMap memo,
305305
const std::unordered_set<std::string>& ignored_methods,
306306
const std::unordered_set<std::string>& ignored_attributes) const;
307307

torch/csrc/jit/passes/quantization/insert_observers.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class ModuleCloneHelper {
9292
const ModuleQConfigMap& module_qconfig_map,
9393
bool inplace = false) {
9494
std::unordered_map<TypePtr, QConfigTypePtrMap> type_remap;
95-
IValue::HashAliasedIValueMap memo;
95+
IValue::HashIdentityIValueMap memo;
9696
return clone_impl(
9797
module, module_qconfig_map, type_remap, inplace, std::move(memo));
9898
}
@@ -103,7 +103,7 @@ class ModuleCloneHelper {
103103
const ModuleQConfigMap& module_qconfig_map,
104104
std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap,
105105
bool inplace,
106-
IValue::HashAliasedIValueMap memo) {
106+
IValue::HashIdentityIValueMap memo) {
107107
auto qconfig = module_qconfig_map.at(module._ivalue());
108108
auto type = module.type();
109109
// Create a new _ivalue in the same compilation unit.

torch/csrc/jit/python/script_init.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -668,13 +668,13 @@ static constexpr std::array<const char*, 48> magic_method_names = {
668668
};
669669

670670
struct DeepCopyMemoTable {
671-
std::shared_ptr<IValue::HashAliasedIValueMap> map;
671+
std::shared_ptr<IValue::HashIdentityIValueMap> map;
672672
};
673673

674674
IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) {
675675
if (!memo.contains(py::str("__torch_script_memo_table"))) {
676676
memo["__torch_script_memo_table"] =
677-
DeepCopyMemoTable{std::make_shared<IValue::HashAliasedIValueMap>()};
677+
DeepCopyMemoTable{std::make_shared<IValue::HashIdentityIValueMap>()};
678678
}
679679
auto& ivalue_memo =
680680
*py::cast<DeepCopyMemoTable>(memo["__torch_script_memo_table"]).map;

0 commit comments

Comments
 (0)
0