From 2fdd13cba0500ddfa08dafb278c5aea25b7b9171 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sat, 18 Jan 2025 11:38:01 -0500 Subject: [PATCH 1/5] [BE]: Apply ruff PERF401 to torch --- torch/_decomp/__init__.py | 3 +-- torch/_decomp/decompositions.py | 9 ++----- torch/_decomp/decompositions_for_jvp.py | 5 +--- torch/_dynamo/bytecode_transformation.py | 5 +--- torch/_dynamo/graph_region_tracker.py | 4 +-- torch/_dynamo/polyfills/__init__.py | 4 +-- torch/_higher_order_ops/triton_kernel_wrap.py | 9 ++++--- torch/_inductor/codegen/wrapper.py | 10 ++++--- torch/_inductor/ir.py | 5 ++-- torch/_inductor/kernel/flex_attention.py | 5 +--- torch/_inductor/select_algorithm.py | 6 +++-- torch/_inductor/utils.py | 9 ++++--- .../_common_operator_config_utils.py | 16 ++++++------ .../algorithms/model_averaging/utils.py | 8 +++--- .../tensor/_ops/_einsum_strategy.py | 5 ++-- torch/jit/supported_ops.py | 26 +++++++++++-------- torch/nn/parallel/comm.py | 9 +++---- .../onnx/_internal/fx/fx_onnx_interpreter.py | 9 +++---- .../_internal/distributed/distributed_test.py | 19 ++++++-------- .../distributed/rpc/dist_autograd_test.py | 4 +-- torchgen/dest/register_dispatch_key.py | 8 +++--- 21 files changed, 82 insertions(+), 96 deletions(-) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 37b50a2efddf6..83012f475a5b8 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -87,8 +87,7 @@ def _add_op_to_registry(registry, op, fn): overloads.append(op) else: assert isinstance(op, OpOverloadPacket) - for ol in op.overloads(): - overloads.append(getattr(op, ol)) + overloads.extend(getattr(op, ol) for ol in op.overloads()) for op_overload in overloads: if op_overload in registry: diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 9543da39dd141..f0adb2f2f2a73 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2266,10 +2266,7 @@ def native_batch_norm_backward( broadcast_mask: list[int] = [1] * input_rank broadcast_mask[axis] = input_shape[axis] - reduction_axes: list[int] = [] - for i in range(input_rank): - if i != axis: - reduction_axes.append(i) + reduction_axes: List[int] = [i for i in range(input_rank) if i != axis] mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type] norm = 1.0 / num_features @@ -4455,10 +4452,8 @@ def matmul(tensor1, tensor2, *, is_out=False): m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1) p = tensor2.size(-1) if dim_tensor2 > 1 else 1 - batch_tensor2: list[int] = [] # TODO: handling of slice - for i in range(dim_tensor2 - 2): - batch_tensor2.append(tensor2.size(i)) + batch_tensor2: List[int] = [tensor2.size(i) for i in range(dim_tensor2 - 2)] # Same optimization for the gradients as that in should_fold # If we're going to broadcast, we force it to go through the should_fold branch diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py index 60a19f3200599..2cd88ab960a2a 100644 --- a/torch/_decomp/decompositions_for_jvp.py +++ b/torch/_decomp/decompositions_for_jvp.py @@ -251,10 +251,7 @@ def native_batch_norm_backward( broadcast_mask = [1] * input_rank broadcast_mask[axis] = input_shape[axis] - reduction_axes: list[int] = [] - for i in range(input_rank): - if i != axis: - reduction_axes.append(i) + reduction_axes: List[int] = [i for i in range(input_rank) if i != axis] mean = torch.reshape(mean, broadcast_mask) norm = 1.0 / num_features diff --git a/torch/_dynamo/bytecode_transformation.py 
b/torch/_dynamo/bytecode_transformation.py index 16de6ef0ce3e8..eb9610891f290 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1602,10 +1602,7 @@ def template(): new_insts.append(inst) insts = new_insts - returns = [] - for inst in insts: - if inst.opname == "RETURN_VALUE": - returns.append(inst) + returns = [inst for inst in insts if inst.opname == "RETURN_VALUE"] if len(returns) == 1 and returns[0] is insts[-1]: # only 1 return at the end - just pop it diff --git a/torch/_dynamo/graph_region_tracker.py b/torch/_dynamo/graph_region_tracker.py index 9875e448c995b..585cf450bdc4f 100644 --- a/torch/_dynamo/graph_region_tracker.py +++ b/torch/_dynamo/graph_region_tracker.py @@ -40,10 +40,8 @@ def _extract_tensor_metadata_for_node_hash( ) -> tuple[Callable[[T], T], tuple[Any, ...]]: from torch._inductor.codecache import _ident, extract_tensor_metadata_for_cache_key - out = [] metadata = extract_tensor_metadata_for_cache_key(x) - for field in fields(metadata): - out.append(getattr(metadata, field.name)) + out = [getattr(metadata, field.name) for field in fields(metadata)] return (_ident, tuple(out)) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 3ef4d94a1385c..7765bb4378a93 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -188,9 +188,7 @@ def foreach_map_fn(*args): if not at_least_one_list: return op(*args[1:]) - out = [] - for unpacked in zip(*new_args): - out.append(op(*unpacked)) + out = [op(*unpacked) for unpacked in zip(*new_args)] return out diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 185f2f5a13091..bff661a52b972 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -1422,10 +1422,11 @@ def call_triton_kernel( return self.call_triton_kernel(new_var, args, kwargs, tx) if isinstance(variable.kernel, Autotuner): - special_param_names = [] - for name in SPECIAL_CONFIG_NAMES: - if name in variable.kernel.fn.arg_names: - special_param_names.append(name) + special_param_names = [ + name + for name in SPECIAL_CONFIG_NAMES + if name in variable.kernel.fn.arg_names + ] if special_param_names: # If the Triton kernel has SPECIAL_CONFIG_NAMES in parameters, those should diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0822ddd191373..2f38a56eb2825 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1741,10 +1741,12 @@ def add_arg(idx, arg, is_constexpr=False, equals_1=False, equals_none=False): # Distinguish between different functions using function id cache_key: list[Any] = [id(kernel.fn)] if len(configs) > 0: - for arg in kwargs.values(): - # We need to key on non tensor arg only in autotune mode - if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)): - cache_key.append(arg) + # We need to key on non tensor arg only in autotune mode + cache_key.extend( + arg + for arg in kwargs.values() + if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)) + ) cache_key.append(str(triton_meta)) cache_key = tuple(cache_key) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 34231f0a7ed6a..f5530754c93e7 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5732,8 +5732,9 @@ def get_kernel_and_metadata(self): # type: ignore[no-untyped-def] restore_value_args.extend(kernel.restore_value) if hasattr(kernel, "reset_idx"): - for i in 
kernel.reset_idx: - reset_to_zero_args.append(kernel.fn.arg_names[i]) + reset_to_zero_args.extend( + kernel.fn.arg_names[i] for i in kernel.reset_idx + ) else: assert hasattr(kernel, "reset_to_zero") reset_to_zero_args.extend(kernel.reset_to_zero) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index b8d6065bc1201..3d07a7dd79b8a 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -897,7 +897,6 @@ def lower_cpu( "torch.compile on current platform is not supported for CPU." ) - fake_buffers: list[Buffer] = [] # noqa: F821 placeholder_inps = [ create_placeholder(name, dtype, query.get_device()) for name, dtype in [ @@ -937,9 +936,7 @@ def lower_cpu( + mask_graph_placeholder_inps + list(mask_mod_other_buffers) ) - for item in buffer_list: - if isinstance(item, TensorBox): - fake_buffers.append(item.data.data) # type: ignore[attr-defined] + fake_buffers: List[Buffer] = [item.data.data for item in buffer_list if isinstance(item, TensorBox)] # type: ignore[attr-defined] ( query, diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 01e95756fc2ee..9c2f7442d7e45 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -630,8 +630,10 @@ def modification( ), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" # Handle scatter stores if isinstance(subgraph, list): - for scatter_graph in subgraph: - scatters.append(self._handle_scatter_graph(scatter_graph)) + scatters.extend( + self._handle_scatter_graph(scatter_graph) + for scatter_graph in subgraph + ) elif isinstance(subgraph.data, ir.InputBuffer): out = subgraph.data.make_loader()(()) else: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index bca692af9ad8c..b92eec689d933 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -606,10 +606,11 @@ def get_kernel_metadata(node_schedule, wrapper): # print the aot_autograd graph fragment if single_graph is not None: detailed_metadata.append(f"{wrapper.comment} Graph fragment:") - for n in inductor_nodes: - # TODO(future): maybe refactor torch/fx/graph.py to make it easy to - # generate python code for graph fragments - detailed_metadata.append(f"{wrapper.comment} {n.format_node()}") + # TODO(future): maybe refactor torch/fx/graph.py to make it easy to + # generate python code for graph fragments + detailed_metadata.extend( + f"{wrapper.comment} {n.format_node()}" for n in inductor_nodes + ) return metadata, "\n".join(detailed_metadata) diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index 60f2fe86b12e4..eeb65bf338818 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -714,14 +714,14 @@ def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConf ) # fused bn configs - for fused_bn in bn_to_fused_bn.values(): - bn_configs.append( - BackendPatternConfig(fused_bn) - .set_observation_type( - ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 - .set_dtype_configs(dtype_configs) - ) + bn_configs.extend( + BackendPatternConfig(fused_bn) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + for fused_bn in 
bn_to_fused_bn.values() + ) return bn_configs diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index 0438043a6e740..7efdfe5ec16cf 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -70,9 +70,11 @@ def get_params_to_average( filtered_params.append(param_data) elif isinstance(param, dict): # optimizer.param_groups input - for param_data in param["params"]: - if param_data.grad is not None: - filtered_params.append(param_data) + filtered_params.extend( + param_data + for param_data in param["params"] + if param_data.grad is not None + ) else: raise NotImplementedError( f"Parameter input of type {type(param)} is not supported" diff --git a/torch/distributed/tensor/_ops/_einsum_strategy.py b/torch/distributed/tensor/_ops/_einsum_strategy.py index 0db79ed2f7002..0f9d52670ad49 100644 --- a/torch/distributed/tensor/_ops/_einsum_strategy.py +++ b/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -149,9 +149,8 @@ def gen_einsum_strategies( # linearity strategy if linearity: - linearity_placement_list: list[Placement] = [Partial()] - for input_dim in input_dims: - linearity_placement_list.append(Partial()) + linearity_placement_list: List[Placement] = [Partial()] + linearity_placement_list.extend(Partial() for input_dim in input_dims) mesh_dim_strategies.append(linearity_placement_list) all_mesh_dim_strategies.append(mesh_dim_strategies) diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index 791a11a9b3aa7..3c3c85dd72556 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -72,9 +72,11 @@ def is_tensor_method(schema): for elem in dir(torch.Tensor): if not _hidden(elem): schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem) - for schema in schemas: - if is_tensor_method(schema): - methods.append(_emit_schema("Tensor", elem, schema, arg_start=1)) + methods.extend( + _emit_schema("Tensor", elem, schema, arg_start=1) + for schema in schemas + if is_tensor_method(schema) + ) return "Supported Tensor Methods", methods @@ -115,10 +117,12 @@ def _get_nn_functional_ops(): builtin = _find_builtin(getattr(mod, elem)) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) - for schema in schemas: - # remove _tan but not __and__ - if not _hidden(elem): - functions.append(_emit_schema(name, elem, schema)) + # remove _tan but not __and__ + functions.extend( + _emit_schema(name, elem, schema) + for schema in schemas + if not _hidden(elem) + ) return "Supported PyTorch Functions", functions @@ -164,8 +168,9 @@ def _get_torchscript_builtins(): builtin = _find_builtin(fn) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) - for schema in schemas: - functions.append(_emit_schema(mod.__name__, fn.__name__, schema)) + functions.extend( + _emit_schema(mod.__name__, fn.__name__, schema) for schema in schemas + ) return "TorchScript Builtin Functions", functions @@ -271,8 +276,7 @@ def _get_global_builtins(): if fn in op_renames: op_name = op_renames[fn] schemas = torch._C._jit_get_schemas_for_operator(op_name) - for s in schemas: - schematized_ops.append(_emit_schema(None, fn, s, padding=0)) + schematized_ops.extend(_emit_schema(None, fn, s, padding=0) for s in schemas) if len(schemas) > 0: schematized_ops.append("") else: diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index 42b3dbd908d64..0b948f1e4cd46 100644 --- a/torch/nn/parallel/comm.py 
+++ b/torch/nn/parallel/comm.py @@ -156,11 +156,10 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): _flatten_dense_tensors(chunk) for chunk in chunks ] # (num_gpus,) flat_result = reduce_add(flat_tensors, destination) - for t in _unflatten_dense_tensors(flat_result, chunks[0]): - # The unflattened tensors do not share storage, and we don't expose - # base flat tensor anyways, so give them different version counters. - # See NOTE [ Version Counter in comm.*_coalesced ] - output.append(t.data) + # The unflattened tensors do not share storage, and we don't expose + # base flat tensor anyways, so give them different version counters. + # See NOTE [ Version Counter in comm.*_coalesced ] + output.extend(t.data for t in _unflatten_dense_tensors(flat_result, chunks[0])) return tuple(_reorder_tensors_as(output, ref_order)) diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py index 1a1cbc9ae922f..60323e9949ad0 100644 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -185,11 +185,10 @@ def _retrieve_or_adapt_input_to_graph_set( onnxscript_graph_building.TorchScriptTensor | None | tuple[onnxscript_graph_building.TorchScriptTensor, ...] - ] = [] - for tensor in onnx_tensor: - sequence_elements.append( - fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None # type: ignore[index, union-attr] - ) + ] = [ + fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None # type: ignore[index, union-attr] + for tensor in onnx_tensor + ] return sequence_elements if isinstance(onnx_tensor, torch.dtype): onnx_tensor = int( # type: ignore[call-overload] diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 6dc5bd4d5c619..773ac2bd91a3f 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -7380,17 +7380,14 @@ def forward(self, x, rank): ) ) - throw_on_early_term_tests = [] - for test_input in models_to_test: - throw_on_early_term_tests.append( - DDPUnevenTestInput( - name=test_input.name, - model=test_input.model, - inp=test_input.inp, - sync_interval=test_input.sync_interval, - throw_on_early_termination=True, - ) - ) + throw_on_early_term_tests = [ + DDPUnevenTestInput( + name=test_input.name, + model=test_input.model, + inp=test_input.inp, + sync_interval=test_input.sync_interval, + throw_on_early_termination=True, + ) for test_input in models_to_test] models_to_test.extend(models_with_sync) models_to_test.extend(throw_on_early_term_tests) diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py index e113b87ccebda..0a833aa29f6a9 100644 --- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -2765,9 +2765,7 @@ def test_gradients_synchronizations(self): dist_autograd.backward(context_id, [x.sum()]) - futs = [] - for remote_layer in remote_layers: - futs.append(remote_layer.rpc_async().gradients(context_id)) + futs = [remote_layer.rpc_async().gradients(context_id) for remote_layer in remote_layers] for i in range(len(futs)): local_gradients = [p.grad for p in local_layers[i].parameters()] diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index 015537df12e05..d5e1009219a8a 100644 --- 
a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -955,10 +955,10 @@ def generate_defn(cpp_sig: CppSignature) -> str: # Go over each output, and check if there is a proxy created for it. # If so, copy it over to the original output. if k is SchemaKind.out or k is SchemaKind.inplace: - for i in range(len(f.func.returns)): - sig_body.append( - f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);" - ) + sig_body.extend( + f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);" + for i in range(len(f.func.returns)) + ) # Destructively return the final tensors # TODO: Do this in translate instead From 14eb45652a1f2ddb071036d1108682b60df78c21 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sat, 18 Jan 2025 12:15:29 -0500 Subject: [PATCH 2/5] Revert decomp changes --- torch/_decomp/__init__.py | 3 ++- torch/_decomp/decompositions.py | 9 +++++++-- torch/_decomp/decompositions_for_jvp.py | 5 ++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 83012f475a5b8..37b50a2efddf6 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -87,7 +87,8 @@ def _add_op_to_registry(registry, op, fn): overloads.append(op) else: assert isinstance(op, OpOverloadPacket) - overloads.extend(getattr(op, ol) for ol in op.overloads()) + for ol in op.overloads(): + overloads.append(getattr(op, ol)) for op_overload in overloads: if op_overload in registry: diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index f0adb2f2f2a73..498a414b0431f 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2266,7 +2266,10 @@ def native_batch_norm_backward( broadcast_mask: list[int] = [1] * input_rank broadcast_mask[axis] = input_shape[axis] - reduction_axes: List[int] = [i for i in range(input_rank) if i != axis] + reduction_axes: List[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type] norm = 1.0 / num_features @@ -4452,8 +4455,10 @@ def matmul(tensor1, tensor2, *, is_out=False): m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1) p = tensor2.size(-1) if dim_tensor2 > 1 else 1 + batch_tensor2: List[int] = [] # TODO: handling of slice - batch_tensor2: List[int] = [tensor2.size(i) for i in range(dim_tensor2 - 2)] + for i in range(dim_tensor2 - 2): + batch_tensor2.append(tensor2.size(i)) # Same optimization for the gradients as that in should_fold # If we're going to broadcast, we force it to go through the should_fold branch diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py index 2cd88ab960a2a..8377a10b59c73 100644 --- a/torch/_decomp/decompositions_for_jvp.py +++ b/torch/_decomp/decompositions_for_jvp.py @@ -251,7 +251,10 @@ def native_batch_norm_backward( broadcast_mask = [1] * input_rank broadcast_mask[axis] = input_shape[axis] - reduction_axes: List[int] = [i for i in range(input_rank) if i != axis] + reduction_axes: List[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) mean = torch.reshape(mean, broadcast_mask) norm = 1.0 / num_features From b11ff174c49bc43b2f259879ab6ab7dd0339c341 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sun, 19 Jan 2025 11:16:23 -0500 Subject: [PATCH 3/5] Remove two files that require additional typing --- 
torch/_inductor/select_algorithm.py | 6 ++---- torch/jit/supported_ops.py | 26 +++++++++++--------------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 9c2f7442d7e45..01e95756fc2ee 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -630,10 +630,8 @@ def modification( ), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" # Handle scatter stores if isinstance(subgraph, list): - scatters.extend( - self._handle_scatter_graph(scatter_graph) - for scatter_graph in subgraph - ) + for scatter_graph in subgraph: + scatters.append(self._handle_scatter_graph(scatter_graph)) elif isinstance(subgraph.data, ir.InputBuffer): out = subgraph.data.make_loader()(()) else: diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index 3c3c85dd72556..791a11a9b3aa7 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -72,11 +72,9 @@ def is_tensor_method(schema): for elem in dir(torch.Tensor): if not _hidden(elem): schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem) - methods.extend( - _emit_schema("Tensor", elem, schema, arg_start=1) - for schema in schemas - if is_tensor_method(schema) - ) + for schema in schemas: + if is_tensor_method(schema): + methods.append(_emit_schema("Tensor", elem, schema, arg_start=1)) return "Supported Tensor Methods", methods @@ -117,12 +115,10 @@ def _get_nn_functional_ops(): builtin = _find_builtin(getattr(mod, elem)) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) - # remove _tan but not __and__ - functions.extend( - _emit_schema(name, elem, schema) - for schema in schemas - if not _hidden(elem) - ) + for schema in schemas: + # remove _tan but not __and__ + if not _hidden(elem): + functions.append(_emit_schema(name, elem, schema)) return "Supported PyTorch Functions", functions @@ -168,9 +164,8 @@ def _get_torchscript_builtins(): builtin = _find_builtin(fn) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) - functions.extend( - _emit_schema(mod.__name__, fn.__name__, schema) for schema in schemas - ) + for schema in schemas: + functions.append(_emit_schema(mod.__name__, fn.__name__, schema)) return "TorchScript Builtin Functions", functions @@ -276,7 +271,8 @@ def _get_global_builtins(): if fn in op_renames: op_name = op_renames[fn] schemas = torch._C._jit_get_schemas_for_operator(op_name) - schematized_ops.extend(_emit_schema(None, fn, s, padding=0) for s in schemas) + for s in schemas: + schematized_ops.append(_emit_schema(None, fn, s, padding=0)) if len(schemas) > 0: schematized_ops.append("") else: From 512c4f9e8aabf74fd4c6348b055c257be8329bf5 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Tue, 4 Feb 2025 10:24:51 -0500 Subject: [PATCH 4/5] Apply changes after rebase --- torch/_decomp/__init__.py | 3 +-- torch/_higher_order_ops/scan.py | 4 +--- torch/_inductor/codegen/triton_utils.py | 5 +---- torch/_inductor/select_algorithm.py | 6 ++++-- torch/jit/supported_ops.py | 26 ++++++++++++++----------- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 37b50a2efddf6..83012f475a5b8 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -87,8 +87,7 @@ def _add_op_to_registry(registry, op, fn): overloads.append(op) else: assert isinstance(op, OpOverloadPacket) - for ol in op.overloads(): - 
overloads.append(getattr(op, ol)) + overloads.extend(getattr(op, ol) for ol in op.overloads()) for op_overload in overloads: if op_overload in registry: diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index 7eaad02763851..d35c952408e61 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -135,9 +135,7 @@ def add(x: torch.Tensor, y: torch.Tensor): dim = utils.canonicalize_dim(ndim, dim) # Move scan dim to 0 and always perform scan on dim 0 - leaves_xs = [] - for elem in leaves_xs_orig: - leaves_xs.append(torch.movedim(elem, dim, 0)) + leaves_xs = [torch.movedim(elem, dim, 0) for elem in leaves_xs_orig] out = combine_fn( pytree.tree_unflatten(leaves_init, spec_init), diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 193080d360c0d..dae378d1f7a37 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -97,10 +97,7 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: def non_constexpr_signature(signature): - new_signature = [] - for arg in signature: - if not isinstance(arg, ConstexprArg): - new_signature.append(arg) + new_signature = [arg for arg in signature if not isinstance(arg, ConstexprArg)] return new_signature diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 01e95756fc2ee..9c2f7442d7e45 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -630,8 +630,10 @@ def modification( ), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" # Handle scatter stores if isinstance(subgraph, list): - for scatter_graph in subgraph: - scatters.append(self._handle_scatter_graph(scatter_graph)) + scatters.extend( + self._handle_scatter_graph(scatter_graph) + for scatter_graph in subgraph + ) elif isinstance(subgraph.data, ir.InputBuffer): out = subgraph.data.make_loader()(()) else: diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index 791a11a9b3aa7..3c3c85dd72556 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -72,9 +72,11 @@ def is_tensor_method(schema): for elem in dir(torch.Tensor): if not _hidden(elem): schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem) - for schema in schemas: - if is_tensor_method(schema): - methods.append(_emit_schema("Tensor", elem, schema, arg_start=1)) + methods.extend( + _emit_schema("Tensor", elem, schema, arg_start=1) + for schema in schemas + if is_tensor_method(schema) + ) return "Supported Tensor Methods", methods @@ -115,10 +117,12 @@ def _get_nn_functional_ops(): builtin = _find_builtin(getattr(mod, elem)) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) - for schema in schemas: - # remove _tan but not __and__ - if not _hidden(elem): - functions.append(_emit_schema(name, elem, schema)) + # remove _tan but not __and__ + functions.extend( + _emit_schema(name, elem, schema) + for schema in schemas + if not _hidden(elem) + ) return "Supported PyTorch Functions", functions @@ -164,8 +168,9 @@ def _get_torchscript_builtins(): builtin = _find_builtin(fn) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) - for schema in schemas: - functions.append(_emit_schema(mod.__name__, fn.__name__, schema)) + functions.extend( + _emit_schema(mod.__name__, fn.__name__, schema) for schema in schemas + ) return "TorchScript Builtin Functions", functions @@ 
-271,8 +276,7 @@ def _get_global_builtins(): if fn in op_renames: op_name = op_renames[fn] schemas = torch._C._jit_get_schemas_for_operator(op_name) - for s in schemas: - schematized_ops.append(_emit_schema(None, fn, s, padding=0)) + schematized_ops.extend(_emit_schema(None, fn, s, padding=0) for s in schemas) if len(schemas) > 0: schematized_ops.append("") else: From efb97982a2d633d74dee373fa752eeeb6696a5ed Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Tue, 4 Feb 2025 10:32:16 -0500 Subject: [PATCH 5/5] Fix accidental typing reversions --- torch/_decomp/decompositions.py | 4 ++-- torch/_decomp/decompositions_for_jvp.py | 2 +- torch/_inductor/kernel/flex_attention.py | 2 +- torch/distributed/tensor/_ops/_einsum_strategy.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 498a414b0431f..9543da39dd141 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2266,7 +2266,7 @@ def native_batch_norm_backward( broadcast_mask: list[int] = [1] * input_rank broadcast_mask[axis] = input_shape[axis] - reduction_axes: List[int] = [] + reduction_axes: list[int] = [] for i in range(input_rank): if i != axis: reduction_axes.append(i) @@ -4455,7 +4455,7 @@ def matmul(tensor1, tensor2, *, is_out=False): m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1) p = tensor2.size(-1) if dim_tensor2 > 1 else 1 - batch_tensor2: List[int] = [] + batch_tensor2: list[int] = [] # TODO: handling of slice for i in range(dim_tensor2 - 2): batch_tensor2.append(tensor2.size(i)) diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py index 8377a10b59c73..60a19f3200599 100644 --- a/torch/_decomp/decompositions_for_jvp.py +++ b/torch/_decomp/decompositions_for_jvp.py @@ -251,7 +251,7 @@ def native_batch_norm_backward( broadcast_mask = [1] * input_rank broadcast_mask[axis] = input_shape[axis] - reduction_axes: List[int] = [] + reduction_axes: list[int] = [] for i in range(input_rank): if i != axis: reduction_axes.append(i) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 3d07a7dd79b8a..ceeef5c1ad756 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -936,7 +936,7 @@ def lower_cpu( + mask_graph_placeholder_inps + list(mask_mod_other_buffers) ) - fake_buffers: List[Buffer] = [item.data.data for item in buffer_list if isinstance(item, TensorBox)] # type: ignore[attr-defined] + fake_buffers: list[Buffer] = [item.data.data for item in buffer_list if isinstance(item, TensorBox)] # type: ignore[attr-defined] ( query, diff --git a/torch/distributed/tensor/_ops/_einsum_strategy.py b/torch/distributed/tensor/_ops/_einsum_strategy.py index 0f9d52670ad49..4d258487b2678 100644 --- a/torch/distributed/tensor/_ops/_einsum_strategy.py +++ b/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -149,7 +149,7 @@ def gen_einsum_strategies( # linearity strategy if linearity: - linearity_placement_list: List[Placement] = [Partial()] + linearity_placement_list: list[Placement] = [Partial()] linearity_placement_list.extend(Partial() for input_dim in input_dims) mesh_dim_strategies.append(linearity_placement_list)
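
For reference, the pattern targeted by ruff's PERF401 check throughout this series is a loop whose body only filters and appends to a list. A minimal sketch of the rewrite, using hypothetical names (`items`, `keep`, and `transform` stand in for whatever the real hunks iterate over):

    # Loop form flagged by PERF401: the body only filters and appends.
    results = []
    for item in items:
        if keep(item):
            results.append(transform(item))

    # Rewrite as a list comprehension when the list is created on the spot...
    results = [transform(item) for item in items if keep(item)]

    # ...or extend an existing list with a generator expression.
    results.extend(transform(item) for item in items if keep(item))

Behavior is unchanged as long as the loop body has no other side effects, and the comprehension form avoids a `results.append` attribute lookup and call on every iteration.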