diff --git a/CHANGELOG.md b/CHANGELOG.md index 35ef66882..3599c619b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ See [Mixins/collections.py](src/runtime/Mixins/collections.py). - .NET arrays implement Python buffer protocol - Python.NET will correctly resolve .NET methods, that accept `PyList`, `PyInt`, and other `PyObject` derived types when called from Python. +- .NET classes, that have `__call__` method are callable from Python - `PyIterable` type, that wraps any iterable object in Python diff --git a/src/embed_tests/CallableObject.cs b/src/embed_tests/CallableObject.cs new file mode 100644 index 000000000..ab732be15 --- /dev/null +++ b/src/embed_tests/CallableObject.cs @@ -0,0 +1,87 @@ +using System; +using System.Collections.Generic; + +using NUnit.Framework; + +using Python.Runtime; + +namespace Python.EmbeddingTest +{ + public class CallableObject + { + [OneTimeSetUp] + public void SetUp() + { + PythonEngine.Initialize(); + using var locals = new PyDict(); + PythonEngine.Exec(CallViaInheritance.BaseClassSource, locals: locals.Handle); + CustomBaseTypeProvider.BaseClass = new PyType(locals[CallViaInheritance.BaseClassName]); + PythonEngine.InteropConfiguration.PythonBaseTypeProviders.Add(new CustomBaseTypeProvider()); + } + + [OneTimeTearDown] + public void Dispose() + { + PythonEngine.Shutdown(); + } + [Test] + public void CallMethodMakesObjectCallable() + { + var doubler = new DerivedDoubler(); + dynamic applyObjectTo21 = PythonEngine.Eval("lambda o: o(21)"); + Assert.AreEqual(doubler.__call__(21), (int)applyObjectTo21(doubler.ToPython())); + } + [Test] + public void CallMethodCanBeInheritedFromPython() + { + var callViaInheritance = new CallViaInheritance(); + dynamic applyObjectTo14 = PythonEngine.Eval("lambda o: o(14)"); + Assert.AreEqual(callViaInheritance.Call(14), (int)applyObjectTo14(callViaInheritance.ToPython())); + } + + [Test] + public void CanOverwriteCall() + { + var callViaInheritance = new CallViaInheritance(); + using var scope = Py.CreateScope(); + scope.Set("o", callViaInheritance); + scope.Exec("orig_call = o.Call"); + scope.Exec("o.Call = lambda a: orig_call(a*7)"); + int result = scope.Eval("o.Call(5)"); + Assert.AreEqual(105, result); + } + + class Doubler + { + public int __call__(int arg) => 2 * arg; + } + + class DerivedDoubler : Doubler { } + + class CallViaInheritance + { + public const string BaseClassName = "Forwarder"; + public static readonly string BaseClassSource = $@" +class MyCallableBase: + def __call__(self, val): + return self.Call(val) + +class {BaseClassName}(MyCallableBase): pass +"; + public int Call(int arg) => 3 * arg; + } + + class CustomBaseTypeProvider : IPythonBaseTypeProvider + { + internal static PyType BaseClass; + + public IEnumerable GetBaseTypes(Type type, IList existingBases) + { + Assert.Greater(BaseClass.Refcount, 0); + return type != typeof(CallViaInheritance) + ? existingBases + : new[] { BaseClass }; + } + } + } +} diff --git a/src/runtime/classbase.cs b/src/runtime/classbase.cs index 570ce3062..311b5b5f3 100644 --- a/src/runtime/classbase.cs +++ b/src/runtime/classbase.cs @@ -1,6 +1,9 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection; using System.Runtime.InteropServices; namespace Python.Runtime @@ -557,5 +560,44 @@ public static int mp_ass_subscript(IntPtr ob, IntPtr idx, IntPtr v) return 0; } + + static IntPtr tp_call_impl(IntPtr ob, IntPtr args, IntPtr kw) + { + IntPtr tp = Runtime.PyObject_TYPE(ob); + var self = (ClassBase)GetManagedObject(tp); + + if (!self.type.Valid) + { + return Exceptions.RaiseTypeError(self.type.DeletedMessage); + } + + Type type = self.type.Value; + + var calls = GetCallImplementations(type).ToList(); + Debug.Assert(calls.Count > 0); + var callBinder = new MethodBinder(); + foreach (MethodInfo call in calls) + { + callBinder.AddMethod(call); + } + return callBinder.Invoke(ob, args, kw); + } + + static IEnumerable GetCallImplementations(Type type) + => type.GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where(m => m.Name == "__call__"); + + static readonly Interop.TernaryFunc tp_call_delegate = tp_call_impl; + + public virtual void InitializeSlots(SlotsHolder slotsHolder) + { + if (!this.type.Valid) return; + + if (GetCallImplementations(this.type.Value).Any() + && !slotsHolder.IsHolding(TypeOffset.tp_call)) + { + TypeManager.InitializeSlot(ObjectReference, TypeOffset.tp_call, tp_call_delegate, slotsHolder); + } + } } } diff --git a/src/runtime/classmanager.cs b/src/runtime/classmanager.cs index 589ac0ad1..06d82c7b8 100644 --- a/src/runtime/classmanager.cs +++ b/src/runtime/classmanager.cs @@ -162,6 +162,9 @@ internal static Dictionary RestoreRuntimeData(R Runtime.PyType_Modified(pair.Value.TypeReference); var context = contexts[pair.Value.pyHandle]; pair.Value.Load(context); + var slotsHolder = TypeManager.GetSlotsHolder(pyType); + pair.Value.InitializeSlots(slotsHolder); + Runtime.PyType_Modified(pair.Value.TypeReference); loadedObjs.Add(pair.Value, context); } diff --git a/src/runtime/interop.cs b/src/runtime/interop.cs index 188db3a58..e10348e39 100644 --- a/src/runtime/interop.cs +++ b/src/runtime/interop.cs @@ -242,8 +242,13 @@ internal static ThunkInfo GetThunk(MethodInfo method, string funcType = null) return ThunkInfo.Empty; } Delegate d = Delegate.CreateDelegate(dt, method); - var info = new ThunkInfo(d); - allocatedThunks[info.Address] = d; + return GetThunk(d); + } + + internal static ThunkInfo GetThunk(Delegate @delegate) + { + var info = new ThunkInfo(@delegate); + allocatedThunks[info.Address] = @delegate; return info; } diff --git a/src/runtime/pytype.cs b/src/runtime/pytype.cs index 52ef60d04..546a3ed05 100644 --- a/src/runtime/pytype.cs +++ b/src/runtime/pytype.cs @@ -121,6 +121,20 @@ internal static BorrowedReference GetBase(BorrowedReference type) return new BorrowedReference(basePtr); } + internal static BorrowedReference GetBases(BorrowedReference type) + { + Debug.Assert(IsType(type)); + IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_bases); + return new BorrowedReference(basesPtr); + } + + internal static BorrowedReference GetMRO(BorrowedReference type) + { + Debug.Assert(IsType(type)); + IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_mro); + return new BorrowedReference(basesPtr); + } + private static IntPtr EnsureIsType(in StolenReference reference) { IntPtr address = reference.DangerousGetAddressOrNull(); diff --git a/src/runtime/typemanager.cs b/src/runtime/typemanager.cs index 1d6321791..7a836bf05 100644 --- a/src/runtime/typemanager.cs +++ b/src/runtime/typemanager.cs @@ -404,6 +404,10 @@ static void InitializeClass(PyType pyType, ClassBase impl, Type clrType) impl.tpHandle = type; impl.pyHandle = type; + impl.InitializeSlots(slotsHolder); + + Runtime.PyType_Modified(pyType.Reference); + //DebugUtil.DumpType(type); } @@ -787,6 +791,12 @@ static void InitializeSlot(IntPtr type, int slotOffset, MethodInfo method, Slots InitializeSlot(type, slotOffset, thunk, slotsHolder); } + internal static void InitializeSlot(BorrowedReference type, int slotOffset, Delegate impl, SlotsHolder slotsHolder) + { + var thunk = Interop.GetThunk(impl); + InitializeSlot(type.DangerousGetAddress(), slotOffset, thunk, slotsHolder); + } + static void InitializeSlot(IntPtr type, int slotOffset, ThunkInfo thunk, SlotsHolder slotsHolder) { Marshal.WriteIntPtr(type, slotOffset, thunk.Address); @@ -848,6 +858,9 @@ private static SlotsHolder CreateSolotsHolder(IntPtr type) _slotsHolders.Add(type, holder); return holder; } + + internal static SlotsHolder GetSlotsHolder(PyType type) + => _slotsHolders[type.Handle]; } @@ -873,6 +886,8 @@ public SlotsHolder(IntPtr type) _type = type; } + public bool IsHolding(int offset) => _slots.ContainsKey(offset); + public void Set(int offset, ThunkInfo thunk) { _slots[offset] = thunk;