Update · pytorch/pytorch@106c7b6 · GitHub

Commit 106c7b6 ("Update")
[ghstack-poisoned]
2 parents 59cf6e5 + 781446b

58 files changed: +1632 −907 lines. (GitHub hides part of large commits by default; only the files below are shown.)

aten/src/ATen/cuda/tunable/Tunable.cpp

Lines changed: 8 additions & 0 deletions
@@ -31,7 +31,11 @@
 
 // for validators
 #ifdef USE_ROCM
+#ifdef _WIN32
+#include <hip/hip_version.h>
+#else
 #include <rocm-core/rocm_version.h>
+#endif
 #define ROCBLAS_BETA_FEATURES_API
 #include <rocblas/rocblas.h>
 #include <hipblaslt/hipblaslt.h>

@@ -218,7 +222,11 @@ TuningResultsValidator::TuningResultsValidator() {
 #ifdef USE_ROCM
   // rocm
   {
+#ifdef _WIN32
+    std::string rocm_version = HIP_VERSION_BUILD_NAME;
+#else
    std::string rocm_version = ROCM_BUILD_INFO;
+#endif
     RegisterValidator(
         "ROCM_VERSION",
         [rocm_version]() { return rocm_version; },
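On Windows the ROCm version string comes from the HIP SDK header instead of rocm-core. A minimal standalone sketch of the pattern, with hypothetical placeholder values for the two macros (the real ones are supplied by <hip/hip_version.h> and <rocm-core/rocm_version.h>):

// Standalone sketch of the platform-conditional version lookup above.
// Assumption: HIP_VERSION_BUILD_NAME is defined by <hip/hip_version.h> on
// Windows and ROCM_BUILD_INFO by <rocm-core/rocm_version.h> elsewhere; the
// literals below are hypothetical stand-ins so the sketch compiles anywhere.
#include <iostream>
#include <string>

#ifndef HIP_VERSION_BUILD_NAME
#define HIP_VERSION_BUILD_NAME "hip-6.x-win-build" // placeholder
#endif
#ifndef ROCM_BUILD_INFO
#define ROCM_BUILD_INFO "rocm-6.x-linux-build" // placeholder
#endif

int main() {
#ifdef _WIN32
  std::string rocm_version = HIP_VERSION_BUILD_NAME; // Windows: HIP SDK header
#else
  std::string rocm_version = ROCM_BUILD_INFO; // Linux: rocm-core header
#endif
  // TunableOp records this string under the "ROCM_VERSION" validator key, so
  // cached tuning results can be rejected when the toolkit version changes.
  std::cout << rocm_version << '\n';
}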

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 35 additions & 31 deletions
@@ -638,28 +638,28 @@ MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) {
   switch (type) {
     case ScalarType::Double:
     case ScalarType::Float:
-      return {.value.f = scalar.to<float>(), .size = sizeof(float), .type = type};
+      return {.size = sizeof(float), .type = type, .value.f = scalar.to<float>()};
     case ScalarType::Half:
-      return {.value.h = scalar.to<Half>(), .size = sizeof(short), .type = type};
+      return {.size = sizeof(short), .type = type, .value.h = scalar.to<Half>()};
     case ScalarType::BFloat16:
-      return {.value.bf16 = scalar.to<BFloat16>(), .size = sizeof(short), .type = type};
+      return {.size = sizeof(short), .type = type, .value.bf16 = scalar.to<BFloat16>()};
     case ScalarType::Long:
-      return {.value.i = scalar.to<int64_t>(), .size = sizeof(int64_t), .type = type};
+      return {.size = sizeof(int64_t), .type = type, .value.i = scalar.to<int64_t>()};
     case ScalarType::Int:
-      return {.value.i = scalar.to<int32_t>(), .size = sizeof(int32_t), .type = type};
+      return {.size = sizeof(int32_t), .type = type, .value.i = scalar.to<int32_t>()};
     case ScalarType::Short:
-      return {.value.i = scalar.to<int16_t>(), .size = sizeof(int16_t), .type = type};
+      return {.size = sizeof(int16_t), .type = type, .value.i = scalar.to<int16_t>()};
     case ScalarType::Char:
-      return {.value.i = scalar.to<int8_t>(), .size = sizeof(int8_t), .type = type};
+      return {.size = sizeof(int8_t), .type = type, .value.i = scalar.to<int8_t>()};
     case ScalarType::Byte:
-      return {.value.i = scalar.to<uint8_t>(), .size = sizeof(uint8_t), .type = type};
+      return {.size = sizeof(uint8_t), .type = type, .value.i = scalar.to<uint8_t>()};
     case ScalarType::Bool:
-      return {.value.b = scalar.to<bool>(), .size = sizeof(bool), .type = type};
+      return {.size = sizeof(bool), .type = type, .value.b = scalar.to<bool>()};
     case ScalarType::ComplexHalf:
-      return {.value.ch = scalar.to<c10::complex<Half>>(), .size = sizeof(int32_t), .type = type};
+      return {.size = sizeof(int32_t), .type = type, .value.ch = scalar.to<c10::complex<Half>>()};
     case ScalarType::ComplexFloat:
     case ScalarType::ComplexDouble:
-      return {.value.cf = scalar.to<c10::complex<float>>(), .size = sizeof(int64_t), .type = type};
+      return {.size = sizeof(int64_t), .type = type, .value.cf = scalar.to<c10::complex<float>>()};
     default:
       TORCH_INTERNAL_ASSERT(false, "Unsupported scalar type '", type, "' on MPS backend.");
   }

@@ -965,45 +965,49 @@ static dispatch_data_t getSectionData(const std::string& name) {
                               std::optional<int64_t> extra) {
   auto inputTensor = iter.input(0);
   auto outputTensor = iter.output(0);
-  bool is_dense_strided = is_dense_in_storage(inputTensor) && inputTensor.strides().equals(outputTensor.strides());
-  bool needs_output_copy = false;
-  uint32_t length = outputTensor.numel();
+  bool is_storage_dense = is_dense_in_storage(inputTensor) && inputTensor.strides().equals(outputTensor.strides());
+  uint32_t length = iter.numel();
   if (length == 0) {
     return;
   }
   using namespace mps;
   @autoreleasepool {
     id<MTLComputePipelineState> cplState = nil;
-    cplState = getPipelineStateForFunc(fmt::format(
-        "{}_dense_{}_{}", name, scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(inputTensor)));
-
-    if (!is_dense_strided) {
-      inputTensor = inputTensor.contiguous();
-      if (!outputTensor.is_contiguous()) {
-        outputTensor = outputTensor.contiguous();
-        needs_output_copy = true;
-      }
-    }
+    cplState = getPipelineStateForFunc(fmt::format("{}_{}_{}_{}",
+                                                   name,
+                                                   is_storage_dense ? "dense" : "strided",
+                                                   scalarToMetalTypeString(outputTensor),
+                                                   scalarToMetalTypeString(inputTensor)));

     MPSStream* mpsStream = getCurrentMPSStream();
     dispatch_sync(mpsStream->queue(), ^() {
-      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      auto computeEncoder = mpsStream->commandEncoder();

       getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor});

       [computeEncoder setComputePipelineState:cplState];
-      mtl_setArgs(computeEncoder, outputTensor, inputTensor);
-      if (extra) {
-        mtl_setBytes(computeEncoder, *extra, 2);
+      if (is_storage_dense) {
+        mtl_setArgs(computeEncoder, outputTensor, inputTensor);
+        if (extra) {
+          mtl_setBytes(computeEncoder, *extra, 2);
+        }
+      } else {
+        mtl_setArgs(computeEncoder,
+                    outputTensor,
+                    inputTensor,
+                    outputTensor.sizes(),
+                    inputTensor.strides(),
+                    outputTensor.strides(),
+                    inputTensor.ndimension());
+        if (extra) {
+          mtl_setBytes(computeEncoder, *extra, 6);
+        }
       }
       mtl_dispatch1DJob(computeEncoder, cplState, length);

       getMPSProfiler().endProfileKernel(cplState);
     });
   }
-  if (needs_output_copy) {
-    iter.output(0).copy_(outputTensor);
-  }
 }

 MetalShaderLibrary& MetalShaderLibrary::getBundledLibrary() {
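Two separate fixes land in this file. The getMPSScalar change reorders the designated initializers to match MPSScalar's member declaration order; C++20 requires designators to appear in declaration order, so the old value-first spelling is rejected by newer compilers. A compilable sketch of the rule, using a hypothetical stand-in struct (the nested .value.f designator is a clang extension, matching how this Objective-C++ file is built):

// Sketch: designated initializers must follow declaration order in C++20.
// MPSScalarish is a hypothetical stand-in for at::native::mps::MPSScalar.
#include <cstddef>

struct MPSScalarish {
  size_t size;
  int type; // stand-in for ScalarType
  union { float f; long long i; bool b; } value;
};

MPSScalarish ok() {
  // Members named in declaration order (size, type, value): well-formed.
  return {.size = sizeof(float), .type = 0, .value.f = 1.5f};
}

// MPSScalarish bad() {
//   // value named before size/type: ill-formed in C++20, clang rejects it.
//   return {.value.f = 1.5f, .size = sizeof(float), .type = 0};
// }

The exec_unary_kernel change replaces the contiguous-copy fallback (and its needs_output_copy write-back) with direct dispatch to a "strided" kernel variant: the kernel name now selects dense or strided at runtime, and the strided path binds sizes, strides, and ndim as extra buffer arguments, which is why the optional extra scalar moves from buffer index 2 to index 6.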

aten/src/ATen/native/mps/kernels/UnaryKernel.metal

Lines changed: 38 additions & 9 deletions
@@ -99,7 +99,7 @@ INSTANTIATE_UNARY_KERNELS_VEC2(half);
 INSTANTIATE_UNARY_KERNELS_VEC2(float);

 template <typename T>
-kernel void round_decimals_kernel(
+kernel void round_decimals_dense(
     device T* output [[buffer(0)]],
     constant T* input [[buffer(1)]],
     constant long& ndigits [[buffer(2)]],

@@ -108,14 +108,43 @@ kernel void round_decimals_kernel(
       rint(exp10(float(ndigits)) * input[index]) * exp10(float(-ndigits)));
 }

-#define INSTANTIATE_ROUND_DECIMALS(DTYPE)                                    \
-  template                                                                   \
-      [[host_name("round_decimals_dense_" #DTYPE "_" #DTYPE)]] kernel void   \
-      round_decimals_kernel(                                                 \
-          device DTYPE* output [[buffer(0)]],                                \
-          constant DTYPE* input [[buffer(1)]],                               \
-          constant long& ndigits [[buffer(2)]],                              \
-          uint id [[thread_position_in_grid]])
+template <typename T>
+kernel void round_decimals_strided(
+    device T* output [[buffer(0)]],
+    constant T* input [[buffer(1)]],
+    constant long* sizes [[buffer(2)]],
+    constant long* input_strides [[buffer(3)]],
+    constant long* output_strides [[buffer(4)]],
+    constant uint& ndim [[buffer(5)]],
+    constant long& ndigits [[buffer(6)]],
+    uint index [[thread_position_in_grid]]) {
+  int pos[max_ndim];
+  pos_from_thread_index(int(index), pos, sizes, ndim);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
+  output[output_offs] = static_cast<T>(
+      rint(exp10(float(ndigits)) * input[input_offs]) * exp10(float(-ndigits)));
+}
+
+#define INSTANTIATE_ROUND_DECIMALS(DTYPE)                                    \
+  template                                                                   \
+      [[host_name("round_decimals_dense_" #DTYPE "_" #DTYPE)]] kernel void   \
+      round_decimals_dense(                                                  \
+          device DTYPE* output [[buffer(0)]],                                \
+          constant DTYPE* input [[buffer(1)]],                               \
+          constant long& ndigits [[buffer(2)]],                              \
+          uint index [[thread_position_in_grid]]);                           \
+  template                                                                   \
+      [[host_name("round_decimals_strided_" #DTYPE "_" #DTYPE)]] kernel void \
+      round_decimals_strided(                                                \
+          device DTYPE* output [[buffer(0)]],                                \
+          constant DTYPE* input [[buffer(1)]],                               \
+          constant long* sizes,                                              \
+          constant long* input_strides,                                      \
+          constant long* output_strides,                                     \
+          constant uint& ndim,                                               \
+          constant long& ndigits [[buffer(6)]],                              \
+          uint index)

 INSTANTIATE_ROUND_DECIMALS(float);
 INSTANTIATE_ROUND_DECIMALS(half);
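Both the dense and strided kernels implement round-to-ndigits as rint(10^d · x) · 10^(−d). A host-side C++ check of the same arithmetic (plain C++ rather than Metal, with exp10 spelled via std::pow):

// Host-side sketch of the round_decimals math: round(x, d) = rint(10^d * x) / 10^d.
#include <cmath>
#include <cstdio>

float round_decimals(float x, long ndigits) {
  const float scale = std::pow(10.0f, float(ndigits));
  return std::rint(scale * x) * std::pow(10.0f, float(-ndigits));
}

int main() {
  std::printf("%f\n", round_decimals(3.14159f, 2));  // 3.140000
  std::printf("%f\n", round_decimals(1234.5f, -2)); // 1200.000000
}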

benchmarks/dynamo/pr_time_benchmarks/check_results.py

Lines changed: 12 additions & 2 deletions
@@ -210,9 +210,19 @@ def log(event_name):
         writer.writerow([])
         writer.writerow([])

-    print("new expected results file content if needed:")
+    print("=" * 80)
+    print("=" * 80)
+    print("=" * 80)
+    print("To update expected results, run the following command:")
+    print()
+    print("cat > benchmarks/dynamo/pr_time_benchmarks/expected_results.csv << EOF")
     with open(reference_expected_results_path) as f:
-        print(f.read())
+        print(f.read().rstrip())
+    print("EOF")
+    print()
+    print("=" * 80)
+    print("=" * 80)
+    print("=" * 80)

     if fail:
         print(
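The rstrip() is what makes the printed heredoc copy-pasteable: f.read() already ends in a newline, so stripping it keeps the EOF terminator on the line immediately after the last CSV row instead of appending stray blank lines to the regenerated expected_results.csv.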

c10/metal/indexing.h

Lines changed: 35 additions & 9 deletions
@@ -52,15 +52,41 @@ kernel void unary_dense(
   output[index] = f(input[index]);
 }

-#define REGISTER_UNARY_OP(NAME, DTYPE0, DTYPE1)                               \
-  static_assert(                                                              \
-      ::metal::                                                               \
-          is_same_v<DTYPE1, ::c10::metal::result_of<NAME##_functor, DTYPE0>>, \
-      "Output dtype mismatch for unary op " #NAME " and input " #DTYPE0);     \
-  template [[host_name(#NAME "_dense_" #DTYPE1 "_" #DTYPE0)]] kernel void ::  \
-      c10::metal::unary_dense<DTYPE0, NAME##_functor>(                        \
-          device ::c10::metal::result_of<NAME##_functor, DTYPE0> * output,    \
-          constant DTYPE0 * input,                                            \
+template <typename T, typename F>
+kernel void unary_strided(
+    device result_of<F, T>* output [[buffer(0)]],
+    constant T* input [[buffer(1)]],
+    constant long* sizes [[buffer(2)]],
+    constant long* input_strides [[buffer(3)]],
+    constant long* output_strides [[buffer(4)]],
+    constant uint& ndim [[buffer(5)]],
+    uint index [[thread_position_in_grid]]) {
+  F f;
+  int pos[max_ndim];
+  pos_from_thread_index(int(index), pos, sizes, ndim);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
+  output[output_offs] = f(input[input_offs]);
+}
+
+#define REGISTER_UNARY_OP(NAME, DTYPE0, DTYPE1)                               \
+  static_assert(                                                              \
+      ::metal::                                                               \
+          is_same_v<DTYPE1, ::c10::metal::result_of<NAME##_functor, DTYPE0>>, \
+      "Output dtype mismatch for unary op " #NAME " and input " #DTYPE0);     \
+  template [[host_name(#NAME "_dense_" #DTYPE1 "_" #DTYPE0)]] kernel void ::  \
+      c10::metal::unary_dense<DTYPE0, NAME##_functor>(                        \
+          device ::c10::metal::result_of<NAME##_functor, DTYPE0> * output,    \
+          constant DTYPE0 * input,                                            \
+          uint index);                                                        \
+  template [[host_name(#NAME "_strided_" #DTYPE1 "_" #DTYPE0)]] kernel void ::\
+      c10::metal::unary_strided<DTYPE0, NAME##_functor>(                      \
+          device ::c10::metal::result_of<NAME##_functor, DTYPE0> * output,    \
+          constant DTYPE0 * input,                                            \
+          constant long* sizes,                                               \
+          constant long* input_strides,                                       \
+          constant long* output_strides,                                      \
+          constant uint& ndim,                                                \
           uint index)

 #define DEFINE_UNARY_FLOATING_FUNCTOR(NAME) \
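unary_strided first unravels the flat thread index into a multi-dimensional coordinate, then folds that coordinate against each tensor's strides to get independent input and output offsets. A plain-C++ sketch of what pos_from_thread_index and offset_from_coord plausibly compute (their real definitions live elsewhere in this header; the fastest-varying-dimension order is an assumption):

#include <cstdint>
#include <cstdio>

constexpr int max_ndim = 16; // assumption: mirrors the Metal-side constant

// Unravel a flat thread index into per-dimension coordinates.
// Assumption: dimension 0 varies fastest.
void pos_from_thread_index(int idx, int pos[max_ndim], const long* sizes, uint32_t ndim) {
  for (uint32_t d = 0; d < ndim; ++d) {
    pos[d] = idx % int(sizes[d]);
    idx /= int(sizes[d]);
  }
}

// Dot product of coordinates with strides yields the element offset.
long offset_from_coord(const int pos[max_ndim], const long* strides, uint32_t ndim) {
  long offs = 0;
  for (uint32_t d = 0; d < ndim; ++d) {
    offs += long(pos[d]) * strides[d];
  }
  return offs;
}

int main() {
  // A 2x3 tensor read through a transposed view (strides {1, 3}) and written
  // densely (strides {3, 1}): the strided kernel needs no contiguous copy.
  const long sizes[] = {2, 3};
  const long in_strides[] = {1, 3};
  const long out_strides[] = {3, 1};
  int pos[max_ndim];
  for (int i = 0; i < 6; ++i) {
    pos_from_thread_index(i, pos, sizes, 2);
    std::printf("thread %d -> in %ld, out %ld\n", i,
                offset_from_coord(pos, in_strides, 2),
                offset_from_coord(pos, out_strides, 2));
  }
}

This is the layout contract behind the host-side change in OperationUtils.mm, which binds sizes, input strides, output strides, and ndim at buffer indices 2 through 5 for the strided variant.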

cmake/Codegen.cmake

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ if(INTERN_BUILD_ATEN_OPS)
     if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
       list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
     endif()
-    if(EXISTING_ARCH_FLAGS MATCHES ".*compute_100a.*")
+    if(EXISTING_ARCH_FLAGS MATCHES ".*compute_100.*")
       list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_100a,code=sm_100a")
     endif()
   endif()
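Dropping the trailing "a" from the match means any compute_100 variant in the existing arch flags (compute_100 or compute_100a) now triggers the extra "-gencode;arch=compute_100a,code=sm_100a" flag for the rowwise-scaled-MM sources, making the sm_100 (Blackwell) branch behave like the compute_90 branch above it.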

setup.py

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@
 #   Builds libtorch.so and its dependencies as a wheel
 #
 # BUILD_PYTHON_ONLY
-#   Builds pytorch as a wheel using libtorch.so from a seperate wheel
+#   Builds pytorch as a wheel using libtorch.so from a separate wheel

 import os
 import sys

test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp

Lines changed: 3 additions & 0 deletions
@@ -363,6 +363,9 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 };

 TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
+  // Note (kwen2501) 03/07/2025
+  // TODO: re-enable
+  GTEST_SKIP() << "Skipping test as the trace write seems unstable.";
   int heartBeatIntervalInSec = 2;
   std::string timeInterval = std::to_string(heartBeatIntervalInSec);
   ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);
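GTEST_SKIP() returns from the test body immediately and reports the test as skipped rather than failed, so the unstable trace-write path below stays in place but is never exercised until the TODO is resolved.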

test/distributed/test_c10d_common.py

Lines changed: 41 additions & 0 deletions
@@ -1559,6 +1559,11 @@ def wait(self, timeout=5.0):


 class DummyProcessGroup(dist.ProcessGroup):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._aborted = False
+        self._shutdown = False
+
     def getBackendName(self):
         return "Dummy"


@@ -1622,6 +1627,12 @@ def recv(self, tensor_list, src, tag=0):

         return DummyWork()

+    def abort(self) -> None:
+        self._aborted = True
+
+    def shutdown(self) -> None:
+        self._shutdown = True
+

 class PythonProcessGroupExtensionTest(MultiProcessTestCase):
     def setUp(self):

@@ -1794,6 +1805,36 @@ def test_send_recv(self):
         # intentionally not calling into `destroy_process_group` as not all
         # user applications would explicitly that.

+    def test_shutdown(self) -> None:
+        dist.Backend.register_backend(
+            "dummy", PythonProcessGroupExtensionTest.create_dummy
+        )
+
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "6789"
+        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)
+
+        pg = c10d._get_default_group()
+
+        dist.destroy_process_group()
+
+        self.assertTrue(pg._shutdown)
+
+    def test_abort(self) -> None:
+        dist.Backend.register_backend(
+            "dummy", PythonProcessGroupExtensionTest.create_dummy
+        )
+
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "6789"
+        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)
+
+        pg = c10d._get_default_group()
+
+        c10d._abort_process_group()
+
+        self.assertTrue(pg._aborted)
+

 instantiate_parametrized_tests(CommonDistributedDataParallelTest)
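Taken together with the DummyProcessGroup additions above, these tests pin down the new backend contract: dist.destroy_process_group() is expected to reach the backend's shutdown() hook and c10d._abort_process_group() its abort() hook, so pure-Python backends can observe (or override) both teardown paths.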

test/distributed/test_c10d_pypg.py

Lines changed: 6 additions & 0 deletions
@@ -191,6 +191,12 @@ def test_attr_overrides(self):
         pg._set_group_desc("desc")
         self.assertEqual(pg.group_desc, "py:desc")

+    def test_abort_shutdown(self) -> None:
+        # verify this are noops
+        pg = DummyAttrProcessGroup(0, 1)
+        pg.abort()
+        pg.shutdown()
+

 if __name__ == "__main__":
     run_tests()
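This complements the tests in test_c10d_common.py: a Python process group that does not override abort()/shutdown() should inherit safe no-op defaults from dist.ProcessGroup rather than raising.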
