Fixed strides adaptor leading to a broadcasting issue · xtensor-stack/xtensor-python@82d66dd · GitHub

Commit 82d66dd

Fixed strides adaptor leading to a broadcasting issue
1 parent f2c2d17 commit 82d66dd

File tree

4 files changed: +100 -16 lines changed

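Background for the change, with a small standalone sketch (not part of the commit; the function name and the (1, 4) example are illustrative): NumPy reports strides in bytes and keeps a nonzero stride even for an axis of extent 1, while broadcasting requires the stride of such an axis to behave as 0. Converting byte strides to element strides by plain division therefore makes a (1, 4) operand look incompatible with a (3, 4) operand, which is the broadcasting issue this commit fixes.

```cpp
// Standalone sketch of the stride rule the commit enforces; not library code.
#include <cstddef>
#include <iostream>

// NumPy strides are in bytes; xtensor works with element strides.
// An axis of extent 1 must get a logical stride of 0 so that it broadcasts.
std::size_t element_stride(std::size_t byte_stride, std::size_t extent, std::size_t item_size)
{
    return extent == 1 ? 0 : byte_stride / item_size;
}

int main()
{
    // C-contiguous float64 array of shape (1, 4): byte strides are {32, 8}.
    const std::size_t shape[]   = {1, 4};
    const std::size_t strides[] = {32, 8};
    for (std::size_t i = 0; i < 2; ++i)
    {
        std::cout << element_stride(strides[i], shape[i], sizeof(double)) << ' ';
    }
    std::cout << '\n';  // prints "0 1"; naive division would give "4 1"
}
```

The diffs below make the strides adaptor apply this rule by giving it access to the array's shape.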

include/xtensor-python/pyarray.hpp

Lines changed: 2 additions & 1 deletion

```diff
@@ -519,7 +519,8 @@ namespace xt
             m_shape = inner_shape_type(reinterpret_cast<size_type*>(PyArray_SHAPE(this->python_array())),
                                        static_cast<size_type>(PyArray_NDIM(this->python_array())));
             m_strides = inner_strides_type(reinterpret_cast<difference_type*>(PyArray_STRIDES(this->python_array())),
-                                           static_cast<size_type>(PyArray_NDIM(this->python_array())));
+                                           static_cast<size_type>(PyArray_NDIM(this->python_array())),
+                                           reinterpret_cast<size_type*>(PyArray_SHAPE(this->python_array())));

             if (L != layout_type::dynamic && !do_strides_match(m_shape, m_strides, L, 1))
             {
```
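The adaptor cannot apply the extent-1 rule from the strides alone, so the constructor call above now also forwards PyArray_SHAPE. A simplified, hypothetical stand-in for the adaptor (the real class is pystrides_adaptor<N>, shown in the next diff; N plays the role of the element size in bytes) illustrates what the extra argument enables:

```cpp
// Simplified sketch only; not the real pystrides_adaptor<N>.
#include <cstddef>
#include <iostream>

template <std::size_t N>
class strides_view
{
public:
    strides_view(const std::size_t* byte_strides, std::size_t size, const std::size_t* shape)
        : p_data(byte_strides), m_size(size), p_shape(shape)
    {
    }

    // Axes of extent 1 report a stride of 0 so they broadcast correctly.
    std::size_t operator[](std::size_t i) const
    {
        return p_shape[i] == std::size_t(1) ? 0 : p_data[i] / N;
    }

    std::size_t size() const { return m_size; }

private:
    const std::size_t* p_data;   // byte strides (from PyArray_STRIDES)
    std::size_t m_size;          // number of dimensions
    const std::size_t* p_shape;  // extents (from PyArray_SHAPE)
};

int main()
{
    const std::size_t shape[]   = {1, 4};
    const std::size_t strides[] = {32, 8};
    strides_view<sizeof(double)> sv(strides, 2, shape);
    std::cout << sv[0] << ' ' << sv[1] << '\n';  // prints "0 1"
}
```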

include/xtensor-python/pystrides_adaptor.hpp

Lines changed: 25 additions & 13 deletions

```diff
@@ -41,8 +41,10 @@ namespace xt
         using const_reverse_iterator = std::reverse_iterator<const_iterator>;
         using reverse_iterator = const_reverse_iterator;

+        using shape_type = size_t*;
+
         pystrides_adaptor() = default;
-        pystrides_adaptor(const_pointer data, size_type size);
+        pystrides_adaptor(const_pointer data, size_type size, shape_type shape);

         bool empty() const noexcept;
         size_type size() const noexcept;
@@ -66,6 +68,7 @@ namespace xt

         const_pointer p_data;
         size_type m_size;
+        shape_type p_shape;
     };

     /**********************************
@@ -84,21 +87,23 @@ namespace xt
         using reference = typename pystrides_adaptor<N>::const_reference;
         using difference_type = typename pystrides_adaptor<N>::difference_type;
         using iterator_category = std::random_access_iterator_tag;
+        using shape_pointer = typename pystrides_adaptor<N>::shape_type;

-        inline pystrides_iterator(pointer current)
+        inline pystrides_iterator(pointer current, shape_pointer shape)
             : p_current(current)
+            , p_shape(shape)
         {
         }

         inline reference operator*() const
         {
-            return *p_current / N;
+            return *p_shape == size_t(1) ? 0 : *p_current / N;
         }

         inline pointer operator->() const
         {
             // Returning the address of a temporary
-            value_type res = *p_current / N;
+            value_type res = this->operator*();
             return &res;
         }

@@ -110,49 +115,55 @@ namespace xt
         inline self_type& operator++()
         {
             ++p_current;
+            ++p_shape;
             return *this;
         }

         inline self_type& operator--()
         {
             --p_current;
+            --p_shape;
             return *this;
         }

         inline self_type operator++(int)
         {
             self_type tmp(*this);
             ++p_current;
+            ++p_shape;
             return tmp;
         }

         inline self_type operator--(int)
         {
             self_type tmp(*this);
             --p_current;
+            --p_shape;
             return tmp;
         }

         inline self_type& operator+=(difference_type n)
         {
             p_current += n;
+            p_shape += n;
             return *this;
         }

         inline self_type& operator-=(difference_type n)
         {
             p_current -= n;
+            p_shape -= n;
             return *this;
         }

         inline self_type operator+(difference_type n) const
         {
-            return self_type(p_current + n);
+            return self_type(p_current + n, p_shape + n);
         }

         inline self_type operator-(difference_type n) const
         {
-            return self_type(p_current - n);
+            return self_type(p_current - n, p_shape - n);
         }

         inline difference_type operator-(const self_type& rhs) const
@@ -166,6 +177,7 @@ namespace xt
     private:

         pointer p_current;
+        shape_pointer p_shape;
     };

     template <std::size_t N>
@@ -215,8 +227,8 @@ namespace xt
     ************************************/

     template <std::size_t N>
-    inline pystrides_adaptor<N>::pystrides_adaptor(const_pointer data, size_type size)
-        : p_data(data), m_size(size)
+    inline pystrides_adaptor<N>::pystrides_adaptor(const_pointer data, size_type size, shape_type shape)
+        : p_data(data), m_size(size), p_shape(shape)
     {
     }

@@ -235,19 +247,19 @@ namespace xt
     template <std::size_t N>
     inline auto pystrides_adaptor<N>::operator[](size_type i) const -> const_reference
     {
-        return p_data[i] / N;
+        return p_shape[i] == size_t(1) ? 0 : p_data[i] / N;
     }

     template <std::size_t N>
     inline auto pystrides_adaptor<N>::front() const -> const_reference
     {
-        return p_data[0] / N;
+        return this->operator[](0);
     }

     template <std::size_t N>
     inline auto pystrides_adaptor<N>::back() const -> const_reference
     {
-        return p_data[m_size - 1] / N;
+        return this->operator[](m_size - 1);
     }

     template <std::size_t N>
@@ -265,13 +277,13 @@ namespace xt
     template <std::size_t N>
     inline auto pystrides_adaptor<N>::cbegin() const -> const_iterator
     {
-        return const_iterator(p_data);
+        return const_iterator(p_data, p_shape);
     }

     template <std::size_t N>
     inline auto pystrides_adaptor<N>::cend() const -> const_iterator
     {
-        return const_iterator(p_data + m_size);
+        return const_iterator(p_data + m_size, p_shape + m_size);
     }

     template <std::size_t N>
```
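With this change, operator[], front() and back() all route through the same zeroing logic, and the iterator carries the shape pointer so that every increment, decrement and offset moves both pointers together; otherwise it + n would pair the stride of one axis with the extent of another. A minimal illustration of that lock-step invariant (illustrative type, assuming double elements; not the real pystrides_iterator):

```cpp
// Minimal lock-step iterator sketch, mirroring the invariant in the diff above.
#include <cstddef>
#include <iostream>

struct stride_it
{
    const std::size_t* p_current;  // byte strides
    const std::size_t* p_shape;    // extents, always advanced together

    std::size_t operator*() const
    {
        return *p_shape == 1 ? 0 : *p_current / sizeof(double);
    }

    stride_it& operator++() { ++p_current; ++p_shape; return *this; }
    bool operator!=(const stride_it& rhs) const { return p_current != rhs.p_current; }
};

int main()
{
    const std::size_t shape[]   = {1, 4};
    const std::size_t strides[] = {32, 8};
    for (stride_it it{strides, shape}, end{strides + 2, shape + 2}; it != end; ++it)
    {
        std::cout << *it << ' ';  // prints "0 1", matching operator[]
    }
    std::cout << '\n';
}
```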

test_python/main.cpp

Lines changed: 27 additions & 0 deletions

```diff
@@ -108,6 +108,28 @@ auto no_complex_overload_reg(const double& a)
 {
     return a;
 }
+//
+// Operator examples
+//
+xt::pyarray<double> array_addition(const xt::pyarray<double>& m, const xt::pyarray<double>& n)
+{
+    return m + n;
+}
+
+xt::pyarray<double> array_subtraction(xt::pyarray<double>& m, xt::pyarray<double>& n)
+{
+    return m - n;
+}
+
+xt::pyarray<double> array_multiplication(xt::pyarray<double>& m, xt::pyarray<double>& n)
+{
+    return m * n;
+}
+
+xt::pyarray<double> array_division(xt::pyarray<double>& m, xt::pyarray<double>& n)
+{
+    return m / n;
+}

 // Vectorize Examples

@@ -310,6 +332,11 @@ PYBIND11_MODULE(xtensor_python_test, m)
     m.def("readme_example1", readme_example1);
     m.def("readme_example2", xt::pyvectorize(readme_example2));

+    m.def("array_addition", array_addition);
+    m.def("array_subtraction", array_subtraction);
+    m.def("array_multiplication", array_multiplication);
+    m.def("array_division", array_division);
+
     m.def("vectorize_example1", xt::pyvectorize(add));

     m.def("rect_to_polar", xt::pyvectorize([](complex_t x) { return std::abs(x); }));
```

test_python/test_pyarray.py

Lines changed: 46 additions & 2 deletions

```diff
@@ -23,7 +23,7 @@
 import numpy as np

 class XtensorTest(TestCase):
-
+    """
     def test_rm(self):
         xt.test_rm(np.array([10], dtype=int))

@@ -62,6 +62,50 @@ def test_example3(self):
         with self.assertRaises(TypeError):
             x = np.arange(3*2).reshape(3, 2)
             xt.example3_xfixed2(x)
+    """
+    def test_broadcast_addition(self):
+        x = np.array([[2., 3., 4., 5.]])
+        y = np.array([[1., 2., 3., 4.],
+                      [1., 2., 3., 4.],
+                      [1., 2., 3., 4.]])
+        res = np.array([[3., 5., 7., 9.],
+                        [3., 5., 7., 9.],
+                        [3., 5., 7., 9.]])
+        z = xt.array_addition(x, y)
+        np.testing.assert_allclose(z, res, 1e-12)
+    """
+    def test_broadcast_subtraction(self):
+        x = np.array([[4., 5., 6., 7.]])
+        y = np.array([[4., 3., 2., 1.],
+                      [4., 3., 2., 1.],
+                      [4., 3., 2., 1.]])
+        res = np.array([[0., 2., 4., 6.],
+                        [0., 2., 4., 6.],
+                        [0., 2., 4., 6.]])
+        z = xt.array_subtraction(x, y)
+        np.testing.assert_allclose(z, res, 1e-12)
+
+    def test_broadcast_multiplication(self):
+        x = np.array([[1., 2., 3., 4.]])
+        y = np.array([[3., 2., 3., 2.],
+                      [3., 2., 3., 2.],
+                      [3., 2., 3., 2.]])
+        res = np.array([[3., 4., 9., 8.],
+                        [3., 4., 9., 8.],
+                        [3., 4., 9., 8.]])
+        z = xt.array_multiplication(x, y)
+        np.testing.assert_allclose(z, res, 1e-12)
+
+    def test_broadcast_division(self):
+        x = np.array([[8., 6., 4., 2.]])
+        y = np.array([[2., 2., 2., 2.],
+                      [2., 2., 2., 2.],
+                      [2., 2., 2., 2.]])
+        res = np.array([[4., 3., 2., 1.],
+                        [4., 3., 2., 1.],
+                        [4., 3., 2., 1.]])
+        z = xt.array_division(x, y)
+        np.testing.assert_allclose(z, res, 1e-12)

     def test_vectorize(self):
         x1 = np.array([[0, 1], [2, 3]])
@@ -263,7 +307,7 @@ def test_native_casters(self):
         self.assertEqual(adapter.shape, (2, 2))
         adapter[1, 1] = -3
         self.assertEqual(arr[0, 5], -3)
-
+    """

 class AttributeTest(TestCase):

```