pytorch/pytorch
Commit 51286ec

Update base for Update on "Connect Tensor.__ipow__ to pow_ method"
The `pow_` method should be connected to `Tensor.__ipow__` so that the operator `**=` works correctly. Part of #58742 [ghstack-poisoned]
2 parents 3074761 + e175065 commit 51286ec
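
For context, the stack being rebased here connects Tensor.__ipow__ to the in-place pow_ method so that the **= operator updates the tensor in place. A minimal sketch of the resulting behavior (tensor values made up for illustration):

    import torch

    t = torch.tensor([1.0, 2.0, 3.0])
    t **= 2   # dispatches through Tensor.__ipow__ to pow_, updating t in place
    print(t)  # tensor([1., 4., 9.])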

155 files changed, +20547 −1394 lines


.github/scripts/gitutils.py

Lines changed: 1 addition & 3 deletions
@@ -248,9 +248,7 @@ def push(self, branch: str, dry_run: bool, retry: int = 3) -> None:
                 else:
                     self._run_git("push", self.remote, branch)
             except RuntimeError as e:
-                # Check if push were rejected because branch is stale
-                if len(e.args) == 0 or re.search(r"\[rejected\].+\(fetch first\)\n", e.args[0]) is None:
-                    raise
+                print(f"{cnt} push attempt failed with {e}")
                 self.fetch()
                 self._run_git("rebase", f"{self.remote}/{branch}")
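
The printed cnt comes from a retry loop that the hunk doesn't show. A hedged sketch of the surrounding control flow (standalone; run_git is a hypothetical stand-in for GitRepo._run_git, and the early return on success is an assumption):

    import subprocess

    def run_git(*args: str) -> None:
        # hypothetical stand-in for GitRepo._run_git
        subprocess.check_output(("git",) + args)

    def push_with_retry(remote: str, branch: str, dry_run: bool = False, retry: int = 3) -> None:
        for cnt in range(retry):
            try:
                if dry_run:
                    run_git("push", "--dry-run", remote, branch)
                else:
                    run_git("push", remote, branch)
                return  # assumed: stop retrying once a push succeeds
            except subprocess.CalledProcessError as e:
                # after this change every failed push is retried after a
                # fetch + rebase, not only pushes rejected as stale
                print(f"{cnt} push attempt failed with {e}")
                run_git("fetch", remote)
                run_git("rebase", f"{remote}/{branch}")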

.github/scripts/trymerge.py

Lines changed: 17 additions & 10 deletions
@@ -384,6 +384,14 @@ def parse_args() -> Any:
     parser.add_argument("pr_num", type=int)
     return parser.parse_args()

+def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -> bool:
+    if comment_id is None:
+        return False
+    comment = pr.get_comment_by_id(comment_id)
+    if comment.editor_login is not None:
+        return False
+    return comment.author_login == "facebook-github-bot"
+

 @dataclass
 class GitHubComment:
@@ -640,7 +648,7 @@ def has_internal_changes(self) -> bool:
             return False
         return checks[checkrun_name] != "SUCCESS"

-    def merge_ghstack_into(self, repo: GitRepo, force: bool) -> None:
+    def merge_ghstack_into(self, repo: GitRepo, force: bool, comment_id: Optional[int] = None) -> None:
         assert self.is_ghstack_pr()
         approved_by = self.get_approved_by()
         # For ghstack, cherry-pick commits based from origin
@@ -661,7 +669,7 @@ def merge_ghstack_into(self, repo: GitRepo, force: bool) -> None:
                 continue
             approved_by = pr.get_approved_by()
             # Raises exception if matching rule is not found
-            find_matching_merge_rule(pr, repo, force=force)
+            find_matching_merge_rule(pr, repo, force=force, skip_internal_checks=can_skip_internal_checks(self, comment_id))

             # Adding the url here makes it clickable within the Github UI
             approved_by_urls = ', '.join(prefix_with_github_url(login) for login in approved_by)
@@ -670,11 +678,9 @@ def merge_ghstack_into(self, repo: GitRepo, force: bool) -> None:
         msg += f"\nApproved by: {approved_by_urls}\n"
         repo.amend_commit_message(msg)

-    def merge_into(self, repo: GitRepo, *, force: bool = False, dry_run: bool = False) -> None:
+    def merge_into(self, repo: GitRepo, *, force: bool = False, dry_run: bool = False, comment_id: Optional[int] = None) -> None:
         # Raises exception if matching rule is not found
-        find_matching_merge_rule(self, repo, force=force)
-        if self.has_internal_changes():
-            raise RuntimeError("This PR must be landed via phabricator")
+        find_matching_merge_rule(self, repo, force=force, skip_internal_checks=can_skip_internal_checks(self, comment_id))
         if repo.current_branch() != self.default_branch():
             repo.checkout(self.default_branch())
         if not self.is_ghstack_pr():
@@ -688,7 +694,7 @@ def merge_into(self, repo: GitRepo, *, force: bool = False, dry_run: bool = Fals
             repo._run_git("merge", "--squash", pr_branch_name)
             repo._run_git("commit", f"--author=\"{self.get_author()}\"", "-m", msg)
         else:
-            self.merge_ghstack_into(repo, force)
+            self.merge_ghstack_into(repo, force, comment_id=comment_id)

         repo.push(self.default_branch(), dry_run)
         if not dry_run:
@@ -823,9 +829,10 @@ def post_comment(msg: str) -> None:
     expected_association = "CONTRIBUTOR" if pr.is_base_repo_private() else "MEMBER"
     if author_association != expected_association and author_association != "OWNER":
         return post_comment(f"Will not revert as @{author_login} is not a {expected_association}, but {author_association}")
+    skip_internal_checks = can_skip_internal_checks(pr, comment_id)

     # Raises exception if matching rule is not found, but ignores all status checks
-    find_matching_merge_rule(pr, repo, force=True)
+    find_matching_merge_rule(pr, repo, force=True, skip_internal_checks=skip_internal_checks)
     commit_sha = pr.get_merge_commit()
     if commit_sha is None:
         commits = repo.commits_resolving_gh_pr(pr.pr_num)
@@ -834,7 +841,7 @@ def post_comment(msg: str) -> None:
         commit_sha = commits[0]
     msg = repo.commit_message(commit_sha)
     rc = RE_DIFF_REV.search(msg)
-    if rc is not None:
+    if rc is not None and not can_skip_internal_checks:
         raise RuntimeError(f"Can't revert PR that was landed via phabricator as {rc.group(1)}")
     repo.checkout(pr.default_branch())
     repo.revert(commit_sha)
@@ -913,7 +920,7 @@ def handle_exception(e: Exception, msg: str = "Merge failed") -> None:
         handle_exception(e)
     else:
         try:
-            pr.merge_into(repo, dry_run=args.dry_run, force=args.force)
+            pr.merge_into(repo, dry_run=args.dry_run, force=args.force, comment_id=args.comment_id)
         except Exception as e:
             handle_exception(e)
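
The comment_id threaded through merge_into, merge_ghstack_into, and the revert path all funnels into the new helper: internal checks may be skipped only when the triggering comment was authored by facebook-github-bot and never edited. A small self-contained model of that rule (FakeComment is a hypothetical stand-in for GitHubComment):

    from typing import Optional

    class FakeComment:
        # hypothetical stand-in for GitHubComment
        def __init__(self, author_login: str, editor_login: Optional[str]) -> None:
            self.author_login = author_login
            self.editor_login = editor_login

    def can_skip(comment: Optional[FakeComment]) -> bool:
        if comment is None:                   # no comment id supplied
            return False
        if comment.editor_login is not None:  # edited comments are not trusted
            return False
        return comment.author_login == "facebook-github-bot"

    assert can_skip(None) is False
    assert can_skip(FakeComment("facebook-github-bot", None)) is True
    assert can_skip(FakeComment("facebook-github-bot", "someone-else")) is False
    assert can_skip(FakeComment("random-user", None)) is False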

.github/workflows/revert.yml

Lines changed: 2 additions & 0 deletions
@@ -37,3 +37,5 @@ jobs:
           else
             python3 .github/scripts/trymerge.py --revert "${PR_NUM}"
           fi
+
+    concurrency: try-revert

.github/workflows/trymerge.yml

Lines changed: 13 additions & 1 deletion
@@ -31,11 +31,23 @@ jobs:
           GH_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
           FORCE: ${{ github.event.client_payload.force}}
           ON_GREEN: ${{ github.event.client_payload.on_green}}
+          COMMENT_ID: ${{ github.event.client_payload.comment_id }}
         run: |
+          set -ex
           if [ -n "${FORCE}" ]; then
-            python3 .github/scripts/trymerge.py --force "${PR_NUM}"
+            if [ -n "${COMMENT_ID}" ]; then
+              python3 .github/scripts/trymerge.py --force --comment-id "${COMMENT_ID}" "${PR_NUM}"
+            else
+              python3 .github/scripts/trymerge.py --force "${PR_NUM}"
+            fi
           elif [ -n "${ON_GREEN}" ]; then
             python3 .github/scripts/trymerge.py --on-green "${PR_NUM}"
+          elif [ -n "${COMMENT_ID}" ]; then
+            python3 .github/scripts/trymerge.py --comment-id "${COMMENT_ID}" "${PR_NUM}"
           else
             python3 .github/scripts/trymerge.py "${PR_NUM}"
           fi
+
+    # TODO: Separate merge on green merges from regular merges to not hold up try-merge workflows overall concurrency
+    # NOTE: force pushes are also put in their concurrency group to put them higher than regular merges
+    concurrency: try-merge-${{ github.event.client_payload.force}}-${{ github.event.client_payload.on_green }}
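
These branches assume trymerge.py accepts a --comment-id flag; the diff shows the flag being consumed (args.comment_id) but not defined. A hedged argparse sketch consistent with the invocations above (the flag set is inferred, not the repo's actual definitions apart from pr_num):

    import argparse

    def parse_args() -> argparse.Namespace:
        parser = argparse.ArgumentParser("Try to merge or revert a PR")
        parser.add_argument("--force", action="store_true")
        parser.add_argument("--on-green", action="store_true")
        parser.add_argument("--revert", action="store_true")
        parser.add_argument("--comment-id", type=int, default=None)  # inferred from usage
        parser.add_argument("pr_num", type=int)  # confirmed by the trymerge.py hunk above
        return parser.parse_args()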

CMakeLists.txt

Lines changed: 26 additions & 17 deletions
@@ -97,25 +97,34 @@ if(APPLE)
   # Determine if we can link against MPSGraph
   set(MPS_FOUND OFF)
   execute_process(
-    COMMAND bash -c "xcrun --sdk macosx --show-sdk-path"
-    OUTPUT_VARIABLE _macosx_sdk_path
+    COMMAND bash -c "xcodebuild -sdk macosx -version SDKVersion"
+    RESULT_VARIABLE _exit_code
+    OUTPUT_VARIABLE _macosx_sdk_version
     OUTPUT_STRIP_TRAILING_WHITESPACE)
-  set(_MPS_supported_os_version OFF)
-  if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin"
-     AND DARWIN_MAJOR_VERSION VERSION_GREATER_EQUAL 21
-     AND DARWIN_MINOR_VERSION VERSION_GREATER_EQUAL 3)
-    set(_MPS_supported_os_version ON)
-  endif()
-  set(_SDK_SEARCH_PATH "${_macosx_sdk_path}/System/Library/Frameworks/")
-  set(_FRAMEWORK_SEARCH_PATH "/System/Library/Frameworks/")
-
-  find_library(_MPS_fwrk_path_ NAMES MetalPerformanceShadersGraph MetalPerformanceShaders PATHS ${_FRAMEWORK_SEARCH_PATH} NO_DEFAULT_PATH)
-  find_library(_MPS_sdk_path_ NAMES MetalPerformanceShadersGraph MetalPerformanceShaders PATHS ${_SDK_SEARCH_PATH} NO_DEFAULT_PATH)
-
-  if(_MPS_supported_os_version AND _MPS_fwrk_path_ AND _MPS_sdk_path_)
-    set(MPS_FOUND ON)
-    message(STATUS "MPSGraph framework found")
+  if(_exit_code EQUAL 0)
+    set(_MPS_supported_os_version OFF)
+    if(_macosx_sdk_version VERSION_GREATER_EQUAL 12.3)
+      set(_MPS_supported_os_version ON)
+    endif()
+    message(STATUS "sdk version: ${_macosx_sdk_version}, mps supported: ${_MPS_supported_os_version}")
+    execute_process(
+      COMMAND bash -c "xcrun --sdk macosx --show-sdk-path"
+      OUTPUT_VARIABLE _macosx_sdk_path
+      OUTPUT_STRIP_TRAILING_WHITESPACE)
+    set(_SDK_SEARCH_PATH "${_macosx_sdk_path}/System/Library/Frameworks/")
+    set(_FRAMEWORK_SEARCH_PATH "/System/Library/Frameworks/")
+
+    find_library(_MPS_fwrk_path_ NAMES MetalPerformanceShadersGraph MetalPerformanceShaders PATHS ${_FRAMEWORK_SEARCH_PATH} NO_DEFAULT_PATH)
+    find_library(_MPS_sdk_path_ NAMES MetalPerformanceShadersGraph MetalPerformanceShaders PATHS ${_SDK_SEARCH_PATH} NO_DEFAULT_PATH)
+
+    if(_MPS_supported_os_version AND _MPS_fwrk_path_ AND _MPS_sdk_path_)
+      set(MPS_FOUND ON)
+      message(STATUS "MPSGraph framework found")
+    else()
+      message(STATUS "MPSGraph framework not found")
+    endif()
   else()
+    message(STATUS "MPS: unable to get MacOS sdk version")
     message(STATUS "MPSGraph framework not found")
   endif()
 endif()

aten/src/ATen/FunctionalTensorWrapper.cpp

Lines changed: 4 additions & 4 deletions
@@ -17,10 +17,10 @@ void FunctionalTensorWrapper::set_constructor_metadata() {
   // For now I'm retroactively setting this in functorch,
   // but once Open Multiple Dispatch lands we should be able to calculate this in core.
   level_ = -1;
-  // shallow_copy_from overwrites the storage and dispatch keyset...
-  auto functional_storage = storage_;
-  shallow_copy_from(value_.getIntrusivePtr());
-  storage_ = functional_storage;
+  // mirror all of the generic tensor metadata onto the wrapper
+  copy_generic_tensor_metadata(value_.getIntrusivePtr().get(), this);
+  refresh_numel();
+  refresh_contiguous();
   storage_access_should_throw_ = false;
   key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set();
   // All of the keys corresponding to functorch transforms should not be copied over.

aten/src/ATen/core/ivalue.h

Lines changed: 7 additions & 5 deletions
@@ -750,15 +750,17 @@ struct TORCH_API IValue final {
   // Scalar, which gets encoded as either an Int, a Double or a ComplexDouble
   IValue(const at::Scalar& s) : IValue() {
     if (s.isFloatingPoint()) {
-      *this = s.toDouble();
+      tag = Tag::Double;
+      payload.u.as_double = s.toDouble();
     } else if (s.isComplex()) {
       *this = s.toComplexDouble();
     } else if (s.isBoolean()) {
-      *this = s.toBool();
-    } else if (s.isIntegral(false)) {
-      *this = s.toLong();
+      tag = Tag::Bool;
+      payload.u.as_bool = s.toBool();
     } else {
-      TORCH_CHECK(false, "Unknown type in Scalar");
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(s.isIntegral(false), "Unknown type in Scalar");
+      tag = Tag::Int;
+      payload.u.as_int = s.toLong();
     }
   }

aten/src/ATen/core/ivalue_inl.h

Lines changed: 1 addition & 1 deletion
@@ -1179,7 +1179,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
         continue;
       }
       c10::Device device = storage->device();
-      if (!device.is_cpu()) {
+      if (!device.is_cpu() && !device.is_meta()) {
         TORCH_CHECK_VALUE(
             device.type() == impl.type(),
             "Expected all data ptrs to be on a device of type ",

aten/src/ATen/core/jit_type.h

Lines changed: 4 additions & 4 deletions
@@ -787,7 +787,7 @@ struct TORCH_API TensorType : public SharedType {
   static const TypeKind Kind = TypeKind::TensorType;

   static std::vector<int64_t> contiguousStridesOf(
-      at::IntArrayRef sizes,
+      at::IntArrayRef in_sizes,
       at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
     auto contiguous_fn = [](const at::IntArrayRef& sizes,
                             const std::vector<int64_t>& dim_order) {
@@ -804,18 +804,18 @@ struct TORCH_API TensorType : public SharedType {
       return strides;
     };

-    std::vector<int64_t> dim_order(sizes.size());
+    std::vector<int64_t> dim_order(in_sizes.size());
     if (memory_format == MemoryFormat::ChannelsLast) {
       dim_order = {1, 3, 2, 0};
     } else if (memory_format == MemoryFormat::ChannelsLast3d) {
       dim_order = {1, 4, 3, 2, 0};
     } else {
-      auto ndims = sizes.size();
+      auto ndims = in_sizes.size();
       for (size_t i = 0; i < ndims; i++) {
         dim_order[i] = ndims - i - 1; // Reverse
       }
     }
-    return contiguous_fn(sizes, dim_order);
+    return contiguous_fn(in_sizes, dim_order);
   }

 private:
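
The renamed in_sizes only feeds the dim_order bookkeeping; the stride rule itself is simple. A hedged Python model of the computation (it ignores any zero- or one-sized-dimension special cases the C++ lambda may handle, which this hunk doesn't show):

    def contiguous_strides_of(in_sizes, dim_order):
        # dim_order[0] is the fastest-moving dimension (stride 1)
        strides = [0] * len(in_sizes)
        stride = 1
        for d in dim_order:
            strides[d] = stride
            stride *= in_sizes[d]
        return strides

    # plain contiguous: reversed dimension order, as in the else branch above
    assert contiguous_strides_of([2, 3, 4, 5], [3, 2, 1, 0]) == [60, 20, 5, 1]
    # ChannelsLast: dim_order = {1, 3, 2, 0} gives NHWC-style strides for NCHW sizes
    assert contiguous_strides_of([2, 3, 4, 5], [1, 3, 2, 0]) == [60, 1, 15, 3]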

aten/src/ATen/core/library.cpp

Lines changed: 2 additions & 2 deletions
@@ -7,15 +7,15 @@ namespace torch {
 namespace {
   // TODO: Consider representing debug info as a struct instead so you
   // don't have to allocate strings all the time
-  std::string debugString(const std::string& file, uint32_t line) {
+  std::string debugString(const char* file, uint32_t line) {
 #ifdef STRIP_ERROR_MESSAGES
     return std::string();
 #else
     return c10::str("registered at ", file, ":", line);
 #endif
   }

-  std::string debugString(std::string debug, const std::string& file, uint32_t line) {
+  std::string debugString(std::string debug, const char* file, uint32_t line) {
 #ifdef STRIP_ERROR_MESSAGES
     return std::string();
 #else

aten/src/ATen/cpu/vec/functional_base.h

Lines changed: 63 additions & 7 deletions
@@ -8,7 +8,7 @@

 namespace at { namespace vec {

-// TODO: Make this more efficient
+// slow path
 template <typename scalar_t, typename Op>
 inline scalar_t vec_reduce_all(
     const Op& vec_fun,
@@ -27,6 +27,62 @@ inline scalar_t vec_reduce_all(
   return acc_arr[0];
 }

+template <typename scalar_t, typename Op>
+struct VecReduceAllSIMD {
+  static inline scalar_t apply(const Op& vec_fun, Vectorized<scalar_t> acc_vec) {
+    return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
+  }
+};
+
+#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
+#if defined(CPU_CAPABILITY_AVX2)
+template <typename Op>
+struct VecReduceAllSIMD<float, Op> {
+  static inline float apply(const Op& vec_fun, Vectorized<float> acc_vec) {
+    using Vec = Vectorized<float>;
+    Vec v = acc_vec;
+    // 128-bit shuffle
+    Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
+    v = vec_fun(v, v1);
+    // 64-bit shuffle
+    v1 = _mm256_shuffle_ps(v, v, 0x4E);
+    v = vec_fun(v, v1);
+    // 32-bit shuffle
+    v1 = _mm256_shuffle_ps(v, v, 0xB1);
+    v = vec_fun(v, v1);
+    return _mm256_cvtss_f32(v);
+  }
+};
+#endif // defined(CPU_CAPABILITY_AVX2)
+#if defined(CPU_CAPABILITY_AVX512)
+template <typename Op>
+struct VecReduceAllSIMD<float, Op> {
+  static inline float apply(const Op& vec_fun, Vectorized<float> acc_vec) {
+    using Vec = Vectorized<float>;
+    Vec v = acc_vec;
+    // 256-bit shuffle
+    Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
+    v = vec_fun(v, v1);
+    // 128-bit shuffle
+    v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
+    v = vec_fun(v, v1);
+    // 64-bit shuffle
+    v1 = _mm512_shuffle_ps(v, v, 0x4E);
+    v = vec_fun(v, v1);
+    // 32-bit shuffle
+    v1 = _mm512_shuffle_ps(v, v, 0xB1);
+    v = vec_fun(v, v1);
+    return _mm512_cvtss_f32(v);
+  }
+};
+#endif // defined(CPU_CAPABILITY_AVX512)
+#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
+
+template <typename scalar_t, typename Op>
+inline scalar_t vec_reduce_all(const Op& vec_fun, Vectorized<scalar_t> acc_vec) {
+  return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
+}
+
 template <typename scalar_t, typename Op>
 inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
   using Vec = vec::Vectorized<scalar_t>;
@@ -42,7 +98,7 @@ inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size
     Vec data_vec = Vec::loadu(data + d, size - d);
     acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
   }
-  return vec_reduce_all(vec_fun, acc_vec, Vec::size());
+  return vec_reduce_all(vec_fun, acc_vec);
 }

 // similar to reduce_all, but reduces into two outputs
@@ -70,8 +126,8 @@ inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2&
     acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
   }
   return std::pair<scalar_t, scalar_t>(
-      vec_reduce_all(vec_fun1, acc_vec1, Vec::size()),
-      vec_reduce_all(vec_fun2, acc_vec2, Vec::size()));
+      vec_reduce_all(vec_fun1, acc_vec1),
+      vec_reduce_all(vec_fun2, acc_vec2));
 }

 template <typename scalar_t, typename MapOp, typename ReduceOp>
@@ -95,7 +151,7 @@ inline scalar_t map_reduce_all(
     data_vec = map_fun(data_vec);
     acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
   }
-  return vec_reduce_all(red_fun, acc_vec, Vec::size());
+  return vec_reduce_all(red_fun, acc_vec);
 }

 template <typename scalar_t, typename MapOp, typename ReduceOp>
@@ -126,7 +182,7 @@ inline scalar_t map2_reduce_all(
     data_vec = map_fun(data_vec, data2_vec);
     acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
   }
-  return vec_reduce_all(red_fun, acc_vec, Vec::size());
+  return vec_reduce_all(red_fun, acc_vec);
 }

 template <typename scalar_t, typename MapOp, typename ReduceOp>
@@ -162,7 +218,7 @@ inline scalar_t map3_reduce_all(
     data_vec = map_fun(data_vec, data2_vec, data3_vec);
     acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
   }
-  return vec_reduce_all(red_fun, acc_vec, Vec::size());
+  return vec_reduce_all(red_fun, acc_vec);
 }

 template <typename scalar_t, typename Op>
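
Each shuffle pairs every lane with a partner, so the horizontal reduction finishes in log2(lanes) combine steps instead of a scalar loop. A pure-Python model of the AVX2 float path (8 lanes; the lane pairings mirror the 0x1 / 0x4E / 0xB1 shuffles above):

    def simd_reduce_all(op, v):
        assert len(v) == 8
        v = [op(v[i], v[(i + 4) % 8]) for i in range(8)]  # 128-bit shuffle: swap halves
        v = [op(v[i], v[i ^ 2]) for i in range(8)]        # 64-bit shuffle: swap pairs
        v = [op(v[i], v[i ^ 1]) for i in range(8)]        # 32-bit shuffle: swap neighbours
        return v[0]  # every lane now holds the full reduction

    x = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
    assert simd_reduce_all(max, x) == max(x)
    assert simd_reduce_all(lambda a, b: a + b, x) == sum(x)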
