Move mps_linear forward to use MPS kernels directly instead of MPSGraph · pytorch/pytorch@4e24ee7 · GitHub


Commit 4e24ee7

jhavukainen authored and malfet committed
Move mps_linear forward to use MPS kernels directly instead of MPSGraph (#152210)
This PR moves `mps_linear` to use MPSNDArrays and call into the MPS kernel directly instead of going through MPSGraph. It also adds a caching mechanism for reusing MPS kernels, since creating the kernel object carries a small overhead of its own. The improvement is relatively larger for small input sizes, where the MPSGraph overhead accounts for a bigger share of the operation's overall execution time, but the speedup shows up for both small and large inputs, as expected.

`mps_linear` before the changes:

```
input shapes: f32:[1,1,20], f32:[1,20]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x109d67110>
func(*args, **kwargs)
  Median: 199.29 us
  IQR:    9.56 us (196.71 to 206.27)
  979 measurements, 1 runs per measurement, 1 thread

input shapes: f32:[1,1,5120], f32:[13284,5120]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x1063b4510>
func(*args, **kwargs)
  Median: 979.29 us
  IQR:    25.29 us (964.83 to 990.13)
  205 measurements, 1 runs per measurement, 1 thread
```

`mps_linear` after the changes:

```
input shapes: f32:[1,1,20], f32:[1,20]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x10693a190>
func(*args, **kwargs)
  Median: 176.08 us
  IQR:    15.02 us (172.42 to 187.44)
  1103 measurements, 1 runs per measurement, 1 thread

input shapes: f32:[1,1,5120], f32:[13284,5120]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x10d524dd0>
func(*args, **kwargs)
  Median: 952.56 us
  IQR:    15.63 us (945.47 to 961.10)
  210 measurements, 1 runs per measurement, 1 thread
```

Pull Request resolved: #152210
Approved by: https://github.com/kulinseth, https://github.com/malfet

Co-authored-by: Nikita Shulga <nshulga@meta.com>
1 parent d07fbd4 commit 4e24ee7
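The benchmark script itself is not part of the commit; below is a minimal sketch of how numbers like the ones quoted above could be collected with `torch.utils.benchmark`. Only the input shapes come from the commit message; the explicit `torch.mps.synchronize()` in the timed statement and all other details are assumptions.

```python
# Hypothetical reproduction of the quoted benchmark (not from the commit).
# Requires a macOS machine with an MPS-capable GPU.
import torch
from torch.utils.benchmark import Timer

cases = [((1, 1, 20), (1, 20)), ((1, 1, 5120), (13284, 5120))]
for in_shape, w_shape in cases:
    x = torch.rand(in_shape, device="mps", dtype=torch.float32)
    w = torch.rand(w_shape, device="mps", dtype=torch.float32)
    # Synchronize inside the timed statement so the asynchronous MPS
    # dispatch is actually accounted for in the measurement.
    t = Timer(
        stmt="torch.nn.functional.linear(x, w); torch.mps.synchronize()",
        globals={"torch": torch, "x": x, "w": w},
    )
    print(f"input shapes: f32:{list(in_shape)}, f32:{list(w_shape)}")
    print("torch.linear time:", t.blocked_autorange(min_run_time=0.2))
```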

File tree

3 files changed (+180, −3 lines)


aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 112 additions & 0 deletions

```diff
@@ -100,6 +100,7 @@ MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
                                  const TensorBase& input,
                                  bool includesInt64 = false);
 
+MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray);
 MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
 MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil);
 // The MPSShape could vary based on memory format
@@ -160,6 +161,26 @@ std::string get_mem_format_string(c10::MemoryFormat memory_format);
 
 using MPSCacheKey = uint64_t;
 
+struct MPSCachedKernel {
+  MPSCachedKernel(NSObject* object) : _object([object retain]) {}
+  virtual ~MPSCachedKernel() {
+    [_object release];
+    _object = nullptr;
+  }
+
+  // Delete copy constructor and assignment
+  MPSCachedKernel(const MPSCachedKernel&) = delete;
+  void operator=(const MPSCachedKernel&) = delete;
+
+  template <typename T>
+  inline T* kernel() const {
+    return (T*)_object;
+  }
+
+ private:
+  NSObject* _object = nullptr;
+};
+
 // derive this class to cache a graph and its inputs/outputs
 // can be used to store any NSObject
 struct MPSCachedGraph {
@@ -214,6 +235,97 @@ struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
   MPSGraphTensor* gradInputTensor_ = nil;
 };
 
+struct MPSKernelCache {
+  typedef MPSCachedKernel* (^CreateCachedKernelBlock)();
+
+  struct CacheEntry {
+    CacheEntry(const std::string& key, MPSCachedKernel* cachedKernel) : cachedKernel_(cachedKernel), key_(key) {}
+    MPSCachedKernel* cachedKernel_ = nullptr;
+    std::string key_;
+  };
+
+ public:
+  static MPSKernelCache* getInstance() {
+    if (_instance_cache == nullptr) {
+      _instance_cache = new MPSKernelCache();
+    }
+    return _instance_cache;
+  }
+
+  ~MPSKernelCache() {
+    dispatch_release(serialQueue_);
+    for (const auto& i : cache_) {
+      delete i.second.cachedKernel_;
+    }
+  }
+
+  // Disallow the copy constructor and operator= functions
+  MPSKernelCache(const MPSKernelCache&) = delete;
+  void operator=(const MPSKernelCache&) = delete;
+
+  MPSCachedKernel* CreateCachedKernel(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
+    __block MPSCachedKernel* cachedKernel = nil;
+    MPSCacheKey hash = std::hash<std::string>{}(key);
+    dispatch_sync_with_rethrow(serialQueue_, ^() {
+      if (cache_.count(hash) != 0) {
+        auto& entry = cache_.at(hash);
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n");
+        cachedKernel = entry.cachedKernel_;
+      } else {
+        cachedKernel = createCacheBlock();
+        CacheEntry entry(key, cachedKernel);
+        cache_.emplace(hash, entry);
+      }
+    });
+    return cachedKernel;
+  }
+  template <typename T>
+  inline T* CreateCachedKernelAs(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
+    return static_cast<T*>(CreateCachedKernel(key, createCacheBlock));
+  }
+
+  MPSCachedKernel* LookUp(const std::string& key) const {
+    __block MPSCachedKernel* cachedKernel = nil;
+
+    MPSCacheKey hash = std::hash<std::string>{}(key);
+    dispatch_sync_with_rethrow(serialQueue_, ^() {
+      if (cache_.count(hash) != 0) {
+        auto& entry = cache_.at(hash);
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n");
+        cachedKernel = entry.cachedKernel_;
+      }
+    });
+    return cachedKernel;
+  }
+
+  template <typename T>
+  inline T* LookUpAs(const std::string& key) const {
+    return static_cast<T*>(LookUp(key));
+  }
+
+ private:
+  MPSKernelCache() {
+    serialQueue_ = dispatch_queue_create("kernel cache queue", DISPATCH_QUEUE_SERIAL);
+  }
+
+  static MPSKernelCache* _instance_cache;
+  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
+  dispatch_queue_t serialQueue_ = nullptr;
+};
+
+// Common template for creating cached kernel if missing
+template <typename T>
+inline T* LookUpOrCreateCachedKernel(const std::string& key, std::function<MPSKernel*()> instantiate) {
+  auto cache_ = MPSKernelCache::getInstance();
+  if (auto rc = cache_->LookUpAs<T>(key)) {
+    return rc;
+  }
+  return cache_->CreateCachedKernelAs<T>(key, ^mps::MPSCachedKernel*() {
+    auto k_ = new mps::MPSCachedKernel(instantiate());
+    return k_;
+  });
+}
+
 // TODO: Improve the overall design of MPSGraphCache.
 // https://github.com/pytorch/pytorch/issues/77176
 // Cache holding various keys mapped to graphs
```

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 4 additions & 2 deletions

```diff
@@ -467,7 +467,7 @@ void printTensorNDArray(const TensorBase& t) {
                                                      offset:t.storage_offset() * t.element_size()
                                                  descriptor:srcTensorDesc] autorelease];
   if (strides != nil) {
-    srcNDArray = [srcNDArray arrayViewWithShape:sizes strides:strides];
+    srcNDArray = getStridedMPSNDArray(t, srcNDArray);
   }
   return srcNDArray;
 }
@@ -476,7 +476,7 @@ void printTensorNDArray(const TensorBase& t) {
   return getMPSNDArray(t, getMPSShape(sizes.empty() ? t.sizes() : sizes), strides.empty() ? nil : getMPSShape(strides));
 }
 
-static MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray) {
+MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray) {
   auto strides = src.strides();
   auto sizes = src.sizes();
   auto nStrides = strides.size();
@@ -778,6 +778,8 @@ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) {
 
 MPSGraphCache* MPSGraphCache::_instance_cache = nullptr;
 
+MPSKernelCache* MPSKernelCache::_instance_cache = nullptr;
+
 void MPSGraphCache::profileCachedGraph(const CacheEntry& cacheEntry) const {
   auto& profiler = getMPSProfiler();
   if (profiler.isOperationProfilingEnabled()) {
```

aten/src/ATen/native/mps/operations/Linear.mm

Lines changed: 64 additions & 1 deletion

```diff
@@ -1,6 +1,8 @@
 // Copyright © 2022 Apple Inc.
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/ExpandUtils.h>
+#include <ATen/mps/MPSProfiler.h>
+#include <ATen/native/mps/MPSGraphSequoiaOps.h>
 #include <ATen/native/mps/OperationUtils.h>
 #include <ATen/ops/linear_backward_native.h>
 #include <ATen/ops/linear_native.h>
@@ -9,6 +11,61 @@
 
 using namespace mps;
 
+static void _mps_linear_nograph(const Tensor& input, const Tensor& weight, const Tensor& bias, Tensor& output) {
+  bool is_bias_defined = bias.defined();
+
+  MPSStream* mpsStream = getCurrentMPSStream();
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+
+  const string key = "mps_linear" + getTensorsStringKey({input, weight, bias}, true, true);
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      mpsStream->endKernelCoalescing();
+
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
+
+      MPSDataType mpsDataType = getMPSDataType(weight.scalar_type());
+
+      auto inputNDArray = getMPSNDArray(input, input.sizes(), input.strides());
+      auto outNDArray = getMPSNDArray(output, output.sizes(), output.strides());
+
+      id<MTLBuffer> weightBuf = getMTLBufferStorage(weight);
+      MPSNDArrayDescriptor* weightDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType
+                                                                                shape:getMPSShape(weight.sizes())];
+      weightDesc.preferPackedRows = YES;
+      [weightDesc transposeDimension:0 withDimension:1];
+      MPSNDArray* weightNDArray = [[MPSNDArray alloc] initWithBuffer:weightBuf
+                                                              offset:weight.storage_offset() * weight.element_size()
+                                                          descriptor:weightDesc];
+
+      if (is_bias_defined) {
+        auto biasNDArray = getMPSNDArray(bias, bias.sizes(), bias.strides());
+        auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(
+            key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:3]; });
+        auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();
+
+        getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias});
+        [kernel encodeToCommandEncoder:computeEncoder
+                         commandBuffer:commandBuffer
+                          sourceArrays:@[ inputNDArray, weightNDArray, biasNDArray ]
+                      destinationArray:outNDArray];
+        getMPSProfiler().endProfileKernel(kernel);
+      } else {
+        auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(
+            key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2]; });
+        auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();
+        getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias});
+        [kernel encodeToCommandEncoder:computeEncoder
+                         commandBuffer:commandBuffer
+                          sourceArrays:@[ inputNDArray, weightNDArray ]
+                      destinationArray:outNDArray];
+        getMPSProfiler().endProfileKernel(kernel);
+      }
+    }
+  });
+}
+
 Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::optional<Tensor>& bias_opt) {
   // wT = transpose(weight);
   // y=x*wT+b
@@ -17,6 +74,8 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt
   TORCH_CHECK(input.is_mps(), "Tensor for argument input is on ", input.device(), " but expected on mps");
   TORCH_CHECK(supportedFloatingOrComplexType(weight_arg), "MPS device does not support linear for non-float weights");
   TORCH_CHECK(weight_arg.is_mps(), "Tensor for argument weight is on ", weight_arg.device(), " but expected on mps");
+  TORCH_CHECK((input.scalar_type() != kComplexFloat && input.scalar_type() != kComplexHalf),
+              "mps linear does not support complex types");
 
   const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt));
   const bool is_bias_defined = bias.defined();
@@ -54,8 +113,12 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt
     return output;
   }
 
+  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) {
+    _mps_linear_nograph(input, weight, bias, output);
+    // Squeeze last dim of 1D linear
+    return weight_arg.dim() != 1 ? output : output.squeeze(-1);
+  }
   MPSStream* stream = getCurrentMPSStream();
-
   struct CachedGraph : public MPSCachedGraph {
     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
     MPSGraphTensor* inputTensor_ = nil;
```
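With the version gate above, any `linear` call on MPS tensors takes the new direct-kernel path on macOS 15+ and falls back to the MPSGraph implementation otherwise. A minimal smoke test exercising the path from Python (assumed, not part of the commit):

```python
# Hypothetical smoke test for the mps_linear forward path (not from the commit).
import torch
import torch.nn.functional as F

assert torch.backends.mps.is_available()

x = torch.rand(1, 1, 20, device="mps")
w = torch.rand(1, 20, device="mps")
b = torch.rand(1, device="mps")

# On macOS 15+ this dispatches through _mps_linear_nograph (a cached
# MPSNDArrayMatrixMultiplication kernel); otherwise through MPSGraph.
y = F.linear(x, w, b)
torch.testing.assert_close(y.cpu(), F.linear(x.cpu(), w.cpu(), b.cpu()))
```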
