Commit bdadc73 · pytorch/pytorch

Update base for Update on "[Array API] Add linalg.vecdot"
This PR adds the function `linalg.vecdot` specified by the [Array API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot). For the complex case, it chooses to implement $\sum_i x_i y_i$. See the discussion in data-apis/array-api#356.

Edit: when it comes to testing, this function is not quite a binary op, nor a reduction op. As such, we are this close to being able to reuse the extra testing those categories get, but we don't quite make it. It is such a simple op, though, that it should do fine without it.

Resolves #18027.

cc mruberry rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi

[ghstack-poisoned]
2 parents 145e48e + b3e7230 commit bdadc73
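As a rough illustration of the semantics described above (a minimal sketch in ATen C++, not the implementation this commit adds; `vecdot_reference` is a hypothetical name):

```cpp
#include <ATen/ATen.h>

// Sketch of the chosen semantics: reduce the elementwise product along the
// last dimension, i.e. sum_i x_i * y_i, with no complex conjugation.
at::Tensor vecdot_reference(const at::Tensor& x, const at::Tensor& y) {
  return (x * y).sum(/*dim=*/-1);
}
```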

File tree

260 files changed, +29933 −16826 lines

.github/scale-config.yml

Lines changed: 4 additions & 4 deletions
@@ -30,25 +30,25 @@ runner_types:
   linux.2xlarge:
     instance_type: c5.2xlarge
     os: linux
-    max_available: 750
+    max_available: 1000
     disk_size: 150
     is_ephemeral: false
   linux.4xlarge: # for binary-builds
     instance_type: c5.4xlarge
     os: linux
-    max_available: 250
+    max_available: 500
     disk_size: 150
     is_ephemeral: false
   linux.8xlarge.nvidia.gpu:
     instance_type: g3.8xlarge
     os: linux
-    max_available: 125
+    max_available: 200
     disk_size: 150
     is_ephemeral: false
   linux.4xlarge.nvidia.gpu:
     instance_type: g3.4xlarge
     os: linux
-    max_available: 175
+    max_available: 250
     disk_size: 150
     is_ephemeral: false
   linux.16xlarge.nvidia.gpu:

.github/scripts/gql_mocks.json

Lines changed: 11905 additions & 11153 deletions
Some generated files are not rendered by default.

.github/scripts/test_trymerge.py

Lines changed: 3 additions & 1 deletion
@@ -176,7 +176,9 @@ def test_pending_status_check(self, mocked_gql: Any, mocked_read_merge_rules: An
         """
         pr = GitHubPR("pytorch", "pytorch", 76118)
         repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
-        self.assertRaisesRegex(MandatoryChecksMissingError, ".*are not yet run.*", lambda: find_matching_merge_rule(pr, repo))
+        self.assertRaisesRegex(MandatoryChecksMissingError,
+                               ".*are pending/not yet run.*",
+                               lambda: find_matching_merge_rule(pr, repo))

     @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
     def test_get_author_many_reviews(self, mocked_gql: Any) -> None:

.github/scripts/trymerge.py

Lines changed: 26 additions & 15 deletions
@@ -77,13 +77,15 @@
           nodes {
             name
             conclusion
+            detailsUrl
           }
           pageInfo {
             endCursor
             hasNextPage
           }
         }
         conclusion
+        url
       }
       pageInfo {
         endCursor
@@ -179,13 +181,15 @@
           nodes {
             name
             conclusion
+            detailsUrl
           }
           pageInfo {
             endCursor
             hasNextPage
           }
         }
         conclusion
+        url
       }
       pageInfo {
         endCursor
@@ -411,7 +415,7 @@ def __init__(self, org: str, project: str, pr_num: int) -> None:
         self.pr_num = pr_num
         self.info = gh_get_pr_info(org, project, pr_num)
         self.changed_files: Optional[List[str]] = None
-        self.conclusions: Optional[Dict[str, str]] = None
+        self.conclusions: Optional[Dict[str, Tuple[str, str]]] = None
         self.comments: Optional[List[GitHubComment]] = None
         self._authors: Optional[List[Tuple[str, str]]] = None
         self._reviews: Optional[List[Tuple[str, str]]] = None
@@ -526,8 +530,8 @@ def get_committer_login(self, num: int = 0) -> str:
     def get_committer_author(self, num: int = 0) -> str:
         return self._fetch_authors()[num][1]

-    def get_checkrun_conclusions(self) -> Dict[str, str]:
-        """ Returns list of checkrun / conclusions """
+    def get_checkrun_conclusions(self) -> Dict[str, Tuple[str, str]]:
+        """ Returns dict of checkrun -> [conclusion, url] """
         if self.conclusions is not None:
             return self.conclusions
         orig_last_commit = self.info["commits"]["nodes"][-1]["commit"]
@@ -539,10 +543,10 @@ def add_conclusions(nodes: List[Dict[str, Any]]) -> None:
             workflow_run = node["workflowRun"]
             checkruns = node["checkRuns"]
             if workflow_run is not None:
-                conclusions[workflow_run["workflow"]["name"]] = node["conclusion"]
+                conclusions[workflow_run["workflow"]["name"]] = (node["conclusion"], node["url"])
             if checkruns is not None:
                 for checkrun_node in checkruns["nodes"]:
-                    conclusions[checkrun_node["name"]] = checkrun_node["conclusion"]
+                    conclusions[checkrun_node["name"]] = (checkrun_node["conclusion"], checkrun_node["detailsUrl"])

         add_conclusions(checksuites["nodes"])
         while bool(checksuites["pageInfo"]["hasNextPage"]):
@@ -646,7 +650,7 @@ def has_internal_changes(self) -> bool:
         checks = self.get_checkrun_conclusions()
         if checks is None or checkrun_name not in checks:
             return False
-        return checks[checkrun_name] != "SUCCESS"
+        return checks[checkrun_name][0] != "SUCCESS"

     def merge_ghstack_into(self, repo: GitRepo, force: bool, comment_id: Optional[int] = None) -> None:
         assert self.is_ghstack_pr()
@@ -785,25 +789,32 @@ def find_matching_merge_rule(pr: GitHubPR,
                           f"{','.join(list(rule_approvers_set)[:5])}{', ...' if len(rule_approvers_set) > 5 else ''}")
             continue
         if rule.mandatory_checks_name is not None:
-            pending_checks = []
-            failed_checks = []
+            pending_checks: List[Tuple[str, Optional[str]]] = []
+            failed_checks: List[Tuple[str, Optional[str]]] = []
             checks = pr.get_checkrun_conclusions()
             # HACK: We don't want to skip CLA check, even when forced
             for checkname in filter(lambda x: force is False or "CLA Check" in x, rule.mandatory_checks_name):
-                if checkname not in checks or checks[checkname] is None:
-                    pending_checks.append(checkname)
-                elif checks[checkname] != 'SUCCESS':
-                    failed_checks.append(checkname)
+                if checkname not in checks:
+                    pending_checks.append((checkname, None))
+                elif checks[checkname][0] is None:
+                    pending_checks.append((checkname, checks[checkname][1]))
+                elif checks[checkname][0] != 'SUCCESS':
+                    failed_checks.append((checkname, checks[checkname][1]))
+
+            def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
+                return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
+
             if len(failed_checks) > 0:
                 if reject_reason_score < 30000:
                     reject_reason_score = 30000
-                    reject_reason = f"Refusing to merge as mandatory check(s) {','.join(failed_checks)} failed for rule {rule_name}"
+                    reject_reason = ("Refusing to merge as mandatory check(s)" +
+                                     checks_to_str(failed_checks) + f" failed for rule {rule_name}")
                 continue
             elif len(pending_checks) > 0:
                 if reject_reason_score < 20000:
                     reject_reason_score = 20000
-                    reject_reason = f"Refusing to merge as mandatory check(s) {','.join(pending_checks)}"
-                    reject_reason += f" are not yet run for rule {rule_name}"
+                    reject_reason = f"Refusing to merge as mandatory check(s) {checks_to_str(pending_checks)}"
+                    reject_reason += f" are pending/not yet run for rule {rule_name}"
                 continue
         if not skip_internal_checks and pr.has_internal_changes():
             raise RuntimeError("This PR has internal changes and must be landed via Phabricator")
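The practical effect of carrying the URL through `get_checkrun_conclusions`: `checks_to_str` renders each failed or pending check as a Markdown link when its URL is known, and falls back to the bare name otherwise. For example (hypothetical values), the tuple `("Lint", "https://example.com/runs/1")` renders as `[Lint](https://example.com/runs/1)` in the reject reason, while `("Lint", None)` renders as plain `Lint`.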

aten/src/ATen/Context.cpp

Lines changed: 20 additions & 0 deletions
@@ -349,6 +349,26 @@ bool NoTF32Guard::should_disable_tf32() {
   return override_allow_tf32_flag;
 }

+#ifdef USE_ROCM
+// Ops can query this flag to know they are in the backward pass.
+// This information can be used, for example, to select implementations
+// with different numerical or performance characteristics.
+// See https://pytorch.org/docs/stable/notes/numerical_accuracy.html for details.
+thread_local bool ROCmBackwardPassGuard::is_backward_pass_;
+
+ROCmBackwardPassGuard::ROCmBackwardPassGuard() {
+  is_backward_pass_ = true;
+}
+
+ROCmBackwardPassGuard::~ROCmBackwardPassGuard() {
+  is_backward_pass_ = false;
+}
+
+bool ROCmBackwardPassGuard::is_backward_pass() {
+  return is_backward_pass_;
+}
+#endif
+
 bool Context::areVmapFallbackWarningsEnabled() const {
   return display_vmap_fallback_warnings_;
 }

aten/src/ATen/Context.h

Lines changed: 10 additions & 0 deletions
@@ -403,4 +403,14 @@ struct TORCH_API NoTF32Guard {
   bool changed = false;
 };

+#ifdef USE_ROCM
+struct TORCH_API ROCmBackwardPassGuard {
+  ROCmBackwardPassGuard();
+  ~ROCmBackwardPassGuard();
+  static bool is_backward_pass();
+ private:
+  static thread_local bool is_backward_pass_;
+};
+#endif
+
 } // namespace at
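Taken together, the two hunks implement a plain RAII switch over a thread-local flag. A usage sketch (the call site here is hypothetical; per the comment, the intent is for backward-pass work to run inside the guard):

```cpp
{
  at::ROCmBackwardPassGuard guard;  // constructor sets the thread-local flag
  // ... run backward kernels; rocBLAS GEMMs on this thread can now opt into
  // the alternate FP16 implementation (see CUDABlas.cpp below) ...
  TORCH_INTERNAL_ASSERT(at::ROCmBackwardPassGuard::is_backward_pass());
}  // destructor resets the flag to false
```

Note that the destructor resets the flag unconditionally rather than restoring a saved previous value, so these guards are not designed to nest.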

aten/src/ATen/Dispatch.h

Lines changed: 40 additions & 0 deletions
@@ -416,6 +416,46 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
   }                                                                           \
   }()

+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(                          \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...)                   \
+  [&] {                                                                       \
+    const auto& the_type = TYPE;                                              \
+    /* don't use TYPE again in case it is an expensive or side-effect op */   \
+    at::ScalarType _st = ::detail::scalar_type(the_type);                     \
+    RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st);                                  \
+    switch (_st) {                                                            \
+      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__)   \
+      AT_PRIVATE_CASE_TYPE(                                                   \
+          NAME,                                                               \
+          at::ScalarType::ComplexDouble,                                      \
+          c10::complex<double>,                                               \
+          __VA_ARGS__)                                                        \
+      AT_PRIVATE_CASE_TYPE(                                                   \
+          NAME,                                                               \
+          at::ScalarType::ComplexFloat,                                       \
+          c10::complex<float>,                                                \
+          __VA_ARGS__)                                                        \
+      AT_PRIVATE_CASE_TYPE(                                                   \
+          NAME,                                                               \
+          SCALARTYPE1,                                                        \
+          decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t),           \
+          __VA_ARGS__)                                                        \
+      AT_PRIVATE_CASE_TYPE(                                                   \
+          NAME,                                                               \
+          SCALARTYPE2,                                                        \
+          decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t),           \
+          __VA_ARGS__)                                                        \
+      AT_PRIVATE_CASE_TYPE(                                                   \
+          NAME,                                                               \
+          SCALARTYPE3,                                                        \
+          decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE3>::t),           \
+          __VA_ARGS__)                                                        \
+      default:                                                                \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");        \
+    }                                                                         \
+  }()
+
 #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   [&] {                                             \
     const auto& the_type = TYPE;                    \
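A usage sketch of the new macro (the kernel name, extra dtypes, and `self` tensor here are hypothetical): it dispatches over the four default floating/complex types plus three caller-chosen scalar types, binding `scalar_t` inside the lambda.

```cpp
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    at::ScalarType::ComplexHalf,
    self.scalar_type(),
    "my_kernel",  // hypothetical kernel name
    [&] {
      // scalar_t is double, float, c10::complex<double>, c10::complex<float>,
      // or the C++ type of one of the three extra ScalarTypes.
      my_kernel_impl<scalar_t>(self);
    });
```

Any other dtype falls through to the `default:` case and raises the "not implemented" error.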

aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ namespace c10 {
   _(prim, reshape_copy) \
   _(prim, squeeze_copy) \
   _(prim, unsqueeze_copy) \
+  _(prim, flatten_copy) \
   _(prim, DifferentiableGraph) \
   _(prim, TensorExprGroup) \
   _(prim, TensorExprDynamicGroup) \

aten/src/ATen/cpu/vec/vec_base.h

Lines changed: 1 addition & 1 deletion
@@ -538,7 +538,7 @@ struct Vectorized {
     // 1 if the pred is true, otherwise 0.
     Vectorized<T> vector;
     for (int i = 0; i != size(); ++ i) {
-      vector[i] = bool(op(values[i], other.values[i]));
+      vector[i] = static_cast<T>(op(values[i], other.values[i]));
     }
     return vector;
   }

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 15 additions & 2 deletions
@@ -15,6 +15,11 @@
 #include <cublasLt.h>
 #endif

+#ifdef USE_ROCM
+#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+#endif
+
 #define CUDABLAS_POSINT_CHECK(FD, X) \
   TORCH_CHECK(                       \
       (X > 0 && X <= INT_MAX),       \
@@ -246,13 +251,17 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
   float falpha = alpha;
   float fbeta = beta;
 #ifdef USE_ROCM
+  int flag = 0;
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
   TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
                                    (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
                                    b, rocblas_datatype_f16_r, (int)ldb, strideb,
                                    (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                    c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                    (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
-                                   0, 0));
+                                   0, flag));
 #else
 #if defined(CUDA_VERSION) && CUDA_VERSION < 11000
   // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
@@ -392,6 +401,10 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
   _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
   GEMM_CHECK_ARGVALUES(at::Half);
 #ifdef USE_ROCM
+  int flag = 0;
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
       handle,
       opa,
@@ -416,7 +429,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
       rocblas_datatype_f32_r,
       rocblas_gemm_algo_standard,
       0,
-      0));
+      flag));
 #else
   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
   if (prop->major >= 5) {
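To make the version gate concrete: rocBLAS 2.42 gives PYTORCH_ROCBLAS_VERSION_DECIMAL = 2 * 100 + 42 = 242, the first value for which USE_GEMM_FLAGS_FP16_ALT_IMPL evaluates to true; only then is `flag` switched to `rocblas_gemm_flags_fp16_alt_impl` inside a backward pass.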

aten/src/ATen/cuda/detail/CUDAHooks.h

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+#pragma once
+
 #include <ATen/detail/CUDAHooksInterface.h>

 #include <ATen/Generator.h>

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 2 additions & 1 deletion
@@ -21,10 +21,11 @@ namespace native {
 static void check_convert(const Scalar& scalar, ScalarType scalarType) {
   // Validate that is possible to convert scalar to tensor dtype without
   // overflow
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
       at::ScalarType::Bool,
       at::ScalarType::BFloat16,
       at::ScalarType::Half,
+      at::ScalarType::ComplexHalf,
       scalarType,
       "check_convert",
       [&] { scalar.to<scalar_t>(); });
0 commit comments