8000 Avoid unnecessary tensor constructions (#139039) · pytorch/pytorch@d8f99f3 · GitHub
[go: up one dir, main page]

Skip to content

Commit d8f99f3

Browse files
cyyeverpytorchmergebot
authored andcommitted
Avoid unnecessary tensor constructions (#139039)
Because Variable is an alias of Tensor Pull Request resolved: #139039 Approved by: https://github.com/Skylion007
1 parent e80fe7f commit d8f99f3

File tree

5 files changed

+8
-10
lines changed

5 files changed

+8
-10
lines changed

torch/csrc/api/include/torch/nn/cloneable.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ class Cloneable : public Module {
4949
"and not the constructor?");
5050
for (const auto& parameter : named_parameters(/*recurse=*/false)) {
5151
auto& tensor = *parameter;
52-
auto data = device && tensor.device() != *device
53-
? tensor.to(*device)
54-
: autograd::Variable(tensor).clone();
52+
auto data = device && tensor.device() != *device ? tensor.to(*device)
53+
: tensor.clone();
5554
copy->parameters_[parameter.key()].set_data(data);
5655
}
5756
TORCH_CHECK(
@@ -62,9 +61,8 @@ class Cloneable : public Module {
6261
"and not the constructor?");
6362
for (const auto& buffer : named_buffers(/*recurse=*/false)) {
6463
auto& tensor = *buffer;
65-
auto data = device && tensor.device() != *device
66-
? tensor.to(*device)
67-
: autograd::Variable(tensor).clone();
64+
auto data = device && tensor.device() != *device ? tensor.to(*device)
65+
: tensor.clone();
6866
copy->buffers_[buffer.key()].set_data(data);
6967
}
7068
TORCH_CHECK(

torch/csrc/jit/frontend/tracer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ static IValue addInput(
347347
value->setType(type);
348348
if (type->isSubtypeOf(*TensorType::get())) {
349349
auto input_tensor = input.toTensor();
350-
auto name = Variable(input_tensor).name();
350+
auto const& name = input_tensor.name();
351351
if (state->hasValue(input)) {
352352
input_tensor = input_tensor.view(input_tensor.sizes());
353353
}

torch/csrc/jit/python/pybind_utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ py::object toPyObject(IValue ivalue) {
612612
}
613613
} else {
614614
guardAgainstNamedTensor<at::Tensor>(tensor);
615-
return py::cast(autograd::Variable(std::move(tensor)));
615+
return py::cast(std::move(tensor));
616616
}
617617
} else if (ivalue.isStorage()) {
618618
return py::cast(std::move(ivalue).toStorage());

torch/csrc/jit/runtime/graph_executor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ struct CaptureList {
121121
}
122122

123123
void captureTensor(const at::Tensor& tensor, bool is_output) {
124-
var_captures_.emplace_back(Variable(tensor), is_output);
124+
var_captures_.emplace_back(tensor, is_output);
125125
}
126126

127127
void capture(const IValue& val, bool is_output) {

torch/csrc/jit/runtime/register_prim_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1309,7 +1309,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
13091309
[](Stack& stack) {
13101310
at::Tensor a;
13111311
pop(stack, a);
1312-
push(stack, autograd::Variable(a).variable_data());
1312+
push(stack, a.variable_data());
13131313
},
13141314
aliasAnalysisFromSchema()),
13151315
// these ops are not defined for Tensor

0 commit comments

Comments
 (0)
0