@@ -59,6 +59,7 @@ struct _interpreter {
5959 PyObject *s_python_function_errorbar;
6060 PyObject *s_python_function_annotate;
6161 PyObject *s_python_function_tight_layout;
62+ PyObject *s_python_colormap;
6263 PyObject *s_python_empty_tuple;
6364 PyObject *s_python_function_stem;
6465 PyObject *s_python_function_xkcd;
@@ -115,9 +116,13 @@ struct _interpreter {
115116
116117 PyObject* matplotlibname = PyString_FromString (" matplotlib" );
117118 PyObject* pyplotname = PyString_FromString (" matplotlib.pyplot" );
119+ PyObject* mpl_toolkits = PyString_FromString (" mpl_toolkits" );
120+ PyObject* axis3d = PyString_FromString (" mpl_toolkits.mplot3d" );
118121 PyObject* pylabname = PyString_FromString (" pylab" );
119- if (!pyplotname || !pylabname || !matplotlibname) {
120- throw std::runtime_error (" couldnt create string" );
122+ PyObject* cmname = PyString_FromString (" matplotlib.cm" );
123+ if (!pyplotname || !pylabname || !matplotlibname || !mpl_toolkits ||
124+ !axis3d || !cmname) {
125+ throw std::runtime_error (" couldnt create string" );
121126 }
122127
123128 PyObject* matplotlib = PyImport_Import (matplotlibname);
@@ -134,11 +139,22 @@ struct _interpreter {
134139 Py_DECREF (pyplotname);
135140 if (!pymod) { throw std::runtime_error (" Error loading module matplotlib.pyplot!" ); }
136141
142+ s_python_colormap = PyImport_Import (cmname);
143+ Py_DECREF (cmname);
144+ if (!s_python_colormap) { throw std::runtime_error (" Error loading module matplotlib.cm!" ); }
137145
138146 PyObject* pylabmod = PyImport_Import (pylabname);
139147 Py_DECREF (pylabname);
140148 if (!pylabmod) { throw std::runtime_error (" Error loading module pylab!" ); }
141149
150+ PyObject* mpl_toolkitsmod = PyImport_Import (mpl_toolkits);
151+ Py_DECREF (mpl_toolkitsmod);
152+ if (!mpl_toolkitsmod) { throw std::runtime_error (" Error loading module mpl_toolkits!" ); }
153+
154+ PyObject* axis3dmod = PyImport_Import (axis3d);
155+ Py_DECREF (axis3dmod);
156+ if (!axis3dmod) { throw std::runtime_error (" Error loading module mpl_toolkits.mplot3d!" ); }
157+
142158 s_python_function_show = PyObject_GetAttrString (pymod, " show" );
143159 s_python_function_close = PyObject_GetAttrString (pymod, " close" );
144160 s_python_function_draw = PyObject_GetAttrString (pymod, " draw" );
@@ -325,6 +341,30 @@ PyObject* get_array(const std::vector<Numeric>& v)
325341 return varray;
326342}
327343
344+ template <typename Numeric>
345+ PyObject* get_2darray (const std::vector<::std::vector<Numeric>>& v)
346+ {
347+ detail::_interpreter::get (); // interpreter needs to be initialized for the numpy commands to work
348+ if (v.size () < 1 ) throw std::runtime_error (" get_2d_array v too small" );
349+
350+ npy_intp vsize[2 ] = {static_cast <npy_intp>(v.size ()),
351+ static_cast <npy_intp>(v[0 ].size ())};
352+
353+ PyArrayObject *varray =
354+ (PyArrayObject *)PyArray_SimpleNew (2 , vsize, NPY_DOUBLE);
355+
356+ double *vd_begin = static_cast <double *>(PyArray_DATA (varray));
357+
358+ for (const ::std::vector<Numeric> &v_row : v) {
359+ if (v_row.size () != static_cast <size_t >(vsize[1 ]))
360+ throw std::runtime_error (" Missmatched array size" );
361+ std::copy (v_row.begin (), v_row.end (), vd_begin);
362+ vd_begin += vsize[1 ];
363+ }
364+
365+ return reinterpret_cast <PyObject *>(varray);
366+ }
367+
328368#else // fallback if we don't have numpy: copy every element of the given vector
329369
330370template <typename Numeric>
@@ -369,6 +409,76 @@ bool plot(const std::vector<Numeric> &x, const std::vector<Numeric> &y, const st
369409 return res;
370410}
371411
412+ template <typename Numeric>
413+ void plot_surface (const std::vector<::std::vector<Numeric>> &x,
414+ const std::vector<::std::vector<Numeric>> &y,
415+ const std::vector<::std::vector<Numeric>> &z,
416+ const std::map<std::string, std::string> &keywords =
417+ std::map<std::string, std::string>()) {
418+ assert (x.size () == y.size ());
419+ assert (y.size () == z.size ());
420+
421+ // using numpy arrays
422+ PyObject *xarray = get_2darray (x);
423+ PyObject *yarray = get_2darray (y);
424+ PyObject *zarray = get_2darray (z);
425+
426+ // construct positional args
427+ PyObject *args = PyTuple_New (3 );
428+ PyTuple_SetItem (args, 0 , xarray);
429+ PyTuple_SetItem (args, 1 , yarray);
430+ PyTuple_SetItem (args, 2 , zarray);
431+
432+ // Build up the kw args.
433+ PyObject *kwargs = PyDict_New ();
434+ PyDict_SetItemString (kwargs, " rstride" , PyInt_FromLong (1 ));
435+ PyDict_SetItemString (kwargs, " cstride" , PyInt_FromLong (1 ));
436+
437+ PyObject *python_colormap_coolwarm = PyObject_GetAttrString (
438+ detail::_interpreter::get ().s_python_colormap , " coolwarm" );
439+
440+ PyDict_SetItemString (kwargs, " cmap" , python_colormap_coolwarm);
441+
442+ for (std::map<std::string, std::string>::const_iterator it = keywords.begin ();
443+ it != keywords.end (); ++it) {
444+ PyDict_SetItemString (kwargs, it->first .c_str (),
445+ PyString_FromString (it->second .c_str ()));
446+ }
447+
448+
449+ PyObject *fig =
450+ PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
451+ detail::_interpreter::get ().s_python_empty_tuple );
452+ if (!fig) throw std::runtime_error (" Call to figure() failed." );
453+
454+ PyObject *gca_kwargs = PyDict_New ();
455+ PyDict_SetItemString (gca_kwargs, " projection" , PyString_FromString (" 3d" ));
456+
457+ PyObject *gca = PyObject_GetAttrString (fig, " gca" );
458+ if (!gca) throw std::runtime_error (" No gca" );
459+ Py_INCREF (gca);
460+ PyObject *axis = PyObject_Call (
461+ gca, detail::_interpreter::get ().s_python_empty_tuple , gca_kwargs);
462+
463+ if (!axis) throw std::runtime_error (" No axis" );
464+ Py_INCREF (axis);
465+
466+ Py_DECREF (gca);
467+ Py_DECREF (gca_kwargs);
468+
469+ PyObject *plot_surface = PyObject_GetAttrString (axis, " plot_surface" );
470+ if (!plot_surface) throw std::runtime_error (" No surface" );
471+ Py_INCREF (plot_surface);
472+ PyObject *res = PyObject_Call (plot_surface, args, kwargs);
473+ if (!res) throw std::runtime_error (" failed surface" );
474+ Py_DECREF (plot_surface);
475+
476+ Py_DECREF (axis);
477+ Py_DECREF (args);
478+ Py_DECREF (kwargs);
479+ if (res) Py_DECREF (res);
480+ }
481+
372482template <typename Numeric>
373483bool stem (const std::vector<Numeric> &x, const std::vector<Numeric> &y, const std::map<std::string, std::string>& keywords)
374484{
0 commit comments