Fix an import loop by cxxxs · Pull Request #119820 · pytorch/pytorch · GitHub

Fix an import loop #119820

Closed · wants to merge 1 commit
21 changes: 12 additions & 9 deletions torch/distributed/_shard/sharded_tensor/__init__.py
@@ -1,9 +1,12 @@

import functools
-from typing import List
+from typing import List, TYPE_CHECKING

import torch
import torch.distributed._shard.sharding_spec as shard_spec

+if TYPE_CHECKING:
+    from torch.distributed._shard.sharding_spec import ShardingSpec
+else:
+    ShardingSpec = "ShardingSpec"

from .api import (
    _CUSTOM_SHARDED_OPS,
@@ -18,7 +21,7 @@
from torch.distributed._shard.op_registry_utils import _decorator_func


-def empty(sharding_spec: shard_spec.ShardingSpec,
+def empty(sharding_spec: ShardingSpec,
          *size,
          dtype=None,
          layout=torch.strided,
@@ -70,7 +73,7 @@ def empty(sharding_spec: shard_spec.ShardingSpec,
        init_rrefs=init_rrefs,
    )

-def ones(sharding_spec: shard_spec.ShardingSpec,
+def ones(sharding_spec: ShardingSpec,
         *size,
         dtype=None,
         layout=torch.strided,
@@ -121,7 +124,7 @@ def ones(sharding_spec: shard_spec.ShardingSpec,
        init_rrefs=init_rrefs
    )

-def zeros(sharding_spec: shard_spec.ShardingSpec,
+def zeros(sharding_spec: ShardingSpec,
          *size,
          dtype=None,
          layout=torch.strided,
@@ -172,7 +175,7 @@ def zeros(sharding_spec: shard_spec.ShardingSpec,
        init_rrefs=init_rrefs
    )

-def full(sharding_spec: shard_spec.ShardingSpec,
+def full(sharding_spec: ShardingSpec,
         size,
         fill_value,
         *,
@@ -225,7 +228,7 @@ def full(sharding_spec: shard_spec.ShardingSpec,
    torch.nn.init.constant_(sharded_tensor, fill_value)  # type: ignore[arg-type]
    return sharded_tensor

-def rand(sharding_spec: shard_spec.ShardingSpec,
+def rand(sharding_spec: ShardingSpec,
         *size,
         dtype=None,
         layout=torch.strided,
@@ -278,7 +281,7 @@ def rand(sharding_spec: shard_spec.ShardingSpec,
    torch.nn.init.uniform_(sharded_tensor, 0, 1)  # type: ignore[arg-type]
    return sharded_tensor

-def randn(sharding_spec: shard_spec.ShardingSpec,
+def randn(sharding_spec: ShardingSpec,
          *size,
          dtype=None,
          layout=torch.strided,
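
Note on the pattern used by this change: the diff applies the standard typing.TYPE_CHECKING idiom. ShardingSpec is imported from torch.distributed._shard.sharding_spec only while a static type checker analyses the module; at runtime the name is bound to the plain string "ShardingSpec", which still works in the function signatures because annotations are simply stored, not evaluated, at definition time. Since the import never runs at module-load time, it cannot re-trigger the circular import. Below is a minimal, self-contained sketch of the same idiom; the module path mypkg.spec and the function empty are placeholders for illustration, not PyTorch's actual layout.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright), never at runtime,
    # so this import cannot participate in an import cycle.
    from mypkg.spec import ShardingSpec  # hypothetical module path
else:
    # At runtime the name is just a string; the annotations below are stored as
    # plain strings (forward references) and are not evaluated at definition time.
    ShardingSpec = "ShardingSpec"


def empty(sharding_spec: ShardingSpec, *size):
    """Placeholder showing how the annotation is used; no real logic here."""
    ...

The else branch is what keeps the module importable at runtime without from __future__ import annotations: type checkers still see the real class because they follow the if TYPE_CHECKING branch, while the interpreter only ever sees a harmless string.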