Skip to content

Commit

Permalink
ComponentCatalog refactor (dotnet#970)
Browse files Browse the repository at this point in the history
* Stop loading assemblies in ComponentCatalog.

Write the AssemblyName into the model, and use it to register the assembly during model load.

* Move ComponentCatalog from a static class to a member of IHostEnvironment.

* Update tests for ComponentCatalog refactoring.

* minor cleanup

* Add AssemblyName to all model VersionInfo instances.

Also fix a couple more tests.

* Load and register all assemblies in the Maml directory.
Ensure all loaded assemblies are registered in Experiment to maintain compability.
Fix tests to not use ComponentCatalog but direct instantiation instead.

* Sync up with latest code.

* Fix newly added test

* Clean up some test changes.

* Fix up for latest code

* Add path filtering back to LoadAssembliesInDir

* Update TestAutoInference to use the correct Environment.

* Respond to PR feedback.

* Make all AutoInference tests use LocalEnvironment.
  • Loading branch information
eerhardt authored Sep 25, 2018
1 parent 655c2e2 commit a02807c
Show file tree
Hide file tree
Showing 177 changed files with 1,261 additions and 770 deletions.
290 changes: 290 additions & 0 deletions src/Common/AssemblyLoadingUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.IO;
using System.IO.Compression;
using System.Reflection;

namespace Microsoft.ML.Runtime
{
internal static class AssemblyLoadingUtils
{
/// <summary>
/// Make sure the given assemblies are loaded and that their loadable classes have been catalogued.
/// </summary>
public static void LoadAndRegister(IHostEnvironment env, string[] assemblies)
{
Contracts.AssertValue(env);

if (Utils.Size(assemblies) > 0)
{
foreach (string path in assemblies)
{
Exception ex = null;
try
{
// REVIEW: Will LoadFrom ever return null?
Contracts.CheckNonEmpty(path, nameof(path));
var assem = LoadAssembly(env, path);
if (assem != null)
continue;
}
catch (Exception e)
{
ex = e;
}

// If it is a zip file, load it that way.
ZipArchive zip;
try
{
zip = ZipFile.OpenRead(path);
}
catch (Exception e)
{
// Couldn't load as an assembly and not a zip, so warn the user.
ex = ex ?? e;
Console.Error.WriteLine("Warning: Could not load '{0}': {1}", path, ex.Message);
continue;
}

string dir;
try
{
dir = CreateTempDirectory();
}
catch (Exception e)
{
throw Contracts.ExceptIO(e, "Creating temp directory for extra assembly zip extraction failed: '{0}'", path);
}

try
{
zip.ExtractToDirectory(dir);
}
catch (Exception e)
{
throw Contracts.ExceptIO(e, "Extracting extra assembly zip failed: '{0}'", path);
}

LoadAssembliesInDir(env, dir, false);
}
}
}

public static IDisposable CreateAssemblyRegistrar(IHostEnvironment env, string loadAssembliesPath = null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValueOrNull(loadAssembliesPath);

return new AssemblyRegistrar(env, loadAssembliesPath);
}

public static void RegisterCurrentLoadedAssemblies(IHostEnvironment env)
{
Contracts.CheckValue(env, nameof(env));

foreach (Assembly a in AppDomain.CurrentDomain.GetAssemblies())
{
TryRegisterAssembly(env.ComponentCatalog, a);
}
}

private static string CreateTempDirectory()
{
string dir = GetTempPath();
Directory.CreateDirectory(dir);
return dir;
}

private static string GetTempPath()
{
Guid guid = Guid.NewGuid();
return Path.GetFullPath(Path.Combine(Path.GetTempPath(), "MLNET_" + guid.ToString()));
}

private static readonly string[] _filePrefixesToAvoid = new string[] {
"api-ms-win",
"clr",
"coreclr",
"dbgshim",
"ext-ms-win",
"microsoft.bond.",
"microsoft.cosmos.",
"microsoft.csharp",
"microsoft.data.",
"microsoft.hpc.",
"microsoft.live.",
"microsoft.platformbuilder.",
"microsoft.visualbasic",
"microsoft.visualstudio.",
"microsoft.win32",
"microsoft.windowsapicodepack.",
"microsoft.windowsazure.",
"mscor",
"msvc",
"petzold.",
"roslyn.",
"sho",
"sni",
"sqm",
"system.",
"zlib",
};

private static bool ShouldSkipPath(string path)
{
string name = Path.GetFileName(path).ToLowerInvariant();
switch (name)
{
case "cqo.dll":
case "fasttreenative.dll":
case "libiomp5md.dll":
case "libvw.dll":
case "matrixinterf.dll":
case "microsoft.ml.neuralnetworks.gpucuda.dll":
case "mklimports.dll":
case "microsoft.research.controls.decisiontrees.dll":
case "microsoft.ml.neuralnetworks.sse.dll":
case "neuraltreeevaluator.dll":
case "optimizationbuilderdotnet.dll":
case "parallelcommunicator.dll":
case "microsoft.ml.runtime.runtests.dll":
case "scopecompiler.dll":
case "tbb.dll":
case "internallearnscope.dll":
case "unmanagedlib.dll":
case "vcclient.dll":
case "libxgboost.dll":
case "zedgraph.dll":
case "__scopecodegen__.dll":
case "cosmosClientApi.dll":
return true;
}

foreach (var s in _filePrefixesToAvoid)
{
if (name.StartsWith(s, StringComparison.OrdinalIgnoreCase))
return true;
}

return false;
}

private static void LoadAssembliesInDir(IHostEnvironment env, string dir, bool filter)
{
if (!Directory.Exists(dir))
return;

using (var ch = env.Start("LoadAssembliesInDir"))
{
// Load all dlls in the given directory.
var paths = Directory.EnumerateFiles(dir, "*.dll");
foreach (string path in paths)
{
if (filter && ShouldSkipPath(path))
{
ch.Info($"Skipping assembly '{path}' because its name was filtered.");
continue;
}

LoadAssembly(env, path);
}
}
}

/// <summary>
/// Given an assembly path, load the assembly and register it with the ComponentCatalog.
/// </summary>
private static Assembly LoadAssembly(IHostEnvironment env, string path)
{
Assembly assembly = null;
try
{
assembly = Assembly.LoadFrom(path);
}
catch (Exception e)
{
using (var ch = env.Start("LoadAssembly"))
{
ch.Error("Could not load assembly {0}:\n{1}", path, e.ToString());
}
return null;
}

if (assembly != null)
{
TryRegisterAssembly(env.ComponentCatalog, assembly);
}

return assembly;
}

/// <summary>
/// Checks whether <paramref name="assembly"/> references the assembly containing LoadableClassAttributeBase,
/// and therefore can contain components.
/// </summary>
private static bool CanContainComponents(Assembly assembly)
{
var targetFullName = typeof(LoadableClassAttributeBase).Assembly.GetName().FullName;

bool found = false;
foreach (var name in assembly.GetReferencedAssemblies())
{
if (name.FullName == targetFullName)
{
found = true;
break;
}
}

return found;
}

private static void TryRegisterAssembly(ComponentCatalog catalog, Assembly assembly)
{
// Don't try to index dynamic generated assembly
if (assembly.IsDynamic)
return;

if (!CanContainComponents(assembly))
return;

catalog.RegisterAssembly(assembly);
}

private sealed class AssemblyRegistrar : IDisposable
{
private readonly IHostEnvironment _env;

public AssemblyRegistrar(IHostEnvironment env, string path)
{
_env = env;

RegisterCurrentLoadedAssemblies(_env);

if (!string.IsNullOrEmpty(path))
{
LoadAssembliesInDir(_env, path, true);
path = Path.Combine(path, "AutoLoad");
LoadAssembliesInDir(_env, path, true);
}

AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainAssemblyLoad;
}

public void Dispose()
{
AppDomain.CurrentDomain.AssemblyLoad -= CurrentDomainAssemblyLoad;
}

private void CurrentDomainAssemblyLoad(object sender, AssemblyLoadEventArgs args)
{
TryRegisterAssembly(_env.ComponentCatalog, args.LoadedAssembly);
}
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Api/ComponentCreation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs ar
{
env.CheckValue(args, nameof(args));

var classes = ComponentCatalog.FindLoadableClasses<TArgs, TSig>();
var classes = env.ComponentCatalog.FindLoadableClasses<TArgs, TSig>();
if (classes.Length == 0)
throw env.Except("Couldn't find a {0} class that accepts {1} as arguments.", typeof(TRes).Name, typeof(TArgs).FullName);
if (classes.Length > 1)
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Api/SerializableLambdaTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature);
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(SerializableLambdaTransform).Assembly.FullName);
}

public const string LoaderSignature = "UserLambdaMapTransform";
Expand Down
Loading

0 comments on commit a02807c

Please sign in to comment.