|
14 | 14 | import torch._inductor
|
15 | 15 | import torch._inductor.config
|
16 | 16 | import torch.nn as nn
|
17 |
| -import torch.nn.functional as F |
18 | 17 | from torch._dynamo.testing import rand_strided, same
|
19 | 18 | from torch._dynamo.utils import counters
|
20 | 19 | from torch._inductor import config
|
@@ -815,97 +814,6 @@ def forward(self, x, bias, scale_a, scale_b):
|
815 | 814 | dynamic_shapes=dynamic_shapes,
|
816 | 815 | )
|
817 | 816 |
|
818 |
| - def test_tile_positional_embedding(self): |
819 |
| - class TilePositionalEmbedding(nn.Module): |
820 |
| - """ |
821 |
| - Positional embedding for tiles, different for every tile, same for every token within a tile. |
822 |
| -
|
823 |
| - Notice that tile is different from patch (token). For details, please check the documentation of |
824 |
| - :class:`torchtune.modules.vision_transformer.VisionTransformer`. |
825 |
| -
|
826 |
| - Args: |
827 |
| - max_num_tiles (int): The maximum number of tiles an image can be divided into. |
828 |
| - embed_dim (int): The dimensionality of each tile embedding. |
829 |
| - """ |
830 |
| - |
831 |
| - def __init__( |
832 |
| - self, |
833 |
| - max_num_tiles: int, |
834 |
| - embed_dim: int, |
835 |
| - ): |
836 |
| - super().__init__() |
837 |
| - self.max_num_tiles = max_num_tiles |
838 |
| - self.embed_dim = embed_dim |
839 |
| - |
840 |
| - scale = embed_dim**-0.5 |
841 |
| - self.embedding = nn.Parameter( |
842 |
| - scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim) |
843 |
| - ) |
844 |
| - self.gate = nn.Parameter(torch.zeros(1)) |
845 |
| - |
846 |
| - def forward( |
847 |
| - self, x: torch.Tensor, aspect_ratio: torch.Tensor |
848 |
| - ) -> torch.Tensor: |
849 |
| - """ |
850 |
| - args: |
851 |
| - x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). |
852 |
| - aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2), |
853 |
| - representing the aspect ratio of the image before tile-cropping, e.g. (2,1). |
854 |
| - returns: |
855 |
| - torch.Tensor: The input tensor with added positional embeddings. |
856 |
| - """ |
857 |
| - bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape |
858 |
| - torch._check(n_tiles <= self.max_num_tiles) |
859 |
| - |
860 |
| - for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): |
861 |
| - # When we batch images, all are padded to the same amount of tiles. |
862 |
| - # The aspect_ratio lets us know the non padded tiles for each image. |
863 |
| - # We only add positional encoding to those. |
864 |
| - n_tiles_h = n_tiles_h.item() |
865 |
| - n_tiles_w = n_tiles_w.item() |
866 |
| - |
867 |
| - n_non_padded_tiles = int(n_tiles_h * n_tiles_w) |
868 |
| - |
869 |
| - # We get only the positional encoding for non padded tiles, |
870 |
| - # i.e. n_tiles_h, n_tiles_w. |
871 |
| - torch._check_is_size(n_tiles_h) |
872 |
| - torch._check_is_size(n_tiles_w) |
873 |
| - torch._check(n_tiles_h > 0) |
874 |
| - torch._check(n_tiles_w > 0) |
875 |
| - torch._check(n_tiles_h <= self.max_num_tiles) |
876 |
| - torch._check(n_tiles_w <= self.max_num_tiles) |
877 |
| - padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1)) |
878 |
| - # pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :] |
879 |
| - pos_embed = padded_embedding.narrow(0, 0, n_tiles_h).narrow( |
880 |
| - 1, 0, n_tiles_w |
881 |
| - ) |
882 |
| - |
883 |
| - # Add pos encoding to the non padded tiles. |
884 |
| - pos_embed = pos_embed.clone() |
885 |
| - pos_embed = pos_embed.view(n_non_padded_tiles, 1, self.embed_dim) |
886 |
| - |
887 |
| - x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0)) |
888 |
| - torch._check_is_size(n_non_padded_tiles) |
889 |
| - torch._check(n_non_padded_tiles < x.size(1)) |
890 |
| - # x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed |
891 |
| - updating = x.narrow(0, batch_idx, batch_idx + 1).narrow( |
892 |
| - 1, 0, n_non_padded_tiles |
893 |
| - ) |
894 |
| - # updating += pos_embed * self.gate.tanh() |
895 |
| - updating.add_(pos_embed * self.gate.tanh()) |
896 |
| - # x = x[:, :n_tiles, :, :] |
897 |
| - x = x.narrow(1, 0, n_tiles) |
898 |
| - |
899 |
| - return x |
900 |
| - |
901 |
| - x = torch.ones(1, 4, 1600, 1280, device=self.device) |
902 |
| - aspect_ratio = torch.tensor([[2, 2]], device=self.device) |
903 |
| - |
904 |
| - self.check_model( |
905 |
| - TilePositionalEmbedding(4, 1280), |
906 |
| - (x, aspect_ratio), |
907 |
| - ) |
908 |
| - |
909 | 817 | def test_poi_multiple_dynamic(self):
|
910 | 818 | class Model(torch.nn.Module):
|
911 | 819 | def __init__(self) -> None:
|
|
0 commit comments