diff --git a/src/runtime/assemblymanager.cs b/src/runtime/assemblymanager.cs index a791b8195..5d9759375 100644 --- a/src/runtime/assemblymanager.cs +++ b/src/runtime/assemblymanager.cs @@ -1,11 +1,11 @@ using System; -using System.IO; using System.Collections; -using System.Collections.Specialized; +using System.IO; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.Reflection; -using System.Reflection.Emit; +using System.Threading; namespace Python.Runtime { @@ -15,12 +15,16 @@ namespace Python.Runtime /// internal class AssemblyManager { - static Dictionary> namespaces; + // modified from event handlers below, potentially triggered from different .NET threads + // therefore this should be a ConcurrentDictionary + static ConcurrentDictionary> namespaces; //static Dictionary> generics; static AssemblyLoadEventHandler lhandler; static ResolveEventHandler rhandler; + // updated only under GIL? static Dictionary probed; - static List assemblies; + // modified from event handlers below, potentially triggered from different .NET threads + static AssemblyList assemblies; internal static List pypath; private AssemblyManager() @@ -36,10 +40,10 @@ private AssemblyManager() internal static void Initialize() { namespaces = new - Dictionary>(32); + ConcurrentDictionary>(); probed = new Dictionary(32); //generics = new Dictionary>(); - assemblies = new List(16); + assemblies = new AssemblyList(16); pypath = new List(16); AppDomain domain = AppDomain.CurrentDomain; @@ -105,9 +109,8 @@ static void AssemblyLoadHandler(Object ob, AssemblyLoadEventArgs args) static Assembly ResolveHandler(Object ob, ResolveEventArgs args) { string name = args.Name.ToLower(); - for (int i = 0; i < assemblies.Count; i++) + foreach (Assembly a in assemblies) { - Assembly a = (Assembly)assemblies[i]; string full = a.FullName.ToLower(); if (full.StartsWith(name)) { @@ -266,9 +269,8 @@ public static Assembly LoadAssemblyFullPath(string name) public static Assembly FindLoadedAssembly(string name) { - for (int i = 0; i < assemblies.Count; i++) + foreach (Assembly a in assemblies) { - Assembly a = (Assembly)assemblies[i]; if (a.GetName().Name == name) { return a; @@ -295,15 +297,15 @@ public static bool LoadImplicit(string name, bool warn = true) bool loaded = false; string s = ""; Assembly lastAssembly = null; - HashSet assemblies = null; + HashSet assembliesSet = null; for (int i = 0; i < names.Length; i++) { s = (i == 0) ? names[0] : s + "." + names[i]; if (!probed.ContainsKey(s)) { - if (assemblies == null) + if (assembliesSet == null) { - assemblies = new HashSet(AppDomain.CurrentDomain.GetAssemblies()); + assembliesSet = new HashSet(AppDomain.CurrentDomain.GetAssemblies()); } Assembly a = FindLoadedAssembly(s); if (a == null) @@ -314,7 +316,7 @@ public static bool LoadImplicit(string name, bool warn = true) { a = LoadAssembly(s); } - if (a != null && !assemblies.Contains(a)) + if (a != null && !assembliesSet.Contains(a)) { loaded = true; lastAssembly = a; @@ -362,16 +364,13 @@ internal static void ScanAssembly(Assembly assembly) for (int n = 0; n < names.Length; n++) { s = (n == 0) ? names[0] : s + "." + names[n]; - if (!namespaces.ContainsKey(s)) - { - namespaces.Add(s, new Dictionary()); - } + namespaces.TryAdd(s, new ConcurrentDictionary()); } } - if (ns != null && !namespaces[ns].ContainsKey(assembly)) + if (ns != null) { - namespaces[ns].Add(assembly, String.Empty); + namespaces[ns].TryAdd(assembly, String.Empty); } if (ns != null && t.IsGenericTypeDefinition) @@ -383,14 +382,12 @@ internal static void ScanAssembly(Assembly assembly) public static AssemblyName[] ListAssemblies() { - AssemblyName[] names = new AssemblyName[assemblies.Count]; - Assembly assembly; - for (int i = 0; i < assemblies.Count; i++) + List names = new List(assemblies.Count); + foreach (Assembly assembly in assemblies) { - assembly = assemblies[i]; - names.SetValue(assembly.GetName(), i); + names.Add(assembly.GetName()); } - return names; + return names.ToArray(); } //=================================================================== @@ -471,9 +468,8 @@ public static List GetNames(string nsname) public static Type LookupType(string qname) { - for (int i = 0; i < assemblies.Count; i++) + foreach (Assembly assembly in assemblies) { - Assembly assembly = (Assembly)assemblies[i]; Type type = assembly.GetType(qname); if (type != null) { @@ -482,5 +478,92 @@ public static Type LookupType(string qname) } return null; } + + /// + /// Wrapper around List for thread safe access + /// + private class AssemblyList : IEnumerable{ + private readonly List _list; + private readonly ReaderWriterLockSlim _lock; + + public AssemblyList(int capacity) { + _list = new List(capacity); + _lock = new ReaderWriterLockSlim(); + } + + public int Count + { + get + { + _lock.EnterReadLock(); + try { + return _list.Count; + } + finally { + _lock.ExitReadLock(); + } + } + } + + public void Add(Assembly assembly) { + _lock.EnterWriteLock(); + try + { + _list.Add(assembly); + } + finally + { + _lock.ExitWriteLock(); + } + } + + public IEnumerator GetEnumerator() + { + return ((IEnumerable) this).GetEnumerator(); + } + + /// + /// Enumerator wrapping around 's enumerator. + /// Acquires and releases a read lock on during enumeration + /// + private class Enumerator : IEnumerator + { + private readonly AssemblyList _assemblyList; + + private readonly IEnumerator _listEnumerator; + + public Enumerator(AssemblyList assemblyList) + { + _assemblyList = assemblyList; + _assemblyList._lock.EnterReadLock(); + _listEnumerator = _assemblyList._list.GetEnumerator(); + } + + public void Dispose() + { + _listEnumerator.Dispose(); + _assemblyList._lock.ExitReadLock(); + } + + public bool MoveNext() + { + return _listEnumerator.MoveNext(); + } + + public void Reset() + { + _listEnumerator.Reset(); + } + + public Assembly Current { get { return _listEnumerator.Current; } } + + object IEnumerator.Current { get { return Current; } } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return new Enumerator(this); + } + } } } diff --git a/src/testing/Python.Test.csproj b/src/testing/Python.Test.csproj index 56efda8e3..a0d62d4f9 100644 --- a/src/testing/Python.Test.csproj +++ b/src/testing/Python.Test.csproj @@ -122,6 +122,7 @@ + diff --git a/src/testing/moduletest.cs b/src/testing/moduletest.cs new file mode 100644 index 000000000..8734f2569 --- /dev/null +++ b/src/testing/moduletest.cs @@ -0,0 +1,25 @@ +using System; +using System.Threading; + +namespace Python.Test { + public class ModuleTest { + private static Thread _thread; + + public static void RunThreads() + { + _thread = new Thread(() => { + var appdomain = AppDomain.CurrentDomain; + var assemblies = appdomain.GetAssemblies(); + foreach (var assembly in assemblies) { + assembly.GetTypes(); + } + }); + _thread.Start(); + } + + public static void JoinThreads() + { + _thread.Join(); + } + } +} \ No newline at end of file diff --git a/src/tests/test_module.py b/src/tests/test_module.py index f03954d28..a23f37d90 100644 --- a/src/tests/test_module.py +++ b/src/tests/test_module.py @@ -65,8 +65,11 @@ def testModuleInterface(self): import System self.assertEquals(type(System.__dict__), type({})) self.assertEquals(System.__name__, 'System') - # the filename can be any module from the System namespace (eg System.Data.dll or System.dll) - self.assertTrue(fnmatch(System.__file__, "*System*.dll")) + # the filename can be any module from the System namespace + # (eg System.Data.dll or System.dll, but also mscorlib.dll) + system_file = System.__file__ + self.assertTrue(fnmatch(system_file, "*System*.dll") or fnmatch(system_file, "*mscorlib.dll"), + "unexpected System.__file__: " + system_file) self.assertTrue(System.__doc__.startswith("Namespace containing types from the following assemblies:")) self.assertTrue(self.isCLRClass(System.String)) self.assertTrue(self.isCLRClass(System.Int32)) @@ -353,6 +356,22 @@ def test_ClrAddReference(self): self.assertRaises(FileNotFoundException, AddReference, "somethingtotallysilly") + def test_AssemblyLoadThreadSafety(self): + import time + from Python.Test import ModuleTest + # spin up .NET thread which loads assemblies and triggers AppDomain.AssemblyLoad event + ModuleTest.RunThreads() + time.sleep(1e-5) + for i in range(1, 100): + # call import clr, which in AssemblyManager.GetNames iterates through the loaded types + import clr + # import some .NET types + from System import DateTime + from System import Guid + from System.Collections.Generic import Dictionary + dict = Dictionary[Guid,DateTime]() + ModuleTest.JoinThreads() + def test_suite(): return unittest.makeSuite(ModuleTests)