@@ -11664,6 +11664,24 @@ def test_skip_init(self, device):
         self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device)
         self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight))

+    @dtypes(torch.float)
+    @dtypesIfCUDA(torch.double, torch.float, torch.half)
+    def test_multihead_attention_qkv_diff_size(self, device, dtype):
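+        # kdim/vdim differ from embed_dim; the output should keep the query's shape and dtype.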
+        embed_dim = 128
+        k_dim = 64
+        v_dim = 32
+        num_heads = 8
+        sl = 10
+        bs = 8
+        model = nn.MultiheadAttention(embed_dim, num_heads, kdim=k_dim, vdim=v_dim).to(device).to(dtype)
+        q = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype)
+        k = torch.randn(sl, bs, k_dim, device=device, dtype=dtype)
+        v = torch.randn(sl, bs, v_dim, device=device, dtype=dtype)
+        out = model(q, k, v)
+        self.assertEqual(q.size(), out[0].size())
+        self.assertEqual(dtype, out[0].dtype)
+
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.double, torch.float, torch.half)
     def test_transformerencoderlayer(self, device, dtype):