@@ -380,7 +380,7 @@ namespace xt
380
380
{
381
381
// for unknown reasons it's much faster to use a temporary variable and
382
382
// std::accumulate here -- probably some cache behavior
383
- result_type tmp = options_t ::has_initial_value ? options. initial_value : init_fct ();
383
+ result_type tmp = init_fct ();
384
384
tmp = std::accumulate (begin , begin + outer_loop_size, tmp, reduce_fct);
385
385
386
386
// use merge function if necessary
@@ -413,7 +413,7 @@ namespace xt
413
413
return merge ?
414
414
reduce_fct (v1, v2) :
415
415
// cast because return type of identity function is not upcasted
416
- reduce_fct (static_cast <result_type>(options_t ::has_initial_value ? options. initial_value : init_fct ()), v2);
416
+ reduce_fct (static_cast <result_type>(init_fct ()), v2);
417
417
});
418
418
419
419
begin += inner_stride;
@@ -439,7 +439,11 @@ namespace xt
439
439
}
440
440
};
441
441
}
442
-
442
+ if (options_t ::has_initial_value)
443
+ {
444
+ std::transform (result.data (), result.data () + result.size (), result.data (),
445
+ [&reduce_fct, &options](auto && v) { return reduce_fct (v, static_cast <result_type>(options.initial_value )); });
446
+ }
443
447
return result;
444
448
}
445
449
@@ -893,8 +897,9 @@ namespace xt
893
897
894
898
private:
895
899
896
- reference aggregate (size_type dim, /* keep_dims=*/ std::false_type) const ;
897
- reference aggregate (size_type dim, /* keep_dims=*/ std::true_type) const ;
900
+ reference aggregate (size_type dim) const ;
901
+ reference aggregate_impl (size_type dim, /* keep_dims=*/ std::false_type) const ;
902
+ reference aggregate_impl (size_type dim, /* keep_dims=*/ std::true_type) const ;
898
903
899
904
substepper_type get_substepper_begin () const ;
900
905
size_type get_dim (size_type dim) const noexcept ;
@@ -1342,7 +1347,7 @@ namespace xt
1342
1347
template <class F , class CT , class X , class O >
1343
1348
inline auto xreducer_stepper<F, CT, X, O>::operator *() const -> reference
1344
1349
{
1345
- reference r = aggregate (0 , typename O::keep_dims () );
1350
+ reference r = aggregate (0 );
1346
1351
return r;
1347
1352
}
1348
1353
@@ -1426,96 +1431,90 @@ namespace xt
1426
1431
}
1427
1432
1428
1433
template <class F , class CT , class X , class O >
1429
- inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim, std::false_type ) const -> reference
1434
+ inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim) const -> reference
1430
1435
{
1431
- reference res;
1432
- if (m_reducer->m_e .shape ().empty ())
1436
+ if (m_reducer->m_e .shape ().empty ())
1433
1437
{
1434
- res = m_reducer->m_reduce (O::has_initial_value ? m_reducer->m_options .initial_value : m_reducer->m_init (), *m_stepper);
1438
+ reference res =
1439
+ m_reducer->m_reduce (static_cast <reference>(O::has_initial_value ? m_reducer->m_options .initial_value : m_reducer->m_init ()),
1440
+ *m_stepper);
1435
1441
return res;
1436
1442
}
1437
1443
else
1438
1444
{
1439
- size_type index = axis (dim);
1440
- size_type size = shape (index);
1441
- if (dim != m_reducer->m_axes .size () - 1 )
1445
+ reference res = aggregate_impl (dim, typename O::keep_dims ());
1446
+ if (O::has_initial_value && dim == 0 )
1442
1447
{
1443
- res = aggregate (dim + 1 , typename O::keep_dims () );
1444
- if (O::has_initial_value && dim == 0 )
1445
- {
1446
- res = m_reducer-> m_merge (m_reducer-> m_options . initial_value , res);
1447
- }
1448
+ res = m_reducer-> m_merge (m_reducer-> m_options . initial_value , res );
1449
+ }
1450
+ return res;
1451
+ }
1452
+ }
1448
1453
1449
- for (size_type i = 1 ; i != size; ++i)
1450
- {
1451
- m_stepper.step (index);
1452
- res = m_reducer->m_merge (res, aggregate (dim + 1 , typename O::keep_dims ()));
1453
- }
1454
+ template <class F , class CT , class X , class O >
1455
+ inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type) const -> reference
1456
+ {
1457
+ reference res;
1458
+ size_type index = axis (dim);
1459
+ size_type size = shape (index);
1460
+ if (dim != m_reducer->m_axes .size () - 1 )
1461
+ {
1462
+ res = aggregate_impl (dim + 1 , typename O::keep_dims ());
1463
+ for (size_type i = 1 ; i != size; ++i)
1464
+ {
1465
+ m_stepper.step (index);
1466
+ res = m_reducer->m_merge (res, aggregate_impl (dim + 1 , typename O::keep_dims ()));
1454
1467
}
1455
- else
1468
+ }
1469
+ else
1470
+ {
1471
+ res = m_reducer->m_init ();
1472
+ for (size_type i = 0 ; i != size; ++i, m_stepper.step (index))
1456
1473
{
1457
- res = m_reducer->m_init ();
1458
- for (size_type i = 0 ; i != size; ++i, m_stepper.step (index))
1459
- {
1460
- res = m_reducer->m_reduce (res, *m_stepper);
1461
- }
1462
- m_stepper.step_back (index);
1474
+ res = m_reducer->m_reduce (res, *m_stepper);
1463
1475
}
1464
- m_stepper.reset (index);
1476
+ m_stepper.step_back (index);
1465
1477
}
1478
+ m_stepper.reset (index);
1466
1479
return res;
1467
1480
}
1468
1481
1469
1482
template <class F , class CT , class X , class O >
1470
- inline auto xreducer_stepper<F, CT, X, O>::aggregate (size_type dim, std::true_type) const -> reference
1483
+ inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl (size_type dim, std::true_type) const -> reference
1471
1484
{
1472
1485
reference res (0 );
1473
- if (m_reducer->m_e .shape ().empty ())
1486
+ auto ax_it = std::find (m_reducer->m_axes .begin (), m_reducer->m_axes .end (), dim);
1487
+ if (ax_it != m_reducer->m_axes .end ())
1474
1488
{
1475
- res = m_reducer->m_reduce (O::has_initial_value ? m_reducer->m_options .initial_value : m_reducer->m_init (), *m_stepper);
1476
- return res;
1477
- }
1478
- else
1479
- {
1480
- auto ax_it = std::find (m_reducer->m_axes .begin (), m_reducer->m_axes .end (), dim);
1481
- if (ax_it != m_reducer->m_axes .end ())
1489
+ size_type index = dim;
1490
+ size_type size = m_reducer->m_e .shape ()[index];
1491
+ if (ax_it != m_reducer->m_axes .end () - 1 )
1482
1492
{
1483
- size_type index = dim;
1484
- size_type size = m_reducer->m_e .shape ()[index];
1485
- if (ax_it != m_reducer->m_axes .end () - 1 )
1486
- {
1487
- res = aggregate (dim + 1 , typename O::keep_dims ());
1488
- if (O::has_initial_value && ax_it == m_reducer->m_axes .begin ())
1489
- {
1490
- res = m_reducer->m_merge (m_reducer->m_options .initial_value , res);
1491
- }
1492
-
1493
- for (size_type i = 1 ; i != size; ++i)
1494
- {
1495
- m_stepper.step (index);
1496
- res = m_reducer->m_merge (res, aggregate (dim + 1 , typename O::keep_dims ()));
1497
- }
1498
- }
1499
- else
1493
+ res = aggregate_impl (dim + 1 , typename O::keep_dims ());
1494
+ for (size_type i = 1 ; i != size; ++i)
1500
1495
{
1501
- res = m_reducer->m_init ();
1502
- for (size_type i = 0 ; i != size; ++i, m_stepper.step (index))
1503
- {
1504
- res = m_reducer->m_reduce (res, *m_stepper);
1505
- }
1506
- m_stepper.step_back (index);
1496
+ m_stepper.step (index);
1497
+ res = m_reducer->m_merge (res, aggregate_impl (dim + 1 , typename O::keep_dims ()));
1507
1498
}
1508
- m_stepper.reset (index);
1509
1499
}
1510
1500
else
1511
1501
{
1512
- if (dim < m_reducer->m_e .dimension ())
1502
+ res = m_reducer->m_init ();
1503
+ for (size_type i = 0 ; i != size; ++i, m_stepper.step (index))
1513
1504
{
1514
- res = aggregate (dim + 1 , typename O::keep_dims () );
1505
+ res = m_reducer-> m_reduce (res, *m_stepper );
1515
1506
}
1507
+ m_stepper.step_back (index);
1508
+ }
1509
+ m_stepper.reset (index);
1510
+ }
1511
+ else
1512
+ {
1513
+ if (dim < m_reducer->m_e .dimension ())
1514
+ {
1515
+ res = aggregate_impl (dim + 1 , typename O::keep_dims ());
1516
1516
}
1517
1517
}
1518
-
1519
1518
return res;
1520
1519
}
1521
1520
0 commit comments