From d37a03244dbf6008350af1364b3fc3ddb0e16def Mon Sep 17 00:00:00 2001 From: Xiaoya Xiang Date: Tue, 27 Feb 2024 10:17:18 -0800 Subject: [PATCH] Fix an import loop (#119820) Summary: We ran into the following import loop when testing aps: ``` Traceback (most recent call last): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/forkserver.py", line 274, in main code = _serve_one(child_r, fds, File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/forkserver.py", line 313, in _serve_one code = spawn._main(child_r, parent_sentinel) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/spawn.py", line 125, in _main prepare(preparation_data) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/spawn.py", line 234, in prepare _fixup_main_from_name(data['init_main_from_name']) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/spawn.py", line 258, in _fixup_main_from_name main_content = runpy.run_module(mod_name, File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/runpy.py", line 224, in run_module return _run_module_code(code, init_globals, run_name, mod_spec) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/runpy.py", line 96, in _run_module_code _run_code(code, mod_globals, init_globals, File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/runpy.py", line 86, in _run_code exec(code, run_globals) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/aps_models/ads/icvr/icvr_launcher.py", line 29, in <module> class 
ICVRConfig(AdsComboLauncherConfig): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/aps_models/ads/common/ads_launcher.py", line 249, in <module> class AdsComboLauncherConfig(AdsConfig): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/aps_models/ads/common/app_config.py", line 16, in <module> class AdsConfig(RecTrainAppConfig): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/apf/rec/config_def.py", line 47, in <module> class EmbeddingKernelConfig: File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/apf/rec/config_def.py", line 52, in EmbeddingKernelConfig cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torchrec/distributed/types.py", line 501, in <module> class ParameterSharding: File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torchrec/distributed/types.py", line 527, in ParameterSharding sharding_spec: Optional[ShardingSpec] = None File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torch/distributed/_shard/sharding_spec/api.py", line 48, in <module> class ShardingSpec(ABC): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torch/distributed/_shard/sharding_spec/api.py", line 55, in ShardingSpec tensor_properties: sharded_tensor_meta.TensorProperties, File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torch/distributed/_shard/sharded_tensor/__init__.py", line 21, in <module> def empty(sharding_spec: shard_spec.ShardingSpec, ImportError: cannot import name 'ShardingSpec' from partially initialized module 'torch.distributed._shard.sharding_spec.api' (most likely due to a circular import) (/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torch/distributed/_shard/sharding_spec/api.py) ``` Using future annotations to 
mitigate: `ShardingSpec` is now imported only under `TYPE_CHECKING` and is a plain string placeholder at runtime, so this module no longer imports `torch.distributed._shard.sharding_spec` while it is being loaded. Test Plan: ``` hg update 1b1b3154616b70fd3325c467db1f7e0f70182a74 CUDA_VISIBLE_DEVICES=1,2 buck2 run @//mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_ctr_cvr_rep ``` Reviewed By: fegin Differential Revision: D53685582 --- .../_shard/sharded_tensor/__init__.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index f2723dca4bfd..152c287ee703 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -1,9 +1,12 @@ - import functools -from typing import List +from typing import List, TYPE_CHECKING import torch -import torch.distributed._shard.sharding_spec as shard_spec + +if TYPE_CHECKING: + from torch.distributed._shard.sharding_spec import ShardingSpec +else: + ShardingSpec = "ShardingSpec" from .api import ( _CUSTOM_SHARDED_OPS, @@ -18,7 +21,7 @@ from torch.distributed._shard.op_registry_utils import _decorator_func -def empty(sharding_spec: shard_spec.ShardingSpec, +def empty(sharding_spec: ShardingSpec, *size, dtype=None, layout=torch.strided, @@ -70,7 +73,7 @@ def empty(sharding_spec: shard_spec.ShardingSpec, init_rrefs=init_rrefs, ) -def ones(sharding_spec: shard_spec.ShardingSpec, +def ones(sharding_spec: ShardingSpec, *size, dtype=None, layout=torch.strided, @@ -121,7 +124,7 @@ def ones(sharding_spec: shard_spec.ShardingSpec, init_rrefs=init_rrefs ) -def zeros(sharding_spec: shard_spec.ShardingSpec, +def zeros(sharding_spec: ShardingSpec, *size, dtype=None, layout=torch.strided, @@ -172,7 +175,7 @@ def zeros(sharding_spec: shard_spec.ShardingSpec, init_rrefs=init_rrefs ) -def full(sharding_spec: shard_spec.ShardingSpec, +def full(sharding_spec: ShardingSpec, size, fill_value, *, @@ -225,7 +228,7 @@ def full(sharding_spec: shard_spec.ShardingSpec, torch.nn.init.constant_(sharded_tensor, fill_value) # type: ignore[arg-type] return 
sharded_tensor -def rand(sharding_spec: shard_spec.ShardingSpec, +def rand(sharding_spec: ShardingSpec, *size, dtype=None, layout=torch.strided, @@ -278,7 +281,7 @@ def rand(sharding_spec: shard_spec.ShardingSpec, torch.nn.init.uniform_(sharded_tensor, 0, 1) # type: ignore[arg-type] return sharded_tensor -def randn(sharding_spec: shard_spec.ShardingSpec, +def randn(sharding_spec: ShardingSpec, *size, dtype=None, layout=torch.strided,