10000 allow substituting base types for CLR types (as seen from Python) · pythonnet/pythonnet@d3c6942 · GitHub
[go: up one dir, main page]

Skip to content

Commit d3c6942

Browse files
committed
allow substituting base types for CLR types (as seen from Python)
When embedding Python, host can now provide custom implementations of IPythonBaseTypeProvider via PythonEngine.InteropConfiguration. When .NET type is reflected to Python, this type provider will be able to specify which bases the resulting Python class will have. This implements #862
1 parent 4b7a23c commit d3c6942

11 files changed

+468
-49
lines changed

src/embed_tests/Inheritance.cs

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
using System;
2 8000 +
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.Runtime.InteropServices;
5+
6+
using NUnit.Framework;
7+
8+
using Python.Runtime;
9+
10+
namespace Python.EmbeddingTest
11+
{
12+
public class Inheritance
13+
{
14+
[OneTimeSetUp]
15+
public void SetUp()
16+
{
17+
PythonEngine.Initialize();
18+
var locals = new PyDict();
19+
PythonEngine.Exec(InheritanceTestBaseClassWrapper.ClassSourceCode, locals: locals.Handle);
20+
ExtraBaseTypeProvider.ExtraBase = new PyType(locals[InheritanceTestBaseClassWrapper.ClassName]);
21+
var baseTypeProviders = PythonEngine.InteropConfiguration.PythonBaseTypeProviders;
22+
baseTypeProviders.Add(new ExtraBaseTypeProvider());
23+
77F4 baseTypeProviders.Add(new NoEffectBaseTypeProvider());
24+
}
25+
26+
[OneTimeTearDown]
27+
public void Dispose()
28+
{
29+
PythonEngine.Shutdown();
30+
}
31+
32+
[Test]
33+
public void ExtraBase_PassesInstanceCheck()
34+
{
35+
var inherited = new Inherited();
36+
bool properlyInherited = PyIsInstance(inherited, ExtraBaseTypeProvider.ExtraBase);
37+
Assert.IsTrue(properlyInherited);
38+
}
39+
40+
static dynamic PyIsInstance => PythonEngine.Eval("isinstance");
41+
42+
[Test]
43+
public void InheritingWithExtraBase_CreatesNewClass()
44+
{
45+
PyObject a = ExtraBaseTypeProvider.ExtraBase;
46+
var inherited = new Inherited();
47+
PyObject inheritedClass = inherited.ToPython().GetAttr("__class__");
48+
Assert.IsFalse(PythonReferenceComparer.Instance.Equals(a, inheritedClass));
49+
}
50+
51+
[Test]
52+
public void InheritedFromInheritedClassIsSelf()
53+
{
54+
using var scope = Py.CreateScope();
55+
scope.Exec($"from {typeof(Inherited).Namespace} import {nameof(Inherited)}");
56+
scope.Exec($"class B({nameof(Inherited)}): pass");
57+
PyObject b = scope.Eval("B");
58+
PyObject bInstance = b.Invoke();
59+
PyObject bInstanceClass = bInstance.GetAttr("__class__");
60+
Assert.IsTrue(PythonReferenceComparer.Instance.Equals(b, bInstanceClass));
61+
}
62+
63+
[Test]
64+
public void Grandchild_PassesExtraBaseInstanceCheck()
65+
{
66+
using var scope = Py.CreateScope();
67+
scope.Exec($"from {typeof(Inherited).Namespace} import {nameof(Inherited)}");
68+
scope.Exec($"class B({nameof(Inherited)}): pass");
69+
PyObject b = scope.Eval("B");
70+
PyObject bInst = b.Invoke();
71+
bool properlyInherited = PyIsInstance(bInst, ExtraBaseTypeProvider.ExtraBase);
72+
Assert.IsTrue(properlyInherited);
73+
}
74+
75+
[Test]
76+
public void CallInheritedClrMethod_WithExtraPythonBase()
77+
{
78+
var instance = new Inherited().ToPython();
79+
string result = instance.InvokeMethod(nameof(PythonWrapperBase.WrapperBaseMethod)).As<string>();
80+
Assert.AreEqual(result, nameof(PythonWrapperBase.WrapperBaseMethod));
81+
}
82+
83+
[Test]
84+
public void CallExtraBaseMethod()
85+
{
86+
var instance = new Inherited();
87+
using var scope = Py.CreateScope();
88+
scope.Set(nameof(instance), instance);
89+
int actual = instance.ToPython().InvokeMethod("callVirt").As<int>();
90+
Assert.AreEqual(expected: Inherited.OverridenVirtValue, actual);
91+
}
92+
93+
[Test]
94+
public void SetAdHocAttributes_WhenExtraBasePresent()
95+
{
96+
var instance = new Inherited();
97+
using var scope = Py.CreateScope();
98+
scope.Set(nameof(instance), instance);
99+
scope.Exec($"super({nameof(instance)}.__class__, {nameof(instance)}).set_x_to_42()");
100+
int actual = scope.Eval<int>($"{nameof(instance)}.{nameof(Inherited.XProp)}");
101+
Assert.AreEqual(expected: Inherited.X, actual);
102+
}
103+
}
104+
105+
class ExtraBaseTypeProvider : IPythonBaseTypeProvider
106+
{
107+
internal static PyType ExtraBase;
108+
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
109+
{
110+
if (type == typeof(InheritanceTestBaseClassWrapper))
111+
{
112+
return new[] { PyType.Get(type.BaseType), ExtraBase };
113+
}
114+
return existingBases;
115+
}
116+
}
117+
118+
class NoEffectBaseTypeProvider : IPythonBaseTypeProvider
119+
{
120+
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
121+
=> existingBases;
122+
}
123+
124+
public class PythonWrapperBase
125+
{
126+
public string WrapperBaseMethod() => nameof(WrapperBaseMethod);
127+
}
128+
129+
public class InheritanceTestBaseClassWrapper : PythonWrapperBase
130+
{
131+
public const string ClassName = "InheritanceTestBaseClass";
132+
public const string ClassSourceCode = "class " + ClassName +
133+
@":
134+
def virt(self):
135+
return 42
136+
def set_x_to_42(self):
137+
self.XProp = 42
138+
def callVirt(self):
139+
return self.virt()
140+
def __getattr__(self, name):
141+
return '__getattr__:' + name
142+
def __setattr__(self, name, value):
143+
value[name] = name
144+
" + ClassName + " = " + ClassName + "\n";
145+
}
146+
147+
public class Inherited : InheritanceTestBaseClassWrapper
148+
{
149+
public const int OverridenVirtValue = -42;
150+
public const int X = 42;
151+
readonly Dictionary<string, object> extras = new Dictionary<string, object>();
152+
public int virt() => OverridenVirtValue;
153+
public int XProp
154+
{
155+
get
156+
{
157+
using (var scope = Py.CreateScope())
158+
{
159+
scope.Set("this", this);
160+
try
161+
{
162+
return scope.Eval<int>($"super(this.__class__, this).{nameof(XProp)}");
163+
}
164+
catch (PythonException ex) when (ex.Type.Handle == Exceptions.AttributeError)
165+
{
166+
if (this.extras.TryGetValue(nameof(this.XProp), out object value))
167+
return (int)value;
168+
throw;
169+
}
170+
}
171+
}
172+
set => this.extras[nameof(this.XProp)] = value;
173+
}
174+
}
175+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
namespace Python.Runtime
5+
{
6+
/// <summary>Minimal Python base type provider</summary>
7+
public sealed class DefaultBaseTypeProvider : IPythonBaseTypeProvider
8+
{
9+
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
10+
{
11+
if (type is null)
12+
throw new ArgumentNullException(nameof(type));
13+
if (existingBases is null)
14+
throw new ArgumentNullException(nameof(existingBases));
15+
if (existingBases.Count > 0)
16+
throw new ArgumentException("To avoid confusion, this type provider requires the initial set of base types to be empty");
17+
18+
return new[] { new PyType(GetBaseType(type)) };
19+
}
20+
21+
static BorrowedReference GetBaseType(Type type)
22+
{
23+
if (type == typeof(Exception))
24+
return new BorrowedReference(Exceptions.Exception);
25+
26+
return type.BaseType is not null
27+
? ClassManager.GetClass(type.BaseType).ObjectReference
28+
: new BorrowedReference(Runtime.PyBaseObjectType);
29+
}
30+
31+
DefaultBaseTypeProvider(){}
32+
public static DefaultBaseTypeProvider Instance { get; } = new DefaultBaseTypeProvider();
33+
}
34+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
namespace Python.Runtime
5+
{
6+
public interface IPythonBaseTypeProvider
7+
{
8+
/// <summary>
9+
/// Get Python types, that should be presented to Python as the base types
10+
/// for the specified .NET type.
11+
/// </summary>
12+
IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases);
13+
}
14+
}

src/runtime/InteropConfiguration.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
namespace Python.Runtime
2+
{
3+
using System;
4+
using System.Collections.Generic;
5+
6+
public sealed class InteropConfiguration
7+
{
8+
internal readonly PythonBaseTypeProviderGroup pythonBaseTypeProviders
9+
= new PythonBaseTypeProviderGroup();
10+
11+
/// <summary>Enables replacing base types of CLR types as seen from Python</summary>
12+
public IList<IPythonBaseTypeProvider> PythonBaseTypeProviders => this.pythonBaseTypeProviders;
13+
14+
public static InteropConfiguration MakeDefault()
15+
{
16+
return new InteropConfiguration
17+
{
18+
PythonBaseTypeProviders =
19+
{
20+
DefaultBaseTypeProvider.Instance,
21+
},
22+
};
23+
}
24+
}
25+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
5+
namespace Python.Runtime
6+
{
7+
class PythonBaseTypeProviderGroup : List<IPythonBaseTypeProvider>, IPythonBaseTypeProvider
8+
{
9+
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
10+
{
11+
if (type is null)
12+
throw new ArgumentNullException(nameof(type));
13+
if (existingBases is null)
14+
throw new ArgumentNullException(nameof(existingBases));
15+
16+
foreach (var provider in this)
17+
{
18+
existingBases = provider.GetBaseTypes(type, existingBases).ToList();
19+
}
20+
21+
return existingBases;
22+
}
23+
}
24+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#nullable enable
2+
using System.Collections.Generic;
3+
4+
namespace Python.Runtime
5+
{
6+
/// <summary>
7+
/// Compares Python object wrappers by Python object references.
8+
/// <para>Similar to <see cref="object.ReferenceEquals"/> but for Python objects</para>
9+
/// </summary>
10+
public sealed class PythonReferenceComparer : IEqualityComparer<PyObject>
11+
{
12+
public static PythonReferenceComparer Instance { get; } = new PythonReferenceComparer();
13+
public bool Equals(PyObject? x, PyObject? y)
14+
{
15+
return x?.Handle == y?.Handle;
16+
}
17+
18+
public int GetHashCode(PyObject obj) => obj.Handle.GetHashCode();
19+
20+
private PythonReferenceComparer() { }
21+
}
22+
}

src/runtime/StolenReference.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ public override bool Equals(object obj)
3535

3636
[Pure]
3737
public override int GetHashCode() => Pointer.GetHashCode();
38+
39+
[Pure]
40+
public static StolenReference DangerousFromPointer(IntPtr ptr)
41+
{
42+
if (ptr == IntPtr.Zero) throw new ArgumentNullException(nameof(ptr));
43+
return new StolenReference(ptr);
44+
}
3845
}
3946

4047
static class StolenReferenceExtensions

src/runtime/classmanager.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,14 @@ private static ClassBase CreateClass(Type type)
252252

253253
private static void InitClassBase(Type type, ClassBase impl)
254254
{
255+
// Ensure, that matching Python type exists first.
256+
// It is required for self-referential classes
257+
// (e.g. with members, that refer to the same class)
258+
var pyType = TypeManager.GetOrCreateClass(type);
259+
260+
// Set the handle attributes on the implementing instance.
261+
impl.tpHandle = impl.pyHandle = pyType.Handle;
262+
255263
// First, we introspect the managed type and build some class
256264
// information, including generating the member descriptors
257265
// that we'll be putting in the Python class __dict__.
@@ -261,12 +269,12 @@ private static void InitClassBase(Type type, ClassBase impl)
261269
impl.indexer = info.indexer;
262270
impl.richcompare = new Dictionary<int, MethodObject>();
263271

264-
// Now we allocate the Python type object to reflect the given
272+
// Now we force initialize the Python type object to reflect the given
265273
// managed type, filling the Python type slots with thunks that
266274
// point to the managed methods providing the implementation.
267275

268276

269-
var pyType = TypeManager.GetType(impl, type);
277+
TypeManager.GetOrInitializeClass(impl, type);
270278

271279
// Finally, initialize the class __dict__ and return the object.
272280
using var dict = Runtime.PyObject_GenericGetDict(pyType.Reference);

src/runtime/pythonengine.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public static ShutdownMode ShutdownMode
2626
private static IntPtr _pythonHome = IntPtr.Zero;
2727
private static IntPtr _programName = IntPtr.Zero;
2828
private static IntPtr _pythonPath = IntPtr.Zero;
29+
private static InteropConfiguration interopConfiguration = InteropConfiguration.MakeDefault();
2930

3031
public PythonEngine()
3132
{
@@ -68,6 +69,18 @@ internal static DelegateManager DelegateManager
6869
}
6970
}
7071

72+
public static InteropConfiguration InteropConfiguration
73+
{
74+
get => interopConfiguration;
75+
set
76+
{
77+
if (IsInitialized)
78+
throw new NotSupportedException("Changing interop configuration when engine is running is not supported");
79+
80+
interopConfiguration = value ?? throw new ArgumentNullException(nameof(InteropConfiguration));
81+
}
82+
}
83+
7184
public static string ProgramName
7285
{
7386
get
@@ -334,6 +347,8 @@ public static void Shutdown(ShutdownMode mode)
334347
PyObjectConversions.Reset();
335348

336349
initialized = false;
350+
351+
InteropConfiguration = InteropConfiguration.MakeDefault();
337352
}
338353

339354
/// <summary>
@@ -563,7 +578,7 @@ public static ulong GetPythonThreadID()
563578
/// Interrupts the execution of a thread.
564579
/// </summary>
565580
/// <param name="pythonThreadID">The Python thread ID.</param>
566-
/// <returns>The number of thread states modified; this is normally one, but will be zero if the thread id isn’t found.</returns>
581+
/// <returns>The number of thread states modified; this is normally one, but will be zero if the thread id is not found.</returns>
567582
public static int Interrupt(ulong pythonThreadID)
568583
{
569584
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))

0 commit comments

Comments
 (0)
0