@@ -1,9 +1,12 @@
-
 import functools
-from typing import List
+from typing import List, TYPE_CHECKING
 
 import torch
-import torch.distributed._shard.sharding_spec as shard_spec
+
+if TYPE_CHECKING:
+    from torch.distributed._shard.sharding_spec import ShardingSpec
+else:
+    ShardingSpec = "ShardingSpec"
 
 from .api import (
     _CUSTOM_SHARDED_OPS,
@@ -18,7 +21,7 @@
 from torch.distributed._shard.op_registry_utils import _decorator_func
 
 
-def empty(sharding_spec: shard_spec.ShardingSpec,
+def empty(sharding_spec: ShardingSpec,
           *size,
           dtype=None,
           layout=torch.strided,
@@ -70,7 +73,7 @@ def empty(sharding_spec: shard_spec.ShardingSpec,
         init_rrefs=init_rrefs,
     )
 
-def ones(sharding_spec: shard_spec.ShardingSpec,
+def ones(sharding_spec: ShardingSpec,
          *size,
          dtype=None,
          layout=torch.strided,
@@ -121,7 +124,7 @@ def ones(sharding_spec: shard_spec.ShardingSpec,
         init_rrefs=init_rrefs
     )
 
-def zeros(sharding_spec: shard_spec.ShardingSpec,
+def zeros(sharding_spec: ShardingSpec,
           *size,
           dtype=None,
           layout=torch.strided,
@@ -172,7 +175,7 @@ def zeros(sharding_spec: shard_spec.ShardingSpec,
         init_rrefs=init_rrefs
     )
 
-def full(sharding_spec: shard_spec.ShardingSpec,
+def full(sharding_spec: ShardingSpec,
          size,
          fill_value,
          *,
@@ -225,7 +228,7 @@ def full(sharding_spec: shard_spec.ShardingSpec,
     torch.nn.init.constant_(sharded_tensor, fill_value)  # type: ignore[arg-type]
     return sharded_tensor
 
-def rand(sharding_spec: shard_spec.ShardingSpec,
+def rand(sharding_spec: ShardingSpec,
          *size,
          dtype=None,
          layout=torch.strided,
@@ -278,7 +281,7 @@ def rand(sharding_spec: shard_spec.ShardingSpec,
     torch.nn.init.uniform_(sharded_tensor, 0, 1)  # type: ignore[arg-type]
     return sharded_tensor
 
-def randn(sharding_spec: shard_spec.ShardingSpec,
+def randn(sharding_spec: ShardingSpec,
           *size,
           dtype=None,
           layout=torch.strided,
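The repeated signature edits all follow from the import change in the first hunk: the eager `import torch.distributed._shard.sharding_spec as shard_spec` is replaced by a type-only import, a standard way to avoid importing a module at load time (e.g. to break an import cycle or defer a heavy import) while keeping annotations checkable. Below is a minimal self-contained sketch of the same pattern, with a hypothetical `spec` module standing in for `torch.distributed._shard.sharding_spec`; it is an illustration, not the PR's code:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed at
    # runtime, so importing `spec` here cannot trigger an import cycle.
    from spec import ShardingSpec  # hypothetical module for illustration
else:
    # At runtime the name is bound to a plain string. Using it in an
    # annotation is legal: Python stores it in __annotations__, and type
    # checkers treat the string as a forward reference.
    ShardingSpec = "ShardingSpec"


def empty(sharding_spec: ShardingSpec, *size):
    """Annotated against ShardingSpec without importing `spec` at runtime."""
    ...
```

The `else` branch matters: without it, `ShardingSpec` would be undefined at runtime and the `def` line would raise `NameError`, because function annotations are evaluated at definition time unless `from __future__ import annotations` is in effect.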