I tried running the following script and found that S5 is far slower than PyTorch's LSTM. Is this expected? Perhaps the scale at which I'm testing is too small to realize the benefit?
from datetime import datetime
import os

import torch
from s5 import S5

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

L = 1200
B = 256
x_dim = 128

m = S5(x_dim, 512).cuda()
# x is laid out as (B, L, x_dim), so the LSTM needs batch_first=True
lstm = torch.nn.LSTM(x_dim, 512, batch_first=True).cuda()
x = torch.randn(B, L, x_dim).cuda()

# time 10 forward + backward passes through the LSTM
t0 = datetime.now()
for i in range(10):
    y, _ = lstm(x)
    torch.sum(y).backward()
t1 = datetime.now()
print(t1 - t0)

# time 10 forward + backward passes through S5
t2 = datetime.now()
for i in range(10):
    y = m(x)
    torch.sum(y).backward()
t3 = datetime.now()
print(t3 - t2)
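Since CUDA kernels launch asynchronously, wall-clock timing with datetime can mostly measure kernel launches rather than execution, so the numbers above may be misleading. Below is a minimal sketch of the same comparison with a warm-up pass and explicit synchronization before reading the clock; the S5 constructor and call signature are assumed to match the snippet above.

import time

import torch
from s5 import S5

L, B, x_dim, hidden = 1200, 256, 128, 512

m = S5(x_dim, hidden).cuda()
lstm = torch.nn.LSTM(x_dim, hidden, batch_first=True).cuda()
x = torch.randn(B, L, x_dim).cuda()

def bench(step, iters=10):
    # warm-up iteration so one-time setup (context init, kernel compilation) is not timed
    step()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        step()
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return time.perf_counter() - t0

def lstm_step():
    y, _ = lstm(x)
    torch.sum(y).backward()

def s5_step():
    y = m(x)  # assumed S5 forward signature, as in the snippet above
    torch.sum(y).backward()

print("LSTM:", bench(lstm_step))
print("S5:  ", bench(s5_step))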
I would greatly appreciate any comment on this. Thanks in advance, and thanks for the implementation!