@@ -59,6 +59,7 @@ struct _interpreter {
59
59
PyObject *s_python_function_fill_between;
60
60
PyObject *s_python_function_hist;
61
61
PyObject *s_python_function_scatter;
62
+ PyObject *s_python_function_spy;
62
63
PyObject *s_python_function_subplot;
63
64
PyObject *s_python_function_legend;
64
65
PyObject *s_python_function_xlim;
@@ -188,6 +189,7 @@ struct _interpreter {
188
189
PyObject_GetAttrString (pymod, " fill_between" );
189
190
s_python_function_hist = PyObject_GetAttrString (pymod, " hist" );
190
191
s_python_function_scatter = PyObject_GetAttrString (pymod, " scatter" );
192
+ s_python_function_spy = PyObject_GetAttrString (pymod, " spy" );
191
193
s_python_function_subplot = PyObject_GetAttrString (pymod, " subplot" );
192
194
s_python_function_legend = PyObject_GetAttrString (pymod, " legend" );
193
195
s_python_function_ylim = PyObject_GetAttrString (pymod, " ylim" );
@@ -236,7 +238,8 @@ struct _interpreter {
236
238
!s_python_function_errorbar || !s_python_function_tight_layout ||
237
239
!s_python_function_stem || !s_python_function_xkcd ||
238
240
!s_python_function_text || !s_python_function_suptitle ||
239
- !s_python_function_bar || !s_python_function_subplots_adjust) {
241
+ !s_python_function_bar || !s_python_function_subplots_adjust ||
242
+ !s_python_function_spy) {
240
243
throw std::runtime_error (" Couldn't find required function!" );
241
244
}
242
245
@@ -253,6 +256,7 @@ struct _interpreter {
253
256
!PyFunction_Check (s_python_function_loglog) ||
254
257
!PyFunction_Check (s_python_function_fill) ||
255
258
!PyFunction_Check (s_python_function_fill_between) ||
259
+ !PyFunction_Check (s_python_function_spy) ||
256
260
!PyFunction_Check (s_python_function_subplot) ||
257
261
!PyFunction_Check (s_python_function_legend) ||
258
262
!PyFunction_Check (s_python_function_annotate) ||
@@ -396,7 +400,7 @@ PyObject *get_2darray(const std::vector<::std::vector<Numeric>> &v) {
396
400
detail::_interpreter::get (); // interpreter needs to be initialized for the
397
401
// numpy commands to work
398
402
if (v.size () < 1 )
399
- throw std::runtime_error (" get_2d_array v too small" );
403
+ throw std::runtime_error (" get_2darray v too small" );
400
404
401
405
npy_intp vsize[2 ] = {static_cast <npy_intp>(v.size ()),
402
406
static_cast <npy_intp>(v[0 ].size ())};
@@ -408,14 +412,39 @@ PyObject *get_2darray(const std::vector<::std::vector<Numeric>> &v) {
408
412
409
413
for (const ::std::vector<Numeric> &v_row : v) {
410
414
if (v_row.size () != static_cast <size_t >(vsize[1 ]))
411
- throw std::runtime_error (" Missmatched array size" );
415
+ throw std
6D40
::runtime_error (" mismatched array size" );
412
416
std::copy (v_row.begin (), v_row.end (), vd_begin);
413
417
vd_begin += vsize[1 ];
414
418
}
415
419
416
420
return reinterpret_cast <PyObject *>(varray);
417
421
}
418
422
423
+ // suitable for Eigen matrices
424
+ template <typename Matrix>
425
+ PyObject *get_2darray (const Matrix &A) {
426
+ detail::_interpreter::get (); // interpreter needs to be initialized for the
427
+ // numpy commands to work
428
+ if (A.size () < 1 )
429
+ throw std::runtime_error (" get_2darray A too small" );
430
+
431
+ npy_intp vsize[2 ] = {static_cast <npy_intp>(A.rows ()),
432
+ static_cast <npy_intp>(A.cols ())};
433
+
434
+ PyArrayObject *varray =
435
+ (PyArrayObject *)PyArray_SimpleNew (2 , vsize, NPY_DOUBLE);
436
+
437
+ double *vd_begin = static_cast <double *>(PyArray_DATA (varray));
438
+
439
+ for (std::size_t i = 0 ; i < A.rows (); ++i) {
440
+ for (std::size_t j = 0 ; j < A.cols (); ++j) {
441
+ *(vd_begin + i * A.cols () + j) = A (i, j);
442
+ }
443
+ }
444
+
445
+ return reinterpret_cast <PyObject *>(varray);
446
+ }
447
+
419
448
#else // fallback if we don't have numpy: copy every element of the given vector
420
449
421
450
template <typename Vector> PyObject *get_array (const Vector &v) {
@@ -869,6 +898,29 @@ bool scatter(const VectorX &x, const VectorY &y, const double s = 1.0) {
869
898
return res;
870
899
}
871
900
901
+ // @brief Spy plot
902
+ // @param A the matrix
903
+ template <typename Matrix>
904
+ bool spy (const Matrix &A, double precision=0 ) {
905
+ PyObject *Aarray = get_2darray (A);
906
+
907
+ PyObject *kwargs = PyDict_New ();
908
+ PyDict_SetItemString (kwargs, " precision" , PyFloat_FromDouble (precision));
909
+
910
+ PyObject *plot_args = PyTuple_New (1 );
911
+ PyTuple_SetItem (plot_args, 0 , Aarray);
912
+
913
+ PyObject *res = PyObject_Call (
914
+ detail::_interpreter::get ().s_python_function_spy , plot_args, kwargs);
915
+
916
+ Py_DECREF (plot_args);
917
+ Py_DECREF (kwargs);
918
+ if (res)
919
+ Py_DECREF (res);
920
+
921
+ return res;
922
+ }
923
+
872
924
template <typename Numeric>
873
925
bool bar (const std::vector<Numeric> &y, std::string ec = " black" ,
874
926
std::string ls = " -" , double lw = 1.0 ,
0 commit comments