diff --git a/matplotlibcpp.h b/matplotlibcpp.h index d95d46a..ffc7899 100644 --- a/matplotlibcpp.h +++ b/matplotlibcpp.h @@ -1,4 +1,6 @@ #pragma once +#ifndef __MATPLOTLIBCPP__H__ +#define __MATPLOTLIBCPP__H__ // Python headers must be included before any system headers, since // they define _POSIX_C_SOURCE @@ -351,9 +353,9 @@ template <> struct select_npy_type { const static NPY_TYPES type = NPY // Sanity checks; comment them out or change the numpy type below if you're compiling on // a platform where they don't apply static_assert(sizeof(long long) == 8); -template <> struct select_npy_type { const static NPY_TYPES type = NPY_INT64; }; +// template <> struct select_npy_type { const static NPY_TYPES type = NPY_INT64; }; static_assert(sizeof(unsigned long long) == 8); -template <> struct select_npy_type { const static NPY_TYPES type = NPY_UINT64; }; +// template <> struct select_npy_type { const static NPY_TYPES type = NPY_UINT64; }; template PyObject* get_array(const std::vector& v) @@ -470,6 +472,33 @@ bool plot(const std::vector &x, const std::vector &y, const st return res; } +// FIXME: +//Given a figure, create a subplot with a 3d axis and return a pointer to that axis. also calls incref on the axis pointer. +PyObject* init_3d_axis(PyObject *fig) +{ + PyObject *asp_kwargs = PyDict_New(); + PyDict_SetItemString(asp_kwargs, "projection", PyString_FromString("3d")); + + PyObject *asp = PyObject_GetAttrString(fig, "add_subplot"); + Py_INCREF(asp); + PyObject *tmpax = PyObject_Call(asp, detail::_interpreter::get().s_python_empty_tuple, asp_kwargs); + Py_INCREF(tmpax); + + PyObject *gca = PyObject_GetAttrString(fig, "gca"); + if (!gca) throw std::runtime_error("No gca"); + Py_INCREF(gca); + PyObject *axis = PyObject_Call(gca, detail::_interpreter::get().s_python_empty_tuple, detail::_interpreter::get().s_python_empty_tuple); + + if (!axis) throw std::runtime_error("No axis"); + Py_INCREF(axis); + + Py_DECREF(gca); + Py_DECREF(tmpax); + Py_DECREF(asp); + return axis; +} + + // TODO - it should be possible to make this work by implementing // a non-numpy alternative for `detail::get_2darray()`. #ifndef WITHOUT_NUMPY @@ -555,20 +584,8 @@ void plot_surface(const std::vector<::std::vector> &x, Py_DECREF(fig_exists); if (!fig) throw std::runtime_error("Call to figure() failed."); - PyObject *gca_kwargs = PyDict_New(); - PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); + PyObject *axis = init_3d_axis(fig); - PyObject *gca = PyObject_GetAttrString(fig, "gca"); - if (!gca) throw std::runtime_error("No gca"); - Py_INCREF(gca); - PyObject *axis = PyObject_Call( - gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); - - if (!axis) throw std::runtime_error("No axis"); - Py_INCREF(axis); - - Py_DECREF(gca); - Py_DECREF(gca_kwargs); PyObject *plot_surface = PyObject_GetAttrString(axis, "plot_surface"); if (!plot_surface) throw std::runtime_error("No surface"); @@ -723,20 +740,7 @@ void plot3(const std::vector &x, } if (!fig) throw std::runtime_error("Call to figure() failed."); - PyObject *gca_kwargs = PyDict_New(); - PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); - - PyObject *gca = PyObject_GetAttrString(fig, "gca"); - if (!gca) throw std::runtime_error("No gca"); - Py_INCREF(gca); - PyObject *axis = PyObject_Call( - gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); - - if (!axis) throw std::runtime_error("No axis"); - Py_INCREF(axis); - - Py_DECREF(gca); - Py_DECREF(gca_kwargs); + PyObject *axis = init_3d_axis(fig); PyObject *plot3 = PyObject_GetAttrString(axis, "plot"); if (!plot3) throw std::runtime_error("No 3D line plot"); @@ -1126,20 +1130,7 @@ bool scatter(const std::vector& x, Py_DECREF(fig_exists); if (!fig) throw std::runtime_error("Call to figure() failed."); - PyObject *gca_kwargs = PyDict_New(); - PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); - - PyObject *gca = PyObject_GetAttrString(fig, "gca"); - if (!gca) throw std::runtime_error("No gca"); - Py_INCREF(gca); - PyObject *axis = PyObject_Call( - gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); - - if (!axis) throw std::runtime_error("No axis"); - Py_INCREF(axis); - - Py_DECREF(gca); - Py_DECREF(gca_kwargs); + PyObject *axis = init_3d_axis(fig); PyObject *plot3 = PyObject_GetAttrString(axis, "scatter"); if (!plot3) throw std::runtime_error("No 3D line plot"); @@ -1496,27 +1487,15 @@ bool quiver(const std::vector& x, const std::vector& y, cons { PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); } - + //get figure gca to enable 3d projection PyObject *fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, detail::_interpreter::get().s_python_empty_tuple); if (!fig) throw std::runtime_error("Call to figure() failed."); - PyObject *gca_kwargs = PyDict_New(); - PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); - - PyObject *gca = PyObject_GetAttrString(fig, "gca"); - if (!gca) throw std::runtime_error("No gca"); - Py_INCREF(gca); - PyObject *axis = PyObject_Call( - gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); + PyObject *axis = init_3d_axis(fig); - if (!axis) throw std::runtime_error("No axis"); - Py_INCREF(axis); - Py_DECREF(gca); - Py_DECREF(gca_kwargs); - //plot our boys bravely, plot them strongly, plot them with a wink and clap PyObject *plot3 = PyObject_GetAttrString(axis, "quiver"); if (!plot3) throw std::runtime_error("No 3D line plot"); @@ -2522,7 +2501,7 @@ inline void set_zlabel(const std::string &str, const std::map& keywords = std::map()) { detail::_interpreter::get(); @@ -2532,10 +2511,16 @@ inline void grid(bool flag) PyObject* args = PyTuple_New(1); PyTuple_SetItem(args, 0, pyflag); - PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_grid, args); + PyObject* kwargs = PyDict_New(); + for (auto it = keywords.begin(); it != keywords.end(); ++it) { + PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); + } + + PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_grid, args, kwargs); if(!res) throw std::runtime_error("Call to grid() failed."); Py_DECREF(args); + Py_DECREF(kwargs); Py_DECREF(res); } @@ -2655,7 +2640,7 @@ inline void rcparams(const std::map& keywords = {}) { PyDict_SetItemString(kwargs, it->first.c_str(), PyLong_FromLong(std::stoi(it->second.c_str()))); else PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); } - + PyObject * update = PyObject_GetAttrString(detail::_interpreter::get().s_python_function_rcparams, "update"); PyObject * res = PyObject_Call(update, args, kwargs); if(!res) throw std::runtime_error("Call to rcParams.update() failed."); @@ -2984,3 +2969,5 @@ class Plot }; } // end namespace matplotlibcpp + +#endif //!__MATPLOTLIBCPP__H__ \ No newline at end of file