diff --git a/src/Microsoft.Extensions.ML/PoolLoader.cs b/src/Microsoft.Extensions.ML/PoolLoader.cs index de394d563a..f678b943f5 100644 --- a/src/Microsoft.Extensions.ML/PoolLoader.cs +++ b/src/Microsoft.Extensions.ML/PoolLoader.cs @@ -1,8 +1,9 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. using System; +using System.Runtime.CompilerServices; using System.Threading; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.ObjectPool; @@ -19,15 +20,23 @@ internal class PoolLoader : IDisposable where TData : class where TPrediction : class, new() { - private DefaultObjectPool> _pool; + private static readonly ObjectPoolProvider _poolProvider = new DefaultObjectPoolProvider(); + + private ObjectPool> _pool; + private ITransformer _model; private readonly IDisposable _changeTokenRegistration; + private readonly ConditionalWeakTable, ObjectPool>> _rentedEngines; + private int _disposed; + public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions poolOptions) { var contextOptions = sp.GetRequiredService>(); Context = contextOptions.Value.MLContext ?? throw new ArgumentNullException(nameof(contextOptions)); Loader = poolOptions.ModelLoader ?? throw new ArgumentNullException(nameof(poolOptions)); + _rentedEngines = new ConditionalWeakTable, ObjectPool>>(); + LoadPool(); _changeTokenRegistration = ChangeToken.OnChange( @@ -37,17 +46,103 @@ public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions> PredictionEnginePool { get { return _pool; } } + + /// + /// The active pool generation. Exposed for compatibility; prefer and + /// , which route an engine back to the generation that created it. + /// + public ObjectPool> PredictionEnginePool { get { return Volatile.Read(ref _pool); } } + + /// + /// Rents an engine from the current pool generation, recording its origin so it can be + /// returned to the correct generation later. + /// + public PredictionEngine Get() + { + var pool = Volatile.Read(ref _pool); + if (Volatile.Read(ref _disposed) != 0 || pool == null) + { + throw new ObjectDisposedException(nameof(PoolLoader)); + } + + var engine = pool.Get(); + + _rentedEngines.Remove(engine); + _rentedEngines.Add(engine, pool); + return engine; + } + + /// + /// Returns an engine to the generation it was rented from. If that generation has already + /// been disposed by a hot-swap, the pool disposes the engine instead of retaining it. + /// + public void Return(PredictionEngine engine) + { + if (engine == null) + { + throw new ArgumentNullException(nameof(engine)); + } + + if (_rentedEngines.TryGetValue(engine, out var origin)) + { + _rentedEngines.Remove(engine); + origin.Return(engine); + } + else + { + engine.Dispose(); + } + } private void LoadPool() { - var predictionEnginePolicy = new PredictionEnginePoolPolicy(Context, Loader.GetModel()); - Interlocked.Exchange(ref _pool, new DefaultObjectPool>(predictionEnginePolicy)); + if (Volatile.Read(ref _disposed) != 0) + { + return; + } + + var model = Loader.GetModel(); + var policy = new PredictionEnginePoolPolicy(Context, model); + var newPool = _poolProvider.Create(policy); + + var oldPool = Interlocked.Exchange(ref _pool, newPool); + var oldModel = Interlocked.Exchange(ref _model, model); + + DisposeGeneration(oldPool, oldModel, model); + + if (Volatile.Read(ref _disposed) != 0 && Interlocked.CompareExchange(ref _pool, null, newPool) == newPool) + { + DisposeGeneration(newPool, model, null); + } + } + + private static void DisposeGeneration( + ObjectPool> pool, + ITransformer model, + ITransformer survivingModel) + { + (pool as IDisposable)?.Dispose(); + + if (model != null && !ReferenceEquals(model, survivingModel)) + { + (model as IDisposable)?.Dispose(); + } } public void Dispose() { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + return; + } + _changeTokenRegistration?.Dispose(); + + var pool = Interlocked.Exchange(ref _pool, null); + var model = Interlocked.Exchange(ref _model, null); + DisposeGeneration(pool, model, null); + + (Loader as IDisposable)?.Dispose(); } } } diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePool.cs b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs index db97d97fde..e1c8fa0c07 100644 --- a/src/Microsoft.Extensions.ML/PredictionEnginePool.cs +++ b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Threading; using Microsoft.Extensions.Options; using Microsoft.ML; @@ -14,7 +15,7 @@ namespace Microsoft.Extensions.ML /// Provides a pool of objects /// that can be used to make predictions. /// - public class PredictionEnginePool + public class PredictionEnginePool : IDisposable where TData : class where TPrediction : class, new() { @@ -23,6 +24,7 @@ public class PredictionEnginePool private readonly IServiceProvider _serviceProvider; private readonly PoolLoader _defaultEnginePool; private readonly ConcurrentDictionary> _namedPools; + private int _disposed; public PredictionEnginePool(IServiceProvider serviceProvider, IOptions mlContextOptions, @@ -85,9 +87,14 @@ public PredictionEngine GetPredictionEngine() /// public PredictionEngine GetPredictionEngine(string modelName) { + if (Volatile.Read(ref _disposed) != 0) + { + throw new ObjectDisposedException(nameof(PredictionEnginePool)); + } + if (_namedPools.TryGetValue(modelName, out var existingPool)) { - return existingPool.PredictionEnginePool.Get(); + return existingPool.Get(); } //This is the case where someone has used string.Empty to get the default model. @@ -100,11 +107,11 @@ public PredictionEngine GetPredictionEngine(string modelName throw new ArgumentException("You need to configure a default, not named, model before you use this method."); } - return _defaultEnginePool.PredictionEnginePool.Get(); + return _defaultEnginePool.Get(); } var pool = AddPool(modelName); - return pool.PredictionEnginePool.Get(); + return pool.Get(); } private PoolLoader AddPool(string modelName) @@ -141,14 +148,50 @@ public void ReturnPredictionEngine(string modelName, PredictionEngine + /// Disposes the pooled prediction engines and releases the file/uri watchers backing each + /// model loader. + /// + public void Dispose() + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + return; + } + + _defaultEnginePool?.Dispose(); + foreach (var pool in _namedPools.Values) + { + pool.Dispose(); } + _namedPools.Clear(); } } } diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs b/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs index e0557ada71..cf16f170a0 100644 --- a/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs +++ b/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs @@ -34,7 +34,9 @@ public PredictionEnginePoolPolicy(MLContext mlContext, ITransformer model) /// public override PredictionEngine Create() => - _mlContext.Model.CreatePredictionEngine(_model); + _mlContext.Model.CreatePredictionEngine( + _model, + new PredictionEngineOptions { OwnsTransformer = false }); /// public override bool Return(PredictionEngine obj) => true; diff --git a/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs b/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs index ac77ee9e4f..3676e98e99 100644 --- a/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs +++ b/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs @@ -9,6 +9,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; +using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.TestFramework; using Microsoft.ML.TestFrameworkCommon; @@ -41,6 +42,134 @@ public void can_load_namedmodel() Assert.NotNull(model); } + [Fact] + public void pool_serves_predictions_across_a_hot_swap() + { + var loader = CreateReloadableLoader(); + using var pool = CreatePool(loader); + + var inFlight = pool.GetPredictionEngine(); + Assert.NotNull(inFlight.Predict(new SentimentData { SentimentText = "great" })); + + loader.Reload(); + + var afterSwap = pool.GetPredictionEngine(); + Assert.NotNull(afterSwap.Predict(new SentimentData { SentimentText = "terrible" })); + pool.ReturnPredictionEngine(afterSwap); + + pool.ReturnPredictionEngine(inFlight); + + var reused = pool.GetPredictionEngine(); + Assert.NotSame(inFlight, reused); + Assert.NotNull(reused.Predict(new SentimentData { SentimentText = "fine" })); + pool.ReturnPredictionEngine(reused); + } + + [Fact] + public void disposing_pool_releases_loader_resources() + { + var loader = CreateReloadableLoader(); + var pool = CreatePool(loader); + + _ = pool.GetPredictionEngine(); + pool.Dispose(); + + loader.Reload(); + pool.Dispose(); + } + + [Fact] + public void pooled_engines_do_not_dispose_the_shared_model() + { + var context = new MLContext(seed: 1); + using var stream = File.OpenRead(Path.Combine("TestModels", "SentimentModel.zip")); + var innerModel = context.Model.Load(stream, out _); + var model = new DisposeCountingTransformer(innerModel); + var loader = new ReloadableModelLoader(model); + using var pool = CreatePool(loader); + + var maximumRetained = Environment.ProcessorCount * 2; + var rented = new PredictionEngine[maximumRetained + 2]; + for (var i = 0; i < rented.Length; i++) + { + rented[i] = pool.GetPredictionEngine(); + } + + foreach (var engine in rented) + { + pool.ReturnPredictionEngine(engine); + } + + Assert.Equal(0, model.DisposeCount); + + var afterOverflow = pool.GetPredictionEngine(); + Assert.NotNull(afterOverflow.Predict(new SentimentData { SentimentText = "still works" })); + pool.ReturnPredictionEngine(afterOverflow); + } + + private static ReloadableModelLoader CreateReloadableLoader() + { + var context = new MLContext(seed: 1); + using var stream = File.OpenRead(Path.Combine("TestModels", "SentimentModel.zip")); + var model = context.Model.Load(stream, out _); + return new ReloadableModelLoader(model); + } + + private static PredictionEnginePool CreatePool(ModelLoader loader) + { + var services = new ServiceCollection().AddOptions().AddLogging(); + services.AddPredictionEnginePool(); + services.Configure>( + string.Empty, o => o.ModelLoader = loader); + + var sp = services.BuildServiceProvider(); + return sp.GetRequiredService>(); + } + + private sealed class ReloadableModelLoader : ModelLoader + { + private readonly ITransformer _model; + private ModelReloadToken _token = new ModelReloadToken(); + + public ReloadableModelLoader(ITransformer model) => _model = model; + + public override IChangeToken GetReloadToken() => _token; + + public override ITransformer GetModel() => _model; + + public void Reload() + { + var previous = Interlocked.Exchange(ref _token, new ModelReloadToken()); + previous.OnReload(); + } + } + + private sealed class DisposeCountingTransformer : ITransformer, IDisposable + { + private readonly ITransformer _inner; + private int _disposeCount; + + public DisposeCountingTransformer(ITransformer inner) => _inner = inner; + + public int DisposeCount => _disposeCount; + + public bool IsRowToRowMapper => _inner.IsRowToRowMapper; + + public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => _inner.GetOutputSchema(inputSchema); + + public IDataView Transform(IDataView input) => _inner.Transform(input); + + public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => _inner.GetRowToRowMapper(inputSchema); + + public void Save(ModelSaveContext ctx) => _inner.Save(ctx); + + public void Dispose() + { + Interlocked.Increment(ref _disposeCount); + (_inner as IDisposable)?.Dispose(); + } + } + public class SentimentData { [ColumnName("Label"), LoadColumn(0)]