8000 add get_2darray for templated matrices and spy plot · intercore/matplotlib-cpp@ab37736 · GitHub
[go: up one dir, main page]

Skip to content

Commit ab37736

Browse files
committed
add get_2darray for templated matrices and spy plot
1 parent 4ba2bf4 commit ab37736

File tree

3 files changed

+97
-4
lines changed

3 files changed

+97
-4
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ eigen_include = -I /usr/local/include/eigen3
2020
example_execs = minimal modern basic animation nonblock xkcd quiver bar surface subplot fill_inbetween fill update
2121

2222
# Executable names for examples using Eigen
23-
eigen_execs = eigen loglog semilogx semilogy small
23+
eigen_execs = eigen loglog semilogx semilogy small spy
2424

2525
# Example targets (default if just 'make' is called)
2626
examples: $(example_execs)

examples/spy.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include <vector>
2+
#include <Eigen/Dense>
3+
#include "../matplotlibcpp.h"
4+
namespace plt = matplotlibcpp;
5+
6+
int main() {
7+
8+
const unsigned n = 100;
9+
Eigen::MatrixXd A(n / 2, n);
10+
std::vector<std::vector<double>> B;
11+
12+
for (unsigned i = 0; i < n / 2; ++i) {
13+
A(i, i) = 1;
14+
std::vector<double> row(n);
15+
row[i] = 1;
16+
17+
if (i < n / 2) {
18+
A(i, i + n / 2) = 1;
19+
row[i + n / 2] = 1;
20+
}
21+
B.push_back(row);
22+
}
23+
24+
for (unsigned i = 0; i < n / 2; ++i) {
25+
for (unsigned j = 0; j < n; ++j) {
26+
if (A(i, j) != B[i][j]) {
27+
std::cout << i << "," << j << " differ!\n";
28+
}
29+
}
30+
}
31+
32+
plt::figure();
33+
plt::title("Eigen");
34+
plt::spy(A);
35+
36+
plt::figure();
37+
plt::title("vector");
38+
plt::spy(B);
39+
plt::show();
40+
return 0;
41+
}

matplotlibcpp.h

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct _interpreter {
5959
PyObject *s_python_function_fill_between;
6060
PyObject *s_python_function_hist;
6161
PyObject *s_python_function_scatter;
62+
PyObject *s_python_function_spy;
6263
PyObject *s_python_function_subplot;
6364
PyObject *s_python_function_legend;
6465
PyObject *s_python_function_xlim;
@@ -188,6 +189,7 @@ struct _interpreter {
188189
PyObject_GetAttrString(pymod, "fill_between");
189190
s_python_function_hist = PyObject_GetAttrString(pymod, "hist");
190191
s_python_function_scatter = PyObject_GetAttrString(pymod, "scatter");
192+
s_python_function_spy = PyObject_GetAttrString(pymod, "spy");
191193
s_python_function_subplot = PyObject_GetAttrString(pymod, "subplot");
192194
s_python_function_legend = PyObject_GetAttrString(pymod, "legend");
193195
s_python_function_ylim = PyObject_GetAttrString(pymod, "ylim");
@@ -236,7 +238,8 @@ struct _interpreter {
236238
!s_python_function_errorbar || !s_python_function_tight_layout ||
237239
!s_python_function_stem || !s_python_function_xkcd ||
238240
!s_python_function_text || !s_python_function_suptitle ||
239-
!s_python_function_bar || !s_python_function_subplots_adjust) {
241+
!s_python_function_bar || !s_python_function_subplots_adjust ||
242+
!s_python_function_spy) {
240243
throw std::runtime_error("Couldn't find required function!");
241244
}
242245

@@ -253,6 +256,7 @@ struct _interpreter {
253256
!PyFunction_Check(s_python_function_loglog) ||
254257
!PyFunction_Check(s_python_function_fill) ||
255258
!PyFunction_Check(s_python_function_fill_between) ||
259+
!PyFunction_Check(s_python_function_spy) ||
256260
!PyFunction_Check(s_python_function_subplot) ||
257261
!PyFunction_Check(s_python_function_legend) ||
258262
!PyFunction_Check(s_python_function_annotate) ||
@@ -396,7 +400,7 @@ PyObject *get_2darray(const std::vector<::std::vector<Numeric>> &v) {
396400
detail::_interpreter::get(); // interpreter needs to be initialized for the
397401
// numpy commands to work
398402
if (v.size() < 1)
399-
throw std::runtime_error("get_2d_array v too small");
403+
throw std::runtime_error("get_2darray v too small");
400404

401405
npy_intp vsize[2] = {static_cast<npy_intp>(v.size()),
402406
static_cast<npy_intp>(v[0].size())};
@@ -408,14 +412,39 @@ PyObject *get_2darray(const std::vector<::std::vector<Numeric>> &v) {
408412

409413
for (const ::std::vector<Numeric> &v_row : v) {
410414
if (v_row.size() != static_cast<size_t>(vsize[1]))
411-
throw std::runtime_error("Missmatched array size");
415+
throw std 6D40 ::runtime_error("mismatched array size");
412416
std::copy(v_row.begin(), v_row.end(), vd_begin);
413417
vd_begin += vsize[1];
414418
}
415419

416420
return reinterpret_cast<PyObject *>(varray);
417421
}
418422

423+
// suitable for Eigen matrices
424+
template <typename Matrix>
425+
PyObject *get_2darray(const Matrix &A) {
426+
detail::_interpreter::get(); // interpreter needs to be initialized for the
427+
// numpy commands to work
428+
if (A.size() < 1)
429+
throw std::runtime_error("get_2darray A too small");
430+
431+
npy_intp vsize[2] = {static_cast<npy_intp>(A.rows()),
432+
static_cast<npy_intp>(A.cols())};
433+
434+
PyArrayObject *varray =
435+
(PyArrayObject *)PyArray_SimpleNew(2, vsize, NPY_DOUBLE);
436+
437+
double *vd_begin = static_cast<double *>(PyArray_DATA(varray));
438+
439+
for (std::size_t i = 0; i < A.rows(); ++i) {
440+
for (std::size_t j = 0; j < A.cols(); ++j) {
441+
*(vd_begin + i * A.cols() + j) = A(i, j);
442+
}
443+
}
444+
445+
return reinterpret_cast<PyObject *>(varray);
446+
}
447+
419448
#else // fallback if we don't have numpy: copy every element of the given vector
420449

421450
template <typename Vector> PyObject *get_array(const Vector &v) {
@@ -869,6 +898,29 @@ bool scatter(const VectorX &x, const VectorY &y, const double s = 1.0) {
869898
return res;
870899
}
871900

901+
// @brief Spy plot
902+
// @param A the matrix
903+
template <typename Matrix>
904+
bool spy(const Matrix &A, double precision=0) {
905+
PyObject *Aarray = get_2darray(A);
906+
907+
PyObject *kwargs = PyDict_New();
908+
PyDict_SetItemString(kwargs, "precision", PyFloat_FromDouble(precision));
909+
910+
PyObject *plot_args = PyTuple_New(1);
911+
PyTuple_SetItem(plot_args, 0, Aarray);
912+
913+
PyObject *res = PyObject_Call(
914+
detail::_interpreter::get().s_python_function_spy, plot_args, kwargs);
915+
916+
Py_DECREF(plot_args);
917+
Py_DECREF(kwargs);
918+
if (res)
919+
Py_DECREF(res);
920+
921+
return res;
922+
}
923+
872924
template <typename Numeric>
873925
bool bar(const std::vector<Numeric> &y, std::string ec = "black",
874926
std::string ls = "-", double lw = 1.0,

0 commit comments

Comments
 (0)
0