Add `max_pool3d` backward pass for MPS by kurtamohler · Pull Request #157498 · pytorch/pytorch · GitHub
11 changes: 11 additions & 0 deletions aten/src/ATen/native/mps/kernels/Pooling.h
@@ -27,3 +27,14 @@ struct PoolingParams {
_ARRAY_NS::array<int64_t, N - 2> padding;
_ARRAY_NS::array<int64_t, N - 2> dilation;
};

template <unsigned N = 5>
struct PoolingBackwardParams {
int32_t dims;
int32_t pooling_dims;
_ARRAY_NS::array<int64_t, N> grad_input_sizes;
_ARRAY_NS::array<int64_t, N> grad_input_strides;
_ARRAY_NS::array<int64_t, N> grad_output_sizes;
_ARRAY_NS::array<int64_t, N> grad_output_strides;
_ARRAY_NS::array<int64_t, N> indices_strides;
};
162 changes: 137 additions & 25 deletions aten/src/ATen/native/mps/kernels/Pooling.metal
@@ -1,7 +1,10 @@
#include <ATen/native/mps/kernels/Pooling.h>
#include <c10/metal/atomic.h>
#include <metal_array>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

// Iterates through all the input elements that this kernel needs to
// apply max to. Specialized for 3 pooling dimensions.
@@ -83,6 +86,50 @@ void max_pool_3d_input_iter(
*indices = max_index;
}
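As a reading aid, here is a standalone host-side C++ sketch of the window walk described above, under an assumed contiguous layout and made-up sizes (none of these names come from the PR): scan the dilated kernel window anchored at one output position, skip taps that fall in the padding region, and track the running max and its flat index.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t size[3] = {4, 5, 6};     // input D, H, W
  const int64_t kernel[3] = {2, 2, 2};
  const int64_t stride[3] = {2, 2, 2};
  const int64_t padding[3] = {1, 1, 1};
  const int64_t dilation[3] = {1, 1, 1};
  const int64_t out[3] = {0, 1, 2};      // one output element [d, h, w]

  float input[4 * 5 * 6];
  for (int i = 0; i < 4 * 5 * 6; i++) {
    input[i] = static_cast<float>(i % 17);
  }

  float max_value = -1e30f;
  int64_t max_index = -1;
  for (int64_t kd = 0; kd < kernel[0]; kd++) {
    for (int64_t kh = 0; kh < kernel[1]; kh++) {
      for (int64_t kw = 0; kw < kernel[2]; kw++) {
        const int64_t d = out[0] * stride[0] - padding[0] + kd * dilation[0];
        const int64_t h = out[1] * stride[1] - padding[1] + kh * dilation[1];
        const int64_t w = out[2] * stride[2] - padding[2] + kw * dilation[2];
        // Taps that land in the padding region contribute nothing.
        if (d < 0 || d >= size[0] || h < 0 || h >= size[1] || w < 0 ||
            w >= size[2]) {
          continue;
        }
        const int64_t idx = (d * size[1] + h) * size[2] + w;
        if (input[idx] > max_value) {
          max_value = input[idx];
          max_index = idx; // flat index, later consumed by the backward pass
        }
      }
    }
  }
  std::printf("max=%f at flat index %lld\n", max_value, (long long)max_index);
  return 0;
}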

struct PoolOffsets {
int64_t output;
int64_t indices;
int64_t input_leading;

PoolOffsets() : output(0), indices(0), input_leading(0) {}
};

// Finds the offset of the output element that a forward pass thread will
// calculate, `output[N, C, d, h, w]`, and the offset of the input for the
// leading dim indices, `input[N, C]`. Optionally, records the output
// pooling dimension indices, `[d, h, w]`.
PoolOffsets find_pool_offsets(
constant int64_t* output_sizes,
constant int64_t* output_strides,
constant int64_t* indices_strides,
constant int64_t* input_strides,
device int64_t* work_pooling_dim_indices,
int32_t dims,
int32_t leading_dims,
uint tid) {
int64_t output_idx = static_cast<int64_t>(tid);
PoolOffsets offsets;

for (int64_t dim = dims - 1; dim >= 0; dim--) {
int64_t dim_idx = output_idx % (output_sizes[dim]);
offsets.output += output_strides[dim] * dim_idx;
offsets.indices += indices_strides[dim] * dim_idx;

if (dim < leading_dims) {
offsets.input_leading += input_strides[dim] * dim_idx;
} else {
// Keep track of pooling dimension indices of the output element, so we
// can use them in the input iteration later on.
if (work_pooling_dim_indices != nullptr) {
work_pooling_dim_indices[dim - leading_dims] = dim_idx;
}
}
output_idx = output_idx / output_sizes[dim];
}

return offsets;
}
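For intuition, the loop above is the usual mixed-radix decomposition of a linear thread id into per-dimension indices. A minimal host-side C++ sketch with assumed contiguous sizes and strides (illustrative only, not PR code):

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int32_t dims = 5;                                    // N, C, D, H, W
  const std::array<int64_t, dims> sizes = {2, 3, 4, 4, 4};
  const std::array<int64_t, dims> strides = {192, 64, 16, 4, 1}; // contiguous
  const uint32_t tid = 100; // one thread per output element

  int64_t output_idx = tid;
  int64_t offset = 0;
  // Peel off the innermost dimension index first, exactly as the kernel does.
  for (int32_t dim = dims - 1; dim >= 0; dim--) {
    const int64_t dim_idx = output_idx % sizes[dim];
    offset += strides[dim] * dim_idx;
    output_idx /= sizes[dim];
  }
  // With contiguous strides the round trip is the identity: prints 100 100.
  std::printf("%u %lld\n", tid, (long long)offset);
  return 0;
}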

// Kernel computes one element of the output per kernel call.
template <typename T>
kernel void max_pool(
@@ -113,32 +160,20 @@ kernel void max_pool(
// element of the output. We need to fill it with the proper values below.
device int64_t* work_pooling_dim_indices =
work_pooling_dim_indices_ + tid * pooling_dims;
int64_t output_idx = static_cast<int64_t>(tid);
int64_t output_offset = 0;
int64_t indices_offset = 0;
int64_t input_leading_offset = 0;

// First, find the offset of the output element this thread will calculate,
// `output[N, C, d, h, w]`. Also, find the offset of the input for the leading
// dim indices, `input[N, C]` and keep track of the pooling dimension indices,
// `[d, h, w]`.
for (int64_t dim = dims - 1; dim >= 0; dim--) {
int64_t dim_idx = output_idx % (output_sizes[dim]);
output_offset += output_strides[dim] * dim_idx;
indices_offset += indices_strides[dim] * dim_idx;

if (dim < leading_dims) {
input_leading_offset += input_strides[dim] * dim_idx;
} else {
// Keep track of pooling dimension indices of the output element, so we
// can use them in the input iteration later on.
work_pooling_dim_indices[dim - leading_dims] = dim_idx;
}
output_idx = output_idx / output_sizes[dim];
}
output += output_offset;
indices += indices_offset;
input += input_leading_offset;
PoolOffsets offsets = find_pool_offsets(
output_sizes,
output_strides,
indices_strides,
input_strides,
work_pooling_dim_indices,
dims,
leading_dims,
tid);

output += offsets.output;
indices += offsets.indices;
input += offsets.input_leading;

max_pool_3d_input_iter<T>(
input,
@@ -153,6 +188,69 @@ kernel void max_pool(
dilation);
}

// Finds the element of the grad input that corresponds to the given index
// into the pool, then atomically adds the grad output element to it.
template <typename T>
void max_pool_backward_impl(
device AtomicType_t<T>* grad_input,
T grad_output_element,
int32_t input_index,
constant int64_t* grad_input_sizes,
constant int64_t* grad_input_strides,
int32_t grad_input_leading_offset,
int32_t pooling_dims) {
int32_t size_prod = 1;
int32_t pool_offset = 0;

for (int32_t dim = pooling_dims - 1; dim >= 0; dim--) {
int32_t next_size_prod = grad_input_sizes[dim] * size_prod;
pool_offset +=
grad_input_strides[dim] * ((input_index % next_size_prod) / size_prod);
size_prod *= grad_input_sizes[dim];
}

AtomicType<T>::atomic_add(
grad_input, grad_input_leading_offset + pool_offset, grad_output_element);
}

Review comment from kurtamohler (Collaborator, Author), Jul 4, 2025:

Is AtomicType::atomic_add deterministic? My guess is no, in which case I think I should mark this op nondeterministic for torch.use_deterministic_algorithms. I think we would only need the nondeterministic alert to be raised if the input requires grad and the stride is less than the kernel size in any of the dimensions.

Reply from kurtamohler (Collaborator, Author):

If there is need for it, I could write a deterministic alternative that doesn't use atomic add. The CUDA impl doesn't have this yet either.
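A hypothetical sketch of the alert discussed in the comments above, should it be added later (the alertNotDeterministic helper exists in ATen, but this call site, condition, and function are assumptions, not part of this PR):

#include <ATen/ATen.h>

// Hypothetical: raise the nondeterministic-algorithms alert only when
// gradients are needed and pooling windows can overlap, i.e. when
// stride < kernel_size in some dimension (so two windows may select the
// same input element and the atomic adds race).
static void alert_nondeterministic_if_needed(const at::Tensor& input,
                                             at::IntArrayRef kernel_size,
                                             at::IntArrayRef stride) {
  bool overlapping = false;
  for (size_t dim = 0; dim < kernel_size.size(); dim++) {
    overlapping |= stride[dim] < kernel_size[dim];
  }
  if (input.requires_grad() && overlapping) {
    at::globalContext().alertNotDeterministic(
        "max_pool3d_with_indices_backward_mps");
  }
}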

// Kernel computes one element of the grad input per kernel call.
template <typename T>
kernel void max_pool_backward(
device AtomicType_t<T>* grad_input [[buffer(0)]],
constant T* grad_output [[buffer(1)]],
constant int64_t* indices [[buffer(2)]],
constant PoolingBackwardParams<5>& params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
int32_t pooling_dims = params.pooling_dims;
int32_t dims = params.dims;
constant int64_t* grad_input_sizes = params.grad_input_sizes.data();
constant int64_t* grad_input_strides = params.grad_input_strides.data();
constant int64_t* grad_output_sizes = params.grad_output_sizes.data();
constant int64_t* grad_output_strides = params.grad_output_strides.data();
constant int64_t* indices_strides = params.indices_strides.data();

int32_t leading_dims = dims - pooling_dims;

PoolOffsets offsets = find_pool_offsets(
grad_output_sizes,
grad_output_strides,
indices_strides,
grad_input_strides,
nullptr,
dims,
leading_dims,
tid);

max_pool_backward_impl<T>(
grad_input,
grad_output[offsets.output],
indices[offsets.indices],
grad_input_sizes + leading_dims,
grad_input_strides + leading_dims,
offsets.input_leading,
pooling_dims);
}
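The index arithmetic in max_pool_backward_impl inverts the flat pooling-dim index stored by the forward pass. A standalone host-side C++ check with made-up sizes and contiguous strides (illustrative, not PR code):

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  const std::array<int64_t, 3> sizes = {4, 6, 8};    // D, H, W of the slab
  const std::array<int64_t, 3> strides = {48, 8, 1}; // contiguous strides
  const int32_t input_index = 2 * 48 + 3 * 8 + 5;    // element [2, 3, 5]

  int64_t size_prod = 1;
  int64_t pool_offset = 0;
  // Recover each dimension's index from the flat index, innermost first,
  // and accumulate the stride-weighted offset, as the kernel does.
  for (int32_t dim = 2; dim >= 0; dim--) {
    const int64_t next_size_prod = sizes[dim] * size_prod;
    pool_offset += strides[dim] * ((input_index % next_size_prod) / size_prod);
    size_prod *= sizes[dim];
  }
  // For a contiguous slab the round trip is the identity: prints 125 125.
  std::printf("%d %lld\n", input_index, (long long)pool_offset);
  return 0;
}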

#define REGISTER_MAX_POOL_OP(DTYPE) \
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
constant void* input_ [[buffer(0)]], \
@@ -162,6 +260,15 @@ kernel void max_pool(
constant PoolingParams<5>& params [[buffer(4)]], \
uint tid [[thread_position_in_grid]]);

#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
template [[host_name("max_pool_backward_" #DTYPE)]] \
kernel void max_pool_backward<DTYPE>( \
device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
constant DTYPE * grad_output_ [[buffer(1)]], \
constant int64_t* grad_indices_ [[buffer(2)]], \
constant PoolingBackwardParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);

REGISTER_MAX_POOL_OP(float);
REGISTER_MAX_POOL_OP(half);
REGISTER_MAX_POOL_OP(int);
@@ -170,6 +277,11 @@ REGISTER_MAX_POOL_OP(short);
REGISTER_MAX_POOL_OP(char);
REGISTER_MAX_POOL_OP(uchar);
REGISTER_MAX_POOL_OP(bool);

REGISTER_MAX_POOL_BACKWARD_OP(float);
REGISTER_MAX_POOL_BACKWARD_OP(half);

#if __METAL_VERSION__ >= 310
REGISTER_MAX_POOL_OP(bfloat);
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
#endif
135 changes: 123 additions & 12 deletions aten/src/ATen/native/mps/operations/Pooling.mm
@@ -18,6 +18,7 @@
#include <ATen/ops/max_pool2d_native.h>
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
#include <ATen/ops/max_pool2d_with_indices_native.h>
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
#include <ATen/ops/max_pool3d_with_indices_native.h>
#endif

@@ -270,16 +271,16 @@ static IntArrayRef tensor_to_intarrayref(const Tensor& tensor) {
return IntArrayRef(data_ptr, length);
}

static void max_pool_with_indices_out_mps_template(const Tensor& output,
const Tensor& indices,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const int32_t pooling_dims,
const std::string& op_name) {
using PoolSizes = std::tuple<int32_t, Tensor, Tensor, Tensor, Tensor, Tensor>;

static PoolSizes process_pool_sizes(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const int32_t pooling_dims,
const std::string& op_name) {
TORCH_INTERNAL_ASSERT(pooling_dims == 1 || pooling_dims == 2 || pooling_dims == 3);

const int32_t dims = input.dim();
@@ -387,9 +388,27 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,

t_output_size.slice(0, leading_dims) = t_output_pooling_size;

return std::tuple<int32_t, Tensor, Tensor, Tensor, Tensor, Tensor>(
dims, t_output_size, t_kernel_size, t_stride, t_padding, t_dilation);
}

static void max_pool_with_indices_out_mps_template(const Tensor& output,
const Tensor& indices,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const int32_t pooling_dims,
const std::string& op_name) {
auto [dims, t_output_size, t_kernel_size, t_stride, t_padding, t_dilation] =
process_pool_sizes(input, kernel_size, stride, padding, dilation, ceil_mode, pooling_dims, op_name);

IntArrayRef output_size = tensor_to_intarrayref(t_output_size);
output.resize_(output_size);
indices.resize_(output_size);
const auto memory_format = input.suggest_memory_format();
output.resize_(output_size, memory_format);
indices.resize_(output_size, memory_format);

auto iter = TensorIteratorConfig().add_output(output).resize_outputs(false).check_all_same_dtype(false).build();

@@ -436,6 +455,52 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,
});
}

static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
const Tensor& indices,
const Tensor& input,
const Tensor& grad_output,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const int32_t pooling_dims,
const std::string& op_name) {
auto [dims, t_output_size, t_kernel_size, t_stride, t_padding, t_dilation] =
process_pool_sizes(input, kernel_size, stride, padding, dilation, ceil_mode, pooling_dims, op_name);

const auto memory_format = input.suggest_memory_format();
grad_input.resize_(input.sizes(), memory_format);
grad_input.fill_(0);

id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const auto numThreads = grad_output.numel();
PoolingBackwardParams<5> params;

params.dims = dims;
params.pooling_dims = pooling_dims;
memcpy(params.grad_input_sizes.data(), grad_input.sizes().data(), dims * sizeof(int64_t));
memcpy(params.grad_input_strides.data(), grad_input.strides().data(), dims * sizeof(int64_t));
memcpy(params.grad_output_strides.data(), grad_output.strides().data(), dims * sizeof(int64_t));
memcpy(params.grad_output_sizes.data(), grad_output.sizes().data(), dims * sizeof(int64_t));
memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t));

dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto maxPoolPSO = lib.getPipelineStateForFunc("max_pool_backward_" + scalarToMetalTypeString(input));

getMPSProfiler().beginProfileKernel(maxPoolPSO, op_name, {input});
[computeEncoder setComputePipelineState:maxPoolPSO];
mtl_setArgs(computeEncoder, grad_input, grad_output, indices, params);

mtl_dispatch1DJob(computeEncoder, maxPoolPSO, numThreads);
getMPSProfiler().endProfileKernel(maxPoolPSO);
}
});
}

static void avg_pool2d_template(const Tensor& input,
const Tensor& output,
const std::optional<Tensor>& grad_output_opt,
@@ -738,6 +803,52 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output,
return std::tuple<Tensor, Tensor>(output, indices);
}

Tensor& max_pool3d_with_indices_backward_out_mps(const Tensor& grad_output,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices,
Tensor& grad_input) {
mps::max_pool_with_indices_backward_out_mps_template(grad_input,
indices,
input,
grad_output,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
/*pooling_dims=*/3,
"max_pool3d_backward");
return grad_input;
}

Tensor max_pool3d_with_indices_backward_mps(const Tensor& grad_output,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices) {
auto grad_input = at::empty({0}, input.options());
mps::max_pool_with_indices_backward_out_mps_template(grad_input,
indices,
input,
grad_output,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
/*pooling_dims=*/3,
"max_pool3d_backward");
return grad_input;
}
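For reference, a minimal libtorch snippet exercising the new backward path (assumes a build with MPS available; illustrative, not part of the PR):

#include <torch/torch.h>

int main() {
  auto x = torch::randn({1, 2, 8, 8, 8},
                        torch::device(torch::kMPS).requires_grad(true));
  auto y = torch::max_pool3d(x, /*kernel_size=*/{2, 2, 2});
  // With this PR, the backward call dispatches to
  // max_pool3d_with_indices_backward_mps.
  y.sum().backward();
  return 0;
}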

TORCH_IMPL_FUNC(avg_pool2d_out_mps)
(const Tensor& input,
int64_t kH,