From dc9172ae3346b51490ef9027f9c2b92e365d3476 Mon Sep 17 00:00:00 2001
From: Victor Milovanov
Date: Sun, 31 May 2020 18:54:18 -0700
Subject: [PATCH] implements buffer interface for .NET arrays of primitive types

fixes https://github.com/losttech/Gradient/issues/27
---
 .github/workflows/main.yml       |   1 +
 CHANGELOG.md                     |   1 +
 src/embed_tests/NumPyTests.cs    |  94 ++++++++++++++
 src/embed_tests/TestExample.cs   |  53 --------
 src/embed_tests/TestPyBuffer.cs  |  13 ++
 src/runtime/arrayobject.cs       | 205 +++++++++++++++++++++++++++++++
 src/runtime/bufferinterface.cs   |  11 ++
 src/runtime/exceptions.cs        |   1 +
 src/runtime/native/TypeOffset.cs |   1 +
 src/runtime/typemanager.cs       |   3 +
 10 files changed, 330 insertions(+), 53 deletions(-)
 create mode 100644 src/embed_tests/NumPyTests.cs
 delete mode 100644 src/embed_tests/TestExample.cs

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index b5a0080a1..2818fb09c 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -41,6 +41,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install --upgrade -r requirements.txt
+          pip install numpy # for tests
 
       - name: Build and Install
         run: |
diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc30155d8..ce9102a5d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,8 +15,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].
 - Ability to implement delegates with `ref` and `out` parameters in Python, by returning the modified parameter values in a tuple. ([#1355][i1355])
 - `PyType` - a wrapper for Python type objects, that also permits creating new heap types from `TypeSpec`
 - Improved exception handling:
   - exceptions can now be converted with codecs
   - `InnerException` and `__cause__` are propagated properly
+- .NET arrays implement Python buffer protocol
 
 ### Changed
 - Drop support for Python 2, 3.4, and 3.5
diff --git a/src/embed_tests/NumPyTests.cs b/src/embed_tests/NumPyTests.cs
new file mode 100644
index 000000000..f31f7b25b
--- /dev/null
+++ b/src/embed_tests/NumPyTests.cs
@@ -0,0 +1,94 @@
+using System;
+using System.Collections.Generic;
+using NUnit.Framework;
+using Python.Runtime;
+using Python.Runtime.Codecs;
+
+namespace Python.EmbeddingTest
+{
+    public class NumPyTests
+    {
+        [OneTimeSetUp]
+        public void SetUp()
+        {
+            PythonEngine.Initialize();
+            TupleCodec.Register();
+        }
+
+        [OneTimeTearDown]
+        public void Dispose()
+        {
+            PythonEngine.Shutdown();
+        }
+
+        [Test]
+        public void TestReadme()
+        {
+            dynamic np;
+            try
+            {
+                np = Py.Import("numpy");
+            }
+            catch (PythonException)
+            {
+                Assert.Inconclusive("Numpy or dependency not installed");
+                return;
+            }
+
+            Assert.AreEqual("1.0", np.cos(np.pi * 2).ToString());
+
+            dynamic sin = np.sin;
+            StringAssert.StartsWith("-0.95892", sin(5).ToString());
+
+            double c = np.cos(5) + sin(5);
+            Assert.AreEqual(-0.675262, c, 0.01);
+
+            dynamic a = np.array(new List<float> { 1, 2, 3 });
+            Assert.AreEqual("float64", a.dtype.ToString());
+
+            dynamic b = np.array(new List<float> { 6, 5, 4 }, Py.kw("dtype", np.int32));
+            Assert.AreEqual("int32", b.dtype.ToString());
+
+            Assert.AreEqual("[ 6. 10. 12.]", (a * b).ToString().Replace("  ", " "));
+        }
+
+        [Test]
+        public void MultidimensionalNumPyArray()
+        {
+            PyObject np;
+            try {
+                np = Py.Import("numpy");
+            } catch (PythonException) {
+                Assert.Inconclusive("Numpy or dependency not installed");
+                return;
+            }
+
+            var array = new[,] { { 1, 2 }, { 3, 4 } };
+            var ndarray = np.InvokeMethod("asarray", array.ToPython());
+            Assert.AreEqual((2,2), ndarray.GetAttr("shape").As<(int,int)>());
+            Assert.AreEqual(1, ndarray[(0, 0).ToPython()].InvokeMethod("__int__").As<int>());
+            Assert.AreEqual(array[1, 0], ndarray[(1, 0).ToPython()].InvokeMethod("__int__").As<int>());
+        }
+
+        [Test]
+        public void Int64Array()
+        {
+            PyObject np;
+            try
+            {
+                np = Py.Import("numpy");
+            }
+            catch (PythonException)
+            {
+                Assert.Inconclusive("Numpy or dependency not installed");
+                return;
+            }
+
+            var array = new long[,] { { 1, 2 }, { 3, 4 } };
+            var ndarray = np.InvokeMethod("asarray", array.ToPython());
+            Assert.AreEqual((2, 2), ndarray.GetAttr("shape").As<(int, int)>());
+            Assert.AreEqual(1, ndarray[(0, 0).ToPython()].InvokeMethod("__int__").As<long>());
+            Assert.AreEqual(array[1, 0], ndarray[(1, 0).ToPython()].InvokeMethod("__int__").As<long>());
+        }
+    }
+}
diff --git a/src/embed_tests/TestExample.cs b/src/embed_tests/TestExample.cs
deleted file mode 100644
index 671f9e33d..000000000
--- a/src/embed_tests/TestExample.cs
+++ /dev/null
@@ -1,53 +0,0 @@
-using System;
-using System.Collections.Generic;
-using NUnit.Framework;
-using Python.Runtime;
-
-namespace Python.EmbeddingTest
-{
-    public class TestExample
-    {
-        [OneTimeSetUp]
-        public void SetUp()
-        {
-            PythonEngine.Initialize();
-        }
-
-        [OneTimeTearDown]
-        public void Dispose()
-        {
-            PythonEngine.Shutdown();
-        }
-
-        [Test]
-        public void TestReadme()
-        {
-            dynamic np;
-            try
-            {
-                np = Py.Import("numpy");
-            }
-            catch (PythonException)
-            {
-                Assert.Inconclusive("Numpy or dependency not installed");
-                return;
-            }
-
-            Assert.AreEqual("1.0", np.cos(np.pi * 2).ToString());
-
-            dynamic sin = np.sin;
-            StringAssert.StartsWith("-0.95892", sin(5).ToString());
-
-            double c = np.cos(5) + sin(5);
-            Assert.AreEqual(-0.675262, c, 0.01);
-
-            dynamic a = np.array(new List<float> { 1, 2, 3 });
-            Assert.AreEqual("float64", a.dtype.ToString());
-
-            dynamic b = np.array(new List<float> { 6, 5, 4 }, Py.kw("dtype", np.int32));
-            Assert.AreEqual("int32", b.dtype.ToString());
-
-            Assert.AreEqual("[ 6. 10. 12.]", (a * b).ToString().Replace("  ", " "));
-        }
-    }
-}
diff --git a/src/embed_tests/TestPyBuffer.cs b/src/embed_tests/TestPyBuffer.cs
index 0338a1480..43ed5ffd4 100644
--- a/src/embed_tests/TestPyBuffer.cs
+++ b/src/embed_tests/TestPyBuffer.cs
@@ -1,6 +1,8 @@
+using System;
 using System.Text;
 using NUnit.Framework;
 using Python.Runtime;
+using Python.Runtime.Codecs;
 
 namespace Python.EmbeddingTest {
     class TestPyBuffer
@@ -9,6 +11,7 @@ class TestPyBuffer
         public void SetUp()
         {
             PythonEngine.Initialize();
+            TupleCodec.Register();
         }
 
         [OneTimeTearDown]
@@ -64,5 +67,15 @@ public void TestBufferRead()
                 }
             }
         }
+
+        [Test]
+        public void ArrayHasBuffer()
+        {
+            var array = new[,] {{1, 2}, {3,4}};
+            var memoryView = PythonEngine.Eval("memoryview");
+            var mem = memoryView.Invoke(array.ToPython());
+            Assert.AreEqual(1, mem[(0, 0).ToPython()].As<int>());
+            Assert.AreEqual(array[1,0], mem[(1, 0).ToPython()].As<int>());
+        }
     }
 }
diff --git a/src/runtime/arrayobject.cs b/src/runtime/arrayobject.cs
index 5c97c6dbf..ac2425001 100644
--- a/src/runtime/arrayobject.cs
+++ b/src/runtime/arrayobject.cs
@@ -1,5 +1,7 @@
 using System;
 using System.Collections;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
 
 namespace Python.Runtime
 {
@@ -366,5 +368,208 @@ public static int sq_contains(IntPtr ob, IntPtr v)
 
             return 0;
         }
+
+        #region Buffer protocol
+        /// <summary>
+        /// Implements <c>bf_getbuffer</c>: pins the array and describes its storage
+        /// as a C-contiguous, writable buffer. The pin is held until <see cref="ReleaseBuffer"/>.
+        /// </summary>
+        /// <returns>0 on success; -1 with a Python <c>BufferError</c> set on failure.</returns>
+        static int GetBuffer(BorrowedReference obj, out Py_buffer buffer, PyBUF flags)
+        {
+            buffer = default;
+
+            if (flags == PyBUF.SIMPLE)
+            {
+                Exceptions.SetError(Exceptions.BufferError, "SIMPLE not implemented");
+                return -1;
+            }
+            if ((flags & PyBUF.F_CONTIGUOUS) == PyBUF.F_CONTIGUOUS)
+            {
+                Exceptions.SetError(Exceptions.BufferError, "only C-contiguous supported");
+                return -1;
+            }
+            var self = (Array)((CLRObject)GetManagedObject(obj)).inst;
+            Type itemType = self.GetType().GetElementType();
+
+            bool formatRequested = (flags & PyBUF.FORMATS) != 0;
+            string format = GetFormat(itemType);
+            if (formatRequested && format is null)
+            {
+                Exceptions.SetError(Exceptions.BufferError, "unsupported element type: " + itemType.Name);
+                return -1;
+            }
+
+            int itemSize;
+            GCHandle gcHandle;
+            try
+            {
+                // size first: if it can not be determined, nothing has been pinned yet
+                itemSize = GetItemSize(itemType);
+                gcHandle = GCHandle.Alloc(self, GCHandleType.Pinned);
+            }
+            catch (ArgumentException ex)
+            {
+                Exceptions.SetError(Exceptions.BufferError, ex.Message);
+                return -1;
+            }
+
+            IntPtr[] shape = GetShape(self);
+            IntPtr[] strides = GetStrides(shape, itemSize);
+            buffer = new Py_buffer
+            {
+                buf = gcHandle.AddrOfPinnedObject(),
+                obj = Runtime.SelfIncRef(obj.DangerousGetAddress()),
+                len = (IntPtr)(self.LongLength*itemSize),
+                itemsize = (IntPtr)itemSize,
+                _readonly = false,
+                ndim = self.Rank,
+                // the buffer protocol requires format to be NULL unless it was requested
+                format = formatRequested ? format : null,
+                shape = ToUnmanaged(shape),
+                strides = (flags & PyBUF.STRIDES) == PyBUF.STRIDES ? ToUnmanaged(strides) : IntPtr.Zero,
+                suboffsets = IntPtr.Zero,
+                // lets ReleaseBuffer find the pin
+                _internal = (IntPtr)gcHandle,
+            };
+
+            return 0;
+        }
+
+        /// <summary>
+        /// Implements <c>bf_releasebuffer</c>: frees the shape and strides copies
+        /// and unpins the array.
+        /// </summary>
+        static void ReleaseBuffer(BorrowedReference obj, ref Py_buffer buffer)
+        {
+            if (buffer._internal == IntPtr.Zero) return;
+
+            UnmanagedFree(ref buffer.shape);
+            UnmanagedFree(ref buffer.strides);
+            UnmanagedFree(ref buffer.suboffsets);
+
+            var gcHandle = (GCHandle)buffer._internal;
+            gcHandle.Free();
+            buffer._internal = IntPtr.Zero;
+        }
+
+        /// <summary>
+        /// Computes row-major strides in bytes for the given shape.
+        /// </summary>
+        static IntPtr[] GetStrides(IntPtr[] shape, long itemSize)
+        {
+            var result = new IntPtr[shape.Length];
+            result[shape.Length - 1] = new IntPtr(itemSize);
+            for (int dim = shape.Length - 2; dim >= 0; dim--)
+            {
+                itemSize *= shape[dim + 1].ToInt64();
+                result[dim] = new IntPtr(itemSize);
+            }
+            return result;
+        }
+
+        /// <summary>
+        /// Gets the length of every dimension of the array.
+        /// </summary>
+        static IntPtr[] GetShape(Array array)
+        {
+            var result = new IntPtr[array.Rank];
+            for (int i = 0; i < result.Length; i++)
+                result[i] = (IntPtr)array.GetLongLength(i);
+            return result;
+        }
+
+        /// <summary>
+        /// Gets the number of bytes an element of the given type occupies in a managed array.
+        /// </summary>
+        /// <exception cref="ArgumentException">The size can not be determined.</exception>
+        static int GetItemSize(Type itemType)
+        {
+            // Marshal.SizeOf reports the size after marshaling, which differs
+            // from the size in managed memory for bool (4 vs 1) and char (1 vs 2)
+            if (itemType == typeof(bool)) return sizeof(bool);
+            if (itemType == typeof(char)) return sizeof(char);
+            return Marshal.SizeOf(itemType);
+        }
+
+        static void UnmanagedFree(ref IntPtr address)
+        {
+            if (address == IntPtr.Zero) return;
+
+            Marshal.FreeHGlobal(address);
+            address = IntPtr.Zero;
+        }
+
+        /// <summary>
+        /// Copies the array into a block allocated with <see cref="Marshal.AllocHGlobal(int)"/>.
+        /// The caller is responsible for freeing the block.
+        /// </summary>
+        static unsafe IntPtr ToUnmanaged<T>(T[] array) where T : unmanaged
+        {
+            IntPtr result = Marshal.AllocHGlobal(checked(Marshal.SizeOf(typeof(T)) * array.Length));
+            fixed (T* ptr = array)
+            {
+                var @out = (T*)result;
+                for (int i = 0; i < array.Length; i++)
+                    @out[i] = ptr[i];
+            }
+            return result;
+        }
+
+        static readonly Dictionary<Type, string> ItemFormats = new Dictionary<Type, string>
+        {
+            [typeof(byte)] = "B",
+            [typeof(sbyte)] = "b",
+
+            [typeof(bool)] = "?",
+
+            [typeof(short)] = "h",
+            [typeof(ushort)] = "H",
+            // see https://github.com/pybind/pybind11/issues/1908#issuecomment-658358767
+            [typeof(int)] = "i",
+            [typeof(uint)] = "I",
+            [typeof(long)] = "q",
+            [typeof(ulong)] = "Q",
+
+            [typeof(IntPtr)] = "n",
+            [typeof(UIntPtr)] = "N",
+
+            // TODO: half = "e"
+            [typeof(float)] = "f",
+            [typeof(double)] = "d",
+        };
+
+        static string GetFormat(Type elementType)
+            => ItemFormats.TryGetValue(elementType, out string result) ? result : null;
+
+        static readonly GetBufferProc getBufferProc = GetBuffer;
+        static readonly ReleaseBufferProc releaseBufferProc = ReleaseBuffer;
+        static readonly IntPtr BufferProcsAddress = AllocateBufferProcs();
+        static IntPtr AllocateBufferProcs()
+        {
+            var procs = new PyBufferProcs
+            {
+                Get = Marshal.GetFunctionPointerForDelegate(getBufferProc),
+                Release = Marshal.GetFunctionPointerForDelegate(releaseBufferProc),
+            };
+            IntPtr result = Marshal.AllocHGlobal(Marshal.SizeOf(typeof(PyBufferProcs)));
+            Marshal.StructureToPtr(procs, result, fDeleteOld: false);
+            return result;
+        }
+        #endregion
+
+        /// <summary>
+        /// Installs the buffer procs into <c>tp_as_buffer</c> of the type,
+        /// unless that slot has already been initialized.
+        /// </summary>
+        public static void InitializeSlots(IntPtr type, ISet<string> initialized, SlotsHolder slotsHolder)
+        {
+            if (initialized.Add(nameof(TypeOffset.tp_as_buffer)))
+            {
+                // TODO: only for unmanaged arrays
+                int offset = TypeOffset.GetSlotOffset(nameof(TypeOffset.tp_as_buffer));
+                Marshal.WriteIntPtr(type, offset, BufferProcsAddress);
+            }
+        }
     }
 }
diff --git a/src/runtime/bufferinterface.cs b/src/runtime/bufferinterface.cs
index 0c0ac2140..e39cdd5b4 100644
--- a/src/runtime/bufferinterface.cs
+++ b/src/runtime/bufferinterface.cs
@@ -103,4 +103,15 @@ public enum PyBUF
         /// </summary>
         FULL_RO = (INDIRECT | FORMATS),
     }
+
+    internal struct PyBufferProcs
+    {
+        public IntPtr Get;
+        public IntPtr Release;
+    }
+
+    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
+    delegate int GetBufferProc(BorrowedReference obj, out Py_buffer buffer, PyBUF flags);
+    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
+    delegate void ReleaseBufferProc(BorrowedReference obj, ref Py_buffer buffer);
 }
diff --git a/src/runtime/exceptions.cs b/src/runtime/exceptions.cs
index cc8da3899..f1a06c328 100644
--- a/src/runtime/exceptions.cs
+++ b/src/runtime/exceptions.cs
@@ -413,6 +413,7 @@ public static variables on the Exceptions class filled in from
 
         public static IntPtr AssertionError;
         public static IntPtr AttributeError;
+        public static IntPtr BufferError;
         public static IntPtr EOFError;
         public static IntPtr FloatingPointError;
         public static IntPtr EnvironmentError;
diff --git a/src/runtime/native/TypeOffset.cs b/src/runtime/native/TypeOffset.cs
index 6e6da2d93..b5957a9c7 100644
--- a/src/runtime/native/TypeOffset.cs
+++ b/src/runtime/native/TypeOffset.cs
@@ -159,6 +159,7 @@ static void ValidateRequiredOffsetsPresent(PropertyInfo[] offsetProperties)
                 "GetClrType",
                 "getPreload",
                 "Initialize",
+                "InitializeSlots",
                 "ListAssemblies",
                 "_load_clr_module",
                 "Release",
diff --git a/src/runtime/typemanager.cs b/src/runtime/typemanager.cs
index 26dcea153..8db3516ac 100644
--- a/src/runtime/typemanager.cs
+++ b/src/runtime/typemanager.cs
@@ -745,6 +745,9 @@ internal static void InitializeSlots(IntPtr type, Type impl, SlotsHolder slotsHo
                     seen.Add(name);
                 }
 
+                var initSlot = impl.GetMethod("InitializeSlots", BindingFlags.Static | BindingFlags.Public);
+                initSlot?.Invoke(null, parameters: new object[] { type, seen, slotsHolder });
+
                 impl = impl.BaseType;
             }
 