Bugfix in vision transformer - save class token and pos embedding (#1…) · pytorch/examples@de85c09 · GitHub

Commit de85c09

Bugfix in vision transformer - save class token and pos embedding (#1204)
1 parent 5a3b333 commit de85c09

File tree: 1 file changed (+2, −2 lines)


vision_transformer/main.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -49,9 +49,9 @@ def __init__(self, args):
         # Linear projection
         self.LinearProjection = nn.Linear(self.input_size, self.latent_size)
         # Class token
-        self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
+        self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))
         # Positional embedding
-        self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
+        self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))

     def forward(self, input_data):
         input_data = input_data.to(self.device)
```
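The diff moves `.to(self.device)` from outside `nn.Parameter(...)` to inside it. That matters because calling `.to()` on an `nn.Parameter` that actually has to be copied (e.g. moved to another device) returns a plain `Tensor`, not a `Parameter`, so the attribute is never registered with the module: it is missing from `state_dict()` (hence not saved in checkpoints, per the commit title) and from `parameters()` (hence never updated by the optimizer). A minimal sketch of the difference, using a hypothetical `TokenDemo` module with illustrative shapes; `copy=True` forces `.to()` to return a new tensor on CPU, mimicking a real cross-device move:

```python
import torch
import torch.nn as nn

device = "cpu"  # stand-in for self.device; copy=True simulates a real device move

class TokenDemo(nn.Module):
    def __init__(self):
        super().__init__()
        # Buggy pattern: .to() on the Parameter returns a plain (non-leaf)
        # Tensor, so this attribute is NOT registered as a module parameter.
        self.class_token = nn.Parameter(torch.randn(1, 1, 8)).to(device, copy=True)
        # Fixed pattern: move the raw tensor first, then wrap it, so the
        # Parameter itself is assigned and registered by nn.Module.__setattr__.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1, 8).to(device))

m = TokenDemo()
print(sorted(m.state_dict().keys()))   # only 'pos_embedding' is saved
print(len(list(m.parameters())))       # 1 — class_token would not be optimized
```

With the fix, both tensors appear in `state_dict()` and are trained and checkpointed as intended.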
