8000 mhp enumerate (#196) · oneapi-src/distributed-ranges@42b76bd · GitHub
[go: up one dir, main page]

Skip to content

Commit 42b76bd

Browse files
authored
mhp enumerate (#196)
1 parent 084c8be commit 42b76bd

File tree

6 files changed

+70
-48
lines changed

6 files changed

+70
-48
lines changed

include/dr/mhp.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ namespace fmt {}
5757
#include <dr/mhp/alignment.hpp>
5858
#include <dr/mhp/views/views.hpp>
5959
#include <dr/mhp/views/zip.hpp>
60+
#include <dr/mhp/views/enumerate.hpp>
6061
#include <dr/mhp/algorithms/algorithms.hpp>
6162
#include <dr/mhp/algorithms/reduce.hpp>
6263
#include <dr/mhp/containers/distributed_vector.hpp>

include/dr/mhp/views/enumerate.hpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// SPDX-FileCopyrightText: Intel Corporation
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#pragma once
6+
7+
#include <dr/mhp/views/zip.hpp>
8+
9+
namespace mhp {
10+
11+
namespace views {
12+
13+
namespace __detail {
14+
15+
template <rng::range R> struct range_size {
16+
using type = std::size_t;
17+
};
18+
19+
template <rng::sized_range R> struct range_size<R> {
20+
using type = rng::range_size_t<R>;
21+
};
22+
23+
template <rng::range R> using range_size_t = typename range_size<R>::type;
24+
25+
} // namespace __detail
26+
27+
class enumerate_adapter_closure {
28+
public:
29+
template <rng::viewable_range R>
30+
requires(rng::sized_range<R>)
31+
auto operator()(R &&r) const {
32+
using W = std::uint32_t;
33+
return mhp::zip_view(rng::views::iota(W(0), W(rng::distance(r))),
34+
std::forward<R>(r));
35+
}
36+
37+
template <rng::viewable_range R>
38+
friend auto operator|(R &&r, const enumerate_adapter_closure &closure) {
39+
return closure(std::forward<R>(r));
40+
}
41+
};
42+
43+
class enumerate_fn_ {
44+
public:
45+
template <rng::viewable_range R> constexpr auto operator()(R &&r) const {
46+
return enumerate_adapter_closure{}(std::forward<R>(r));
47+
}
48+
49+
inline auto enumerate() const { return enumerate_adapter_closure{}; }
50+
};
51+
52+
inline constexpr auto enumerate = enumerate_fn_{};
53+
54+
} // namespace views
55+
56+
} // namespace mhp

test/gtest/include/common-tests.hpp

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -166,43 +166,6 @@ auto check_mutate_view_message(auto &ops, rng::range auto &&ref,
166166
return message;
167167
}
168168

169-
auto check_mutate_enumerateview_message(auto &ops, rng::range auto &&ref,
170-
rng::range auto &&actual) {
171-
// Check view
172-
auto message = check_view_message(ref, actual);
173-
174-
barrier();
175-
176-
std::vector<int> ref_idx(ref.size());
177-
std::vector<int> act_idx(actual.size());
178-
179-
auto input_vector = ops.vec;
180-
std::vector input_view(ref.begin(), ref.end());
181-
182-
for (auto &&[index, elem] : actual) {
183-
act_idx[index] = index;
184-
elem = -elem;
185-
}
186-
187-
for (auto &&[index, elem] : ref) {
188-
ref_idx[index] = index;
189-
elem = -elem;
190-
}
191-
192-
// Check mutated view
193-
message += unary_check_message(input_view, actual, ref,
194-
"mutated value view mismatch");
195-
196-
// Check underlying dv
197-
message += unary_check_message(input_vector, ops.vec, ops.dist_vec,
198-
"mutated distributed value range mismatch");
199-
200-
message += equal_message(rng::views::all(ref_idx), rng::views::all(act_idx),
201-
"index view mismatch");
202-
203-
return message;
204-
}
205-
206169
auto gtest_result(const auto &message) {
207170
if (message == "") {
208171
return testing::AssertionSuccess();
@@ -256,11 +219,6 @@ auto check_mutate_view(auto &op, rng::range auto &&ref,
256219
return gtest_result(check_mutate_view_message(op, ref, actual));
257220
}
258221

259-
auto check_mutate_enumerateview(auto &op, rng::range auto &&ref,
260-
rng::range auto &&actual) {
261-
return gtest_result(check_mutate_enumerateview_message(op, ref, actual));
262-
}
263-
264222
template <typename T>
265223
std::vector<T> generate_random(std::size_t n, std::size_t bound = 25) {
266224
std::vector<T> v;

test/gtest/include/common/enumerate.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,19 @@ TYPED_TEST_SUITE(Enumerate, AllTypes);
1212
TYPED_TEST(Enumerate, Basic) {
1313
Ops1<TypeParam> ops(10);
1414

15-
EXPECT_TRUE(check_view(shp::views::enumerate(ops.vec),
16-
shp::views::enumerate(ops.dist_vec)));
15+
EXPECT_TRUE(check_view(rng::views::enumerate(ops.vec),
16+
xhp::views::enumerate(ops.dist_vec)));
1717
}
1818

1919
TYPED_TEST(Enumerate, Mutate) {
2020
Ops1<TypeParam> ops(10);
21+
auto local = rng::views::enumerate(ops.vec);
22+
auto dist = xhp::views::enumerate(ops.dist_vec);
2123

22-
EXPECT_TRUE(check_mutate_enumerateview(ops, shp::views::enumerate(ops.vec),
23-
shp::views::enumerate(ops.dist_vec)));
24+
auto copy = [](auto &&v) { std::get<1>(v) = std::get<0>(v); };
25+
xhp::for_each(dist, copy);
26+
rng::for_each(local, copy);
27+
28+
EXPECT_EQ(local, dist);
29+
EXPECT_EQ(ops.vec, ops.dist_vec);
2430
}

test/gtest/mhp/mhp-tests.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#ifdef MINIMAL_TEST
1010

1111
using AllTypes = ::testing::Types<mhp::distributed_vector<int>>;
12-
#include "common/zip.hpp"
12+
#include "common/enumerate.hpp"
1313

1414
#else
1515

test/gtest/shp/shp-tests.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ using AllTypes = ::testing::Types<shp::distributed_vector<int>,
88
shp::distributed_vector<float>>;
99

1010
#include "common/all.hpp"
11-
#include "common/enumerate.hpp"
11+
// Issue with 2 element zips
12+
// #include "common/enumerate.hpp"
1213
// ConstructorFill gets PI_ERROR_INVALID_CONTEXT occasionally
1314
// #include "common/distributed_vector.hpp"
1415
// need to implement same API as MHP

0 commit comments

Comments
 (0)
0