[MPS] Add API to query GPU core count (#160414) · pytorch/pytorch@2e4c878 · GitHub

Commit 2e4c878

malfet authored and chuanhaozhuge committed
[MPS] Add API to query GPU core count (#160414)
Use good old IOKit to get the `gpu-core-count` property from the device implementing the `AGXAccelerator` service. Expose it as `torch.backends.mps.get_core_count()` and make it accessible to the inductor via `MpsInterface`.

Test Plan: Run `python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"` and compare the output to `system_profiler SPDisplaysDataType|head -n10`:

```
% python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"
Apple M1 Pro 16
% system_profiler SPDisplaysDataType|head -n10
Graphics/Displays:

    Apple M1 Pro:

      Chipset Model: Apple M1 Pro
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 16
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
```

This would significantly improve occupancy for torch.compile-generated kernels.

Pull Request resolved: #160414
Approved by: https://github.com/dcci
1 parent c550f17 commit 2e4c878
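
For reference, the test plan's one-liner as a standalone sketch (assumes a macOS build of this commit with MPS support; the sample values shown are from the M1 Pro output above):

```python
import torch

if torch.backends.mps.is_available():
    # New public APIs added by this commit.
    print(torch.backends.mps.get_name())        # e.g. "Apple M1 Pro"
    print(torch.backends.mps.get_core_count())  # e.g. 16
```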

File tree

9 files changed: +77 −9 lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -1196,7 +1196,7 @@ if(APPLE)
   string(
     APPEND
     CMAKE_SHARED_LINKER_FLAGS
-    " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal"
+    " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal -weak_framework IOKit"
   )
   # To suppress MPSGraph availability warnings
   append_cxx_flag_if_supported("-Wno-unguarded-availability-new"
```

aten/src/ATen/mps/MPSDevice.h

Lines changed: 11 additions & 0 deletions

```diff
@@ -55,6 +55,17 @@ class TORCH_API MPSDevice {
    */
   bool isMacOS13Plus(MacOSVersion version) const;
 
+  /**
+   * Returns device name
+   */
+  std::string getName() const;
+
+  /**
+   * Returns number of GPU cores.
+   * 1 Core = 16 ExecutionUnit x 8 ALU x 24 threads
+   */
+  unsigned getCoreCount() const;
+
   ~MPSDevice();
 
  private:
```

aten/src/ATen/mps/MPSDevice.mm

Lines changed: 27 additions & 1 deletion

```diff
@@ -85,10 +85,36 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
   }
 }
 
+std::string MPSDevice::getName() const {
+  @autoreleasepool {
+    return [[_mtl_device name] UTF8String];
+  }
+}
+
+unsigned MPSDevice::getCoreCount() const {
+  io_iterator_t iterator = 0;
+  io_registry_entry_t entry = 0;
+  int core_count = 0;
+  auto matchingDict = IOServiceMatching("AGXAccelerator");
+  TORCH_INTERNAL_ASSERT(matchingDict, "Failed to create matching dict");
+  const auto status = IOServiceGetMatchingServices(kIOMainPortDefault, matchingDict, &iterator);
+  TORCH_INTERNAL_ASSERT(status == KERN_SUCCESS);
+  while ((entry = IOIteratorNext(iterator)) != 0) {
+    auto property = IORegistryEntryCreateCFProperty(entry, CFSTR("gpu-core-count"), kCFAllocatorDefault, 0);
+    auto found = CFNumberGetValue(static_cast<CFNumberRef>(property), kCFNumberIntType, &core_count);
+    CFRelease(property);
+    IOObjectRelease(entry);
+    if (found) {
+      break;
+    }
+  }
+  IOObjectRelease(iterator);
+  return core_count;
+}
+
 at::Allocator* GetMPSAllocator(bool useSharedAllocator) {
   return getIMPSAllocator(useSharedAllocator);
 }
-
 bool is_available() {
   return MPSDevice::getInstance()->device() != nil;
 }
```
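
The same `gpu-core-count` value that this IOKit registry walk reads can be cross-checked from userspace without touching IOKit, by scraping the `system_profiler` output shown in the test plan. A minimal, hypothetical Python sketch (the `Total Number of Cores:` label comes from the sample output above; the exact formatting may vary across macOS versions):

```python
import re
import subprocess

def core_count_from_system_profiler() -> int:
    """Parse 'Total Number of Cores: N' from system_profiler output."""
    out = subprocess.run(
        ["system_profiler", "SPDisplaysDataType"],
        capture_output=True, text=True, check=True,
    ).stdout
    match = re.search(r"Total Number of Cores:\s*(\d+)", out)
    if match is None:
        raise RuntimeError("GPU core count not found in system_profiler output")
    return int(match.group(1))

# Expected to agree with torch.backends.mps.get_core_count(), e.g. 16 on an M1 Pro.
print(core_count_from_system_profiler())
```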

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 0 deletions

```diff
@@ -1979,7 +1979,9 @@ def _mtia_resetPeakMemoryStats(device: _int) -> None: ...
 
 # Defined in torch/csrc/mps/Module.cpp
 def _mps_deviceSynchronize() -> None: ...
+def _mps_get_core_count() -> _int: ...
 def _mps_get_default_generator() -> Generator: ...
+def _mps_get_name() -> _str: ...
 def _mps_emptyCache() -> None: ...
 def _mps_setMemoryFraction(fraction: _float) -> None: ...
 def _mps_currentAllocatedMemory() -> _int: ...
```

torch/_dynamo/device_interface.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -17,6 +17,7 @@
 
 import inspect
 import time
+from collections import namedtuple
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, Optional, Union
@@ -544,8 +545,10 @@ def synchronize(device: torch.types.Device = None) -> None:
 
     class Worker:
         @staticmethod
-        def get_device_properties(device: torch.types.Device = None) -> dict[str, Any]:
-            return {}
+        def get_device_properties(device: torch.types.Device = None) -> Any:
+            return namedtuple("MPSProperties", ["multi_processor_count"])(
+                torch.backends.mps.get_core_count()  # type: ignore[arg-type]
+            )
 
         @staticmethod
         def current_device() -> int:
```
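
The namedtuple built here gives the inductor an object exposing the same `multi_processor_count` attribute that CUDA-style device properties provide, instead of the empty dict returned before. A small standalone sketch of that shape (illustrative only, not PyTorch code; `core_count = 16` stands in for the real query):

```python
from collections import namedtuple

# Hypothetical stand-in for torch.backends.mps.get_core_count() on an M1 Pro.
core_count = 16

MPSProperties = namedtuple("MPSProperties", ["multi_processor_count"])
props = MPSProperties(core_count)

# Callers use attribute access, matching how other backends' properties are read.
print(props.multi_processor_count)  # 16
```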

torch/_inductor/runtime/hints.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -153,9 +153,6 @@ def create(cls, device) -> DeviceProperties:
         except AttributeError:
             if device_type == "xpu":
                 multi_processor_count = props.gpu_subslice_count
-            elif device_type == "mps":
-                # TODO: Fetch the actual value from ioreg
-                multi_processor_count = 8
             elif device_type == "mtia":
                 multi_processor_count = 64
             else:
```

torch/backends/mps/__init__.py

Lines changed: 25 additions & 1 deletion

```diff
@@ -5,7 +5,14 @@
 from torch.library import Library as _Library
 
 
-__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
+__all__ = [
+    "get_core_count",
+    "get_name",
+    "is_built",
+    "is_available",
+    "is_macos13_or_newer",
+    "is_macos_or_newer",
+]
 
 
 def is_built() -> bool:
@@ -36,6 +43,23 @@ def is_macos13_or_newer(minor: int = 0) -> bool:
     return torch._C._mps_is_on_macos_or_newer(13, minor)
 
 
+@_lru_cache
+def get_name() -> str:
+    r"""Return Metal device name"""
+    return torch._C._mps_get_name()
+
+
+@_lru_cache
+def get_core_count() -> int:
+    r"""Return GPU core count.
+
+    According to the documentation, one core is comprised of 16 Execution Units.
+    One execution Unit has 8 ALUs.
+    And one ALU can run 24 threads, i.e. one core is capable of executing 3072 threads concurrently.
+    """
+    return torch._C._mps_get_core_count()
+
+
 _lib: Optional[_Library] = None
```
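
The docstring's arithmetic gives a simple upper bound on concurrently executing threads, which is what makes the core count useful as an occupancy hint. A quick worked check (the per-core factors come straight from the docstring; real scheduling limits may differ):

```python
EXECUTION_UNITS_PER_CORE = 16
ALUS_PER_EXECUTION_UNIT = 8
THREADS_PER_ALU = 24

threads_per_core = EXECUTION_UNITS_PER_CORE * ALUS_PER_EXECUTION_UNIT * THREADS_PER_ALU
assert threads_per_core == 3072

# For the 16-core Apple M1 Pro from the test plan:
print(16 * threads_per_core)  # 49152 concurrently executing threads
```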

torch/csrc/Module.cpp

Lines changed: 0 additions & 1 deletion

```diff
@@ -20,7 +20,6 @@
 #include <ATen/Parallel.h>
 #include <ATen/Utils.h>
 #include <ATen/core/Vitals.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <ATen/dlpack.h>
 #include <ATen/native/ConvUtils.h>
 #include <ATen/native/ForeachUtils.h>
```

torch/csrc/mps/Module.cpp

Lines changed: 6 additions & 0 deletions

```diff
@@ -501,6 +501,12 @@ void initModule(PyObject* module) {
     at::mps::getMPSProfiler().startCapture(fileName);
   });
   m.def("_mps_stopCapture", []() { at::mps::getMPSProfiler().stopCapture(); });
+  m.def("_mps_get_name", []() {
+    return at::mps::MPSDevice::getInstance()->getName();
+  });
+  m.def("_mps_get_core_count", []() {
+    return at::mps::MPSDevice::getInstance()->getCoreCount();
+  });
 }
 #endif /* USE_MPS */
```
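
These bindings surface through `torch._C` with the signatures declared in `__init__.pyi.in` above; the public `torch.backends.mps` helpers wrap them with `@_lru_cache`, so the IOKit registry walk in `MPSDevice::getCoreCount()` runs at most once per process. A quick illustrative check (assumes an MPS-enabled build; normally you would call the cached public wrappers instead):

```python
import torch

if torch.backends.mps.is_available():
    # Private bindings registered above by m.def(...).
    print(torch._C._mps_get_name())        # e.g. "Apple M1 Pro"
    print(torch._C._mps_get_core_count())  # e.g. 16
```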
