@@ -777,14 +777,6 @@ namespace xt
777
777
using result_type = typename argfunc_result_type<E>::type;
778
778
using result_shape_type = typename result_type::shape_type;
779
779
780
- if (e.dimension () == 1 )
781
- {
782
- auto begin = e.template begin <L>();
783
- auto end = e.template end <L>();
784
- std::size_t i = static_cast <std::size_t >(std::distance (begin, std::min_element (begin, end)));
785
- return xtensor<size_t , 0 >{i};
786
- }
787
-
788
780
result_shape_type alt_shape;
789
781
xt::resize_container (alt_shape, e.dimension () - 1 );
790
782
@@ -854,6 +846,11 @@ namespace xt
854
846
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E >
855
847
inline auto argmin (const xexpression<E>& e, std::ptrdiff_t axis)
856
848
{
849
+ if (e.dimension () == 1 )
850
+ {
851
+ return argmin (e);
852
+ }
853
+
857
854
using value_type = typename E::value_type;
858
855
auto && ed = eval (e.derived_cast ());
859
856
std::size_t ax = normalize_axis (ed.dimension (), axis);
@@ -884,6 +881,11 @@ namespace xt
884
881
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E >
885
882
inline auto argmax (const xexpression<E>& e, std::ptrdiff_t axis)
886
883
{
884
+ if (e.dimension () == 1 )
885
+ {
886
+ return argmax (e);
887
+ }
888
+
887
889
using value_type = typename E::value_type;
888
890
auto && ed = eval (e.derived_cast ());
889
891
std::size_t ax = normalize_axis (ed.dimension (), axis);
0 commit comments