Fix conjugate transpose for 0/1D complex tensor · IBMZ-Linux-OSS-Python/tensorflow@579775b · GitHub
[go: up one dir, main page]

Skip to content

Commit 579775b

Browse files
wenscarl authored and pak-laura committed
Fix conjugate transpose for 0/1D complex tensor
1 parent 10e2e7c commit 579775b

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

tensorflow/core/kernels/transpose_functor_gpu.cu.cc

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ typedef Eigen::GpuDevice GPUDevice;
3131
namespace tensorflow {
3232
namespace internal {
3333

34+
// Element-wise complex conjugation: dst[i] = conj(src[i]) for i in
// [0, nthreads). Launched by TransposeSimple when the tensor is rank 0 or 1,
// where the permutation is the identity and only conjugation is required.
// src and dst are __restrict__-qualified, so the caller must ensure the
// buffers do not alias. Reads go through ldg() to use the read-only cache.
template <typename T>
__global__ void ConjugateKernel(int nthreads, const T* __restrict__ src,
                                T* __restrict__ dst) {
  GPU_1D_KERNEL_LOOP(i, nthreads) {
    dst[i] = Eigen::numext::conj(ldg(src + i));
  }
}
41+
3442
template <typename T, bool conjugate>
3543
__global__ void TransposeKernel(int nthreads, const T* __restrict__ src,
3644
const int32* __restrict__ buf,
@@ -62,6 +70,15 @@ void TransposeSimple(const GPUDevice& d, const Tensor& in,
6270
CHECK_LT(nelem, kint32max) << "Tensor too large to transpose on GPU";
6371
// Pack strides and permutation into one buffer.
6472
const int32 ndims = in.dims();
73+
GpuLaunchConfig cfg = GetGpuLaunchConfig(nelem, d);
74+
const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
75+
T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
76+
if (conjugate && ndims < 2) {
77+
TF_CHECK_OK(GpuLaunchKernel(ConjugateKernel<T>, cfg.block_count,
78+
cfg.thread_per_block, 0, d.stream(),
79+
cfg.virtual_thread_count, p, q));
80+
return;
81+
}
6582
gtl::InlinedVector<int32, 24> host_buf(ndims * 3);
6683
gtl::InlinedVector<int32, 8> in_strides = ComputeStride<int32>(in.shape());
6784
gtl::InlinedVector<int32, 8> out_strides = ComputeStride<int32>(out->shape());
@@ -78,9 +95,6 @@ void TransposeSimple(const GPUDevice& d, const Tensor& in,
7895
// therefore we are doing a sync copy effectively.
7996
d.memcpyHostToDevice(dev_buf, host_buf.data(), num_bytes);
8097
// Launch kernel to q[...] = p[...].
81-
const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
82-
T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
83-
GpuLaunchConfig cfg = GetGpuLaunchConfig(nelem, d);
8498
TF_CHECK_OK(GpuLaunchKernel(
8599
TransposeKernel<T, conjugate>, cfg.block_count, cfg.thread_per_block, 0,
86100
d.stream(), cfg.virtual_thread_count, p,
@@ -179,8 +193,7 @@ template <typename T, bool conjugate>
179193
struct Transpose<GPUDevice, T, conjugate> {
180194
static void run(const GPUDevice& d, const Tensor& in,
181195
const gtl::ArraySlice<int32> perm, Tensor* out) {
182-
if (in.dims() < 2) return;
183-
if (internal::TransposeUsingTile<T, conjugate>::run(d, in, perm, out)) {
196+
if (in.dims() > 1 && internal::TransposeUsingTile<T, conjugate>::run(d, in, perm, out)) {
184197
return;
185198
}
186199

0 commit comments

Comments
 (0)
0