8000 [tf.data] Fix broken sparse/ragged iterators on TPU devices. · IBMZ-Linux-OSS-Python/tensorflow@4527fc0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4527fc0

Browse files
mrrytensorflower-gardener
authored andcommitted
[tf.data] Fix broken sparse/ragged iterators on TPU devices.
In a recent change, we added a colocation constraint between iterator ops and the sparse/ragged decoding ops that transform their outputs into structured tensors. This broke some workloads that attempted to prefetch SparseTensor and RaggedTensor objects to TPU memory, because the relevant decoding kernels were not registered for `DEVICE_TPU`. This change adds the missing kernel registrations, using the same host-memory annotations that are present for `DEVICE_GPU` and preserving the previous behavior (when the decoding ops would fall back to running on some `DEVICE_CPU`... although not necessarily in the same process). PiperOrigin-RevId: 673113849
1 parent 2294b26 commit 4527fc0

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

tensorflow/core/kernels/deserialize_sparse_variant_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,14 @@ REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
375375
.HostMemory("sparse_values")
376376
.HostMemory("sparse_shape"),
377377
DeserializeSparseOp)
378+
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
379+
.Device(DEVICE_TPU)
380+
.TypeConstraint<Variant>("Tserialized")
381+
.HostMemory("serialized_sparse")
382+
.HostMemory("sparse_indices")
383+
.HostMemory("sparse_values")
384+
.HostMemory("sparse_shape"),
385+
DeserializeSparseOp)
378386

379387
} // namespace
380388

tensorflow/core/kernels/ragged_tensor_from_variant_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,14 @@ class RaggedTensorFromVariantOp : public OpKernel {
367367
.HostMemory("encoded_ragged") \
368368
.HostMemory("output_nested_splits") \
369369
.HostMemory("output_dense_values"), \
370+
RaggedTensorFromVariantOp<value_type, split_type>) \
371+
REGISTER_KERNEL_BUILDER(Name("RaggedTensorFromVariant") \
372+
.Device(DEVICE_TPU) \
373+
.TypeConstraint<value_type>("Tvalues") \
374+
.TypeConstraint<split_type>("Tsplits") \
375+
.HostMemory("encoded_ragged") \
376+
.HostMemory("output_nested_splits") \
377+
.HostMemory("output_dense_values"), \
370378
RaggedTensorFromVariantOp<value_type, split_type>);
371379
#define REGISTER_KERNELS(value_type) \
372380
REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \

0 commit comments

Comments
 (0)
0