add test for qkv_size · pytorch/pytorch@1bbb11e · GitHub


Commit 1bbb11e

committed
add test for qkv_size
1 parent 2e5a1cb commit 1bbb11e

File tree

1 file changed: +17 −0 lines changed


test/test_nn.py

Lines changed: 17 additions & 0 deletions
@@ -11664,6 +11664,23 @@ def test_skip_init(self, device):
         self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device)
         self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight))
 
+    @dtypes(torch.float)
+    @dtypesIfCUDA(torch.double, torch.float, torch.half)
+    def test_multihead_attention_qkv_diff_size(self, device, dtype):
+        embed_dim = 128
+        k_dim = 64
+        v_dim = 32
+        num_heads = 8
+        sl = 10
+        bs = 8
+        model = nn.MultiheadAttention(embed_dim, num_heads, kdim=k_dim, vdim=v_dim).to(device).to(dtype)
+        q = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype)
+        k = torch.randn(sl, bs, k_dim, device=device, dtype=dtype)
+        v = torch.randn(sl, bs, v_dim, device=device, dtype=dtype)
+        out = model(q, k, v)
+        self.assertEqual(q.size(), out[0].size())
+        self.assertEqual(dtype, out[0].dtype)
+
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.double, torch.float, torch.half)
     def test_transformerencoderlayer(self, device, dtype):
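
For readers skimming the diff: the new test exercises nn.MultiheadAttention's kdim/vdim arguments, which let keys and values carry embedding sizes different from the query's embed_dim while the output keeps the query's shape and dtype. A rough standalone sketch of the same scenario (variable names and printed shapes below are illustrative and not part of the commit):

import torch
import torch.nn as nn

embed_dim, k_dim, v_dim, num_heads = 128, 64, 32, 8
seq_len, batch_size = 10, 8

# kdim/vdim allow separate key/value embedding sizes; both are projected
# to embed_dim internally before attention is computed.
mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=k_dim, vdim=v_dim)

q = torch.randn(seq_len, batch_size, embed_dim)  # (L, N, E_q)
k = torch.randn(seq_len, batch_size, k_dim)      # (S, N, E_k)
v = torch.randn(seq_len, batch_size, v_dim)      # (S, N, E_v)

attn_output, attn_weights = mha(q, k, v)
print(attn_output.shape)   # torch.Size([10, 8, 128]) -- matches the query's shape
print(attn_weights.shape)  # torch.Size([8, 10, 10])  -- (N, L, S)

The assertions in the committed test check exactly this: the attention output matches the query's size and dtype.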

0 commit comments
