8000 Fix schema check logic for NJT's chunk op. · pytorch/pytorch@056193e · GitHub
[go: up one dir, main page]

Skip to content

Commit 056193e

Browse files
committed
Fix schema check logic for NJT's chunk op.
1 parent b297e01 commit 056193e

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

torch/nested/_internal/ops.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None:
112112
arg_type_check_fns = {
113113
"t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
114114
"jt": lambda x: isinstance(x, NestedTensor)
115-
and x._lengths is None
116-
and x._ragged_idx == 1, # ops with "jt" require contiguous JT only
115+
and x.is_contiguous(), # ops with "jt" require contiguous JT only
117116
"jt_all": lambda x: isinstance(
118117
x, NestedTensor
119118
), # ops with "jt_all" can accept all kinds of JT
@@ -966,7 +965,7 @@ def narrow(func, *args, **kwargs):
966965
return NestedTensor(values, **extract_kwargs(inp))
967966

968967

969-
@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
968+
@register_jagged_func(torch.ops.aten.chunk.default, "self: jt_all, chunks: any, dim: any?")
970969
def chunk_default(func, *args, **kwargs):
971970
_, new_kwargs = normalize_function( # type: ignore[misc]
972971
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True

0 commit comments

Comments
 (0)
0