8000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5a3b333 commit de85c09Copy full SHA for de85c09
vision_transformer/main.py
@@ -49,9 +49,9 @@ def __init__(self, args):
49
# Linear projection
50
self.LinearProjection = nn.Linear(self.input_size, self.latent_size)
51
# Class token
52
- self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
+ self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))
53
# Positional embedding
54
- self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
+ self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))
55
56
def forward(self, input_data):
57
input_data = input_data.to(self.device)
0 commit comments