@@ -99,6 +99,7 @@ struct _interpreter {
9999 PyObject *s_python_function_barh;
100100 PyObject *s_python_function_colorbar;
101101 PyObject *s_python_function_subplots_adjust;
102+ PyObject *s_python_function_rcparams;
102103
103104
104105 /* For now, _interpreter is implemented as a singleton since its currently not possible to have
@@ -189,6 +190,7 @@ struct _interpreter {
189190 }
190191
191192 PyObject* matplotlib = PyImport_Import (matplotlibname);
193+
192194 Py_DECREF (matplotlibname);
193195 if (!matplotlib) {
194196 PyErr_Print ();
@@ -201,6 +203,8 @@ struct _interpreter {
201203 PyObject_CallMethod (matplotlib, const_cast <char *>(" use" ), const_cast <char *>(" s" ), s_backend.c_str ());
202204 }
203205
206+
207+
204208 PyObject* pymod = PyImport_Import (pyplotname);
205209 Py_DECREF (pyplotname);
206210 if (!pymod) { throw std::runtime_error (" Error loading module matplotlib.pyplot!" ); }
@@ -264,6 +268,7 @@ struct _interpreter {
264268 s_python_function_barh = safe_import (pymod, " barh" );
265269 s_python_function_colorbar = PyObject_GetAttrString (pymod, " colorbar" );
266270 s_python_function_subplots_adjust = safe_import (pymod," subplots_adjust" );
271+ s_python_function_rcparams = PyObject_GetAttrString (pymod, " rcParams" );
267272#ifndef WITHOUT_NUMPY
268273 s_python_function_imshow = safe_import (pymod, " imshow" );
269274#endif
@@ -464,6 +469,7 @@ template <typename Numeric>
464469void plot_surface (const std::vector<::std::vector<Numeric>> &x,
465470 const std::vector<::std::vector<Numeric>> &y,
466471 const std::vector<::std::vector<Numeric>> &z,
472+ const long fig_number=0 ,
467473 const std::map<std::string, std::string> &keywords =
468474 std::map<std::string, std::string>())
469475{
@@ -516,14 +522,29 @@ void plot_surface(const std::vector<::std::vector<Numeric>> &x,
516522
517523 for (std::map<std::string, std::string>::const_iterator it = keywords.begin ();
518524 it != keywords.end (); ++it) {
519- PyDict_SetItemString (kwargs, it->first .c_str (),
520- PyString_FromString (it->second .c_str ()));
525+ if (it->first == " linewidth" || it->first == " alpha" ) {
526+ PyDict_SetItemString (kwargs, it->first .c_str (),
527+ PyFloat_FromDouble (std::stod (it->second )));
528+ } else {
529+ PyDict_SetItemString (kwargs, it->first .c_str (),
530+ PyString_FromString (it->second .c_str ()));
531+ }
521532 }
522533
523-
524- PyObject *fig =
525- PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
526- detail::_interpreter::get ().s_python_empty_tuple );
534+ PyObject *fig_args = PyTuple_New (1 );
535+ PyObject* fig = nullptr ;
536+ PyTuple_SetItem (fig_args, 0 , PyLong_FromLong (fig_number));
537+ PyObject *fig_exists =
538+ PyObject_CallObject (
539+ detail::_interpreter::get ().s_python_function_fignum_exists , fig_args);
540+ if (!PyObject_IsTrue (fig_exists)) {
541+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
542+ detail::_interpreter::get ().s_python_empty_tuple );
543+ } else {
544+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
545+ fig_args);
546+ }
547+ Py_DECREF (fig_exists);
527548 if (!fig) throw std::runtime_error (" Call to figure() failed." );
528549
529550 PyObject *gca_kwargs = PyDict_New ();
@@ -559,6 +580,7 @@ template <typename Numeric>
559580void plot3 (const std::vector<Numeric> &x,
560581 const std::vector<Numeric> &y,
561582 const std::vector<Numeric> &z,
583+ const long fig_number=0 ,
562584 const std::map<std::string, std::string> &keywords =
563585 std::map<std::string, std::string>())
564586{
@@ -607,9 +629,18 @@ void plot3(const std::vector<Numeric> &x,
607629 PyString_FromString (it->second .c_str ()));
608630 }
609631
610- PyObject *fig =
611- PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
612- detail::_interpreter::get ().s_python_empty_tuple );
632+ PyObject *fig_args = PyTuple_New (1 );
633+ PyObject* fig = nullptr ;
634+ PyTuple_SetItem (fig_args, 0 , PyLong_FromLong (fig_number));
635+ PyObject *fig_exists =
636+ PyObject_CallObject (detail::_interpreter::get ().s_python_function_fignum_exists , fig_args);
637+ if (!PyObject_IsTrue (fig_exists)) {
638+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
639+ detail::_interpreter::get ().s_python_empty_tuple );
640+ } else {
641+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
642+ fig_args);
643+ }
613644 if (!fig) throw std::runtime_error (" Call to figure() failed." );
614645
615646 PyObject *gca_kwargs = PyDict_New ();
@@ -911,6 +942,103 @@ bool scatter(const std::vector<NumericX>& x,
911942 return res;
912943}
913944
945+ template <typename NumericX, typename NumericY, typename NumericZ>
946+ bool scatter (const std::vector<NumericX>& x,
947+ const std::vector<NumericY>& y,
948+ const std::vector<NumericZ>& z,
949+ const double s=1.0 , // The marker size in points**2
950+ const long fig_number=0 ,
951+ const std::map<std::string, std::string> & keywords = {}) {
952+ detail::_interpreter::get ();
953+
954+ // Same as with plot_surface: We lazily load the modules here the first time
955+ // this function is called because I'm not sure that we can assume "matplotlib
956+ // installed" implies "mpl_toolkits installed" on all platforms, and we don't
957+ // want to require it for people who don't need 3d plots.
958+ static PyObject *mpl_toolkitsmod = nullptr , *axis3dmod = nullptr ;
959+ if (!mpl_toolkitsmod) {
960+ detail::_interpreter::get ();
961+
962+ PyObject* mpl_toolkits = PyString_FromString (" mpl_toolkits" );
963+ PyObject* axis3d = PyString_FromString (" mpl_toolkits.mplot3d" );
964+ if (!mpl_toolkits || !axis3d) { throw std::runtime_error (" couldnt create string" ); }
965+
966+ mpl_toolkitsmod = PyImport_Import (mpl_toolkits);
967+ Py_DECREF (mpl_toolkits);
968+ if (!mpl_toolkitsmod) { throw std::runtime_error (" Error loading module mpl_toolkits!" ); }
969+
970+ axis3dmod = PyImport_Import (axis3d);
971+ Py_DECREF (axis3d);
972+ if (!axis3dmod) { throw std::runtime_error (" Error loading module mpl_toolkits.mplot3d!" ); }
973+ }
974+
975+ assert (x.size () == y.size ());
976+ assert (y.size () == z.size ());
977+
978+ PyObject *xarray = detail::get_array (x);
979+ PyObject *yarray = detail::get_array (y);
980+ PyObject *zarray = detail::get_array (z);
981+
982+ // construct positional args
983+ PyObject *args = PyTuple_New (3 );
984+ PyTuple_SetItem (args, 0 , xarray);
985+ PyTuple_SetItem (args, 1 , yarray);
986+ PyTuple_SetItem (args, 2 , zarray);
987+
988+ // Build up the kw args.
989+ PyObject *kwargs = PyDict_New ();
990+
991+ for (std::map<std::string, std::string>::const_iterator it = keywords.begin ();
992+ it != keywords.end (); ++it) {
993+ PyDict_SetItemString (kwargs, it->first .c_str (),
994+ PyString_FromString (it->second .c_str ()));
995+ }
996+ PyObject *fig_args = PyTuple_New (1 );
997+ PyObject* fig = nullptr ;
998+ PyTuple_SetItem (fig_args, 0 , PyLong_FromLong (fig_number));
999+ PyObject *fig_exists =
1000+ PyObject_CallObject (detail::_interpreter::get ().s_python_function_fignum_exists , fig_args);
1001+ if (!PyObject_IsTrue (fig_exists)) {
1002+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
1003+ detail::_interpreter::get ().s_python_empty_tuple );
1004+ } else {
1005+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
1006+ fig_args);
1007+ }
1008+ Py_DECREF (fig_exists);
1009+ if (!fig) throw std::runtime_error (" Call to figure() failed." );
1010+
1011+ PyObject *gca_kwargs = PyDict_New ();
1012+ PyDict_SetItemString (gca_kwargs, " projection" , PyString_FromString (" 3d" ));
1013+
1014+ PyObject *gca = PyObject_GetAttrString (fig, " gca" );
1015+ if (!gca) throw std::runtime_error (" No gca" );
1016+ Py_INCREF (gca);
1017+ PyObject *axis = PyObject_Call (
1018+ gca, detail::_interpreter::get ().s_python_empty_tuple , gca_kwargs);
1019+
1020+ if (!axis) throw std::runtime_error (" No axis" );
1021+ Py_INCREF (axis);
1022+
1023+ Py_DECREF (gca);
1024+ Py_DECREF (gca_kwargs);
1025+
1026+ PyObject *plot3 = PyObject_GetAttrString (axis, " scatter" );
1027+ if (!plot3) throw std::runtime_error (" No 3D line plot" );
1028+ Py_INCREF (plot3);
1029+ PyObject *res = PyObject_Call (plot3, args, kwargs);
1030+ if (!res) throw std::runtime_error (" Failed 3D line plot" );
1031+ Py_DECREF (plot3);
1032+
1033+ Py_DECREF (axis);
1034+ Py_DECREF (args);
1035+ Py_DECREF (kwargs);
1036+ Py_DECREF (fig);
1037+ if (res) Py_DECREF (res);
1038+ return res;
1039+
1040+ }
1041+
9141042template <typename Numeric>
9151043bool boxplot (const std::vector<std::vector<Numeric>>& data,
9161044 const std::vector<std::string>& labels = {},
@@ -1139,9 +1267,9 @@ bool contour(const std::vector<NumericX>& x, const std::vector<NumericY>& y,
11391267 const std::map<std::string, std::string>& keywords = {}) {
11401268 assert (x.size () == y.size () && x.size () == z.size ());
11411269
1142- PyObject* xarray = get_array (x);
1143- PyObject* yarray = get_array (y);
1144- PyObject* zarray = get_array (z);
1270+ PyObject* xarray = detail:: get_array (x);
1271+ PyObject* yarray = detail:: get_array (y);
1272+ PyObject* zarray = detail:: get_array (z);
11451273
11461274 PyObject* plot_args = PyTuple_New (3 );
11471275 PyTuple_SetItem (plot_args, 0 , xarray);
@@ -2094,12 +2222,14 @@ inline void axvspan(double xmin, double xmax, double ymin = 0., double ymax = 1.
20942222
20952223 // construct keyword args
20962224 PyObject* kwargs = PyDict_New ();
2097- for (std::map<std::string, std::string>::const_iterator it = keywords.begin (); it != keywords.end (); ++it)
2098- {
2099- if (it->first == " linewidth" || it->first == " alpha" )
2100- PyDict_SetItemString (kwargs, it->first .c_str (), PyFloat_FromDouble (std::stod (it->second )));
2101- else
2102- PyDict_SetItemString (kwargs, it->first .c_str (), PyString_FromString (it->second .c_str ()));
2225+ for (auto it = keywords.begin (); it != keywords.end (); ++it) {
2226+ if (it->first == " linewidth" || it->first == " alpha" ) {
2227+ PyDict_SetItemString (kwargs, it->first .c_str (),
2228+ PyFloat_FromDouble (std::stod (it->second )));
2229+ } else {
2230+ PyDict_SetItemString (kwargs, it->first .c_str (),
2231+ PyString_FromString (it->second .c_str ()));
2232+ }
21032233 }
21042234
21052235 PyObject* res = PyObject_Call (detail::_interpreter::get ().s_python_function_axvspan , args, kwargs);
@@ -2319,6 +2449,25 @@ inline void save(const std::string& filename)
23192449 Py_DECREF (res);
23202450}
23212451
2452+ inline void rcparams (const std::map<std::string, std::string>& keywords = {}) {
2453+ detail::_interpreter::get ();
2454+ PyObject* args = PyTuple_New (0 );
2455+ PyObject* kwargs = PyDict_New ();
2456+ for (auto it = keywords.begin (); it != keywords.end (); ++it) {
2457+ if (" text.usetex" == it->first )
2458+ PyDict_SetItemString (kwargs, it->first .c_str (), PyLong_FromLong (std::stoi (it->second .c_str ())));
2459+ else PyDict_SetItemString (kwargs, it->first .c_str (), PyString_FromString (it->second .c_str ()));
2460+ }
2461+
2462+ PyObject * update = PyObject_GetAttrString (detail::_interpreter::get ().s_python_function_rcparams , " update" );
2463+ PyObject * res = PyObject_Call (update, args, kwargs);
2464+ if (!res) throw std::runtime_error (" Call to rcParams.update() failed." );
2465+ Py_DECREF (args);
2466+ Py_DECREF (kwargs);
2467+ Py_DECREF (update);
2468+ Py_DECREF (res);
2469+ }
2470+
23222471inline void clf () {
23232472 detail::_interpreter::get ();
23242473
0 commit comments