8000 Bugfix: calling `argmax(a, 0)` with `a.dimension() == 1` resulted in … · xtensor-stack/xtensor@fc39f69 · GitHub
[go: up one dir, main page]

Skip to content

Commit fc39f69

Browse files
committed
Bugfix: calling argmax(a, 0) with a.dimension() == 1 resulted in argmin(a, 0)
1 parent 3084585 commit fc39f69

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

include/xtensor/xsort.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -777,14 +777,6 @@ namespace xt
777777
using result_type = typename argfunc_result_type<E>::type;
778778
using result_shape_type = typename result_type::shape_type;
779779

780-
if (e.dimension() == 1)
781-
{
782-
auto begin = e.template begin<L>();
783-
auto end = e.template end<L>();
784-
std::size_t i = static_cast<std::size_t>(std::distance(begin, std::min_element(begin, end)));
785-
return xtensor<size_t, 0>{i};
786-
}
787-
788780
result_shape_type alt_shape;
789781
xt::resize_container(alt_shape, e.dimension() - 1);
790782

@@ -854,6 +846,11 @@ namespace xt
854846
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
855847
inline auto argmin(const xexpression<E>& e, std::ptrdiff_t axis)
856848
{
849+
if (e.dimension() == 1)
850+
{
851+
return argmin(e);
852+
}
853+
857854
using value_type = typename E::value_type;
858855
auto&& ed = eval(e.derived_cast());
859856
std::size_t ax = normalize_axis(ed.dimension(), axis);
@@ -884,6 +881,11 @@ namespace xt
884881
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
885882
inline auto argmax(const xexpression<E>& e, std::ptrdiff_t axis)
886883
{
884+
if (e.dimension() == 1)
885+
{
886+
return argmax(e);
887+
}
888+
887889
using value_type = typename E::value_type;
888890
auto&& ed = eval(e.derived_cast());
889891
std::size_t ax = normalize_axis(ed.dimension(), axis);

test/test_xsort.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,10 @@ namespace xt
190190
EXPECT_EQ(ex, argmin(xa));
191191
EXPECT_EQ(ex_2, argmin(xa, 0));
192192
EXPECT_EQ(ex_3, argmin(xa, 1));
193+
194+
xtensor<double, 1> ya = {0, 1, 2, 3, 4};
195+
EXPECT_EQ(0, argmin(ya)());
196+
EXPECT_EQ(0, argmin(ya, 0)());
193197
}
194198

195199
TEST(xsort, argmax)
@@ -218,6 +222,10 @@ namespace xt
218222

219223
xtensor<std::size_t, 2> ex_6 = {{0, 0, 0, 0}, {0, 0, 0, 0}};
220224
EXPECT_EQ(ex_6, argmax(c, 1));
225+
226+
xtensor<double, 1> ya = {0, 1, 2, 3, 4};
227+
EXPECT_EQ(4, argmax(ya)());
228+
EXPECT_EQ(4, argmax(ya, 0)());
221229
}
222230

223231
TEST(xsort, sort_large_prob)

0 commit comments

Comments
 (0)
0