8000 Add support for PickleOpCode::APPEND in torch unpickler (#104027) · pytorch/pytorch@fe1f26a · GitHub
[go: up one dir, main page]

Skip to content

Commit fe1f26a

Browse files
emasappytorchmergebot
authored andcommitted
Add support for PickleOpCode::APPEND in torch unpickler (#104027)
Reviewed By: qiminglu Differential Revision: D46760650 Pull Request resolved: #104027 Approved by: https://github.com/ezyang
1 parent 0297232 commit fe1f26a

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

test/cpp/jit/test_save_load.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <torch/csrc/jit/serialization/export_bytecode.h>
1212
#include <torch/csrc/jit/serialization/import.h>
1313
#include <torch/csrc/jit/serialization/import_source.h>
14+
#include <torch/script.h>
1415
#include <torch/torch.h>
1516

1617
#include "caffe2/serialize/istream_adapter.h"
@@ -318,5 +319,15 @@ TEST(TestSaveLoad, LoadWithoutDebugInfo) { // NOLINT (use =delete in gtest)
318319
ASSERT_THROWS_WITH_MESSAGE(m3.run_method("exception"), error2);
319320
}
320321

322+
TEST(SerializationTest, TestPickleAppend) {
323+
auto data = std::vector<char>({'\x80', char(2), ']', 'K', char(2), 'a', '.'});
324+
325+
torch::IValue actual = torch::jit::unpickle(data.data(), data.size());
326+
327+
torch::IValue expected = c10::impl::GenericList(at::AnyType::get());
328+
expected.toList().push_back(2);
329+
ASSERT_EQ(expected, actual);
330+
}
331+
321332
} // namespace jit
322333
} // namespace torch

torch/csrc/jit/serialization/unpickler.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <torch/csrc/jit/serialization/unpickler.h>
1111
#include <torch/csrc/utils/byte_order.h>
1212
#include <string>
13+
#include <utility>
1314

1415
namespace torch::jit {
1516

@@ -406,7 +407,7 @@ PickleOpCode Unpickler::readInstruction() {
406407
} break;
407408
case PickleOpCode::TUPLE1: {
408409
TORCH_CHECK(
409-
stack_.size() > 0,
410+
!stack_.empty(),
410411
"Parsing error: stack_ contains ",
411412
stack_.size(),
412413
" elements, at least 1 expected");
@@ -451,6 +452,12 @@ PickleOpCode Unpickler::readInstruction() {
451452
auto list_ivalue = stack_.at(start - 1);
452453
readList(list_ivalue);
453454
} break;
455+
case PickleOpCode::APPEND: {
456+
TORCH_CHECK(
457+
stack_.size() >= 2, "Parsing error: missing elements in stack_.");
458+
auto list_ivalue = stack_.at(stack_.size() - 2);
459+
readListElements(list_ivalue, stack_.size() - 1);
460+
} break;
454461
case PickleOpCode::LIST: {
455462
IValue list_ivalue = c10::impl::GenericList(AnyType::get());
456463
readList(list_ivalue);
@@ -1118,12 +1125,7 @@ std::string Unpickler::readBytes(size_t length) {
11181125
return data;
11191126
}
11201127

1121-
// Pop all the list items off of the stack and append them to the list at
1122-
// the corresponding MARK
1123-
void Unpickler::readList(IValue list_ivalue) {
1124-
TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
1125-
size_t start = marks_.back();
1126-
marks_.pop_back();
1128+
void Unpickler::readListElements(IValue list_ivalue, size_t start) {
11271129
auto num_elements = stack_.size() - start;
11281130
auto elements = c10::ArrayRef<IValue>(stack_).slice(start);
11291131
if (list_ivalue.isIntList()) {
@@ -1159,10 +1161,18 @@ void Unpickler::readList(IValue list_ivalue) {
11591161
} else {
11601162
AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind());
11611163
}
1162-
11631164
stack_.erase(stack_.begin() + start, stack_.end());
11641165
}
11651166

1167+
// Pop all the list items off of the stack and append them to the list at
1168+
// the corresponding MARK
1169+
void Unpickler::readList(IValue list_ivalue) {
1170+
TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
1171+
size_t start = marks_.back();
1172+
marks_.pop_back();
1173+
readListElements(std::move(list_ivalue), start);
1174+
}
1175+
11661176
inline bool is_valid_python_id_char(char c) {
11671177
return c == '_' || c == '.' || (c >= '0' && c <= '9') ||
11681178
(c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');

torch/csrc/jit/serialization/unpickler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class TORCH_API Unpickler {
131131
}
132132
std::string readString();
133133
void readList(IValue list_ivalue);
134+
void readListElements(IValue list_ivalue, size_t start);
134135
void setInput(size_t memo_id);
135136
void run();
136137

0 commit comments

Comments
 (0)
0