8000 IComparable and IEquatable implementations for PyInt, PyFloat, and Py… · pythonnet/pythonnet@563e369 · GitHub
[go: up one dir, main page]

Skip to content

Commit 563e369

Browse files
committed
IComparable and IEquatable implementations for PyInt, PyFloat, and PyString for primitive .NET types
1 parent 9d18a24 commit 563e369

File tree

10 files changed

+331
-4
lines changed

10 files changed

+331
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].
1111

1212
- Added `ToPythonAs<T>()` extension method to allow for explicit conversion using a specific type. ([#2311][i2311])
1313

14+
- Added `IComparable` and `IEquatable` implementations to `PyInt`, `PyFloat`, and `PyString`
15+
to compare with primitive .NET types like `long`.
16+
1417
### Changed
1518

1619
### Fixed

src/embed_tests/TestPyFloat.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,32 @@ public void AsFloatBad()
126126
StringAssert.StartsWith("could not convert string to float", ex.Message);
127127
Assert.IsNull(a);
128128
}
129+
130+
[Test]
131+
public void CompareTo()
132+
{
133+
var v = new PyFloat(42);
134+
135+
Assert.AreEqual(0, v.CompareTo(42f));
136+
Assert.AreEqual(0, v.CompareTo(42d));
137+
138+
Assert.AreEqual(1, v.CompareTo(41f));
139+
Assert.AreEqual(1, v.CompareTo(41d));
140+
141+
Assert.AreEqual(-1, v.CompareTo(43f));
142+
Assert.AreEqual(-1, v.CompareTo(43d));
143+
}
144+
145+
[Test]
146+
public void Equals()
147+
{
148+
var v = new PyFloat(42);
149+
150+
Assert.IsTrue(v.Equals(42f));
151+
Assert.IsTrue(v.Equals(42d));
152+
153+
Assert.IsFalse(v.Equals(41f));
154+
Assert.IsFalse(v.Equals(41d));
155+
}
129156
}
130157
}

src/embed_tests/TestPyInt.cs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,76 @@ public void ToBigInteger()
210210
CollectionAssert.AreEqual(expected, actual);
211211
}
212212

213+
[Test]
214+
public void CompareTo()
215+
{
216+
var v = new PyInt(42);
217+
218+
#region Signed
219+
Assert.AreEqual(0, v.CompareTo(42L));
220+
Assert.AreEqual(0, v.CompareTo(42));
221+
Assert.AreEqual(0, v.CompareTo((short)42));
222+
Assert.AreEqual(0, v.CompareTo((sbyte)42));
223+
224+
Assert.AreEqual(1, v.CompareTo(41L));
225+
Assert.AreEqual(1, v.CompareTo(41));
226+
Assert.AreEqual(1, v.CompareTo((short)41));
227+
Assert.AreEqual(1, v.CompareTo((sbyte)41));
228+
229+
Assert.AreEqual(-1, v.CompareTo(43L));
230+
Assert.AreEqual(-1, v.CompareTo(43));
231+
Assert.AreEqual(-1, v.CompareTo((short)43));
232+
Assert.AreEqual(-1, v.CompareTo((sbyte)43));
233+
#endregion Signed
234+
235+
#region Unsigned
236+
Assert.AreEqual(0, v.CompareTo(42UL));
237+
Assert.AreEqual(0, v.CompareTo(42U));
238+
Assert.AreEqual(0, v.CompareTo((ushort)42));
239+
Assert.AreEqual(0, v.CompareTo((byte)42));
240+
241+
Assert.AreEqual(1, v.CompareTo(41UL));
242+
Assert.AreEqual(1, v.CompareTo(41U));
243+
Assert.AreEqual(1, v.CompareTo((ushort)41));
244+
Assert.AreEqual(1, v.CompareTo((byte)41));
245+
246+
Assert.AreEqual(-1, v.CompareTo(43UL));
247+
Assert.AreEqual(-1, v.CompareTo(43U));
248+
Assert.AreEqual(-1, v.CompareTo((ushort)43));
249+
Assert.AreEqual(-1, v.CompareTo((byte)43));
250+
#endregion Unsigned
251+
}
252+
253+
[Test]
254+
public void Equals()
255+
{
256+
var v = new PyInt(42);
257+
258+
#region Signed
259+
Assert.True(v.Equals(42L));
260+
Assert.True(v.Equals(42));
261+
Assert.True(v.Equals((short)42));
262+
Assert.True(v.Equals((sbyte)42));
263+
264+
Assert.False(v.Equals(41L));
265+
Assert.False(v.Equals(41));
266+
Assert.False(v.Equals((short)41));
267+
Assert.False(v.Equals((sbyte)41));
268+
#endregion Signed
269+
270+
#region Unsigned
271+
Assert.True(v.Equals(42UL));
272+
Assert.True(v.Equals(42U)) E377 ;
273+
Assert.True(v.Equals((ushort)42));
274+
Assert.True(v.Equals((byte)42));
275+
276+
Assert.False(v.Equals(41UL));
277+
Assert.False(v.Equals(41U));
278+
Assert.False(v.Equals((ushort)41));
279+
Assert.False(v.Equals((byte)41));
280+
#endregion Unsigned
281+
}
282+
213283
[Test]
214284
public void ToBigIntegerLarge()
215285
{

src/embed_tests/TestPyString.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,5 +112,24 @@ public void TestUnicodeSurrogate()
112112
Assert.AreEqual(4, actual.Length());
113113
Assert.AreEqual(expected, actual.ToString());
114114
}
115+
116+
[Test]
117+
public void CompareTo()
118+
{
119+
var a = new PyString("foo");
120+
121+
Assert.AreEqual(0, a.CompareTo("foo"));
122+
Assert.AreEqual("foo".CompareTo("bar"), a.CompareTo("bar"));
123+
Assert.AreEqual("foo".CompareTo("foz"), a.CompareTo("foz"));
124+
}
125+
126+
[Test]
127+
public void Equals()
128+
{
129+
var a = new PyString("foo");
130+
131+
Assert.True(a.Equals("foo"));
132+
Assert.False(a.Equals("bar"));
133+
}
115134
}
116135
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
3+
namespace Python.Runtime;
4+
5+
partial class PyFloat : IComparable<double>, IComparable<float>
6+
, IEquatable<double>, IEquatable<float>
7+
, IComparable<PyFloat?>, IEquatable<PyFloat?>
8+
{
9+
public override bool Equals(object o)
10+
{
11+
using var _ = Py.GIL();
12+
return o switch
13+
{
14+
double f64 => this.Equals(f64),
15+
float f32 => this.Equals(f32),
16+
_ => base.Equals(o),
17+
};
18+
}
19+
20+
public int CompareTo(double other) => this.ToDouble().CompareTo(other);
21+
22+
public int CompareTo(float other) => this.ToDouble().CompareTo(other);
23+
24+
public bool Equals(double other) => this.ToDouble().Equals(other);
25+
26+
public bool Equals(float other) => this.ToDouble().Equals(other);
27+
28+
public int CompareTo(PyFloat? other)
29+
{
30+
return other is null ? 1 : this.CompareTo(other.BorrowNullable());
31+
}
32+
33+
public bool Equals(PyFloat? other) => base.Equals(other);
34+
}

src/runtime/PythonTypes/PyFloat.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Python.Runtime
88
/// PY3: https://docs.python.org/3/c-api/float.html
99
/// for details.
1010
/// </summary>
11-
public class PyFloat : PyNumber
11+
public partial class PyFloat : PyNumber
1212
{
1313
internal PyFloat(in StolenReference ptr) : base(ptr)
1414
{
@@ -100,6 +100,8 @@ public static PyFloat AsFloat(PyObject value)
100100
return new PyFloat(op.Steal());
101101
}
102102

103+
public double ToDouble() => Runtime.PyFloat_AsDouble(obj);
104+
103105
public override TypeCode GetTypeCode() => TypeCode.Double;
104106
}
105107
}
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
using System;
2+
3+
namespace Python.Runtime;
4+
5+
partial class PyInt : IComparable<long>, IComparable<int>, IComparable<sbyte>, IComparable<short>
6+
, IComparable<ulong>, IComparable<uint>, IComparable<ushort>, IComparable<byte>
7+
, IEquatable<long>, IEquatable<int>, IEquatable<short>, IEquatable<sbyte>
8+
, IEquatable<ulong>, IEquatable<uint>, IEquatable<ushort>, IEquatable<byte>
9+
, IComparable<PyInt?>, IEquatable<PyInt?>
10+
{
11+
public override bool Equals(object o)
12+
{
13+
using var _ = Py.GIL();
14+
return o switch
15+
{
16+
long i64 => this.Equals(i64),
17+
int i32 => this.Equals(i32),
18+
short i16 => this.Equals(i16),
19+
sbyte i8 => this.Equals(i8),
20+
21+
ulong u64 => this.Equals(u64),
22+
uint u32 => this.Equals(u32),
23+
ushort u16 => this.Equals(u16),
24+
byte u8 => this.Equals(u8),
25+
26+
_ => base.Equals(o),
27+
};
28+
}
29+
30+
#region Signed
31+
public int CompareTo(long other)
32+
{
33+
using var pyOther = Runtime.PyInt_FromInt64(other);
34+
return this.CompareTo(pyOther.BorrowOrThrow());
35+
}
36+
37+
public int CompareTo(int other)
38+
{
39+
using var pyOther = Runtime.PyInt_FromInt32(other);
40+
return this.CompareTo(pyOther.BorrowOrThrow());
41+
}
42+
43+
public int CompareTo(short other)
44+
{
45+
using var pyOther = Runtime.PyInt_FromInt32(other);
46+
return this.CompareTo(pyOther.BorrowOrThrow());
47+
}
48+
49+
public int CompareTo(sbyte other)
50+
{
51+
using var pyOther = Runtime.PyInt_FromInt32(other);
52+
return this.CompareTo(pyOther.BorrowOrThrow());
53+
}
54+
55+
public bool Equals(long other)
56+
{
57+
using var pyOther = Runtime.PyInt_FromInt64(other);
58+
return this.Equals(pyOther.BorrowOrThrow());
59+
}
60+
61+
public bool Equals(int other)
62+
{
63+
using var pyOther = Runtime.PyInt_FromInt32(other);
64+
return this.Equals(pyOther.BorrowOrThrow());
65+
}
66+
67+
public bool Equals(short other)
68+
{
69+
using var pyOther = Runtime.PyInt_FromInt32(other);
70+
return this.Equals(pyOther.BorrowOrThrow());
71+
}
72+
73+
public bool Equals(sbyte other)
74+
{
75+
using var pyOther = Runtime.PyInt_FromInt32(other);
76+
return this.Equals(pyOther.BorrowOrThrow());
77+
}
78+
#endregion Signed
79+
80+
#region Unsigned
81+
public int CompareTo(ulong other)
82+
{
83+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
84+
return this.CompareTo(pyOther.BorrowOrThrow());
85+
}
86+
87+
public int CompareTo(uint other)
88+
{
89+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
90+
return this.CompareTo(pyOther.BorrowOrThrow());
91+
}
92+
93+
public int CompareTo(ushort other)
94+
{
95+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
96+
return this.CompareTo(pyOther.BorrowOrThrow());
97+
}
98+
99+
public int CompareTo(byte other)
100+
{
101+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
102+
return this.CompareTo(pyOther.BorrowOrThrow());
103+
}
104+
105< D7AE /td>+
public bool Equals(ulong other)
106+
{
107+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
108+
return this.Equals(pyOther.BorrowOrThrow());
109+
}
110+
111+
public bool Equals(uint other)
112+
{
113+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
114+
return this.Equals(pyOther.BorrowOrThrow());
115+
}
116+
117+
public bool Equals(ushort other)
118+
{
119+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
120+
return this.Equals(pyOther.BorrowOrThrow());
121+
}
122+
123+
public bool Equals(byte other)
124+
{
125+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
126+
return this.Equals(pyOther.BorrowOrThrow());
127+
}
128+
#endregion Unsigned
129+
130+
public int CompareTo(PyInt? other)
131+
{
132+
return other is null ? 1 : this.CompareTo(other.BorrowNullable());
133+
}
134+
135+
public bool Equals(PyInt? other) => base.Equals(other);
136+
}

src/runtime/PythonTypes/PyInt.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Python.Runtime
99
/// Represents a Python integer object.
1010
/// See the documentation at https://docs.python.org/3/c-api/long.html
1111
/// </summary>
12-
public class PyInt : PyNumber, IFormattable
12+
public partial class PyInt : PyNumber, IFormattable
1313
{
1414
internal PyInt(in StolenReference ptr) : base(ptr)
1515
{

src/runtime/PythonTypes/PyObject.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,23 @@ public long Refcount
11361136
}
11371137
}
11381138

1139+
internal int CompareTo(BorrowedReference other)
1140+
{
1141+
int greater = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_GT);
1142+
Debug.Assert(greater != -1);
1143+
if (greater > 0)
1144+
return 1;
1145+
int less = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_LT);
1146+
Debug.Assert(less != -1);
1147+
return less > 0 ? -1 : 0;
1148+
}
1149+
1150+
internal bool Equals(BorrowedReference other)
1151+
{
1152+
int equal = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_EQ);
1153+
Debug.Assert(equal != -1);
1154+
return equal > 0;
1155+
}
11391156

11401157
public override bool TryGetMember(GetMemberBinder binder, out object? result)
11411158
{
@@ -1325,7 +1342,7 @@ private bool TryCompare(PyObject arg, int op, out object @out)
13251342
}
13261343
return true;
13271344
}
1328-
1345+
13291346
public override bool TryBinaryOperation(BinaryOperationBinder binder, object arg, out object? result)
13301347
{
13311348
using var _ = Py.GIL();

0 commit comments

Comments
 (0)
0