[DTensor] Assert DTensorSpec has valid placements (#158133) · pytorch/pytorch@226e35c

Commit 226e35c

wconstab authored and facebook-github-bot committed
[DTensor] Assert DTensorSpec has valid placements (#158133)
Summary: This helped identify buggy sharding rules during debugging, so why not check it in.

Approved by: https://github.com/XilunWu, https://github.com/zpcore

ghstack dependencies: #158132

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/1839e8d04b81ee6eda0cff6fbfc218a7a600f6f7

Rollback Plan:

Differential Revision: D78929245
1 parent abb0bf4 commit 226e35c

2 files changed (+5, -1 lines)

test/distributed/tensor/test_dtensor_compile.py

Lines changed: 1 addition & 1 deletion
@@ -343,7 +343,7 @@ def test_dtensor_constructor_w_graph_break(self):
         x = torch.randn(64, 32, requires_grad=True)
         spec = DTensorSpec(
             mesh,
-            (Replicate(), Shard(0)),
+            (Replicate(),),
             tensor_meta=TensorMeta(
                 shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype
             ),
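For context, a minimal sketch (not part of the diff) of the invariant behind this test fix, assuming the mesh built earlier in this test is one-dimensional (it is not shown in the hunk): the placements tuple must have exactly one entry per mesh dimension.

# Hypothetical illustration of the invariant the test now satisfies;
# Replicate and Shard are the public placement types re-exported by
# torch.distributed.tensor.
from torch.distributed.tensor import Replicate, Shard

mesh_ndim = 1  # assumption: the test's mesh is one-dimensional
old_placements = (Replicate(), Shard(0))  # what the test constructed before
new_placements = (Replicate(),)           # what the test constructs now

assert len(old_placements) != mesh_ndim   # would now raise in DTensorSpec.__post_init__
assert len(new_placements) == mesh_ndim   # satisfies the new check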

torch/distributed/tensor/_dtensor_spec.py

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ class DTensorSpec:
     def __post_init__(self) -> None:
         if not isinstance(self.placements, tuple):
             self.placements = tuple(self.placements)
+        if not len(self.placements) == self.mesh.ndim:
+            raise ValueError(
+                f"DTensorSpec requires one placement per mesh dim (mesh.ndim={self.mesh.ndim}), got {self.placements=}"
+            )
         self._hash: Optional[int] = None

     def __setattr__(self, attr: str, value: Any) -> None:
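To illustrate the new behavior, here is a minimal sketch, not taken from the commit, of how the __post_init__ check reacts to a mismatched placements tuple. It assumes a single-process CPU ("gloo") setup, an unused local TCP port, and that DTensorSpec and TensorMeta can be imported from the private module shown in the diff above.

# Sketch of the new validation; the process-group setup below is an
# assumption made only so a DeviceMesh can be constructed locally.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))  # 1-D mesh, so mesh.ndim == 1
meta = TensorMeta(shape=torch.Size([128, 32]), stride=(32, 1), dtype=torch.float32)

# One placement per mesh dim: accepted.
ok = DTensorSpec(mesh, (Replicate(),), tensor_meta=meta)

# Two placements on a 1-D mesh: now rejected in __post_init__.
try:
    DTensorSpec(mesh, (Replicate(), Shard(0)), tensor_meta=meta)
except ValueError as e:
    print(e)  # DTensorSpec requires one placement per mesh dim (mesh.ndim=1), ...

dist.destroy_process_group()

This can run as a plain single-process script; only the placement-length check is being exercised here, not any actual sharding.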

0 commit comments