@@ -31,6 +31,14 @@ typedef Eigen::GpuDevice GPUDevice;
31
31
namespace tensorflow {
32
32
namespace internal {
33
33
34
// Elementwise kernel: writes the complex conjugate of each element of `src`
// into `dst`. For real T, Eigen::numext::conj is the identity, so this is a
// plain copy. `nthreads` is the total element count; the grid-loop macro
// guards every index against it, so any launch configuration is valid.
template <typename T>
__global__ void ConjugateKernel(int nthreads, const T* __restrict__ src,
                                T* __restrict__ dst) {
  GPU_1D_KERNEL_LOOP(i, nthreads) {
    // ldg routes the load through the read-only data cache; src is never
    // written by this kernel.
    dst[i] = Eigen::numext::conj(ldg(src + i));
  }
}
41
+
34
42
template <typename T, bool conjugate>
35
43
__global__ void TransposeKernel (int nthreads, const T* __restrict__ src,
36
44
const int32* __restrict__ buf,
@@ -62,6 +70,15 @@ void TransposeSimple(const GPUDevice& d, const Tensor& in,
62
70
CHECK_LT (nelem, kint32max) << " Tensor too large to transpose on GPU" ;
63
71
// Pack strides and permutation into one buffer.
64
72
const int32 ndims = in.dims ();
73
+ GpuLaunchConfig cfg = GetGpuLaunchConfig (nelem, d);
74
+ const T* p = reinterpret_cast <const T*>(in.tensor_data ().data ());
75
+ T* q = reinterpret_cast <T*>(const_cast <char *>((out->tensor_data ().data ())));
76
+ if (conjugate && ndims < 2 ) {
77
+ TF_CHECK_OK (GpuLaunchKernel (ConjugateKernel<T>, cfg.block_count ,
78
+ cfg.thread_per_block , 0 , d.stream (),
79
+ cfg.virtual_thread_count , p, q));
80
+ return ;
81
+ }
65
82
gtl::InlinedVector<int32, 24 > host_buf (ndims * 3 );
66
83
gtl::InlinedVector<int32, 8 > in_strides = ComputeStride<int32>(in.shape ());
67
84
gtl::InlinedVector<int32, 8 > out_strides = ComputeStride<int32>(out->shape ());
@@ -78,9 +95,6 @@ void TransposeSimple(const GPUDevice& d, const Tensor& in,
78
95
// therefore we are doing a sync copy effectively.
79
96
d.memcpyHostToDevice (dev_buf, host_buf.data (), num_bytes);
80
97
// Launch kernel to q[...] = p[...].
81
- const T* p = reinterpret_cast <const T*>(in.tensor_data ().data ());
82
- T* q = reinterpret_cast <T*>(const_cast <char *>((out->tensor_data ().data ())));
83
- GpuLaunchConfig cfg = GetGpuLaunchConfig (nelem, d);
84
98
TF_CHECK_OK (GpuLaunchKernel (
85
99
TransposeKernel<T, conjugate>, cfg.block_count , cfg.thread_per_block , 0 ,
86
100
d.stream (), cfg.virtual_thread_count , p,
@@ -179,8 +193,7 @@ template <typename T, bool conjugate>
179
193
struct Transpose <GPUDevice, T, conjugate> {
180
194
static void run (const GPUDevice& d, const Tensor& in,
181
195
const gtl::ArraySlice<int32> perm, Tensor* out) {
182
- if (in.dims () < 2 ) return ;
183
- if (internal::TransposeUsingTile<T, conjugate>::run (d, in, perm, out)) {
196
+ if (in.dims () > 1 && internal::TransposeUsingTile<T, conjugate>::run (d, in, perm, out)) {
184
197
return ;
185
198
}
186
199
0 commit comments