Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 100 additions & 5 deletions src/Microsoft.Extensions.ML/PoolLoader.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.CompilerServices;
using System.Threading;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.ObjectPool;
Expand All @@ -19,15 +20,23 @@ internal class PoolLoader<TData, TPrediction> : IDisposable
where TData : class
where TPrediction : class, new()
{
private DefaultObjectPool<PredictionEngine<TData, TPrediction>> _pool;
private static readonly ObjectPoolProvider _poolProvider = new DefaultObjectPoolProvider();

private ObjectPool<PredictionEngine<TData, TPrediction>> _pool;
private ITransformer _model;
private readonly IDisposable _changeTokenRegistration;

private readonly ConditionalWeakTable<PredictionEngine<TData, TPrediction>, ObjectPool<PredictionEngine<TData, TPrediction>>> _rentedEngines;
private int _disposed;

public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions<TData, TPrediction> poolOptions)
{
var contextOptions = sp.GetRequiredService<IOptions<MLOptions>>();
Context = contextOptions.Value.MLContext ?? throw new ArgumentNullException(nameof(contextOptions));
Loader = poolOptions.ModelLoader ?? throw new ArgumentNullException(nameof(poolOptions));

_rentedEngines = new ConditionalWeakTable<PredictionEngine<TData, TPrediction>, ObjectPool<PredictionEngine<TData, TPrediction>>>();

LoadPool();

_changeTokenRegistration = ChangeToken.OnChange(
Expand All @@ -37,17 +46,103 @@ public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions<TData, TPredi

public ModelLoader Loader { get; }
private MLContext Context { get; }
public ObjectPool<PredictionEngine<TData, TPrediction>> PredictionEnginePool { get { return _pool; } }

/// <summary>
/// The active pool generation. Exposed for compatibility; prefer <see cref="Get"/> and
/// <see cref="Return"/>, which route an engine back to the generation that created it.
/// </summary>
public ObjectPool<PredictionEngine<TData, TPrediction>> PredictionEnginePool { get { return Volatile.Read(ref _pool); } }

/// <summary>
/// Rents an engine from the current pool generation, recording its origin so it can be
/// returned to the correct generation later.
/// </summary>
public PredictionEngine<TData, TPrediction> Get()
{
var pool = Volatile.Read(ref _pool);
if (Volatile.Read(ref _disposed) != 0 || pool == null)
{
throw new ObjectDisposedException(nameof(PoolLoader<TData, TPrediction>));
}

var engine = pool.Get();

_rentedEngines.Remove(engine);
_rentedEngines.Add(engine, pool);
Comment thread
rosebyte marked this conversation as resolved.
return engine;
}

/// <summary>
/// Returns an engine to the generation it was rented from. If that generation has already
/// been disposed by a hot-swap, the pool disposes the engine instead of retaining it.
/// </summary>
public void Return(PredictionEngine<TData, TPrediction> engine)
{
if (engine == null)
Comment thread
rosebyte marked this conversation as resolved.
{
throw new ArgumentNullException(nameof(engine));
}

if (_rentedEngines.TryGetValue(engine, out var origin))
{
_rentedEngines.Remove(engine);
origin.Return(engine);
}
else
{
engine.Dispose();
}
}

private void LoadPool()
{
var predictionEnginePolicy = new PredictionEnginePoolPolicy<TData, TPrediction>(Context, Loader.GetModel());
Interlocked.Exchange(ref _pool, new DefaultObjectPool<PredictionEngine<TData, TPrediction>>(predictionEnginePolicy));
Comment thread
rosebyte marked this conversation as resolved.
if (Volatile.Read(ref _disposed) != 0)
{
return;
}

var model = Loader.GetModel();
var policy = new PredictionEnginePoolPolicy<TData, TPrediction>(Context, model);
var newPool = _poolProvider.Create(policy);

var oldPool = Interlocked.Exchange(ref _pool, newPool);
var oldModel = Interlocked.Exchange(ref _model, model);

DisposeGeneration(oldPool, oldModel, model);

if (Volatile.Read(ref _disposed) != 0 && Interlocked.CompareExchange(ref _pool, null, newPool) == newPool)
{
DisposeGeneration(newPool, model, null);
}
}

private static void DisposeGeneration(
ObjectPool<PredictionEngine<TData, TPrediction>> pool,
ITransformer model,
ITransformer survivingModel)
{
(pool as IDisposable)?.Dispose();

if (model != null && !ReferenceEquals(model, survivingModel))
{
(model as IDisposable)?.Dispose();
}
}

public void Dispose()
{
if (Interlocked.Exchange(ref _disposed, 1) != 0)
{
return;
}

_changeTokenRegistration?.Dispose();

var pool = Interlocked.Exchange(ref _pool, null);
var model = Interlocked.Exchange(ref _model, null);
DisposeGeneration(pool, model, null);

(Loader as IDisposable)?.Dispose();
}
Comment thread
kotlarmilos marked this conversation as resolved.
}
}
55 changes: 49 additions & 6 deletions src/Microsoft.Extensions.ML/PredictionEnginePool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using Microsoft.Extensions.Options;
using Microsoft.ML;

Expand All @@ -14,7 +15,7 @@ namespace Microsoft.Extensions.ML
/// Provides a pool of <see cref="PredictionEngine{TSrc, TDst}"/> objects
/// that can be used to make predictions.
/// </summary>
public class PredictionEnginePool<TData, TPrediction>
public class PredictionEnginePool<TData, TPrediction> : IDisposable
where TData : class
where TPrediction : class, new()
{
Expand All @@ -23,6 +24,7 @@ public class PredictionEnginePool<TData, TPrediction>
private readonly IServiceProvider _serviceProvider;
private readonly PoolLoader<TData, TPrediction> _defaultEnginePool;
private readonly ConcurrentDictionary<string, PoolLoader<TData, TPrediction>> _namedPools;
private int _disposed;

public PredictionEnginePool(IServiceProvider serviceProvider,
IOptions<MLOptions> mlContextOptions,
Expand Down Expand Up @@ -85,9 +87,14 @@ public PredictionEngine<TData, TPrediction> GetPredictionEngine()
/// </param>
public PredictionEngine<TData, TPrediction> GetPredictionEngine(string modelName)
{
if (Volatile.Read(ref _disposed) != 0)
{
throw new ObjectDisposedException(nameof(PredictionEnginePool<TData, TPrediction>));
}

if (_namedPools.TryGetValue(modelName, out var existingPool))
{
return existingPool.PredictionEnginePool.Get();
return existingPool.Get();
}

//This is the case where someone has used string.Empty to get the default model.
Expand All @@ -100,11 +107,11 @@ public PredictionEngine<TData, TPrediction> GetPredictionEngine(string modelName
throw new ArgumentException("You need to configure a default, not named, model before you use this method.");
}

return _defaultEnginePool.PredictionEnginePool.Get();
return _defaultEnginePool.Get();
}

var pool = AddPool(modelName);
return pool.PredictionEnginePool.Get();
return pool.Get();
}

private PoolLoader<TData, TPrediction> AddPool(string modelName)
Expand Down Expand Up @@ -141,14 +148,50 @@ public void ReturnPredictionEngine(string modelName, PredictionEngine<TData, TPr
throw new ArgumentNullException(nameof(engine));
}

if (Volatile.Read(ref _disposed) != 0)
{
engine.Dispose();
return;
}

if (string.IsNullOrEmpty(modelName))
{
_defaultEnginePool.PredictionEnginePool.Return(engine);
if (_defaultEnginePool != null)
{
_defaultEnginePool.Return(engine);
}
else
{
engine.Dispose();
}
}
else if (_namedPools.TryGetValue(modelName, out var pool))
{
pool.Return(engine);
}
else
{
_namedPools[modelName].PredictionEnginePool.Return(engine);
engine.Dispose();
}
}

/// <summary>
/// Disposes the pooled prediction engines and releases the file/uri watchers backing each
/// model loader.
/// </summary>
public void Dispose()
{
if (Interlocked.Exchange(ref _disposed, 1) != 0)
{
return;
}

_defaultEnginePool?.Dispose();
foreach (var pool in _namedPools.Values)
{
pool.Dispose();
}
_namedPools.Clear();
}
}
}
4 changes: 3 additions & 1 deletion src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ public PredictionEnginePoolPolicy(MLContext mlContext, ITransformer model)

/// <inheritdoc />
public override PredictionEngine<TData, TPrediction> Create() =>
_mlContext.Model.CreatePredictionEngine<TData, TPrediction>(_model);
_mlContext.Model.CreatePredictionEngine<TData, TPrediction>(
_model,
new PredictionEngineOptions { OwnsTransformer = false });

/// <inheritdoc />
public override bool Return(PredictionEngine<TData, TPrediction> obj) => true;
Expand Down
Loading
Loading