Move mps_linear forward to use MPS kernels directly instead of MPSGraph (#152210)
This PR moves `mps_linear` to use MPSNDArrays and call the MPS kernel directly instead of going through MPSGraph. It also adds a cache for reusing MPS kernel objects, since creating a kernel object carries a small overhead of its own.
The improvement is proportionally larger for small inputs, where the MPSGraph overhead accounts for a larger share of the operation's total execution time, but the speedup is visible for both small and large input sizes, as expected.
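The kernel caching idea can be illustrated with a minimal Python sketch. This is not the code from the PR (which lives in the MPS backend); `KernelCache`, `make_kernel`, and the choice of cache key are hypothetical names and assumptions used only to show the create-once, reuse-afterwards pattern.

```python
from typing import Callable, Dict, Hashable


class KernelCache:
    """Hypothetical cache: build a kernel object once per key, then reuse it."""

    def __init__(self, factory: Callable[[Hashable], object]) -> None:
        self._factory = factory               # expensive constructor (assumed)
        self._kernels: Dict[Hashable, object] = {}
        self.misses = 0                       # number of times a kernel was built

    def lookup(self, key: Hashable) -> object:
        # Only pay the creation cost the first time a key is seen.
        if key not in self._kernels:
            self._kernels[key] = self._factory(key)
            self.misses += 1
        return self._kernels[key]


def make_kernel(key: Hashable) -> object:
    # Stand-in for creating a real kernel object; the key layout is an assumption.
    return {"config": key}


cache = KernelCache(make_kernel)
key = ("float32", False)                      # e.g. (dtype, has_bias), assumed
first = cache.lookup(key)
second = cache.lookup(key)                    # served from the cache
```

Here the second `lookup` returns the same object as the first without calling the factory again, which is the saving the cache is meant to provide.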
`mps_linear` before the changes:
```
input shapes: f32:[1,1,20], f32:[1,20]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x109d67110>
func(*args, **kwargs)
Median: 199.29 us
IQR: 9.56 us (196.71 to 206.27)
979 measurements, 1 runs per measurement, 1 thread
input shapes: f32:[1,1,5120], f32:[13284,5120]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x1063b4510>
func(*args, **kwargs)
Median: 979.29 us
IQR: 25.29 us (964.83 to 990.13)
205 measurements, 1 runs per measurement, 1 thread
```
`mps_linear` after the changes:
```
input shapes: f32:[1,1,20], f32:[1,20]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x10693a190>
func(*args, **kwargs)
Median: 176.08 us
IQR: 15.02 us (172.42 to 187.44)
1103 measurements, 1 runs per measurement, 1 thread
input shapes: f32:[1,1,5120], f32:[13284,5120]
torch.linear time: <torch.utils.benchmark.utils.common.Measurement object at 0x10d524dd0>
func(*args, **kwargs)
Median: 952.56 us
IQR: 15.63 us (945.47 to 961.10)
210 measurements, 1 runs per measurement, 1 thread
```
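To make the comparison concrete, the short sketch below computes the time saved from the median values reported above. The `speedup` helper and the dictionary labels are illustrative names, not part of the PR.

```python
def speedup(before_us: float, after_us: float) -> tuple:
    """Return (microseconds saved, percent reduction) between two medians."""
    saved = before_us - after_us
    return saved, 100.0 * saved / before_us


# Median timings in microseconds, copied from the benchmark output above.
medians = {
    "f32:[1,1,20], f32:[1,20]": (199.29, 176.08),
    "f32:[1,1,5120], f32:[13284,5120]": (979.29, 952.56),
}

results = {shape: speedup(before, after) for shape, (before, after) in medians.items()}

for shape, (saved, pct) in results.items():
    print(f"{shape}: saved {saved:.2f} us ({pct:.1f}% faster)")
```

This works out to roughly 11.6% for the small input and 2.7% for the large one. The absolute saving is similar in both cases (about 23 us and 27 us), which is consistent with removing a roughly fixed per-call overhead.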
Pull Request resolved: #152210
Approved by: https://github.com/kulinseth, https://github.com/malfet
Co-authored-by: Nikita Shulga <nshulga@meta.com>