@@ -103,21 +103,19 @@ namespace xt
103
103
struct reducer_options
104
104
{
105
105
template <class X >
106
- struct initial_tester : std::false_type {
107
- using type = double ;
108
- };
106
+ struct initial_tester : std::false_type {};
109
107
110
108
template <class X >
111
- struct initial_tester <xinitial<X>> : std::true_type {
112
- using type = X;
113
- };
109
+ struct initial_tester <xinitial<X>> : std::true_type {};
110
+
111
+ // Workaround for Apple because tuple_cat is buggy!
112
+ template <class X >
113
+ struct initial_tester <const xinitial<X>> : std::true_type {};
114
114
115
115
using d_t = std::decay_t <T>;
116
116
117
117
static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t >::value;
118
118
reducer_options () = default ;
119
- reducer_options (reducer_options&&) = default ;
120
- reducer_options (const reducer_options&) = default ;
121
119
122
120
reducer_options (const T& tpl)
123
121
{
@@ -145,7 +143,7 @@ namespace xt
145
143
using rebind_t = reducer_options<NR, T>;
146
144
147
145
template <class NR >
148
- auto rebind (NR initial, const reducer_options<R, T>& opts ) const
146
+ auto rebind (NR initial, const reducer_options<R, T>&) const
149
147
{
150
148
reducer_options<NR, T> res;
151
149
res.initial_value = initial;
@@ -380,7 +378,7 @@ namespace xt
380
378
{
381
379
// for unknown reasons it's much faster to use a temporary variable and
382
380
// std::accumulate here -- probably some cache behavior
383
- result_type tmp = options_t ::has_initial_value ? options. initial_value : init_fct ();
381
+ result_type tmp = init_fct ();
384
382
tmp = std::accumulate (begin , begin + outer_loop_size, tmp, reduce_fct);
385
383
386
384
// use merge function if necessary
@@ -413,7 +411,7 @@ namespace xt
413
411
return merge ?
414
412
reduce_fct (v1, v2) :
415
413
// cast because return type of identity function is not upcasted
416
- reduce_fct (static_cast <result_type>(options_t ::has_initial_value ? options. initial_value : init_fct ()), v2);
414
+ reduce_fct (static_cast <result_type>(init_fct ()), v2);
417
415
});
418
416
419
417
begin += inner_stride;
@@ -439,7 +437,11 @@ namespace xt
439
437
}
440
438
};
441
439
}
442
-
440
+ if (options_t ::has_initial_value)
441
+ {
442
+ std::transform (result.data (), result.data () + result.size (), result.data (),
443
+ [&reduce_fct, &options](auto && v) { return reduce_fct (static_cast <result_type>(v), options.initial_value ); });
444
+ }
443
445
return result;
444
446
}
445
447
@@ -893,8 +895,9 @@ namespace xt
893
895
894
896
private:
895
897
896
- reference aggregate (size_type dim, /* keep_dims=*/ std::false_type) const ;
897
- reference aggregate (size_type dim, /* keep_dims=*/ std::true_type) const ;
898
+ reference aggregate (size_type dim) const ;
899
+ reference aggregate_impl (size_type dim, /* keep_dims=*/ std::false_type) const ;
900
+ reference aggregate_impl (size_type dim, /* keep_dims=*/ std::true_type) const ;
898
901
899
902
substepper_type get_substepper_begin () const ;
900
903
size_type get_dim (size_type dim) const noexcept ;
@@ -1342,7 +1345,7 @@ namespace xt
1342
1345
template <class F , class CT , class X , class O >
1343
1346
inline auto xreducer_stepper<F, CT, X, O>::operator *() const -> reference
1344
1347
{
1345
- reference r = aggregate (0 , typename O::keep_dims () );
1348
+ reference r = aggregate (0 );
1346
1349
return r;
1347
1350
}
1348
1351
@@ -1426,96 +1429,90 @@ namespace xt
1426
1429
}
1427
1430
1428
1431
template <class F , class CT , class X , class O >
1429
- inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim, std::false_type ) const -> reference
1432
+ inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim) const -> reference
1430
1433
{
1431
- reference res;
1432
- if (m_reducer->m_e .shape ().empty ())
1434
+ if (m_reducer->m_e .shape ().empty ())
1433
1435
{
1434
- res = m_reducer->m_reduce (O::has_initial_value ? m_reducer->m_options .initial_value : m_reducer->m_init (), *m_stepper);
1436
+ reference res =
1437
+ m_reducer->m_reduce (static_cast <reference>(O::has_initial_value ? m_reducer->m_options .initial_value : m_reducer->m_init ()),
1438
+ *m_stepper);
1435
1439
return res;
1436
1440
}
1437
1441
else
1438
1442
{
1439
- size_type index = axis (dim);
1440
- size_type size = shape (index);
1441
- if (dim != m_reducer->m_axes .size () - 1 )
1443
+ reference res = aggregate_impl (dim, typename O::keep_dims ());
1444
+ if (O::has_initial_value && dim == 0 )
1442
1445
{
1443
- res = aggregate (dim + 1 , typename O::keep_dims () );
1444
- if (O::has_initial_value && dim == 0 )
1445
- {
1446
- res = m_reducer-> m_merge (m_reducer-> m_options . initial_value , res);
1447
- }
1446
+ res = m_reducer-> m_merge (m_reducer-> m_options . initial_value , res );
1447
+ }
1448
+ return res;
1449
+ }
1450
+ }
1448
1451
1449
- for (size_type i = 1 ; i != size; ++i)
1450
- {
1451
- m_stepper.step (index);
1452
- res = m_reducer->m_merge (res, aggregate (dim + 1 , typename O::keep_dims ()));
1453
- }
1452
+ template <class F , class CT , class X , class O >
1453
+ inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type) const -> reference
1454
+ {
1455
+ reference res;
1456
+ size_type index = axis (dim);
1457
+ size_type size = shape (index);
1458
+ if (dim != m_reducer->m_axes .size () - 1 )
1459
+ {
1460
+ res = aggregate_impl (dim + 1 , typename O::keep_dims ());
1461
+ for (size_type i = 1 ; i != size; ++i)
1462
+ {
1463
+ m_stepper.step (index);
1464
+ res = m_reducer->m_merge (res, aggregate_impl (dim + 1 , typename O::keep_dims ()));
1454
1465
}
1455
- else
1466
+ }
1467
+ else
1468
+ {
1469
+ res = m_reducer->m_init ();
1470
+ for (size_type i = 0 ; i != size; ++i, m_stepper.step (index))
1456
1471
{
1457
- res = m_reducer->m_init ();
1458
- for (size_type i = 0 ; i != size; ++i, m_stepper.step (index))
1459
- {
1460
- res = m_reducer->m_reduce (res, *m_stepper);
1461
- }
1462
- m_stepper.step_back (index);
1472
+ res = m_reducer->m_reduce (res, *m_stepper);
1463
1473
}
1464
- m_stepper.reset (index);
1474
+ m_stepper.step_back (index);
1465
1475
}
1476
+ m_stepper.reset (index);
1466
1477
return res;
1467
1478
}
1468
1479
1469
1480
template <class F , class CT , class X , class O >
1470
- inline auto xreducer_stepper<F, CT, X, O>::aggregate (size_type dim, std::true_type) const -> reference
1481
+ inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl (size_type dim, std::true_type) const -> reference
1471
1482
{
1472
1483
reference res (0 );
1473
- if (m_reducer->m_e .shape ().empty ())
1474
- {
1475
- res = m_reducer->m_reduce (O::has_initial_value ? m_reducer->m_options .initial_value : m_reducer->m_init (), *m_stepper);
1476
- return res;
1477
- }
1478
- else
1484
+ auto ax_it = std::find (m_reducer->m_axes .begin (), m_reducer->m_axes .end (), dim);
1485
+ if (ax_it != m_reducer->m_axes .end ())
1479
1486
{
1480
- auto ax_it = std::find (m_reducer->m_axes .begin (), m_reducer->m_axes .end (), dim);
1481
- if (ax_it != m_reducer->m_axes .end ())
1487
+ size_type index = dim;
1488
+ size_type size = m_reducer->m_e .shape ()[index];
1489
+ if (ax_it != m_reducer->m_axes .end () - 1 )
1482
1490
{
1483
- size_type index = dim;
1484
- size_type size = m_reducer->m_e .shape ()[index];
1485
- if (ax_it != m_reducer->m_axes .end () - 1 )
1486
- {
1487
- res = aggregate (dim + 1 , typename O::keep_dims ());
1488
- if (O::has_initial_value && ax_it == m_reducer->m_axes .begin ())
1489
- {
1490
- res = m_reducer->m_merge (m_reducer->m_options .initial_value , res);
1491
- }
1492
-
1493
- for (size_type i = 1 ; i != size; ++i)
1494
- {
1495
- m_stepper.step (index);
1496
- res = m_reducer->m_merge (res, aggregate (dim + 1 , typename O::keep_dims ()));
1497
- }
1498
- }
1499
- else
1491
+ res = aggregate_impl (dim + 1 , typename O::keep_dims ());
1492
+ for (size_type i = 1 ; i != size; ++i)
1500
1493
{
1501
- res = m_reducer->m_init ();
1502
- for (size_type i = 0 ; i != size; ++i, m_stepper.step (index))
1503
- {
1504
- res = m_reducer->m_reduce (res, *m_stepper);
1505
- }
1506
- m_stepper.step_back (index);
1494
+ m_stepper.step (index);
1495
+ res = m_reducer->m_merge (res, aggregate_impl (dim + 1 , typename O::keep_dims ()));
1507
1496
}
1508
- m_stepper.reset (index);
1509
1497
}
1510
1498
else
1511
1499
{
1512
- if (dim < m_reducer->m_e .dimension ())
1500
+ res = m_reducer->m_init ();
1501
+ for (size_type i = 0 ; i != size; ++i, m_stepper.step (index))
1513
1502
{
1514
- res = aggregate (dim + 1 , typename O::keep_dims () );
1503
+ res = m_reducer-> m_reduce (res, *m_stepper );
1515
1504
}
1505
+ m_stepper.step_back (index);
1506
+ }
1507
+ m_stepper.reset (index);
1508
+ }
1509
+ else
1510
+ {
1511
+ if (dim < m_reducer->m_e .dimension ())
1512
+ {
1513
+ res = aggregate_impl (dim + 1 , typename O::keep_dims ());
1516
1514
}
1517
1515
}
1518
-
1519
1516
return res;
1520
1517
}
1521
1518
0 commit comments