8000 improve list codec · pythonnet/pythonnet@d9e1e2c · GitHub
[go: up one dir, main page]

Skip to content

Commit d9e1e2c

Browse files
committed
improve list codec
1 parent 9d3d4cb commit d9e1e2c

File tree

2 files changed

+61
-32
lines changed

2 files changed

+61
-32
lines changed

src/embed_tests/Codecs.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,16 @@ public void ListCodecTest()
9090
var items = new List<PyObject>() { new PyInt(1), new PyInt(2), new PyInt(3) };
9191

9292
var x = new PyList(items.ToArray());
93-
Assert.IsTrue(codec.CanDecode(x, typeof(List<int>)));
9493
Assert.IsTrue(codec.CanDecode(x, typeof(IList<bool>)));
9594
Assert.IsTrue(codec.CanDecode(x, typeof(System.Collections.IEnumerable)));
9695
Assert.IsTrue(codec.CanDecode(x, typeof(IEnumerable<int>)));
9796
Assert.IsTrue(codec.CanDecode(x, typeof(ICollection<float>)));
9897
Assert.IsFalse(codec.CanDecode(x, typeof(bool)));
9998

99+
//we'd have to copy into a list to do this. not the best idea to support it.
100+
//maybe there can be a flag on listcodec to allow it.
101+
Assert.IsFalse(codec.CanDecode(x, typeof(List<int>)));
102+
100103
Action<System.Collections.IEnumerable> checkPlainEnumerable = (System.Collections.IEnumerable enumerable) =>
101104
{
102105
Assert.IsNotNull(enumerable);
@@ -145,7 +148,9 @@ raise StopIteration
145148
Assert.IsFalse(codec.CanDecode(foo, typeof(ICollection<int>)));
146149
Assert.IsFalse(codec.CanDecode(foo, typeof(IList<int>)));
147150

148-
151+
IEnumerable<int> intEnumerable = null;
152+
Assert.DoesNotThrow(() => { codec.TryDecode<IEnumerable<int>>(x, out intEnumerable); });
153+
checkPlainEnumerable(intEnumerable);
149154
}
150155
}
151156
}

src/runtime/Codecs/ListCodec.cs

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ private CollectionRank GetRank(PyObject objectType)
3737
return CollectionRank.Iterable;
3838
}
3939

40-
private CollectionRank GetRank(Type targetType)
40+
private Tuple<CollectionRank, Type> GetRankAndType(Type collectionType)
4141
{
4242
//if it is a plain IEnumerable, we can decode it using sequence protocol.
43-
if (targetType == typeof(System.Collections.IEnumerable))
44-
return CollectionRank.Iterable;
43+
if (collectionType == typeof(System.Collections.IEnumerable))
44+
return new Tuple<CollectionRank, Type>(CollectionRank.Iterable, typeof(object));
4545

4646
Func<Type, CollectionRank> getRankOfType = (Type type) => {
4747
if (type.GetGenericTypeDefinition() == typeof(IList<>))
@@ -53,32 +53,25 @@ private CollectionRank GetRank(Type targetType)
5353
return CollectionRank.None;
5454
};
5555

56-
if (targetType.IsGenericType)
57-
{
58-
var thisRank = getRankOfType(targetType);
59-
if (thisRank != CollectionRank.None)
60-
return thisRank;
61-
}
62-
63-
var maxRank = CollectionRank.None;
64-
//if it implements any of the standard C# collection interfaces, we can decode it.
65-
foreach (Type itf in targetType.GetInterfaces())
56+
if (collectionType.IsGenericType)
6657
{
67-
if (!itf.IsGenericType) continue;
68-
69-
var thisRank = getRankOfType(itf);
58+
//for compatibility we *could* do this and copy the value but probably not the best option.
59+
/*if (collectionType.GetGenericTypeDefinition() == typeof(List<>))
60+
return new Tuple<CollectionRank, Type>(CollectionRank.List, elementType);*/
7061

71-
//this is the most specialized type. return early
72-
if (thisRank == CollectionRank.List) return thisRank;
73-
74-
//if it is more specialized, assign to max rank
75-
if ((int)thisRank > (int)maxRank)
76-
maxRank = thisRank;
62+
var elementType = collectionType.GetGenericArguments()[0];
63+
var thisRank = getRankOfType(collectionType);
64+
if (thisRank != CollectionRank.None)
65+
return new Tuple<CollectionRank, Type>(thisRank, elementType);
7766
}
7867

79-
return maxRank;
68+
return null;
8069
}
8170

71+
private CollectionRank? GetRank(Type targetType)
72+
{
73+
return GetRankAndType(targetType)?.Item1;
74+
}
8275

8376
public bool CanDecode(PyObject objectType, Type targetType)
8477
{
@@ -89,7 +82,7 @@ public bool CanDecode(PyObject objectType, Type targetType)
8982

9083
//get the clr object rank
9184
var clrRank = GetRank(targetType);
92-
if (clrRank == CollectionRank.None)
85+
if (clrRank == null || clrRank == CollectionRank.None)
9386
return false;
9487

9588
//if it is a plain IEnumerable, we can decode it using sequence protocol.
@@ -99,15 +92,16 @@ public bool CanDecode(PyObject objectType, Type targetType)
9992
return (int)pyRank >= (int)clrRank;
10093
}
10194

102-
private class PyEnumerable : System.Collections.IEnumerable
95+
private class GenericPyEnumerable<T> : IEnumerable<T>
10396
{
104-
PyObject iterObject;
105-
internal PyEnumerable(PyObject pyObj)
97+
protected PyObject iterObject;
98+
99+
internal GenericPyEnumerable(PyObject pyObj)
106100
{
107101
iterObject = new PyObject(Runtime.PyObject_GetIter(pyObj.Handle));
108102
}
109103

110-
public IEnumerator GetEnumerator()
104+
IEnumerator IEnumerable.GetEnumerator()
111105
{
112106
IntPtr item;
113107
while ((item = Runtime.PyIter_Next(iterObject.Handle)) != IntPtr.Zero)
@@ -123,11 +117,32 @@ public IEnumerator GetEnumerator()
123117
yield return obj;
124118
}
125119
}
120+
121+
public IEnumerator<T> GetEnumerator()
122+
{
123+
IntPtr item;
124+
while ((item = Runtime.PyIter_Next(iterObject.Handle)) != IntPtr.Zero)
125+
{
126+
object obj = null;
127+
if (!Converter.ToManaged(item, typeof(T), out obj, true))
128+
{
129+
Runtime.XDecref(item);
130+
break;
131+
}
132+
133+
Runtime.XDecref(item);
134+
yield return (T)obj;
135+
}
136+
}
126137
}
127138

128139
private object ToPlainEnumerable(PyObject pyObj)
129140
{
130-
return new PyEnumerable(pyObj);
141+
return new GenericPyEnumerable<object>(pyObj);
142+
}
143+
private object ToEnumerable<T>(PyObject pyObj)
144+
{
145+
return new GenericPyEnumerable<T>(pyObj);
131146
}
132147

133148
public bool TryDecode<T>(PyObject pyObj, out T value)
@@ -136,7 +151,16 @@ public bool TryDecode<T>(PyObject pyObj, out T value)
136151
//first see if T is a plan IEnumerable
137152
if (typeof(T) == typeof(System.Collections.IEnumerable))
138153
{
139-
var = ToPlainEnumerable(pyObj);
154+
var = new GenericPyEnumerable<object>(pyObj);
155+
}
156+
157+
//next use the rank to return the appropriate type
158+
var clrRank = GetRank(typeof(T));
159+
if (clrRank == CollectionRank.Iterable)
160+
var = new GenericPyEnumerable<int>(pyObj);
161+
else
162+
{
163+
//var = null;
140164
}
141165

142166
value = (T)var;

0 commit comments

Comments
 (0)
0