8000 Fixed column major reducer · xtensor-stack/xtensor@3a9827c · GitHub
[go: up one dir, main page]

Skip to content

Commit 3a9827c

Browse files
committed
Fixed column major reducer
1 parent 3f60185 commit 3a9827c

File tree

1 file changed

+66
-67
lines changed

1 file changed

+66
-67
lines changed

include/xtensor/xreducer.hpp

Lines changed: 66 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ namespace xt
380380
{
381381
// for unknown reasons it's much faster to use a temporary variable and
382382
// std::accumulate here -- probably some cache behavior
383-
result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
383+
result_type tmp = init_fct();
384384
tmp = std::accumulate(begin , begin + outer_loop_size, tmp, reduce_fct);
385385

386386
// use merge function if necessary
@@ -413,7 +413,7 @@ namespace xt
413413
return merge ?
414414
reduce_fct(v1, v2) :
415415
// cast because return type of identity function is not upcasted
416-
reduce_fct(static_cast<result_type>(options_t::has_initial_value ? options.initial_value : init_fct()), v2);
416+
reduce_fct(static_cast<result_type>(init_fct()), v2);
417417
});
418418

419419
begin += inner_stride;
@@ -439,7 +439,11 @@ namespace xt
439439
}
440440
};
441441
}
442-
442+
if (options_t::has_initial_value)
443+
{
444+
std::transform(result.data(), result.data() + result.size(), result.data(),
445+
[&reduce_fct, &options](auto&& v) { return reduce_fct(v, static_cast<result_type>(options.initial_value)); });
446+
}
443447
return result;
444448
}
445449

@@ -893,8 +897,9 @@ namespace xt
893897

894898
private:
895899

896-
reference aggregate(size_type dim, /*keep_dims=*/ std::false_type) const;
897-
reference aggregate(size_type dim, /*keep_dims=*/ std::true_type) const;
900+
reference aggregate(size_type dim) const;
901+
reference aggregate_impl(size_type dim, /*keep_dims=*/ std::false_type) const;
902+
reference aggregate_impl(size_type dim, /*keep_dims=*/ std::true_type) const;
898903

899904
substepper_type get_substepper_begin() const;
900905
size_type get_dim(size_type dim) const noexcept;
@@ -1342,7 +1347,7 @@ namespace xt
13421347
template <class F, class CT, class X, class O>
13431348
inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
13441349
{
1345-
reference r = aggregate(0, typename O::keep_dims());
1350+
reference r = aggregate(0);
13461351
return r;
13471352
}
13481353

@@ -1426,96 +1431,90 @@ namespace xt
14261431
}
14271432

14281433
template <class F, class CT, class X, class O>
1429-
inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim, std::false_type) const -> reference
1434+
inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim) const -> reference
14301435
{
1431-
reference res;
1432-
if(m_reducer->m_e.shape().empty())
1436+
if (m_reducer->m_e.shape().empty())
14331437
{
1434-
res = m_reducer->m_reduce(O::has_initial_value ? m_reducer->m_options.initial_value : m_reducer->m_init(), *m_stepper);
1438+
reference res =
1439+
m_reducer->m_reduce(static_cast<reference>(O::has_initial_value ? m_reducer->m_options.initial_value : m_reducer->m_init()),
1440+
*m_stepper);
14351441
return res;
14361442
}
14371443
else
14381444
{
1439-
size_type index = axis(dim);
1440-
size_type size = shape(index);
1441-
if (dim != m_reducer->m_axes.size() - 1)
1445+ E30A
reference res = aggregate_impl(dim, typename O::keep_dims());
1446+
if (O::has_initial_value && dim == 0)
14421447
{
1443-
res = aggregate(dim + 1, typename O::keep_dims());
1444-
if (O::has_initial_value && dim == 0)
1445-
{
1446-
res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1447-
}
1448+
res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1449+
}
1450+
return res;
1451+
}
1452+
}
14481453

1449-
for (size_type i = 1; i != size; ++i)
1450-
{
1451-
m_stepper.step(index);
1452-
res = m_reducer->m_merge(res, aggregate(dim + 1, typename O::keep_dims()));
1453-
}
1454+
template <class F, class CT, class X, class O>
1455+
inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type) const -> reference
1456+
{
1457+
reference res;
1458+
size_type index = axis(dim);
1459+
size_type size = shape(index);
1460+
if (dim != m_reducer->m_axes.size() - 1)
1461+
{
1462+
res = aggregate_impl(dim + 1, typename O::keep_dims());
1463+
for (size_type i = 1; i != size; ++i)
1464+
{
1465+
m_stepper.step(index);
1466+
res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
14541467
}
1455-
else
1468+
}
1469+
else
1470+
{
1471+
res = m_reducer->m_init();
1472+
for (size_type i = 0; i != size; ++i, m_stepper.step(index))
14561473
{
1457-
res = m_reducer->m_init();
1458-
for (size_type i = 0; i != size; ++i, m_stepper.step(index))
1459-
{
1460-
res = m_reducer->m_reduce(res, *m_stepper);
1461-
}
1462-
m_stepper.step_back(index);
1474+
res = m_reducer->m_reduce(res, *m_stepper);
14631475
}
1464-
m_stepper.reset(index);
1476+
m_stepper.step_back(index);
14651477
}
1478+
m_stepper.reset(index);
14661479
return res;
14671480
}
14681481

14691482
template <class F, class CT, class X, class O>
1470-
inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim, std::true_type) const -> reference
1483+
inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type) const -> reference
14711484
{
14721485
reference res(0);
1473-
if(m_reducer->m_e.shape().empty())
1486+
auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1487+
if (ax_it != m_reducer->m_axes.end())
14741488
{
1475-
res = m_reducer->m_reduce(O::has_initial_value ? m_reducer->m_options.initial_value : m_reducer->m_init(), *m_stepper);
1476-
return res;
1477-
}
1478-
else
1479-
{
1480-
auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1481-
if (ax_it != m_reducer->m_axes.end())
1489+
size_type index = dim;
1490+
size_type size = m_reducer->m_e.shape()[index];
1491+
if (ax_it != m_reducer->m_axes.end() - 1)
14821492
{
1483-
size_type index = dim;
1484-
size_type size = m_reducer->m_e.shape()[index];
1485-
if (ax_it != m_reducer->m_axes.end() - 1)
1486-
{
1487-
res = aggregate(dim + 1, typename O::keep_dims());
1488-
if (O::has_initial_value && ax_it == m_reducer->m_axes.begin())
1489-
{
1490-
res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1491-
}
1492-
1493-
for (size_type i = 1; i != size; ++i)
1494-
{
1495-
m_stepper.step(index);
1496-
res = m_reducer->m_merge(res, aggregate(dim + 1, typename O::keep_dims()));
1497-
}
1498-
}
1499-
else
1493+
res = aggregate_impl(dim + 1, typename O::keep_dims());
1494+
for (size_type i = 1; i != size; ++i)
15001495
{
1501-
res = m_reducer->m_init();
1502-
for (size_type i = 0; i != size; ++i, m_stepper.step(index))
1503-
{
1504-
res = m_reducer->m_reduce(res, *m_stepper);
1505-
}
1506-
m_stepper.step_back(index);
1496+
m_stepper.step(index);
1497+
res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
15071498
}
1508-
m_stepper.reset(index);
15091499
}
15101500
else
15111501
{
1512-
if (dim < m_reducer->m_e.dimension())
1502+
res = m_reducer->m_init();
1503+
for (size_type i = 0; i != size; ++i, m_stepper.step(index))
15131504
{
1514-
res = aggregate(dim + 1, typename O::keep_dims());
1505+
res = m_reducer->m_reduce(res, *m_stepper);
15151506
}
1507+
m_stepper.step_back(index);
1508+
}
1509+
m_stepper.reset(index);
1510+
}
1511+
else
1512+
{
1513+
if (dim < m_reducer->m_e.dimension())
1514+
{
1515+
res = aggregate_impl(dim + 1, typename O::keep_dims());
15161516
}
15171517
}
1518-
15191518
return res;
15201519
}
15211520

0 commit comments

Comments
 (0)
0