Hi,
I'm trying to build a recurrent net which, instead of using the last hidden state, computes bilinear attention over the last k hidden states. I have it implemented as a module as follows:
import numpy as np
import torch as th
from torch.nn import Module, Parameter
from torch.autograd import Variable
from torch.nn.functional import relu, softmax


class LocalRNN(Module):
    def __init__(self, window_size, hidden_size):
        super(LocalRNN, self).__init__()
        self.window_size = window_size
        self.hidden_size = hidden_size
        # bilinear attention weights over the window of past hidden states
        self.att_W = Parameter(th.Tensor(1, hidden_size, hidden_size))
        # input-to-hidden and hidden-to-hidden weights and biases
        self.i2h_W = Parameter(th.Tensor(hidden_size, hidden_size))
        self.h2h_W = Parameter(th.Tensor(hidden_size, hidden_size))
        self.i2h_b = Parameter(th.Tensor(hidden_size))
        self.h2h_b = Parameter(th.Tensor(hidden_size))
        self.reset()

    def reset(self):
        # uniform init scaled by sqrt(2 / hidden_size)
        stdv = np.sqrt(2.0 / self.hidden_size)
        self.att_W.data.uniform_(-stdv, stdv)
        self.i2h_W.data.uniform_(-stdv, stdv)
        self.h2h_W.data.uniform_(-stdv, stdv)
        self.i2h_b.data.uniform_(-stdv, stdv)
        self.h2h_b.data.uniform_(-stdv, stdv)
        return self

    def forward(self, input, window=None):
        seq_length, batch_size, hidden_size = input.size()
        # vector of ones, used with addr_ to broadcast the biases over the batch
        add_buffer = Variable(input.data.new(batch_size).fill_(1))
        output = []
        if window is None:
            # start from window_size all-zero hidden states
            window = [Variable(input.data.new(batch_size, 1, hidden_size).zero_())] * self.window_size
        for t in range(seq_length):
            x_t = input[t]
            # last window_size hidden states: batch_size x window_size x hidden_size
            w_t = th.cat(window[-self.window_size:], 1)
            att_W = self.att_W.expand(batch_size, hidden_size, hidden_size)
            # bilinear attention scores over the window, then the attention-weighted context
            h_tm1 = th.bmm(softmax(th.bmm(w_t, th.bmm(att_W, x_t.unsqueeze(2)))).transpose(1, 2), w_t).squeeze()
            # Elman-style update using the attended context in place of the last hidden state
            h_t = relu(th.mm(x_t, self.i2h_W).addr_(add_buffer, self.i2h_b).addmm(h_tm1, self.h2h_W).addr_(add_buffer, self.h2h_b))
            output.append(h_t.unsqueeze(0))
            window.append(h_t.unsqueeze(1))
        output = th.cat(output)
        return output
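For reference, I measure the timings below with a standalone driver roughly like this (a minimal sketch; window_size=5 and the random input are placeholders, not my real model):

import time
import torch as th
from torch.autograd import Variable

seq_length, batch_size, hidden_size = 100, 50, 512
rnn = LocalRNN(window_size=5, hidden_size=hidden_size)  # window_size=5 is a placeholder
x = Variable(th.randn(seq_length, batch_size, hidden_size))

start = time.time()
out = rnn(x)                      # forward pass over the whole sequence
# (on the GPU, th.cuda.synchronize() would be needed before reading the clock)
print('forward:  %.2f s' % (time.time() - start))

loss = out.sum()                  # dummy scalar loss
start = time.time()
loss.backward()                   # backward pass through the unrolled graph
print('backward: %.2f s' % (time.time() - start))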
With an input of size 100 x 50 x 512 (batch_first=False), the forward pass takes ~0.6 s, but for some reason the autograd backward pass takes ~16 s, over 20x longer. Am I using autograd incorrectly (I'm just stacking this module with other modules), or is this a bug? How do I profile the backward graph to find the bottleneck?
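The only tool I've found so far is torch.autograd.profiler; is something like the sketch below the right way to break the backward pass down per op (assuming my build ships it), or is there a better approach?

# rnn and x as in the driver above; requires a build with torch.autograd.profiler
with th.autograd.profiler.profile() as prof:
    out = rnn(x)
    out.sum().backward()
print(prof)  # prints each recorded autograd op with its CPU time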