10000 Fix memory leak on masked Tensor · pytorch/pytorch@41881d7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 41881d7

Browse files
committed
Fix memory leak on masked Tensor
1 parent 0e4d426 commit 41881d7

File tree

2 files changed

+1
-10
lines changed

2 files changed

+1
-10
lines changed

test/test_maskedtensor.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,6 @@ def _compare_forward_backward(data, mask, fn):
7878
_compare_mt_t(masked_res, tensor_res)
7979
_compare_mt_t(mt.grad, t.grad, atol=1e-06)
8080

81-
# Free up the masked tensors manually to avoid memory leak
82-
del mt._masked_mask
83-
del mt._masked_data
84-
8581
def _create_random_mask(shape, device):
8682
return make_tensor(shape, device=device, dtype=torch.bool)
8783

@@ -229,11 +225,6 @@ def test_stack(self, device):
229225
for mt, t in zip(masked_tensors, data_tensors):
230226
_compare_mt_t(mt.grad, t.grad, atol=1e-06)
231227

232-
# Free up the masked tensors manually to avoid memory leak
233-
for mt in masked_tensors:
234-
del mt._masked_mask
235-
del mt._masked_data
236-
237228
def test_to_sparse(self, device):
238229
for sample in _generate_sample_data(device=device):
239230
data = sample.input

torch/masked/maskedtensor/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def get_data(self):
334334
class GetData(torch.autograd.Function):
335335
@staticmethod
336336
def forward(ctx, self):
337-
return self._masked_data
337+
return self._masked_data.detach()
338338

339339
@staticmethod
340340
def backward(ctx, grad_output):

0 commit comments

Comments
 (0)
0