@@ -132,11 +132,19 @@ def __repr__(self) -> str:
132
132
return f"VR[{ self .lower } , { self .upper } ]"
133
133
134
134
@overload
135
- def __init__ (self : ValueRanges [sympy .Expr ], lower : ExprIn , upper : ExprIn ) -> None :
135
+ def __init__ (
136
+ self : ValueRanges [sympy .Expr ],
137
+ lower : ExprIn ,
138
+ upper : ExprIn ,
139
+ ) -> None :
136
140
...
137
141
138
142
@overload
139
- def __init__ (self : ValueRanges [SympyBoolean ], lower : BoolIn , upper : BoolIn ) -> None :
143
+ def __init__ ( # type: ignore[misc]
144
+ self : ValueRanges [SympyBoolean ],
145
+ lower : BoolIn ,
146
+ upper : BoolIn ,
147
+ ) -> None :
140
148
...
141
149
142
150
def __init__ (self , lower : AllIn , upper : AllIn ) -> None :
@@ -149,26 +157,31 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None:
149
157
raise ValueRangeError (f"Invalid ranges [{ lower } :{ upper } ]" )
150
158
except TypeError as e :
151
159
raise TypeError (f"Could not compare { lower } <= { upper } " ) from e
152
- # Because this is a frozen class
153
- object .__setattr__ (self , "lower" , lower )
154
- object .__setattr__ (self , "upper" , upper )
155
- # Unlike bool/int in Python, we don't report bools are ints
156
- object .__setattr__ (self , "is_bool" , isinstance (lower , SympyBoolean ))
157
- if self .is_bool :
158
- assert isinstance (upper , SympyBoolean ), (lower , upper )
160
+
161
+ is_bool_lower = isinstance (lower , SympyBoolean )
162
+ is_bool_upper = isinstance (upper , SympyBoolean )
163
+ assert is_bool_lower == is_bool_upper , (lower , upper )
159
164
160
165
# Warning: is_int/is_float is best effort. We do pretty well in
161
166
# Dynamo, but in Inductor these attributes are often wrong because we
162
167
# are not very rigorous in dtype analysis. This is also why we need
163
168
# the flexible analysis for is_int: sometimes a sympy.oo pops in for
164
169
# an integer bound. I would /like/ for us not to do this, but it's
165
170
# too hard to push the invariant through right now.
171
+ is_int_lower = isinstance (lower , sympy .Integer )
172
+ is_int_upper = isinstance (upper , sympy .Integer )
166
173
174
+ # Because this is a frozen class
175
+ object .__setattr__ (self , "lower" , lower )
176
+ object .__setattr__ (self , "upper" , upper )
177
+ # Unlike bool/int in Python, we don't report bools are ints
178
+ #
179
+ # NB: is_bool_lower == is_bool_upper, so we only need to check one
180
+ object .__setattr__ (self , "is_bool" , is_bool_lower )
167
181
object .__setattr__ (
168
182
self ,
169
183
"is_int" ,
170
- not self .is_bool
171
- and (isinstance (lower , sympy .Integer ) or isinstance (upper , sympy .Integer )),
184
+ not self .is_bool and (is_int_lower or is_int_upper ),
172
185
)
173
186
"""
174
187
# This assert is just impossible right now, too many sympy bugs
@@ -209,13 +222,15 @@ def tighten(self, other) -> ValueRanges:
209
222
# Intersection
210
223
@overload
211
224
def __and__ (
212
- self : ValueRanges [sympy .Expr ], other : ValueRanges [sympy .Expr ]
225
+ self : ValueRanges [sympy .Expr ],
226
+ other : ValueRanges [sympy .Expr ],
213
227
) -> ValueRanges [sympy .Expr ]:
214
228
...
215
229
216
230
@overload
217
- def __and__ (
218
- self : ValueRanges [SympyBoolean ], other : ValueRanges [SympyBoolean ]
231
+ def __and__ ( # type: ignore[misc]
232
+ self : ValueRanges [SympyBoolean ],
233
+ other : ValueRanges [SympyBoolean ],
219
234
) -> ValueRanges [SympyBoolean ]:
220
235
...
221
236
@@ -239,20 +254,24 @@ def __and__(self: AllVR, other: AllVR) -> AllVR:
239
254
# Union
240
255
@overload
241
256
def __or__ (
242
- self : ValueRanges [sympy .Expr ], other : ValueRanges [sympy .Expr ]
257
+ self : ValueRanges [sympy .Expr ],
258
+ other : ValueRanges [sympy .Expr ],
243
259
) -> ValueRanges [sympy .Expr ]:
244
260
...
245
261
246
262
@overload
247
- def __or__ (
248
- self : ValueRanges [SympyBoolean ], other : ValueRanges [SympyBoolean ]
263
+ def __or__ ( # type: ignore[misc]
264
+ self : ValueRanges [SympyBoolean ],
265
+ other : ValueRanges [SympyBoolean ],
249
266
) -> ValueRanges [SympyBoolean ]:
250
267
...
251
268
252
269
def __or__ (self : AllVR , other : AllVR ) -> AllVR :
253
270
if ValueRanges .unknown () in (self , other ):
254
271
return ValueRanges .unknown ()
255
272
assert self .is_bool == other .is_bool , (self , other )
273
+ assert self .is_int == other .is_int , (self , other )
274
+ assert self .is_float == other .is_float , (self , other )
256
275
if self .is_bool :
257
276
return ValueRanges (
258
277
sympy .And (self .lower , other .lower ), sympy .Or (self .upper , other .upper )
@@ -282,7 +301,7 @@ def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR: # type: ignore[overload-overlap
282
301
283
302
@overload
284
303
@staticmethod
285
- def wrap (arg : Union [BoolIn , BoolVR ]) -> BoolVR :
304
+ def wrap (arg : Union [BoolIn , BoolVR ]) -> BoolVR : # type: ignore[misc]
286
305
...
287
306
288
307
@staticmethod
@@ -307,7 +326,7 @@ def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
307
326
308
327
@overload
309
328
@staticmethod
310
- def decreasing_map (x : Union [BoolIn , BoolVR ], fn : BoolFn ) -> BoolVR :
329
+ def decreasing_map (x : Union [BoolIn , BoolVR ], fn : BoolFn ) -> BoolVR : # type: ignore[misc]
311
330
...
312
331
313
332
@staticmethod
@@ -330,27 +349,36 @@ def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
330
349
"""Fn is convex and has a minimum at 0."""
331
350
x = ValueRanges .wrap (x )
332
351
if 0 in x :
333
- return ValueRanges (0 , max (fn (x .lower ), fn (x .upper )))
334
- else :
335
- return ValueRanges .monotone_map (x , fn )
352
+ upper = max (fn (x .lower ), fn (x .upper ))
353
+ upper = simple_sympify (upper )
354
+ if isinstance (upper , sympy .Float ) or upper == sympy .oo :
355
+ return ValueRanges (0.0 , upper )
356
+ return ValueRanges (0 , upper )
357
+ return ValueRanges .monotone_map (x , fn )
336
358
337
359
@overload
338
360
@staticmethod
339
361
def coordinatewise_increasing_map (
340
- x : Union [ExprIn , ExprVR ], y : Union [ExprIn , ExprVR ], fn : ExprFn2
362
+ x : Union [ExprIn , ExprVR ],
363
+ y : Union [ExprIn , ExprVR ],
364
+ fn : ExprFn2 ,
341
365
) -> ExprVR :
342
366
...
343
367
344
368
@overload
345
369
@staticmethod
346
- def coordinatewise_increasing_map (
347
- x : Union [BoolIn , BoolVR ], y : Union [BoolIn , BoolVR ], fn : BoolFn2
370
+ def coordinatewise_increasing_map ( # type: ignore[misc]
371
+ x : Union [BoolIn , BoolVR ],
372
+ y : Union [BoolIn , BoolVR ],
373
+ fn : BoolFn2 ,
348
374
) -> BoolVR :
349
375
...
350
376
351
377
@staticmethod
352
378
def coordinatewise_increasing_map (
353
- x : Union [AllIn , AllVR ], y : Union [AllIn , AllVR ], fn : AllFn2
379
+ x : Union [AllIn , AllVR ],
380
+ y : Union [AllIn , AllVR ],
381
+ fn : AllFn2 ,
354
382
) -> AllVR :
355
383
"""
356
384
It's increasing on each coordinate.
@@ -1001,7 +1029,7 @@ def bound_sympy(
1001
1029
if unbounded_vars :
1002
1030
# Give some bounds to the free variables via their SymPy assumptions
1003
1031
# TODO A better way of doing this would be to assign them a range upon creation, as
1004
- # size variables can come with a lower bound of 2, as we specialise on 0 and 1
1032
+ # size variables can come with a lower bound of 2, as we specialize on 0 and 1
1005
1033
unbounded_ranges : Dict [sympy .Symbol , ValueRanges ] = {}
1006
1034
for s in unbounded_vars :
1007
1035
if s .is_integer : # type: ignore[attr-defined]
0 commit comments