8000 Fix an import loop · pytorch/pytorch@a61ac5c · GitHub
  • [go: up one dir, main page]

    Skip to content

    Commit a61ac5c

    Browse files
    cxxxs authored and facebook-github-bot committed
    Fix an import loop
    Summary: We ran into the following import loop when testing aps: ``` Traceback (most recent call last): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/forkserver.py", line 274, in main code = _serve_one(child_r, fds, File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/forkserver.py", line 313, in _serve_one code = spawn._main(child_r, parent_sentinel) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/spawn.py", line 125, in _main prepare(preparation_data) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/spawn.py", line 234, in prepare _fixup_main_from_name(data['init_main_from_name']) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/multiprocessing/spawn.py", line 258, in _fixup_main_from_name main_content = runpy.run_module(mod_name, File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/runpy.py", line 224, in run_module return _run_module_code(code, init_globals, run_name, mod_spec) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/runpy.py", line 96, in _run_module_code _run_code(code, mod_globals, init_globals, File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/runtime/lib/python3.10/runpy.py", line 86, in _run_code exec(code, run_globals) File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/aps_models/ads/icvr/icvr_launcher.py", line 29, in <module> class ICVRConfig(AdsComboLauncherConfig): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/aps_models/ads/common/ads_launcher.py", line 
249, in <module> class AdsComboLauncherConfig(AdsConfig): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/aps_models/ads/common/app_config.py", line 16, in <module> class AdsConfig(RecTrainAppConfig): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/apf/rec/config_def.py", line 47, in <module> class EmbeddingKernelConfig: File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/apf/rec/config_def.py", line 52, in EmbeddingKernelConfig cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torchrec/distributed/types.py", line 501, in <module> class ParameterSharding: File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torchrec/distributed/types.py", line 527, in ParameterSharding sharding_spec: Optional[ShardingSpec] = None File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torch/distributed/_shard/sharding_spec/api.py", line 48, in <module> class ShardingSpec(ABC): File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torch/distributed/_shard/sharding_spec/api.py", line 55, in ShardingSpec tensor_properties: sharded_tensor_meta.TensorProperties, File "/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torch/distributed/_shard/sharded_tensor/__init__.py", line 21, in <module> def empty(sharding_spec: shard_spec.ShardingSpec, ImportError: cannot import name 'ShardingSpec' from partially initialized module 'torch.distributed._shard.sharding_spec.api' (most likely due to a circular import) (/mnt/xarfuse/uid-26572/e04e8e0a-seed-nspid4026534049_cgpid5889271-ns-4026534028/torch/distributed/_shard/sharding_spec/api.py) ``` Using future annotations to mitigate. 
Test Plan: ``` hg update 1b1b3154616b70fd3325c467db1f7e0f70182a74 CUDA_VISIBLE_DEVICES=1,2 buck2 run @//mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_ctr_cvr_rep ``` Differential Revision: D53685582
    1 parent 7ad4ab4 commit a61ac5c

    File tree

    1 file changed

    +12
    -9
    lines changed

    1 file changed

    +12
    -9
    lines changed

    torch/distributed/_shard/sharded_tensor/__init__.py

    Lines changed: 12 additions & 9 deletions
    Original file line numberDiff line numberDiff line change
    @@ -1,9 +1,12 @@
    1-
    21
    import functools
    3-
    from typing import List
    2+
    from typing import List, TYPE_CHECKING
    43

    54
    import torch
    6-
    import torch.distributed._shard.sharding_spec as shard_spec
    5+
    6+
    if TYPE_CHECKING:
    7+
    from torch.distributed._shard.sharding_spec import ShardingSpec
    8+
    else:
    9+
    ShardingSpec = "ShardingSpec"
    710

    811
    from .api import (
    912
    _CUSTOM_SHARDED_OPS,
    @@ -18,7 +21,7 @@
    1821
    from torch.distributed._shard.op_registry_utils import _decorator_func
    1922

    2023

    21-
    def empty(sharding_spec: shard_spec.ShardingSpec,
    24+
    def empty(sharding_spec: ShardingSpec,
    2225
    *size,
    2326
    dtype=None,
    2427
    layout=torch.strided,
    @@ -70,7 +73,7 @@ def empty(sharding_spec: shard_spec.ShardingSpec,
    7073
    init_rrefs=init_rrefs,
    7174
    )
    7275

    73-
    def ones(sharding_spec: shard_spec.ShardingSpec,
    76+
    def ones(sharding_spec: ShardingSpec,
    7477
    *size,
    7578
    dtype=None,
    7679
    layout=torch.strided,
    @@ -121,7 +124,7 @@ def ones(sharding_spec: shard_spec.ShardingSpec,
    121124
    init_rrefs=init_rrefs
    122125
    )
    123126

    124-
    def zeros(sharding_spec: shard_spec.ShardingSpec,
    127+
    def zeros(sharding_spec: ShardingSpec,
    125128
    *size,
    126129
    dtype=None,
    127130
    layout=torch.strided,
    @@ -172,7 +175,7 @@ def zeros(sharding_spec: shard_spec.ShardingSpec,
    172175
    init_rrefs=init_rrefs
    173176
    )
    174177

    175-
    def full(sharding_spec: shard_spec.ShardingSpec,
    178+
    def full(sharding_spec: ShardingSpec,
    176179
    size,
    177180
    fill_value,
    178181
    *,
    @@ -225,7 +228,7 @@ def full(sharding_spec: shard_spec.ShardingSpec,
    225228
    torch.nn.init.constant_(sharded_tensor, fill_value) # type: ignore[arg-type]
    226229
    return sharded_tensor
    227230

    228-
    def rand(sharding_spec: shard_spec.ShardingSpec,
    231+
    def rand(sharding_spec: ShardingSpec,
    229232
    *size,
    230233
    dtype=None,
    231234
    layout=torch.strided,
    @@ -278,7 +281,7 @@ def rand(sharding_spec: shard_spec.ShardingSpec,
    278281
    torch.nn.init.uniform_(sharded_tensor, 0, 1) # type: ignore[arg-type]
    279282
    return sharded_tensor
    280283

    281-
    def randn(sharding_spec: shard_spec.ShardingSpec,
    284+
    def randn(sharding_spec: ShardingSpec,
    282285
    *size,
    283286
    dtype=None,
    284287
    layout=torch.strided,

    0 commit comments

    Comments (0)