8000 Expose serialization api by BadSingleton · Pull Request #2336 · pythonnet/pythonnet · GitHub
[go: up one dir, main page]

Skip to content

Expose serialization api #2336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].
### Added

### Changed
- Added a `FormatterFactory` member in RuntimeData to create formatters with parameters. For compatibility, the `FormatterType` member is still present and has precedence when defining both `FormatterFactory` and `FormatterType`
- Added a post-serialization and a pre-deserialization step callbacks to extend (de)serialization process
- Added an API to stash serialized data on Python capsules

### Fixed

Expand Down
14 changes: 14 additions & 0 deletions src/runtime/StateSerialization/NoopFormatter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System;
using System.IO;
using System.Runtime.Serialization;

namespace Python.Runtime;

public class NoopFormatter : IFormatter {
public object Deserialize(Stream s) => throw new NotImplementedException();
public void Serialize(Stream s, object o) {}

public SerializationBinder? Binder { get; set; }
public StreamingContext Context { get; set; }
public ISurrogateSelector? SurrogateSelector { get; set; }
}
138 changes: 132 additions & 6 deletions src/runtime/StateSerialization/RuntimeData.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
Expand All @@ -17,7 +15,34 @@ namespace Python.Runtime
{
public static class RuntimeData
{
private static Type? _formatterType;

public readonly static Func<IFormatter> DefaultFormatterFactory = () =>
{
try
{
return new BinaryFormatter();
}
catch
{
return new NoopFormatter();
}
};

private static Func<IFormatter> _formatterFactory { get; set; } = DefaultFormatterFactory;

public static Func<IFormatter> FormatterFactory
{
get => _formatterFactory;
set
{
if (value == null)
throw new ArgumentNullException(nameof(value));

_formatterFactory = value;
}
}

private static Type? _formatterType = null;
public static Type? FormatterType
{
get => _formatterType;
Expand All @@ -31,6 +56,14 @@ public static Type? FormatterType
}
}

/// <summary>
/// Callback called as a last step in the serialization process
/// </summary>
public static Action? PostStashHook { get; set; } = null;
/// <summary>
/// Callback called as the first step in the deserialization process
/// </summary>
public static Action? PreRestoreHook { get; set; } = null;
public static ICLRObjectStorer? WrappersStorer { get; set; }

/// <summary>
Expand Down Expand Up @@ -74,6 +107,7 @@ internal static void Stash()
using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
int res = PySys_SetObject("clr_data", capsule.BorrowOrThrow());
PythonException.ThrowIfIsNotZero(res);
PostStashHook?.Invoke();
}

internal static void RestoreRuntimeData()
Expand All @@ -90,6 +124,7 @@ internal static void RestoreRuntimeData()

private static void RestoreRuntimeDataImpl()
{
PreRestoreHook?.Invoke();
BorrowedReference capsule = PySys_GetObject("clr_data");
if (capsule.IsNull)
{
Expand Down Expand Up @@ -250,11 +285,102 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage)
}
}

static readonly string serialization_key_namepsace = "pythonnet_serialization_";
/// <summary>
/// Removes the serialization capsule from the `sys` module object.
/// </summary>
/// <remarks>
/// The serialization data must have been set with <code>StashSerializationData</code>
/// </remarks>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
public static void FreeSerializationData(string key)
{
key = serialization_key_namepsace + key;
BorrowedReference oldCapsule = PySys_GetObject(key);
if (!oldCapsule.IsNull)
{
IntPtr oldData = PyCapsule_GetPointer(oldCapsule, IntPtr.Zero);
Marshal.FreeHGlobal(oldData);
PyCapsule_SetPointer(oldCapsule, IntPtr.Zero);
PySys_SetObject(key, null);
}
}

/// <summary>
/// Stores the data in the <paramref name="stream"/> argument in a Python capsule and stores
/// the capsule on the `sys` module object with the name <paramref name="key"/>.
/// </summary>
/// <remarks>
/// No checks on pre-existing names on the `sys` module object are made.
/// </remarks>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
/// <param name="stream">A MemoryStream that contains the data to be placed in the capsule</param>
public static void StashSerializationData(string key, MemoryStream stream)
{
if (stream.TryGetBuffer(out var data))
{
IntPtr mem = Marshal.AllocHGlobal(IntPtr.Size + data.Count);

// store the length of the buffer first
Marshal.WriteIntPtr(mem, (IntPtr)data.Count);
Marshal.Copy(data.Array, data.Offset, mem + IntPtr.Size, data.Count);

try
{
using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
int res = PySys_SetObject(key, capsule.BorrowOrThrow());
PythonException.ThrowIfIsNotZero(res);
}
catch
{
Marshal.FreeHGlobal(mem);
}
}
else
{
throw new NotImplementedException($"{nameof(stream)} must be exposable");
}

}

static byte[] emptyBuffer = new byte[0];
/// <summary>
/// Retreives the previously stored data on a Python capsule.
/// Throws if the object corresponding to the <paramref name="key"/> parameter
/// on the `sys` module object is not a capsule.
/// </summary>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
/// <returns>A MemoryStream containing the previously saved serialization data.
/// The stream is empty if no name matches the key. </returns>
public static MemoryStream GetSerializationData(string key)
{
BorrowedReference capsule = PySys_GetObject(key);
if (capsule.IsNull)
{
// nothing to do.
return new MemoryStream(emptyBuffer, writable:false);
}
var ptr = PyCapsule_GetPointer(capsule, IntPtr.Zero);
if (ptr == IntPtr.Zero)
{
// The PyCapsule API returns NULL on error; NULL cannot be stored
// as a capsule's value
PythonException.ThrowIfIsNull(null);
}
var len = (int)Marshal.ReadIntPtr(ptr);
byte[] buffer = new byte[len];
Marshal.Copy(ptr+IntPtr.Size, buffer, 0, len);
return new MemoryStream(buffer, writable:false);
}

internal static IFormatter CreateFormatter()
{
return FormatterType != null ?
(IFormatter)Activator.CreateInstance(FormatterType)
: new BinaryFormatter();

if (FormatterType != null)
{
return (IFormatter)Activator.CreateInstance(FormatterType);
}
return FormatterFactory();
}
}
}
117 changes: 117 additions & 0 deletions tests/domain_tests/TestRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,66 @@ import System

",
},
new TestCase
{
Name = "test_serialize_unserializable_object",
DotNetBefore = @"
namespace TestNamespace
{
public class NotSerializableTextWriter : System.IO.TextWriter
{
override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
}
[System.Serializable]
public static class SerializableWriter
{
private static System.IO.TextWriter _writer = null;
public static System.IO.TextWriter Writer {get { return _writer; }}
public static void CreateInternalWriter()
{
_writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
}
}
}
",
DotNetAfter = @"
namespace TestNamespace
{
public class NotSerializableTextWriter : System.IO.TextWriter
{
override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
}
[System.Serializable]
public static class SerializableWriter
{
private static System.IO.TextWriter _writer = null;
public static System.IO.TextWriter Writer {get { return _writer; }}
public static void CreateInternalWriter()
{
_writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
}
}
}
",
PythonCode = @"
import sys

def before_reload():
import clr
import System
clr.AddReference('DomainTests')
import TestNamespace
TestNamespace.SerializableWriter.CreateInternalWriter();
sys.__obj = TestNamespace.SerializableWriter.Writer
sys.__obj.WriteLine('test')

def after_reload():
import clr
import System
sys.__obj.WriteLine('test')

",
}
};

/// <summary>
Expand All @@ -1142,7 +1202,59 @@ import System
const string CaseRunnerTemplate = @"
using System;
using System.IO;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using Python.Runtime;

namespace Serialization
{{
// Classes in this namespace is mostly useful for test_serialize_unserializable_object
class NotSerializableSerializer : ISerializationSurrogate
{{
public NotSerializableSerializer()
{{
}}
public void GetObjectData(object obj, SerializationInfo info, StreamingContext context)
{{
info.AddValue(""notSerialized_tp"", obj.GetType());
}}
public object SetObjectData(object obj, SerializationInfo info, StreamingContext context, ISurrogateSelector selector)
{{
if (info == null)
{{
return null;
}}
Type typeObj = info.GetValue(""notSerialized_tp"", typeof(Type)) as Type;
if (typeObj == null)
{{
return null;
}}

obj = Activator.CreateInstance(typeObj);
return obj;
}}
}}
class NonSerializableSelector : SurrogateSelector
{{
public override ISerializationSurrogate GetSurrogate(Type type, StreamingContext context, out ISurrogateSelector selector)
{{
if (type == null)
{{
throw new ArgumentNullException();
}}
selector = (ISurrogateSelector)this;
if (type.IsSerializable)
{{
return null; // use whichever default
}}
else
{{
return (ISerializationSurrogate)(new NotSerializableSerializer());
}}
}}
}}
}}

namespace CaseRunner
{{
class CaseRunner
Expand All @@ -1151,6 +1263,11 @@ public static int Main()
{{
try
{{
RuntimeData.FormatterFactory = () =>
{{
return new BinaryFormatter(){{SurrogateSelector = new Serialization.NonSerializableSelector()}};
}};

PythonEngine.Initialize();
using (Py.GIL())
{{
Expand Down
3 changes: 3 additions & 0 deletions tests/domain_tests/test_domain_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,6 @@ def test_nested_type():

def test_import_after_reload():
_run_test("import_after_reload")

def test_import_after_reload():
_run_test("test_serialize_unserializable_object")
Loading
0