     ELEMENTWISE_TYPE_PROMOTION_KIND,
     type_to_dtype,
 )
-from torch.fx.experimental.symbolic_shapes import definitely_true, guard_size_oblivious
+from torch.fx.experimental.symbolic_shapes import (
+    definitely_true,
+    guard_else_false,
+    guard_else_true,
+)
 
 from . import config, inductor_prims
 from .utils import (
@@ -261,13 +265,13 @@ def bmm(
     batch2: torch.Tensor,
 ) -> torch.Tensor:
     if config.coordinate_descent_tuning and self.device.type != "cpu":
-        if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious(
+        if guard_else_false(self.shape[1] == 1) or guard_else_false(
             batch2.shape[2] == 1
         ):
             out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
             return out
     if self.device.type == "cpu":
-        if guard_size_oblivious(self.size(1) == 1) and guard_size_oblivious(
+        if guard_else_false(self.size(1) == 1) and guard_else_false(
             batch2.size(-1) == 1
         ):
             counters["inductor"]["decompose_bmm"] += 1
@@ -287,16 +291,14 @@ def addmm(
     alpha: torch.types.Number = 1,
 ) -> torch.Tensor:
     if self.device.type == "cpu":
-        if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious(
-            mat2.size(-1) == 1
-        ):
+        if guard_else_false(mat1.size(0) == 1) and guard_else_false(mat2.size(-1) == 1):
             counters["inductor"]["decompose_addmm"] += 1
             out = torch.sum(
                 mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
             ).unsqueeze(0)
             return alpha * out + beta * self
         if (
-            guard_size_oblivious(mat1.size(0) == 1)
+            guard_else_false(mat1.size(0) == 1)
             and definitely_true(mat2.size(0) <= 16)
             and definitely_true(mat2.size(1) <= 16)
         ):
@@ -315,21 +317,21 @@ def mm(
     # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
     # todo: Look into why and fix it (hopefully)
     if config.coordinate_descent_tuning and self.device.type != "cpu":
-        if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious(
+        if guard_else_false(self.shape[0] == 1) or guard_else_false(
             input2.shape[1] == 1
         ):
             return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
     if self.device.type == "cpu":
         if (
-            guard_size_oblivious(self.size(-1) == 1)
-            and guard_size_oblivious(self.size(0) > 0)
-            and guard_size_oblivious(input2.size(0) == 1)
+            guard_else_false(self.size(-1) == 1)
+            and guard_else_true(self.size(0) > 0)
+            and guard_else_false(input2.size(0) == 1)
             and (self.dtype == input2.dtype)
             and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
         ):
             counters["inductor"]["decompose_mm"] += 1
             return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
-        if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
+        if guard_else_false(self.size(0) == 1) and guard_else_false(
             input2.size(-1) == 1
         ):
             counters["inductor"]["decompose_mm"] += 1
@@ -348,8 +350,6 @@ def cat(
     tensors: list[torch.Tensor],
     dim: int = 0,
 ) -> torch.Tensor:
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
     def non_empty_tensor(x: torch.Tensor) -> bool:
         # For better or worse, this is a valid cat:
         #
@@ -367,10 +367,10 @@ def non_empty_tensor(x: torch.Tensor) -> bool:
         # runtime assert forcing u0 to be zero. So if this hasn't happened,
         # we know that the unbacked SymInt has appropriate size and there are
         # no problems.
-        if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0):
+        if len(x.shape) == 1 and guard_else_false(x.shape[0] == 0):
             return False
 
-        if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0):
+        if dim < len(x.shape) and guard_else_false(x.shape[dim] == 0):
             return False
 
         return True
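
A note on the helpers this diff switches to, for readers who have not met them: guard_size_oblivious can still raise a data-dependent error when a comparison over an unbacked symbolic size cannot be decided even size-obliviously, whereas the new names suggest that guard_else_false and guard_else_true fall back to a fixed answer instead, so a decomposition simply skips (or keeps) the fast path rather than failing. The sketch below only illustrates that assumed contract; guard_else_false_sketch, guard_else_true_sketch, and _static_truth are hypothetical stand-ins, not PyTorch APIs.

# Minimal sketch of the assumed fallback contract, not the real implementation.
from typing import Optional


def _static_truth(expr: object) -> Optional[bool]:
    # Stand-in for "can this comparison be decided without adding a guard?"
    # Returns None when the answer would depend on an unbacked/dynamic size.
    return expr if isinstance(expr, bool) else None


def guard_else_false_sketch(expr: object) -> bool:
    # Decidable comparisons keep their real answer; undecidable ones
    # default to False, so a size-specialized path is simply not taken.
    result = _static_truth(expr)
    return result if result is not None else False


def guard_else_true_sketch(expr: object) -> bool:
    # Same idea with a True fallback, matching checks such as
    # `self.size(0) > 0` in the `mm` decomposition above.
    result = _static_truth(expr)
    return result if result is not None else True

Under this reading, bmm and addmm only take their cheap decompositions when a dimension is statically known to be 1, while the size(0) > 0 check in mm defaults to True so an unbacked batch size does not block the small-matrix path.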