scatter
>>> torch.zeros(3,5).scatter_(1, torch.tensor([[0],[3],[3]]), 3)
tensor([[3., 0., 0., 0., 0.],
        [0., 0., 0., 3., 0.],
        [0., 0., 0., 3., 0.]])
>>> torch.zeros(3,5).scatter_(0, torch.tensor([[0,1,2,2,2]]), 3)
tensor([[3., 0., 0., 0., 0.],
        [0., 3., 0., 0., 0.],
        [0., 0., 3., 3., 3.]])
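To make the two results above easier to read, here is a minimal loop-based sketch of what `scatter_` does with a scalar value on a 2-D tensor. The helper name `scatter_fill_2d` is hypothetical (it is not a PyTorch function) and it only covers the 2-D case used in these notes.

```python
import torch

def scatter_fill_2d(shape, dim, index, value):
    """Hypothetical reference for torch.zeros(shape).scatter_(dim, index, value), 2-D only."""
    out = torch.zeros(shape)
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            k = int(index[i, j])
            if dim == 0:
                out[k, j] = value  # index value replaces the row coordinate
            else:
                out[i, k] = value  # index value replaces the column coordinate
    return out

idx_cols = torch.tensor([[0], [3], [3]])      # used with dim=1
idx_rows = torch.tensor([[0, 1, 2, 2, 2]])    # used with dim=0

a = scatter_fill_2d((3, 5), 1, idx_cols, 3.0)
b = scatter_fill_2d((3, 5), 0, idx_rows, 3.0)
print(a)
print(b)
```

In short, the index tensor is walked element by element, and the value stored at each position replaces the coordinate along `dim` while the other coordinate stays the same.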
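advanced indexing:
import torch

num_vectors = 5
max_length = 3
num_out = 3
# source vectors: column vector [0., 1., 2., 3., 4.]
t = torch.arange(num_vectors).view(-1, 1).float()
# flat destination buffer with num_out * max_length rows
zeros = torch.zeros(num_out * max_length, 1)
# copy row t[src] into row zeros[dst] for each (dst, src) pair
zeros[[0, 1, 2, 3, 6, 7]] = t[[2, 1, 0, 4, 3, 2]]
print(zeros.view(3, 3, 1))
tensor([[[2.],
         [1.],
         [0.]],

        [[4.],
         [0.],
         [0.]],

        [[3.],
         [2.],
         [0.]]])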
select: torch.index_select(input, dim, index, out=None) → Tensor
index (LongTensor) – the 1-D tensor containing the indices to index
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])
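For comparison with the advanced-indexing note earlier, a small sketch showing that `index_select` along a dimension gives the same values as indexing that dimension with a 1-D integer tensor. The input here is a fixed `arange` tensor chosen for reproducibility, not the random tensor from the transcript.

```python
import torch

x = torch.arange(12.0).view(3, 4)
indices = torch.tensor([0, 2])

rows_a = torch.index_select(x, 0, indices)  # pick rows 0 and 2
rows_b = x[indices]                         # same selection via indexing
cols_a = torch.index_select(x, 1, indices)  # pick columns 0 and 2
cols_b = x[:, indices]                      # same selection via indexing

print(torch.equal(rows_a, rows_b), torch.equal(cols_a, cols_b))
```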
masked_select
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
        [-1.2035,  1.2252,  0.5002,  0.6248],
        [ 0.1307, -2.0608,  0.1244,  2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[ 0,  0,  0,  0],
        [ 0,  1,  1,  1],
        [ 0,  0,  0,  1]], dtype=torch.uint8)
>>> torch.masked_select(x, mask)
tensor([ 1.2252,  0.5002,  0.6248,  2.0139])
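A short sketch that reuses the values from the transcript above with a fixed tensor, so the result is reproducible. It assumes a recent PyTorch release, where comparison ops such as `ge` return a `torch.bool` mask rather than the `torch.uint8` mask printed above, and it shows that `masked_select` matches boolean indexing.

```python
import torch

x = torch.tensor([[ 0.3552, -2.3825, -0.8297, 0.3477],
                  [-1.2035,  1.2252,  0.5002, 0.6248],
                  [ 0.1307, -2.0608,  0.1244, 2.0139]])

mask = x.ge(0.5)                       # elementwise x >= 0.5
selected = torch.masked_select(x, mask)
indexed = x[mask]                      # boolean indexing, same values

print(mask.dtype)                      # torch.bool on recent versions (assumption)
print(selected)
print(torch.equal(selected, indexed))
```

The result is always a new 1-D tensor, read in row-major order, regardless of the shape of `x`.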
view and reshape
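view always shares the underlying data with the original tensor, and it only works on a contiguous tensor. reshape may share the underlying data, but if the tensor is not contiguous the result may be a copy instead; you can't count on either behavior.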
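A minimal sketch of the difference, using a transpose to get a non-contiguous tensor. The variable names are illustrative only.

```python
import torch

x = torch.arange(6).view(2, 3)   # contiguous tensor
v = x.view(3, 2)                 # view: shares storage with x
v[0, 0] = 100                    # the write is visible through x as well

t = x.t()                        # transpose: same storage, non-contiguous strides
try:
    t.view(6)                    # view cannot flatten these strides
    view_failed = False
except RuntimeError:
    view_failed = True

r = t.reshape(6)                 # reshape falls back to a copy here
r[0] = -1                        # does not touch x
same_storage = r.data_ptr() == x.data_ptr()

print(x[0, 0].item(), t.is_contiguous(), view_failed, same_storage)
```

If a write must propagate to the original, use view and let it fail loudly; if an independent tensor is needed, calling `.clone()` explicitly avoids relying on what reshape happens to do.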