torch.sum (used in Expand.backward) is slow #700
@shawnjhenry

Description


Hi,

I'm trying to build a recurrent net which, instead of using the last hidden state, computes bilinear attention over the last k hidden states. I have it implemented as a module as follows:

import numpy as np
import torch as th

from torch.nn import Module, Parameter
from torch.autograd import Variable, Function
from torch.nn.functional import relu, softmax


class LocalRNN(Module):

	def __init__(self, window_size, hidden_size):
		super(LocalRNN, self).__init__()

		self.window_size = window_size
		self.hidden_size = hidden_size

		self.att_W = Parameter(th.Tensor(1, hidden_size, hidden_size))

		self.i2h_W = Parameter(th.Tensor(hidden_size, hidden_size))
		self.h2h_W = Parameter(th.Tensor(hidden_size, hidden_size))

		self.i2h_b = Parameter(th.Tensor(hidden_size))
		self.h2h_b = Parameter(th.Tensor(hidden_size))

		self.reset()

	def reset(self):
		stdv = np.sqrt(2.0/self.hidden_size)

		self.att_W.data.uniform_(-stdv, stdv)

		self.i2h_W.data.uniform_(-stdv, stdv)
		self.h2h_W.data.uniform_(-stdv, stdv)

		self.i2h_b.data.uniform_(-stdv, stdv)
		self.h2h_b.data.uniform_(-stdv, stdv)

		return self

	def forward(self, input, window=None):
		seq_length, batch_size, hidden_size = input.size()

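		# vector of ones: addr_ below uses it to add each bias to every row of the batch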
		add_buffer = Variable(input.data.new(batch_size).fill_(1))

		output = []

		if window is None:
			window = [Variable(input.data.new(batch_size, 1, hidden_size).zero_())]*self.window_size

		for t in range(seq_length):
			x_t = input[t]

			w_t = th.cat(window[-self.window_size:], 1)

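			# broadcast the single attention weight matrix over the batch; the backward of
			# expand() has to sum the gradient back over the batch dimension (see issue title)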
			att_W = self.att_W.expand(batch_size, hidden_size, hidden_size)

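			# bilinear attention: score each windowed state against x_t, softmax the scores,
			# then take the weighted sum of the window states as the effective previous state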
			h_tm1 = th.bmm(softmax(th.bmm(w_t, th.bmm(att_W, x_t.unsqueeze(2)))).transpose(1, 2), w_t).squeeze()

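			# standard RNN update: relu(x_t @ W_ih + b_ih + h_tm1 @ W_hh + b_hh), with addr_
			# adding each bias across the batch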
			h_t = relu(th.mm(x_t, self.i2h_W).addr_(add_buffer, self.i2h_b).addmm(h_tm1, self.h2h_W).addr_(add_buffer, self.h2h_b))

			output.append(h_t.unsqueeze(0))
			window.append(h_t.unsqueeze(1))

		output = th.cat(output)

		return output

With an input of size 100 x 50 x 512 (batch_first=False) forward prop takes ~0.6 s, but for some reason autograd back prop takes ~16 s, over 20x longer! Am I using autograd incorrectly (I'm just stacking this module with other modules) or is this a bug? How do I profile the backward graph to find the bottleneck?
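For reference, here is a minimal way I could isolate the suspected cost (this is not part of the module above; the sizes and variable names just mirror the example): time forward and backward separately on the expanded-weight bmm alone, since the title points at torch.sum inside Expand.backward.

import time
import torch as th
from torch.autograd import Variable

batch_size, hidden_size = 50, 512

# hypothetical micro-benchmark of a single attention step, not the full model
att_W = Variable(th.randn(1, hidden_size, hidden_size), requires_grad=True)
w_t = Variable(th.randn(batch_size, 1, hidden_size))
x_t = Variable(th.randn(batch_size, hidden_size, 1))

start = time.time()
# expand broadcasts the single weight matrix across the batch without copying
scores = th.bmm(w_t, th.bmm(att_W.expand(batch_size, hidden_size, hidden_size), x_t))
loss = scores.sum()
print('forward:  %.4f s' % (time.time() - start))

start = time.time()
# the backward of expand() reduces the gradient back to size (1, hidden, hidden),
# which is where the torch.sum call from the issue title shows up
loss.backward()
print('backward: %.4f s' % (time.time() - start))

Running the same comparison on the whole module (output.sum().backward() after a forward pass over the 100 x 50 x 512 input) should show whether this expand backward accounts for most of the 16 s.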
