Fixed column major reducer · xtensor-stack/xtensor@39440fc · GitHub
[go: up one dir, main page]

Skip to content

Commit 39440fc

Browse files
committed
Fixed column major reducer
1 parent 3f60185 commit 39440fc

File tree

1 file changed

+73
-76
lines changed

1 file changed

+73
-76
lines changed

include/xtensor/xreducer.hpp

Lines changed: 73 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -103,21 +103,19 @@ namespace xt
103103
struct reducer_options
104104
{
105105
template <class X>
106-
struct initial_tester : std::false_type {
107-
using type = double;
108-
};
106+
struct initial_tester : std::false_type {};
109107

110108
template <class X>
111-
struct initial_tester<xinitial<X>> : std::true_type {
112-
using type = X;
113-
};
109+
struct initial_tester<xinitial<X>> : std::true_type {};
110+
111+
// Workaround for Apple because tuple_cat is buggy!
112+
template <class X>
113+
struct initial_tester<const xinitial<X>> : std::true_type {};
114114

115115
using d_t = std::decay_t<T>;
116116

117117
static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;
118118
reducer_options() = default;
119-
reducer_options(reducer_options&&) = default;
120-
reducer_options(const reducer_options&) = default;
121119

122120
reducer_options(const T& tpl)
123121
{
@@ -145,7 +143,7 @@ namespace xt
145143
using rebind_t = reducer_options<NR, T>;
146144

147145
template <class NR>
148-
auto rebind(NR initial, const reducer_options<R, T>& opts) const
146+
auto rebind(NR initial, const reducer_options<R, T>&) const
149147
{
150148
reducer_options<NR, T> res;
151149
res.initial_value = initial;
@@ -380,7 +378,7 @@ namespace xt
380378
{
381379
// for unknown reasons it's much faster to use a temporary variable and
382380
// std::accumulate here -- probably some cache behavior
383-
result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
381+
result_type tmp = init_fct();
384382
tmp = std::accumulate(begin , begin + outer_loop_size, tmp, reduce_fct);
385383

386384
// use merge function if necessary
@@ -413,7 +411,7 @@ namespace xt
413411
return merge ?
414412
reduce_fct(v1, v2) :
415413
// cast because return type of identity function is not upcasted
416-
reduce_fct(static_cast<result_type>(options_t::has_initial_value ? options.initial_value : init_fct()), v2);
414+
reduce_fct(static_cast<result_type>(init_fct()), v2);
417415
});
418416

419417
begin += inner_stride;
@@ -439,7 +437,11 @@ namespace xt
439437
}
440438
};
441439
}
442-
440+
if (options_t::has_initial_value)
441+
{
442+
std::transform(result.data(), result.data() + result.size(), result.data(),
443+
[&reduce_fct, &options](auto&& v) { return reduce_fct(static_cast<result_type>(v), options.initial_value); });
444+
}
443445
return result;
444446
}
445447

@@ -893,8 +895,9 @@ namespace xt
893895

894896
private:
895897

896-
reference aggregate(size_type dim, /*keep_dims=*/ std::false_type) const;
897-
reference aggregate(size_type dim, /*keep_dims=*/ std::true_type) const;
898+
reference aggregate(size_type dim) const;
899+
reference aggregate_impl(size_type dim, /*keep_dims=*/ std::false_type) const;
900+
reference aggregate_impl(size_type dim, /*keep_dims=*/ std::true_type) const;
898901

899902
substepper_type get_substepper_begin() const;
900903
size_type get_dim(size_type dim) const noexcept;
@@ -1342,7 +1345,7 @@ namespace xt
13421345
template <class F, class CT, class X, class O>
13431346
inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
13441347
{
1345-
reference r = aggregate(0, typename O::keep_dims());
1348+
reference r = aggregate(0);
13461349
return r;
13471350
}
13481351

@@ -1426,96 +1429,90 @@ namespace xt
14261429
}
14271430

14281431
template <class F, class CT, class X, class O>
1429-
inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim, std::false_type) const -> reference
1432+
inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim) const -> reference
14301433
{
1431-
reference res;
1432-
if(m_reducer->m_e.shape().empty())
1434+
if (m_reducer->m_e.shape().empty())
14331435
{
1434-
res = m_reducer->m_reduce(O::has_initial_value ? m_reducer->m_options.initial_value : m_reducer->m_init(), *m_stepper);
1436+
reference res =
1437+
m_reducer->m_reduce(static_cast<reference>(O::has_initial_value ? m_reducer->m_options.initial_value : m_reducer->m_init()),
1438+
*m_stepper);
14351439
return res;
14361440
}
14371441
else
14381442
{
1439-
size_type index = axis(dim);
1440-
size_type size = shape(index);
1441-
if (dim != m_reducer->m_axes.size() - 1)
1443+
reference res = aggregate_impl(dim, typename O::keep_dims());
1444+
if (O::has_initial_value && dim == 0)
14421445
{
1443-
res = aggregate(dim + 1, typename O::keep_dims());
1444-
if (O::has_initial_value && dim == 0)
1445-
{
1446-
res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1447-
}
1446+
res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1447+
}
1448+
return res;
1449+
}
1450+
}
14481451

1449-
for (size_type i = 1; i != size; ++i)
1450-
{
1451-
m_stepper.step(index);
1452-
res = m_reducer->m_merge(res, aggregate(dim + 1, typename O::keep_dims()));
1453-
}
1452+
template <class F, class CT, class X, class O>
1453+
inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type) const -> reference
1454+
{
1455+
reference res;
1456+
size_type index = axis(dim);
1457+
size_type size = shape(index);
1458+
if (dim != m_reducer->m_axes.size() - 1)
1459+
{
1460+
res = aggregate_impl(dim + 1, typename O::keep_dims());
1461+
for (size_type i = 1; i != size; ++i)
1462+
{
1463+
m_stepper.step(index);
1464+
res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
14541465
}
1455-
else
1466+
}
1467+
else
1468+
{
1469+
res = m_reducer->m_init();
1470+
for (size_type i = 0; i != size; ++i, m_stepper.step(index))
14561471
{
1457-
res = m_reducer->m_init();
1458-
for (size_type i = 0; i != size; ++i, m_stepper.step(index))
1459-
{
1460-
res = m_reducer->m_reduce(res, *m_stepper);
1461-
}
1462-
m_stepper.step_back(index);
1472+
res = m_reducer->m_reduce(res, *m_stepper);
14631473
}
1464-
m_stepper.reset(index);
1474+
m_stepper.step_back(index);
14651475
}
1476+
m_stepper.reset(index);
14661477
return res;
14671478
}
14681479

14691480
template <class F, class CT, class X, class O>
1470-
inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim, std::true_type) const -> reference
1481+
inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type) const -> reference
14711482
{
14721483
reference res(0);
1473-
if(m_reducer->m_e.shape().empty())
1474-
{
1475-
res = m_reducer->m_reduce(O::has_initial_value ? m_reducer->m_options.initial_value : m_reducer->m_init(), *m_stepper);
1476-
return res;
1477-
}
1478-
else
1484+
auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1485+
if (ax_it != m_reducer->m_axes.end())
14791486
{
1480-
auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1481-
if (ax_it != m_reducer->m_axes.end())
1487+
size_type index = dim;
1488+
size_type size = m_reducer->m_e.shape()[index];
1489+
if (ax_it != m_reducer->m_axes.end() - 1)
14821490
{
1483-
size_type index = dim;
1484-
size_type size = m_reducer->m_e.shape()[index];
1485-
if (ax_it != m_reducer->m_axes.end() - 1)
1486-
{
1487-
res = aggregate(dim + 1, typename O::keep_dims());
1488-
if (O::has_initial_value && ax_it == m_reducer->m_axes.begin())
1489-
{
1490-
res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1491-
}
1492-
1493-
for (size_type i = 1; i != size; ++i)
1494-
{
1495-
m_stepper.step(index);
1496-
res = m_reducer->m_merge(res, aggregate(dim + 1, typename O::keep_dims()));
1497-
}
1498-
}
1499-
else
1491+
res = aggregate_impl(dim + 1, typename O::keep_dims());
1492+
for (size_type i = 1; i != size; ++i)
15001493
{
1501-
res = m_reducer->m_init();
1502-
for (size_type i = 0; i != size; ++i, m_stepper.step(index))
1503-
{
1504-
res = m_reducer->m_reduce(res, *m_stepper);
1505-
}
1506-
m_stepper.step_back(index);
1494+
m_stepper.step(index);
1495+
res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
15071496
}
1508-
m_stepper.reset(index);
15091497
}
15101498
else
15111499
{
1512-
if (dim < m_reducer->m_e.dimension())
1500+
res = m_reducer->m_init();
1501+
for (size_type i = 0; i != size; ++i, m_stepper.step(index))
15131502
{
1514-
res = aggregate(dim + 1, typename O::keep_dims());
1503+
res = m_reducer->m_reduce(res, *m_stepper);
15151504
}
1505+
m_stepper.step_back(index);
1506+
}
1507+
m_stepper.reset(index);
1508+
}
1509+
else
1510+
{
1511+
if (dim < m_reducer->m_e.dimension())
1512+
{
1513+
res = aggregate_impl(dim + 1, typename O::keep_dims());
15161514
}
15171515
}
1518-
15191516
return res;
15201517
}
15211518

0 commit comments

Comments
 (0)
0