8000 Fix accidental typing reversions · pytorch/pytorch@efb9798 · GitHub
[go: up one dir, main page]

Skip to content

Commit efb9798

Browse files
committed
Fix accidental typing reversions
1 parent 512c4f9 commit efb9798

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

torch/_decomp/decompositions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,7 +2266,7 @@ def native_batch_norm_backward(
22662266
broadcast_mask: list[int] = [1] * input_rank
22672267
broadcast_mask[axis] = input_shape[axis]
22682268

2269-
reduction_axes: List[int] = []
2269+
reduction_axes: list[int] = []
22702270
for i in range(input_rank):
22712271
if i != axis:
22722272
reduction_axes.append(i)
@@ -4455,7 +4455,7 @@ def matmul(tensor1, tensor2, *, is_out=False):
44554455
m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
44564456
p = tensor2.size(-1) if dim_tensor2 > 1 else 1
44574457

4458-
batch_tensor2: List[int] = []
4458+
batch_tensor2: list[int] = []
44594459
# TODO: handling of slice
44604460
for i in range(dim_tensor2 - 2):
44614461
batch_tensor2.append(tensor2.size(i))

torch/_decomp/decompositions_for_jvp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def native_batch_norm_backward(
251251
broadcast_mask = [1] * input_rank
252252
broadcast_mask[axis] = input_shape[axis]
253253

254-
reduction_axes: List[int] = []
254+
reduction_axes: list[int] = []
255255
for i in range(input_rank):
256256
if i != axis:
257257
reduction_axes.append(i)

torch/_inductor/kernel/flex_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -936,7 +936,7 @@ def lower_cpu(
936936
+ mask_graph_placeholder_inps
937937
+ list(mask_mod_other_buffers)
938938
)
939-
fake_buffers: List[Buffer] = [item.data.data for item in buffer_list if isinstance(item, TensorBox)] # type: ignore[attr-defined]
939+
fake_buffers: list[Buffer] = [item.data.data for item in buffer_list if isinstance(item, TensorBox)] # type: ignore[attr-defined]
940940

941941
(
942942
query,

torch/distributed/tensor/_ops/_einsum_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def gen_einsum_strategies(
149149

150150
# linearity strategy
151151
if linearity:
152-
linearity_placement_list: List[Placement] = [Partial()]
152+
linearity_placement_list: list[Placement] = [Partial()]
153153
linearity_placement_list.extend(Partial() for input_dim in input_dims)
154154
mesh_dim_strategies.append(linearity_placement_list)
155155

0 commit comments

Comments
 (0)
0