8000 Cuda reduce in a consistent direction by bunelr · Pull Request #1542 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Cuda reduce in a consistent direction #1542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Nicer looking fix for the Min/MaxReduce
  • Loading branch information
bunelr committed May 12, 2017
commit 53dfce956a4923ead2881a7e65f1e338f815e63c
12 changes: 6 additions & 6 deletions torch/lib/THC/THCTensorMathReduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,8 @@ kernelTransformReduceOuterDimIndex(K *tgt1,

for (unsigned col = 0; col < row_size; ++col) {
// +1 for Lua index
acc = binary_op(thrust::make_pair<K, Index>(*src, col + TH_INDEX_BASE),
acc);
acc = binary_op(acc,
thrust::make_pair<K, Index>(*src, col + TH_INDEX_BASE));
src += num_irows;
}

Expand Down Expand Up @@ -550,7 +550,7 @@ kernelTransformReduceInnermostDimIndex(K *tgt1,
K *src = src_ + row * row_size;
// Sequential reduction within a thread.
for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) {
acc = binary_op(thrust::make_pair<K, Index>(src[col], col + TH_INDEX_BASE), acc);
acc = binary_op(acc, thrust::make_pair<K, Index>(src[col], col + TH_INDEX_BASE));
}
}

Expand All @@ -568,7 +568,7 @@ kernelTransformReduceInnermostDimIndex(K *tgt1,
thrust::make_pair<K, Index>(sline[threadIdx.x], iline[threadIdx.x]);
thrust::pair<K, Index> arg2 =
thrust::make_pair<K, Index>(sline[threadIdx.x + s], iline[threadIdx.x + s]);
thrust::pair<K, Index> res = binary_op(arg2, arg1);
thrust::pair<K, Index> res = binary_op(arg1, arg2);

sline[threadIdx.x] = res.first;
iline[threadIdx.x] = res.second;
Expand Down Expand Up @@ -665,7 +665,7 @@ struct MaxValuePair {
__host__ __device__
thrust::pair<T, Index> operator()(const thrust::pair<T, Index>& a,
const thrust::pair<T, Index>& b) {
return THCNumerics<T>::gt(a.first, b.first) ? a : b;
return THCNumerics<T>::ge(a.first, b.first) ? a : b;
}
};

Expand All @@ -674,7 +674,7 @@ struct MinValuePair {
__host__ __device__
thrust::pair<T, Index> operator()(const thrust::pair<T, Index>& a,
const thrust::pair<T, Index>& b) {
return THCNumerics<T>::lt(a.first, b.first) ? a : b;
return THCNumerics<T>::le(a.first, b.first) ? a : b;
}
};

Expand Down
0