8000 Merge pull request #276 from gouarin/add-xtensor-fixed · xtensor-stack/xtensor-python@9cf4028 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9cf4028

Browse files
authored
Merge pull request #276 from gouarin/add-xtensor-fixed
add binding for xtensor_fixed
2 parents 6544a55 + 9a26ec8 commit 9cf4028

File tree

6 files changed

+88
-5
lines changed

6 files changed

+88
-5
lines changed

.azure-pipelines/azure-pipelines-osx.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ jobs:
22
- job: 'OSX'
33
strategy:
44
matrix:
5-
macOS_10_14:
6-
image_name: 'macOS-10.14'
75
macOS_10_15:
86
image_name: 'macOS-10.15'
7+
macOS_11:
8+
image_name: 'macOS-11'
99
pool:
1010
vmImage: $(image_name)
1111
variables:

include/xtensor-python/pycontainer.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,9 @@ namespace xt
320320
inline auto pycontainer<D>::get_buffer_size() const -> size_type
321321
{
322322
const size_type& (*min)(const size_type&, const size_type&) = std::min<size_type>;
323-
size_type min_stride = this->strides().empty() ? size_type(1) :
323+
size_type min_stride = this->strides().empty() ? size_type(1) :
324324
std::max(size_type(1), std::accumulate(this->strides().cbegin(),
325-
this->strides().cend(),
325+
this->strides().cend(),
326326
std::numeric_limits<size_type>::max(),
327327
min));
328328
return min_stride * static_cast<size_type>(PyArray_SIZE(this->python_array()));

include/xtensor-python/pynative_casters.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ namespace pybind11
2828
{
2929
};
3030

31+
// Type caster for casting xt::xtensor_fixed to ndarray
32+
template <class T, class FSH, xt::layout_type L>
33+
struct type_caster<xt::xtensor_fixed<T, FSH, L>> : xtensor_type_caster_base<xt::xtensor_fixed<T, FSH, L>>
34+
{
35+
};
36+
3137
// Type caster for casting xt::xstrided_view to ndarray
3238
template <class CT, class S, xt::layout_type L, class FST>
3339
struct type_caster<xt::xstrided_view<CT, S, L, FST>> : xtensor_type_caster_base<xt::xstrided_view<CT, S, L, FST>>

include/xtensor-python/xtensor_type_caster_base.hpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <vector>
1616

1717
#include "xtensor/xtensor.hpp"
18+
#include "xtensor/xfixed.hpp"
1819

1920
#include <pybind11/numpy.h>
2021
#include <pybind11/pybind11.h>
@@ -64,6 +65,15 @@ namespace pybind11
6465
}
6566
};
6667

68+
template <class T, class FSH, xt::layout_type L>
69+
struct pybind_array_getter<xt::xtensor_fixed<T, FSH, L>>
70+
{
71+
static auto run(handle src)
72+
{
73+
return pybind_array_getter_impl<T, L>::run(src);
74+
}
75+
};
76+
6777
template <class CT, class S, xt::layout_type L, class FST>
6878
struct pybind_array_getter<xt::xstrided_view<CT, S, L, FST>>
6979
{
@@ -113,6 +123,37 @@ namespace pybind11
113123
}
114124
};
115125

126+
template <class T, class FSH, xt::layout_type L>
127+
struct pybind_array_dim_checker<xt::xtensor_fixed<T, FSH, L>>
128+
{
129+
template <class B>
130+
static bool run(const B& buf)
131+
{
132+
return buf.ndim() == FSH::size();
133+
}
134+
};
135+
136+
137+
template <class T>
138+
struct pybind_array_shape_checker
139+
{
140+
template <class B>
141+
static bool run(const B& buf)
142+
{
143+
return true;
144+
}
145+
};
146+
147+
template <class T, class FSH, xt::layout_type L>
148+
struct pybind_array_shape_checker<xt::xtensor_fixed<T, FSH, L>>
149+
{
150+
template <class B>
151+
static bool run(const B& buf)
152+
{
153+
auto shape = FSH();
154+
return std::equal(shape.begin(), shape.end(), buf.shape());
155+
}
156+
};
116157

117158
// Casts a strided expression type to numpy array.If given a base,
118159
// the numpy array references the src data, otherwise it'll make a copy.
@@ -215,6 +256,11 @@ namespace pybind11
215256
return false;
216257
}
217258

259+
if (!pybind_array_shape_checker<Type>::run(buf))
260+
{
261+
return false;
262+
}
263+
218264
std::vector<size_t> shape(buf.ndim());
219265
std::copy(buf.shape(), buf.shape() + buf.ndim(), shape.begin());
220266
value = Type::from_shape(shape);

test_python/main.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "xtensor/xmath.hpp"
1313
#include "xtensor/xarray.hpp"
14+
#include "xtensor/xfixed.hpp"
1415
#define FORCE_IMPORT_ARRAY
1516
#include "xtensor-python/pyarray.hpp"
1617
#include "xtensor-python/pytensor.hpp"
@@ -60,6 +61,22 @@ xt::xtensor<int, 2, xt::layout_type::column_major> example3_xtensor2_colmajor(
6061
return xt::transpose(m) + 2;
6162
}
6263

64+
xt::xtensor_fixed<int, xt::xshape<4, 3, 2>> example3_xfixed3(const xt::xtensor_fixed<int, xt::xshape<2, 3, 4>>& m)
65+
{
66+
return xt::transpose(m) + 2;
67+
}
68+
69+
xt::xtensor_fixed<int, xt::xshape<3, 2>> example3_xfixed2(const xt::xtensor_fixed<int, xt::xshape<2, 3>>& m)
70+
{
71+
return xt::transpose(m) + 2;
72+
}
73+
74+
xt::xtensor_fixed<int, xt::xshape<3, 2>, xt::layout_type::column_major> example3_xfixed2_colmajor(
75+
const xt::xtensor_fixed<int, xt::xshape<2, 3>, xt::layout_type::column_major>& m)
76+
{
77+
return xt::transpose(m) + 2;
78+
}
79+
6380
// Readme Examples
6481

6582
double readme_example1(xt::pyarray<double>& m)
@@ -281,6 +298,9 @@ PYBIND11_MODULE(xtensor_python_test, m)
281298
m.def("example3_xtensor3", example3_xtensor3);
282299
m.def("example3_xtensor2", example3_xtensor2);
283300
m.def("example3_xtensor2_colmajor", example3_xtensor2_colmajor);
301+
m.def("example3_xfixed3", example3_xfixed3);
302+
m.def("example3_xfixed2", example3_xfixed2);
303+
m.def("example3_xfixed2_colmajor", example3_xfixed2_colmajor);
284304

285305
m.def("complex_overload", no_complex_overload);
286306
m.def("complex_overload", complex_overload);

test_python/test_pyarray.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,20 @@ def test_example3(self):
4949
np.testing.assert_array_equal(xt.example3_xtensor2(y[1:, 1:, 0]), v.T + 2)
5050
np.testing.assert_array_equal(xt.example3_xtensor2_colmajor(xc), xc.T + 2)
5151

52+
np.testing.assert_array_equal(xt.example3_xfixed3(y), y.T + 2)
53+
np.testing.assert_array_equal(xt.example3_xfixed2(x), x.T + 2)
54+
np.testing.assert_array_equal(xt.example3_xfixed2_colmajor(xc), xc.T + 2)
55+
5256
with self.assertRaises(TypeError):
5357
xt.example3_xtensor3(x)
5458

59+
with self.assertRaises(TypeError):
60+
xt.example3_xfixed3(x)
61+
62+
with self.assertRaises(TypeError):
63+
x = np.arange(3*2).reshape(3, 2)
64+
xt.example3_xfixed2(x)
65+
5566
def test_vectorize(self):
5667
x1 = np.array([[0, 1], [2, 3]])
5768
x2 = np.array([0, 1])
@@ -85,7 +96,7 @@ def test_readme_example2(self):
8596
x = np.arange(15).reshape(3, 5)
8697
y = [1, 2, 3, 4, 5]
8798
z = xt.readme_example2(x, y)
88-
np.testing.assert_allclose(z,
99+
np.testing.assert_allclose(z,
89100
[[-0.540302, 1.257618, 1.89929 , 0.794764, -1.040465],
90101
[-1.499227, 0.136731, 1.646979, 1.643002, 0.128456],
91102
[-1.084323, -0.583843, 0.45342 , 1.073811, 0.706945]], 1e-5)

0 commit comments

Comments
 (0)
0