8000 Add a basic test for the new 'LoadAssemblyFromNativeMemory' API · awakecoding/PowerShell@2a2efcf · GitHub
[go: up one dir, main page]

Skip to content

Commit 2a2efcf

Browse files
daxian-dbwSteveL-MSFT
authored andcommitted
Add a basic test for the new 'LoadAssemblyFromNativeMemory' API
1 parent 39b1c78 commit 2a2efcf

File tree

2 files changed

+105
-2
lines changed

2 files changed

+105
-2
lines changed

src/System.Management.Automation/CoreCLR/CorePsAssemblyLoadContext.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -603,9 +603,9 @@ public static unsafe class PowerShellUnsafeAssemblyLoad
603603
/// Size in bytes of the assembly data buffer.
604604
/// </param>
605605
[UnmanagedCallersOnly]
606-
public static void LoadAssemblyFromMemory(IntPtr data, int size)
606+
public static void LoadAssemblyFromNativeMemory(IntPtr data, int size)
607607
{
608-
using var stream = new UnmanagedMemoryStream((Byte*)data, size);
608+
using var stream = new UnmanagedMemoryStream((byte*)data, size);
609609
AssemblyLoadContext.Default.LoadFromStream(stream);
610610
}
611611
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Linq;
5+
using System.Reflection;
6+
using System.Runtime.InteropServices;
7+
using System.Runtime.Loader;
8+
using System.Management.Automation;
9+
using Xunit;
10+
11+
using Microsoft.CodeAnalysis;
12+
using Microsoft.CodeAnalysis.CSharp;
13+
using Microsoft.CodeAnalysis.Emit;
14+
using Microsoft.CodeAnalysis.Text;
15+
16+
namespace PSTests.Sequential
17+
{
18+
public static class NativeInterop
19+
{
20+
[Fact]
21+
public static void TestLoadNativeInMemoryAssembly()
22+
{
23+
string tempDir = Path.Combine(Path.GetTempPath(), "TestLoadNativeInMemoryAssembly");
24+
string testDll = Path.Combine(tempDir, "test.dll");
25+
26+
if (!File.Exists(testDll))
27+
{
28+
Directory.CreateDirectory(tempDir);
29+
bool result = CreateTestDll(testDll);
30+
Assert.True(result, "The call to 'CreateTestDll' should be successful and return true.");
31+
Assert.True(File.Exists(testDll), "The test assembly should be created.");
32+
}
33+
34+
var asmName = AssemblyName.GetAssemblyName(testDll);
35+
string asmFullName = SearchAssembly(asmName.Name);
36+
Assert.Null(asmFullName);
37+
38+
unsafe { LoadAssemblyTest(testDll); }
39+
40+
asmFullName = SearchAssembly(asmName.Name);
41+
Assert.Equal(asmName.FullName, asmFullName);
42+
}
43+
44+
private static unsafe void LoadAssemblyTest(string assemblyPath)
45+
{
46+
// The 'LoadAssemblyFromNativeMemory' method is annotated with 'UnmanagedCallersOnly' attribute,
47+
// so we have to use the 'unmanaged' function pointer to invoke it.
48+
delegate* unmanaged<IntPtr, int, void> funcPtr = &PowerShellUnsafeAssemblyLoad.LoadAssemblyFromNativeMemory;
49+
50+
int length = 0;
51+
IntPtr nativeMem = IntPtr.Zero;
52+
53+
try
54+
{
55+
using (var fileStream = new FileStream(assemblyPath, FileMode.Open, FileAccess.Read))
56+
{
57+
length = (int)fileStream.Length;
58+
nativeMem = Marshal.AllocHGlobal(length);
59+
60+
using var unmanagedStream = new UnmanagedMemoryStream((byte*)nativeMem, length, length, FileAccess.Write);
61+
fileStream.CopyTo(unmanagedStream);
62+
}
63+
64+
// Call the function pointer.
65+
funcPtr(nativeMem, length);
66+
}
67+
finally
68+
{
69+
// Free the native memory
70+
Marshal.FreeHGlobal(nativeMem);
71+
}
72+
}
73+
74+
private static string SearchAssembly(string assemblyName)
75+
{
76+
Assembly asm = AssemblyLoadContext.Default.Assemblies.FirstOrDefault(
77+
assembly => assembly.FullName.StartsWith(assemblyName, StringComparison.OrdinalIgnoreCase));
78+
79+
return asm?.FullName;
80+
}
81+
82+
private static bool CreateTestDll(string dllPath)
83+
{
84+
var parseOptions = CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.Latest);
85+
var compilationOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary);
86+
87+
List<SyntaxTree> syntaxTrees = new();
88+
SourceText sourceText = SourceText.From("public class Utt { }");
89+
syntaxTrees.Add(CSharpSyntaxTree.ParseText(sourceText, parseOptions));
90+
91+
var refs = new List<PortableExecutableReference> { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) };
92+
Compilation compilation = CSharpCompilation.Create(
93+
Path.GetRandomFileName(),
94+
syntaxTrees: syntaxTrees,
95+
references: refs,
96+
options: compilationOptions);
97+
98+
using var fs = new FileStream(dllPath, FileMode.CreateNew, FileAccess.ReadWrite, FileShare.None);
99+
EmitResult emitResult = compilation.Emit(peStream: fs, options: null);
100+
return emitResult.Success;
101+
}
102+
}
103+
}

0 commit comments

Comments
 (0)
0