Revert "In Inductor, be willing to generate deferred runtime asserts when unbacked (#137097)" · pytorch/pytorch@f69bf00
Revert "In Inductor, be willing to generate deferred runtime asserts when unbacked (#137097)"
This reverts commit 4304c68. Reverted #137097 on behalf of https://github.com/huydhn: "Sorry for reverting your change, it seems to increase the compilation time a lot, causing some jobs to time out" (see the comment on #137097).
1 parent eea1f79 commit f69bf00

File tree: 3 files changed, +3 -98 lines

test/inductor/test_aot_inductor.py
Lines changed: 0 additions & 92 deletions

@@ -14,7 +14,6 @@
 import torch._inductor
 import torch._inductor.config
 import torch.nn as nn
-import torch.nn.functional as F
 from torch._dynamo.testing import rand_strided, same
 from torch._dynamo.utils import counters
 from torch._inductor import config
@@ -815,97 +814,6 @@ def forward(self, x, bias, scale_a, scale_b):
             dynamic_shapes=dynamic_shapes,
         )
 
-    def test_tile_positional_embedding(self):
-        class TilePositionalEmbedding(nn.Module):
-            """
-            Positional embedding for tiles, different for every tile, same for every token within a tile.
-
-            Notice that tile is different from patch (token). For details, please check the documentation of
-            :class:`torchtune.modules.vision_transformer.VisionTransformer`.
-
-            Args:
-                max_num_tiles (int): The maximum number of tiles an image can be divided into.
-                embed_dim (int): The dimensionality of each tile embedding.
-            """
-
-            def __init__(
-                self,
-                max_num_tiles: int,
-                embed_dim: int,
-            ):
-                super().__init__()
-                self.max_num_tiles = max_num_tiles
-                self.embed_dim = embed_dim
-
-                scale = embed_dim**-0.5
-                self.embedding = nn.Parameter(
-                    scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim)
-                )
-                self.gate = nn.Parameter(torch.zeros(1))
-
-            def forward(
-                self, x: torch.Tensor, aspect_ratio: torch.Tensor
-            ) -> torch.Tensor:
-                """
-                args:
-                    x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
-                    aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
-                        representing the aspect ratio of the image before tile-cropping, e.g. (2,1).
-                returns:
-                    torch.Tensor: The input tensor with added positional embeddings.
-                """
-                bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape
-                torch._check(n_tiles <= self.max_num_tiles)
-
-                for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio):
-                    # When we batch images, all are padded to the same amount of tiles.
-                    # The aspect_ratio lets us know the non padded tiles for each image.
-                    # We only add positional encoding to those.
-                    n_tiles_h = n_tiles_h.item()
-                    n_tiles_w = n_tiles_w.item()
-
-                    n_non_padded_tiles = int(n_tiles_h * n_tiles_w)
-
-                    # We get only the positional encoding for non padded tiles,
-                    # i.e. n_tiles_h, n_tiles_w.
-                    torch._check_is_size(n_tiles_h)
-                    torch._check_is_size(n_tiles_w)
-                    torch._check(n_tiles_h > 0)
-                    torch._check(n_tiles_w > 0)
-                    torch._check(n_tiles_h <= self.max_num_tiles)
-                    torch._check(n_tiles_w <= self.max_num_tiles)
-                    padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
-                    # pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]
-                    pos_embed = padded_embedding.narrow(0, 0, n_tiles_h).narrow(
-                        1, 0, n_tiles_w
-                    )
-
-                    # Add pos encoding to the non padded tiles.
-                    pos_embed = pos_embed.clone()
-                    pos_embed = pos_embed.view(n_non_padded_tiles, 1, self.embed_dim)
-
-                    x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0))
-                    torch._check_is_size(n_non_padded_tiles)
-                    torch._check(n_non_padded_tiles < x.size(1))
-                    # x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed
-                    updating = x.narrow(0, batch_idx, batch_idx + 1).narrow(
-                        1, 0, n_non_padded_tiles
-                    )
-                    # updating += pos_embed * self.gate.tanh()
-                    updating.add_(pos_embed * self.gate.tanh())
-                    # x = x[:, :n_tiles, :, :]
-                    x = x.narrow(1, 0, n_tiles)
-
-                return x
-
-        x = torch.ones(1, 4, 1600, 1280, device=self.device)
-        aspect_ratio = torch.tensor([[2, 2]], device=self.device)
-
-        self.check_model(
-            TilePositionalEmbedding(4, 1280),
-            (x, aspect_ratio),
-        )
-
     def test_poi_multiple_dynamic(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
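The deleted test is the main exercise of unbacked SymInt handling in this commit: the tile counts come from .item() calls, so the compiler cannot know them ahead of time and relies on the torch._check* calls to bound them before narrow() and F.pad(). A minimal sketch of that pattern follows (not part of the commit; the function name, flag setting, and tensor values are illustrative only):

import torch
import torch._dynamo

# Assumption for this sketch: keeping .item() inside the compiled graph
# requires enabling scalar-output capture.
torch._dynamo.config.capture_scalar_outputs = True


def narrow_to_item(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    k = n.item()                   # becomes an unbacked SymInt under torch.compile
    torch._check_is_size(k)        # record that k is a valid, non-negative size
    torch._check(k <= x.size(0))   # bound k so narrow() is provably in range
    return x.narrow(0, 0, k)


compiled = torch.compile(narrow_to_item, fullgraph=True)
print(compiled(torch.randn(8, 4), torch.tensor(3)).shape)  # torch.Size([3, 4])

The removed test applies the same idea per image: each n_tiles_h / n_tiles_w read via .item() is checked and bounded before it is used as a size.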

torch/_inductor/sizevars.py
Lines changed: 2 additions & 4 deletions

@@ -415,16 +415,14 @@ def guard_equals(self, left: Expr, right: Expr) -> Expr:
             left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
         if isinstance(right, Expr):
             right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
-        assert self.shape_env.defer_runtime_assert(
-            sympy.Eq(left, right), "guard_equals"
-        )
+        assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
         return left
 
     def guard_leq(self, left: Expr, right: Expr) -> None:
         return self.guard_lt(left, right + 1)
 
     def guard_lt(self, left: Expr, right: Expr) -> None:
-        assert self.shape_env.defer_runtime_assert(sympy.Lt(left, right), "guard_lt")
+        assert self.shape_env.evaluate_expr(sympy.Lt(left, right))
 
     def guarded_order(self, seq):
         """

torch/fx/experimental/symbolic_shapes.py
Lines changed: 1 addition & 2 deletions

@@ -6356,8 +6356,7 @@ def defer_runtime_assert(
         if not self._suppress_guards_tls():
             # If you're here because of this assert, read Note [Backwards runtime asserts]
             # in torch/_inductor/graph.py
-            if self.runtime_asserts_frozen:
-                log.warning("runtime_asserts_frozen but then got %s", expr)
+            assert not self.runtime_asserts_frozen, expr
 
             self._check_frozen(expr, sympy.true)
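This hunk restores the hard assert on runtime_asserts_frozen: once the collected runtime asserts have been frozen (already consumed downstream), defer_runtime_assert may no longer accept new expressions, whereas #137097 had downgraded that to a warning. A toy sketch of the freezing idea (not the ShapeEnv code; the class and method names are invented for illustration):

class AssertStore:
    """Toy stand-in for the deferred runtime asserts a ShapeEnv collects."""

    def __init__(self) -> None:
        self.asserts: list[str] = []
        self.frozen = False

    def freeze(self) -> None:
        # Called once the asserts have been consumed; later additions would be lost.
        self.frozen = True

    def defer(self, expr: str) -> None:
        assert not self.frozen, expr   # restored behavior: fail loudly
        self.asserts.append(expr)


store = AssertStore()
store.defer("Eq(s0, s1)")
store.freeze()
# store.defer("s0 < 16")  # would now raise AssertionError: s0 < 16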
