Skip to content

Commit

Permalink
Improve package mirroring (loic-sharma#679)
Browse files Browse the repository at this point in the history
Improves the mirroring implementation:

1. Push all mirroring logic into the `MirrorService`. This will make it easier to reuse mirroring logic in the upcoming Razor Pages rewrite: loic-sharma#678
2. Add unit and integration tests on the mirroring functionality
3. Fixed bugs caught by tests
  • Loading branch information
loic-sharma authored Sep 6, 2021
1 parent a742a9a commit 7d5e2ad
Show file tree
Hide file tree
Showing 21 changed files with 827 additions and 354 deletions.
16 changes: 3 additions & 13 deletions src/BaGet.Core/Content/DefaultPackageContentService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,10 @@ public async Task<PackageVersionsResponse> GetPackageVersionsOrNullAsync(
string id,
CancellationToken cancellationToken = default)
{
// First, attempt to find all package versions using the upstream source.
var versions = await _mirror.FindPackageVersionsOrNullAsync(id, cancellationToken);

if (versions == null)
var versions = await _mirror.FindPackageVersionsAsync(id, cancellationToken);
if (!versions.Any())
{
// Fallback to the local packages if mirroring is disabled.
var packages = await _packages.FindAsync(id, includeUnlisted: true, cancellationToken);

if (!packages.Any())
{
return null;
}

versions = packages.Select(p => p.Version).ToList();
return null;
}

return new PackageVersionsResponse
Expand Down
12 changes: 6 additions & 6 deletions src/BaGet.Core/Extensions/DependencyInjectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ private static void AddBaGetServices(this IServiceCollection services)
services.TryAddTransient<MirrorService>();
services.TryAddTransient<MirrorV2Client>();
services.TryAddTransient<MirrorV3Client>();
services.TryAddTransient<NullMirrorService>();
services.TryAddTransient<DisabledMirrorService>();
services.TryAddSingleton<NullStorageService>();
services.TryAddTransient<PackageService>();

services.TryAddTransient(IMirrorServiceFactory);
services.TryAddTransient(IMirrorNuGetClientFactory);
services.TryAddTransient(IMirrorClientFactory);
}

private static void AddDefaultProviders(this IServiceCollection services)
Expand Down Expand Up @@ -197,17 +197,17 @@ private static NuGetClientFactory NuGetClientFactoryFactory(IServiceProvider pro
private static IMirrorService IMirrorServiceFactory(IServiceProvider provider)
{
var options = provider.GetRequiredService<IOptionsSnapshot<MirrorOptions>>();
var service = options.Value.Enabled ? typeof(MirrorService) : typeof(NullMirrorService);
var service = options.Value.Enabled ? typeof(MirrorService) : typeof(DisabledMirrorService);

return (IMirrorService)provider.GetRequiredService(service);
}

private static IMirrorNuGetClient IMirrorNuGetClientFactory(IServiceProvider provider)
private static IMirrorClient IMirrorClientFactory(IServiceProvider provider)
{
var options = provider.GetRequiredService<IOptionsSnapshot<MirrorOptions>>();
var service = options.Value.Legacy ? typeof(MirrorV2Client) : typeof(MirrorV3Client);

return (IMirrorNuGetClient)provider.GetRequiredService(service);
return (IMirrorClient)provider.GetRequiredService(service);
}
}
}
6 changes: 3 additions & 3 deletions src/BaGet.Core/Indexing/PackageIndexingService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ public async Task<PackageIndexingResult> IndexAsync(Stream packageStream, Cancel
package.Published = _time.UtcNow;

nuspecStream = await packageReader.GetNuspecAsync(cancellationToken);
nuspecStream = await nuspecStream.AsTemporaryFileStreamAsync();
nuspecStream = await nuspecStream.AsTemporaryFileStreamAsync(cancellationToken);

if (package.HasReadme)
{
readmeStream = await packageReader.GetReadmeAsync(cancellationToken);
readmeStream = await readmeStream.AsTemporaryFileStreamAsync();
readmeStream = await readmeStream.AsTemporaryFileStreamAsync(cancellationToken);
}
else
{
Expand All @@ -64,7 +64,7 @@ public async Task<PackageIndexingResult> IndexAsync(Stream packageStream, Cancel
if (package.HasEmbeddedIcon)
{
iconStream = await packageReader.GetIconAsync(cancellationToken);
iconStream = await iconStream.AsTemporaryFileStreamAsync();
iconStream = await iconStream.AsTemporaryFileStreamAsync(cancellationToken);
}
else
{
Expand Down
30 changes: 2 additions & 28 deletions src/BaGet.Core/Metadata/DefaultPackageMetadataService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public async Task<BaGetRegistrationIndexResponse> GetRegistrationIndexOrNullAsyn
string packageId,
CancellationToken cancellationToken = default)
{
var packages = await FindPackagesOrNullAsync(packageId, cancellationToken);
if (packages == null)
var packages = await _mirror.FindPackagesAsync(packageId, cancellationToken);
if (!packages.Any())
{
return null;
}
Expand All @@ -58,31 +58,5 @@ public async Task<RegistrationLeafResponse> GetRegistrationLeafOrNullAsync(

return _builder.BuildLeaf(package);
}

private async Task<IReadOnlyList<Package>> FindPackagesOrNullAsync(
string packageId,
CancellationToken cancellationToken)
{
var upstreamPackages = await _mirror.FindPackagesOrNullAsync(packageId, cancellationToken);
var localPackages = await _packages.FindAsync(packageId, includeUnlisted: true, cancellationToken);

if (upstreamPackages == null)
{
return localPackages.Any()
? localPackages
: null;
}

// Mrge the local packages into the upstream packages.
var result = upstreamPackages.ToDictionary(p => new PackageIdentity(p.Id, p.Version));
var local = localPackages.ToDictionary(p => new PackageIdentity(p.Id, p.Version));

foreach (var localPackage in local)
{
result[localPackage.Key] = localPackage.Value;
}

return result.Values.ToList();
}
}
}
121 changes: 81 additions & 40 deletions src/BaGet.Core/Mirror/Clients/MirrorV2Client.cs
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using BaGet.Protocol.Models;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using NuGet.Common;
using NuGet.Configuration;
using NuGet.Protocol;
using NuGet.Protocol.Core.Types;
using NuGet.Versioning;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace BaGet.Core
{
internal sealed class MirrorV2Client : IMirrorNuGetClient
using ILogger = Microsoft.Extensions.Logging.ILogger<MirrorV2Client>;
using INuGetLogger = NuGet.Common.ILogger;

/// <summary>
/// The mirroring client for a NuGet server that uses the V2 protocol.
/// </summary>
internal sealed class MirrorV2Client : IMirrorClient, IDisposable
{
private readonly ILogger _logger;
private readonly SourceCacheContext _cache;
private readonly SourceRepository _repository;
private readonly INuGetLogger _ngLogger;
private readonly ILogger _logger;

public MirrorV2Client(IOptionsSnapshot<MirrorOptions> options)
public MirrorV2Client(
IOptionsSnapshot<MirrorOptions> options,
ILogger logger)
{
if (options is null)
{
Expand All @@ -32,60 +42,91 @@ public MirrorV2Client(IOptionsSnapshot<MirrorOptions> options)
throw new ArgumentException("No mirror package source has been set.");
}

_logger = NullLogger.Instance;
_logger = logger ?? throw new ArgumentNullException(nameof(logger));

_ngLogger = NullLogger.Instance;
_cache = new SourceCacheContext();
_repository = Repository.Factory.GetCoreV2(new PackageSource(options.Value.PackageSource.AbsoluteUri));
}

public async Task<IReadOnlyList<NuGetVersion>> ListPackageVersionsAsync(string id, bool includeUnlisted, CancellationToken cancellationToken)
public async Task<IReadOnlyList<NuGetVersion>> ListPackageVersionsAsync(string id, CancellationToken cancellationToken)
{
var resource = await _repository.GetResourceAsync<FindPackageByIdResource>();
var versions = await resource.GetAllVersionsAsync(id, _cache, _logger, cancellationToken);
try
{
var resource = await _repository.GetResourceAsync<FindPackageByIdResource>(cancellationToken);
var versions = await resource.GetAllVersionsAsync(id, _cache, _ngLogger, cancellationToken);

return versions.ToList();
return versions.ToList();
}
catch (Exception e)
{
_logger.LogError(e, "Failed to mirror {PackageId}'s upstream versions", id);
return new List<NuGetVersion>();
}
}

public async Task<IReadOnlyList<PackageMetadata>> GetPackageMetadataAsync(string id, CancellationToken cancellationToken)
{
var resource = await _repository.GetResourceAsync<PackageMetadataResource>();
var packages = await resource.GetMetadataAsync(id, includePrerelease: true, includeUnlisted: false, _cache, _logger, cancellationToken);
try
{
var resource = await _repository.GetResourceAsync<PackageMetadataResource>(cancellationToken);
var packages = await resource.GetMetadataAsync(

id,
includePrerelease: true,
includeUnlisted: true,
_cache,
_ngLogger,
cancellationToken);

var result = new List<PackageMetadata>();
foreach (var package in packages)
return packages
.Select(package => new PackageMetadata
{
Authors = package.Authors,
Description = package.Description,
IconUrl = package.IconUrl?.AbsoluteUri,
LicenseUrl = package.LicenseUrl?.AbsoluteUri,
Listed = package.IsListed,
PackageId = id,
Summary = package.Summary,
Version = package.Identity.Version.ToString(),
Tags = package.Tags?.Split(new[] {';'}, StringSplitOptions.RemoveEmptyEntries),
Title = package.Title,
RequireLicenseAcceptance = package.RequireLicenseAcceptance,
Published = package.Published?.UtcDateTime ?? DateTimeOffset.MinValue,
ProjectUrl = package.ProjectUrl?.AbsoluteUri,
DependencyGroups = ToDependencyGroups(package),
})
.ToList();
}
catch (Exception e)
{
result.Add(new PackageMetadata
{
Authors = package.Authors,
Description = package.Description,
IconUrl = package.IconUrl?.AbsoluteUri,
LicenseUrl = package.LicenseUrl?.AbsoluteUri,
Listed = package.IsListed,
PackageId = id,
Summary = package.Summary,
Version = package.Identity.Version.ToString(),
Tags = package.Tags?.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries),
Title = package.Title,
RequireLicenseAcceptance = package.RequireLicenseAcceptance,
Published = package.Published?.UtcDateTime ?? DateTimeOffset.MinValue,
ProjectUrl = package.ProjectUrl?.AbsoluteUri,
DependencyGroups = GetDependencies(package),
});
_logger.LogError(e, "Failed to mirror {PackageId}'s upstream versions", id);
return new List<PackageMetadata>();
}

return result;
}

public async Task<Stream> DownloadPackageAsync(string id, NuGetVersion version, CancellationToken cancellationToken)
{
var packageStream = new MemoryStream();
var resource = await _repository.GetResourceAsync<FindPackageByIdResource>();
await resource.CopyNupkgToStreamAsync(id, version, packageStream, _cache, _logger, cancellationToken);
var resource = await _repository.GetResourceAsync<FindPackageByIdResource>(cancellationToken);
var success = await resource.CopyNupkgToStreamAsync(
id, version, packageStream, _cache, _ngLogger,
cancellationToken);

if (!success)
{
throw new PackageNotFoundException(id, version);
}

packageStream.Seek(0, SeekOrigin.Begin);

return packageStream;
}

private IReadOnlyList<DependencyGroupItem> GetDependencies(IPackageSearchMetadata package)
public void Dispose() => _cache.Dispose();

private IReadOnlyList<DependencyGroupItem> ToDependencyGroups(IPackageSearchMetadata package)
{
var groupItems = new List<DependencyGroupItem>();
foreach (var set in package.DependencySets)
Expand Down
47 changes: 37 additions & 10 deletions src/BaGet.Core/Mirror/Clients/MirrorV3Client.cs
Original file line number Diff line number Diff line change
@@ -1,36 +1,63 @@
using BaGet.Protocol;
using BaGet.Protocol.Models;
using NuGet.Versioning;
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using BaGet.Protocol;
using BaGet.Protocol.Models;
using NuGet.Versioning;
using Microsoft.Extensions.Logging;

namespace BaGet.Core
{
internal sealed class MirrorV3Client : IMirrorNuGetClient
/// <summary>
/// The mirroring client for a NuGet server that uses the V3 protocol.
/// </summary>
internal sealed class MirrorV3Client : IMirrorClient
{
private readonly NuGetClient _client;
private readonly ILogger<MirrorV3Client> _logger;

public MirrorV3Client(NuGetClient client)
public MirrorV3Client(NuGetClient client, ILogger<MirrorV3Client> logger)
{
_client = client ?? throw new ArgumentNullException(nameof(client));
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
}

public async Task<Stream> DownloadPackageAsync(string id, NuGetVersion version, CancellationToken cancellationToken)
public async Task<Stream> DownloadPackageAsync(string id, NuGetVersion version,
CancellationToken cancellationToken)
{
return await _client.DownloadPackageAsync(id, version, cancellationToken);
}

public async Task<IReadOnlyList<PackageMetadata>> GetPackageMetadataAsync(string id, CancellationToken cancellationToken)
public async Task<IReadOnlyList<PackageMetadata>> GetPackageMetadataAsync(
string id,
CancellationToken cancellationToken)
{
return await _client.GetPackageMetadataAsync(id, cancellationToken);
try
{
return await _client.GetPackageMetadataAsync(id, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to mirror {PackageId}'s upstream metadata", id);
return new List<PackageMetadata>();
}
}

public async Task<IReadOnlyList<NuGetVersion>> ListPackageVersionsAsync(string id, bool includeUnlisted, CancellationToken cancellationToken)
public async Task<IReadOnlyList<NuGetVersion>> ListPackageVersionsAsync(
string id,
CancellationToken cancellationToken)
{
return await _client.ListPackageVersionsAsync(id, includeUnlisted, cancellationToken);
try
{
return await _client.ListPackageVersionsAsync(id, includeUnlisted: true, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to mirror {PackageId}'s upstream versions", id);
return new List<NuGetVersion>();
}
}
}
}
Loading

0 comments on commit 7d5e2ad

Please sign in to comment.