diff --git a/CHANGELOG.md b/CHANGELOG.md
index fdab9bf64..fd78f138f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].
 ### Added
 
 ### Changed
+- Added a `FormatterFactory` member in RuntimeData to create formatters with parameters. For compatibility, the `FormatterType` member is still present and has precedence when defining both `FormatterFactory` and `FormatterType`
+- Added post-serialization and pre-deserialization step callbacks to extend the (de)serialization process
+- Added an API to stash serialized data on Python capsules
 
 ### Fixed
 
diff --git a/src/runtime/StateSerialization/NoopFormatter.cs b/src/runtime/StateSerialization/NoopFormatter.cs
new file mode 100644
index 000000000..f05b7ebb2
--- /dev/null
+++ b/src/runtime/StateSerialization/NoopFormatter.cs
@@ -0,0 +1,18 @@
+using System;
+using System.IO;
+using System.Runtime.Serialization;
+
+namespace Python.Runtime;
+
+/// <summary>
+/// Formatter that serializes nothing and cannot deserialize; used as a fallback
+/// when <c>BinaryFormatter</c> cannot be created.
+/// </summary>
+public class NoopFormatter : IFormatter {
+    public object Deserialize(Stream s) => throw new NotImplementedException();
+    public void Serialize(Stream s, object o) {}
+
+    public SerializationBinder? Binder { get; set; }
+    public StreamingContext Context { get; set; }
+    public ISurrogateSelector? SurrogateSelector { get; set; }
+}
diff --git a/src/runtime/StateSerialization/RuntimeData.cs b/src/runtime/StateSerialization/RuntimeData.cs
index 204e15b5b..8eda9ce0b 100644
--- a/src/runtime/StateSerialization/RuntimeData.cs
+++ b/src/runtime/StateSerialization/RuntimeData.cs
@@ -1,7 +1,5 @@
 using System;
-using System.Collections;
 using System.Collections.Generic;
-using System.Collections.ObjectModel;
 using System.Diagnostics;
 using System.IO;
 using System.Linq;
@@ -17,7 +15,42 @@ namespace Python.Runtime
 {
     public static class RuntimeData
     {
-        private static Type? _formatterType;
+
+        /// <summary>
+        /// Default formatter factory: creates a <see cref="BinaryFormatter"/>, or a
+        /// <see cref="NoopFormatter"/> when the <see cref="BinaryFormatter"/> constructor throws.
+        /// </summary>
+        public static readonly Func<IFormatter> DefaultFormatterFactory = () =>
+        {
+            try
+            {
+                return new BinaryFormatter();
+            }
+            catch
+            {
+                return new NoopFormatter();
+            }
+        };
+
+        private static Func<IFormatter> _formatterFactory { get; set; } = DefaultFormatterFactory;
+
+        /// <summary>
+        /// Factory invoked by <see cref="CreateFormatter"/> when <see cref="FormatterType"/> is not set.
+        /// </summary>
+        /// <exception cref="ArgumentNullException">The assigned value is <c>null</c>.</exception>
+        public static Func<IFormatter> FormatterFactory
+        {
+            get => _formatterFactory;
+            set
+            {
+                if (value == null)
+                    throw new ArgumentNullException(nameof(value));
+
+                _formatterFactory = value;
+            }
+        }
+
+        private static Type? _formatterType = null;
         public static Type? FormatterType
         {
             get => _formatterType;
@@ -31,6 +64,14 @@ public static Type? FormatterType
             }
         }
 
+        /// <summary>
+        /// Callback called as a last step in the serialization process
+        /// </summary>
+        public static Action? PostStashHook { get; set; } = null;
+        /// <summary>
+        /// Callback called as the first step in the deserialization process
+        /// </summary>
+        public static Action? PreRestoreHook { get; set; } = null;
         public static ICLRObjectStorer? WrappersStorer { get; set; }
 
         /// <summary>
@@ -74,6 +115,7 @@ internal static void Stash()
             using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
             int res = PySys_SetObject("clr_data", capsule.BorrowOrThrow());
             PythonException.ThrowIfIsNotZero(res);
+            PostStashHook?.Invoke();
         }
 
         internal static void RestoreRuntimeData()
@@ -90,6 +132,7 @@ internal static void RestoreRuntimeData()
 
         private static void RestoreRuntimeDataImpl()
         {
+            PreRestoreHook?.Invoke();
             BorrowedReference capsule = PySys_GetObject("clr_data");
             if (capsule.IsNull)
             {
@@ -250,11 +293,106 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage)
             }
         }
 
+        static readonly string serialization_key_namespace = "pythonnet_serialization_";
+        /// <summary>
+        /// Removes the serialization capsule from the `sys` module object.
+        /// </summary>
+        /// <remarks>
+        /// The serialization data must have been set with <see cref="StashSerializationData"/>
+        /// </remarks>
+        /// <param name="key">The name identifying the capsule; it is prefixed with `pythonnet_serialization_` on the `sys` module object</param>
+        public static void FreeSerializationData(string key)
+        {
+            key = serialization_key_namespace + key;
+            BorrowedReference oldCapsule = PySys_GetObject(key);
+            if (!oldCapsule.IsNull)
+            {
+                IntPtr oldData = PyCapsule_GetPointer(oldCapsule, IntPtr.Zero);
+                Marshal.FreeHGlobal(oldData);
+                // PyCapsule_SetPointer rejects NULL (it fails and sets a Python error), so the
+                // pointer is not cleared here; removing the `sys` entry drops that reference instead.
+                PySys_SetObject(key, null);
+            }
+        }
+
+        /// <summary>
+        /// Stores the data in the <paramref name="stream"/> argument in a Python capsule and stores
+        /// the capsule on the `sys` module object under the namespaced <paramref name="key"/>.
+        /// </summary>
+        /// <remarks>
+        /// No checks on pre-existing names on the `sys` module object are made.
+        /// </remarks>
+        /// <param name="key">The name identifying the capsule; it is prefixed with `pythonnet_serialization_` on the `sys` module object</param>
+        /// <param name="stream">A MemoryStream that contains the data to be placed in the capsule</param>
+        public static void StashSerializationData(string key, MemoryStream stream)
+        {
+            if (stream.TryGetBuffer(out var data))
+            {
+                // use the same namespaced name FreeSerializationData/GetSerializationData look up
+                key = serialization_key_namespace + key;
+                IntPtr mem = Marshal.AllocHGlobal(IntPtr.Size + data.Count);
+
+                // store the length of the buffer first
+                Marshal.WriteIntPtr(mem, (IntPtr)data.Count);
+                Marshal.Copy(data.Array, data.Offset, mem + IntPtr.Size, data.Count);
+
+                try
+                {
+                    using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
+                    int res = PySys_SetObject(key, capsule.BorrowOrThrow());
+                    PythonException.ThrowIfIsNotZero(res);
+                }
+                catch
+                {
+                    // nothing references the buffer on failure: release it, then propagate the error
+                    Marshal.FreeHGlobal(mem);
+                    throw;
+                }
+            }
+            else
+            {
+                throw new NotImplementedException($"{nameof(stream)} must be exposable");
+            }
+        }
+
+        static readonly byte[] emptyBuffer = new byte[0];
+        /// <summary>
+        /// Retrieves the previously stored data on a Python capsule.
+        /// Throws if the object corresponding to the <paramref name="key"/> parameter
+        /// on the `sys` module object is not a capsule.
+        /// </summary>
+        /// <param name="key">The name identifying the capsule; it is prefixed with `pythonnet_serialization_` on the `sys` module object</param>
+        /// <returns>A MemoryStream containing the previously saved serialization data.
+        /// The stream is empty if no name matches the key.</returns>
+        public static MemoryStream GetSerializationData(string key)
+        {
+            key = serialization_key_namespace + key;
+            BorrowedReference capsule = PySys_GetObject(key);
+            if (capsule.IsNull)
+            {
+                // nothing to do.
+                return new MemoryStream(emptyBuffer, writable:false);
+            }
+            var ptr = PyCapsule_GetPointer(capsule, IntPtr.Zero);
+            if (ptr == IntPtr.Zero)
+            {
+                // The PyCapsule API returns NULL on error; NULL cannot be stored
+                // as a capsule's value
+                PythonException.ThrowIfIsNull(null);
+            }
+            var len = (int)Marshal.ReadIntPtr(ptr);
+            byte[] buffer = new byte[len];
+            Marshal.Copy(ptr+IntPtr.Size, buffer, 0, len);
+            return new MemoryStream(buffer, writable:false);
+        }
+
         internal static IFormatter CreateFormatter()
         {
-            return FormatterType != null ?
-                (IFormatter)Activator.CreateInstance(FormatterType)
-                : new BinaryFormatter();
+            if (FormatterType != null)
+            {
+                return (IFormatter)Activator.CreateInstance(FormatterType);
+            }
+            return FormatterFactory();
         }
     }
 }
diff --git a/tests/domain_tests/TestRunner.cs b/tests/domain_tests/TestRunner.cs
index 4f6a3ea28..bbee81b3d 100644
--- a/tests/domain_tests/TestRunner.cs
+++ b/tests/domain_tests/TestRunner.cs
@@ -1132,6 +1132,66 @@ import System
 
                     ",
             },
+            new TestCase
+            {
+                Name = "test_serialize_unserializable_object",
+                DotNetBefore = @"
+                    namespace TestNamespace
+                    {
+                        public class NotSerializableTextWriter : System.IO.TextWriter
+                        {
+                            override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
+                        }
+                        [System.Serializable]
+                        public static class SerializableWriter
+                        {
+                            private static System.IO.TextWriter _writer = null;
+                            public static System.IO.TextWriter Writer {get { return _writer; }}
+                            public static void CreateInternalWriter()
+                            {
+                                _writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
+                            }
+                        }
+                    }
+",
+                DotNetAfter = @"
+                    namespace TestNamespace
+                    {
+                        public class NotSerializableTextWriter : System.IO.TextWriter
+                        {
+                            override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
+                        }
+                        [System.Serializable]
+                        public static class SerializableWriter
+                        {
+                            private static System.IO.TextWriter _writer = null;
+                            public static System.IO.TextWriter Writer {get { return _writer; }}
+                            public static void CreateInternalWriter()
+                            {
+                                _writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
+                            }
+                        }
+                    }
+                    ",
+                PythonCode = @"
+import sys
+
+def before_reload():
+    import clr
+    import System
+    clr.AddReference('DomainTests')
+    import TestNamespace
+    TestNamespace.SerializableWriter.CreateInternalWriter()
+    sys.__obj = TestNamespace.SerializableWriter.Writer
+    sys.__obj.WriteLine('test')
+
+def after_reload():
+    import clr
+    import System
+    sys.__obj.WriteLine('test')
+
+                    ",
+            }
         };
 
         /// <summary>
@@ -1142,7 +1202,59 @@ import System
         const string CaseRunnerTemplate = @"
 using System;
 using System.IO;
+using System.Runtime.Serialization;
+using System.Runtime.Serialization.Formatters.Binary;
 using Python.Runtime;
+
+namespace Serialization
+{{
+    // Classes in this namespace are mostly useful for test_serialize_unserializable_object
+    class NotSerializableSerializer : ISerializationSurrogate
+    {{
+        public NotSerializableSerializer()
+        {{
+        }}
+        public void GetObjectData(object obj, SerializationInfo info, StreamingContext context)
+        {{
+            info.AddValue(""notSerialized_tp"", obj.GetType());
+        }}
+        public object SetObjectData(object obj, SerializationInfo info, StreamingContext context, ISurrogateSelector selector)
+        {{
+            if (info == null)
+            {{
+                return null;
+            }}
+            Type typeObj = info.GetValue(""notSerialized_tp"", typeof(Type)) as Type;
+            if (typeObj == null)
+            {{
+                return null;
+            }}
+
+            obj = Activator.CreateInstance(typeObj);
+            return obj;
+        }}
+    }}
+    class NonSerializableSelector : SurrogateSelector
+    {{
+        public override ISerializationSurrogate GetSurrogate(Type type, StreamingContext context, out ISurrogateSelector selector)
+        {{
+            if (type == null)
+            {{
+                throw new ArgumentNullException(nameof(type));
+            }}
+            selector = (ISurrogateSelector)this;
+            if (type.IsSerializable)
+            {{
+                return null; // use whichever default
+            }}
+            else
+            {{
+                return (ISerializationSurrogate)(new NotSerializableSerializer());
+            }}
+        }}
+    }}
+}}
+
 namespace CaseRunner
 {{
     class CaseRunner
@@ -1151,6 +1263,11 @@ public static int Main()
         {{
             try
             {{
+                RuntimeData.FormatterFactory = () =>
+                {{
+                    return new BinaryFormatter(){{SurrogateSelector = new Serialization.NonSerializableSelector()}};
+                }};
+
                 PythonEngine.Initialize();
                 using (Py.GIL())
                 {{
diff --git a/tests/domain_tests/test_domain_reload.py b/tests/domain_tests/test_domain_reload.py
index 8999e481b..1e5e8e81b 100644
--- a/tests/domain_tests/test_domain_reload.py
+++ b/tests/domain_tests/test_domain_reload.py
@@ -88,3 +88,6 @@ def test_nested_type():
 
 def test_import_after_reload():
     _run_test("import_after_reload")
+
+def test_serialize_unserializable_object():
+    _run_test("test_serialize_unserializable_object")