From d421aab340efb02e03bc1c9a80456b7bf680db69 Mon Sep 17 00:00:00 2001 From: Milos Kotlar Date: Sat, 13 Jun 2026 13:42:48 +0200 Subject: [PATCH 1/2] Fix memory leak and stale-model results when hot-swapping PredictionEnginePool ## Summary Makes zero-downtime model hot-swapping in `Microsoft.Extensions.ML` safe under live traffic: `PredictionEnginePool` can now be reloaded continuously without leaking native memory, serving predictions from a stale model, or disposing a shared transformer out from under sibling engines. ## Problem `PredictionEngine` wraps native resources and is `IDisposable`. Three defects made hot-swap unsafe for a long-lived, high-throughput service: 1. **Native memory leak on reload.** `PoolLoader.LoadPool()` created the pool with `new DefaultObjectPool<>(...)` and, on reload, swapped in a new pool via `Interlocked.Exchange` without disposing the old one. `DefaultObjectPool` never disposes the objects it retains, so every reload abandoned a generation of native engines to the GC finalizer (effectively unbounded growth at a 15-20 minute swap cadence). 2. **Cross-generation contamination (wrong-model results).** `ReturnPredictionEngine` always returned an engine to the *current* pool. An engine rented against the old model and returned after a swap landed in the *new* model's pool and was later handed to a caller, producing stale-model predictions. 3. **Shared-transformer double dispose.** `PredictionEnginePoolPolicy.Create()` built engines with the default `ownsTransformer: true`, so every engine in a generation believed it owned the single shared `ITransformer`. Once engines are actually disposed (e.g. on overflow), the first disposal tears down the shared model for all sibling engines. ## Change - `PoolLoader` now builds each pool generation with `DefaultObjectPoolProvider`, which returns the framework's `DisposableObjectPool` for `IDisposable` element types. That pool already disposes overflow engines, disposes engines returned after it has been disposed, and disposes everything it retains on `Dispose()` - so no custom pool type is needed. - `PoolLoader` is generation-aware: a `ConditionalWeakTable` records the generation each rented engine came from and routes it back to that exact generation on return, so an old-model engine is never mixed into the new pool. On reload the old generation is disposed atomically after the swap, and the generation's `ITransformer` is disposed with it (guarded by reference equality so a model still in use by the surviving generation is never disposed). - `PredictionEnginePoolPolicy.Create()` now passes `OwnsTransformer = false`; the shared model's lifetime is owned by `PoolLoader` (per generation) instead of by each engine. - `PredictionEnginePool` implements `IDisposable`, guards `GetPredictionEngine` against use after disposal, disposes (rather than throws on) engines returned after disposal, and disposes its loaders so the file/uri watchers and change-token registrations are torn down. ## Proof New tests: `pool_serves_predictions_across_a_hot_swap` (serve across an in-flight swap; stale engine not reused), `disposing_pool_releases_loader_resources`, and `pooled_engines_do_not_dispose_the_shared_model` (rents past the retention limit to force overflow disposal and asserts the shared transformer is never disposed; this test fails if `ownsTransformer` is left at its default of true). ``` Passed! - Failed: 0, Passed: 10, Skipped: 0, Total: 10 - Microsoft.Extensions.ML.Tests.dll (net8.0) ``` ## Risk Behavioral change limited to disposal/lifetime semantics. Public API addition only (`PredictionEnginePool` now `IDisposable`). No change to prediction results on the happy path; stale-model engines are simply no longer reused after a swap. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Microsoft.Extensions.ML/PoolLoader.cs | 106 +++++++++++++- .../PredictionEnginePool.cs | 56 +++++++- .../PredictionEnginePoolPolicy.cs | 4 +- .../PredictionEnginePoolTests.cs | 129 ++++++++++++++++++ 4 files changed, 283 insertions(+), 12 deletions(-) diff --git a/src/Microsoft.Extensions.ML/PoolLoader.cs b/src/Microsoft.Extensions.ML/PoolLoader.cs index de394d563a..1263a68f61 100644 --- a/src/Microsoft.Extensions.ML/PoolLoader.cs +++ b/src/Microsoft.Extensions.ML/PoolLoader.cs @@ -1,8 +1,9 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. using System; +using System.Runtime.CompilerServices; using System.Threading; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.ObjectPool; @@ -19,15 +20,23 @@ internal class PoolLoader : IDisposable where TData : class where TPrediction : class, new() { - private DefaultObjectPool> _pool; + private static readonly ObjectPoolProvider _poolProvider = new DefaultObjectPoolProvider(); + + private ObjectPool> _pool; + private ITransformer _model; private readonly IDisposable _changeTokenRegistration; + private readonly ConditionalWeakTable, ObjectPool>> _rentedEngines; + private bool _disposed; + public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions poolOptions) { var contextOptions = sp.GetRequiredService>(); Context = contextOptions.Value.MLContext ?? throw new ArgumentNullException(nameof(contextOptions)); Loader = poolOptions.ModelLoader ?? throw new ArgumentNullException(nameof(poolOptions)); + _rentedEngines = new ConditionalWeakTable, ObjectPool>>(); + LoadPool(); _changeTokenRegistration = ChangeToken.OnChange( @@ -37,17 +46,104 @@ public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions> PredictionEnginePool { get { return _pool; } } + + /// + /// The active pool generation. Exposed for compatibility; prefer and + /// , which route an engine back to the generation that created it. + /// + public ObjectPool> PredictionEnginePool { get { return Volatile.Read(ref _pool); } } + + /// + /// Rents an engine from the current pool generation, recording its origin so it can be + /// returned to the correct generation later. + /// + public PredictionEngine Get() + { + var pool = Volatile.Read(ref _pool); + if (_disposed || pool == null) + { + throw new ObjectDisposedException(nameof(PoolLoader)); + } + + var engine = pool.Get(); + + _rentedEngines.Remove(engine); + _rentedEngines.Add(engine, pool); + return engine; + } + + /// + /// Returns an engine to the generation it was rented from. If that generation has already + /// been disposed by a hot-swap, the pool disposes the engine instead of retaining it. + /// + public void Return(PredictionEngine engine) + { + if (engine == null) + { + throw new ArgumentNullException(nameof(engine)); + } + + if (_rentedEngines.TryGetValue(engine, out var origin)) + { + _rentedEngines.Remove(engine); + origin.Return(engine); + } + else + { + engine.Dispose(); + } + } private void LoadPool() { - var predictionEnginePolicy = new PredictionEnginePoolPolicy(Context, Loader.GetModel()); - Interlocked.Exchange(ref _pool, new DefaultObjectPool>(predictionEnginePolicy)); + if (_disposed) + { + return; + } + + var model = Loader.GetModel(); + var policy = new PredictionEnginePoolPolicy(Context, model); + var newPool = _poolProvider.Create(policy); + + var oldPool = Interlocked.Exchange(ref _pool, newPool); + var oldModel = Interlocked.Exchange(ref _model, model); + + DisposeGeneration(oldPool, oldModel, model); + + if (_disposed && Interlocked.CompareExchange(ref _pool, null, newPool) == newPool) + { + DisposeGeneration(newPool, model, null); + } + } + + private static void DisposeGeneration( + ObjectPool> pool, + ITransformer model, + ITransformer survivingModel) + { + (pool as IDisposable)?.Dispose(); + + if (model != null && !ReferenceEquals(model, survivingModel)) + { + (model as IDisposable)?.Dispose(); + } } public void Dispose() { + if (_disposed) + { + return; + } + + _disposed = true; _changeTokenRegistration?.Dispose(); + + var pool = Interlocked.Exchange(ref _pool, null); + var model = Interlocked.Exchange(ref _model, null); + DisposeGeneration(pool, model, null); + + (Loader as IDisposable)?.Dispose(); } } } diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePool.cs b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs index db97d97fde..9fb2c18c9e 100644 --- a/src/Microsoft.Extensions.ML/PredictionEnginePool.cs +++ b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs @@ -14,7 +14,7 @@ namespace Microsoft.Extensions.ML /// Provides a pool of objects /// that can be used to make predictions. /// - public class PredictionEnginePool + public class PredictionEnginePool : IDisposable where TData : class where TPrediction : class, new() { @@ -23,6 +23,7 @@ public class PredictionEnginePool private readonly IServiceProvider _serviceProvider; private readonly PoolLoader _defaultEnginePool; private readonly ConcurrentDictionary> _namedPools; + private bool _disposed; public PredictionEnginePool(IServiceProvider serviceProvider, IOptions mlContextOptions, @@ -85,9 +86,14 @@ public PredictionEngine GetPredictionEngine() /// public PredictionEngine GetPredictionEngine(string modelName) { + if (_disposed) + { + throw new ObjectDisposedException(nameof(PredictionEnginePool)); + } + if (_namedPools.TryGetValue(modelName, out var existingPool)) { - return existingPool.PredictionEnginePool.Get(); + return existingPool.Get(); } //This is the case where someone has used string.Empty to get the default model. @@ -100,11 +106,11 @@ public PredictionEngine GetPredictionEngine(string modelName throw new ArgumentException("You need to configure a default, not named, model before you use this method."); } - return _defaultEnginePool.PredictionEnginePool.Get(); + return _defaultEnginePool.Get(); } var pool = AddPool(modelName); - return pool.PredictionEnginePool.Get(); + return pool.Get(); } private PoolLoader AddPool(string modelName) @@ -141,14 +147,52 @@ public void ReturnPredictionEngine(string modelName, PredictionEngine + /// Disposes the pooled prediction engines and releases the file/uri watchers backing each + /// model loader. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + + _defaultEnginePool?.Dispose(); + foreach (var pool in _namedPools.Values) + { + pool.Dispose(); } + _namedPools.Clear(); } } } diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs b/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs index e0557ada71..cf16f170a0 100644 --- a/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs +++ b/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs @@ -34,7 +34,9 @@ public PredictionEnginePoolPolicy(MLContext mlContext, ITransformer model) /// public override PredictionEngine Create() => - _mlContext.Model.CreatePredictionEngine(_model); + _mlContext.Model.CreatePredictionEngine( + _model, + new PredictionEngineOptions { OwnsTransformer = false }); /// public override bool Return(PredictionEngine obj) => true; diff --git a/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs b/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs index ac77ee9e4f..3676e98e99 100644 --- a/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs +++ b/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs @@ -9,6 +9,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; +using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.TestFramework; using Microsoft.ML.TestFrameworkCommon; @@ -41,6 +42,134 @@ public void can_load_namedmodel() Assert.NotNull(model); } + [Fact] + public void pool_serves_predictions_across_a_hot_swap() + { + var loader = CreateReloadableLoader(); + using var pool = CreatePool(loader); + + var inFlight = pool.GetPredictionEngine(); + Assert.NotNull(inFlight.Predict(new SentimentData { SentimentText = "great" })); + + loader.Reload(); + + var afterSwap = pool.GetPredictionEngine(); + Assert.NotNull(afterSwap.Predict(new SentimentData { SentimentText = "terrible" })); + pool.ReturnPredictionEngine(afterSwap); + + pool.ReturnPredictionEngine(inFlight); + + var reused = pool.GetPredictionEngine(); + Assert.NotSame(inFlight, reused); + Assert.NotNull(reused.Predict(new SentimentData { SentimentText = "fine" })); + pool.ReturnPredictionEngine(reused); + } + + [Fact] + public void disposing_pool_releases_loader_resources() + { + var loader = CreateReloadableLoader(); + var pool = CreatePool(loader); + + _ = pool.GetPredictionEngine(); + pool.Dispose(); + + loader.Reload(); + pool.Dispose(); + } + + [Fact] + public void pooled_engines_do_not_dispose_the_shared_model() + { + var context = new MLContext(seed: 1); + using var stream = File.OpenRead(Path.Combine("TestModels", "SentimentModel.zip")); + var innerModel = context.Model.Load(stream, out _); + var model = new DisposeCountingTransformer(innerModel); + var loader = new ReloadableModelLoader(model); + using var pool = CreatePool(loader); + + var maximumRetained = Environment.ProcessorCount * 2; + var rented = new PredictionEngine[maximumRetained + 2]; + for (var i = 0; i < rented.Length; i++) + { + rented[i] = pool.GetPredictionEngine(); + } + + foreach (var engine in rented) + { + pool.ReturnPredictionEngine(engine); + } + + Assert.Equal(0, model.DisposeCount); + + var afterOverflow = pool.GetPredictionEngine(); + Assert.NotNull(afterOverflow.Predict(new SentimentData { SentimentText = "still works" })); + pool.ReturnPredictionEngine(afterOverflow); + } + + private static ReloadableModelLoader CreateReloadableLoader() + { + var context = new MLContext(seed: 1); + using var stream = File.OpenRead(Path.Combine("TestModels", "SentimentModel.zip")); + var model = context.Model.Load(stream, out _); + return new ReloadableModelLoader(model); + } + + private static PredictionEnginePool CreatePool(ModelLoader loader) + { + var services = new ServiceCollection().AddOptions().AddLogging(); + services.AddPredictionEnginePool(); + services.Configure>( + string.Empty, o => o.ModelLoader = loader); + + var sp = services.BuildServiceProvider(); + return sp.GetRequiredService>(); + } + + private sealed class ReloadableModelLoader : ModelLoader + { + private readonly ITransformer _model; + private ModelReloadToken _token = new ModelReloadToken(); + + public ReloadableModelLoader(ITransformer model) => _model = model; + + public override IChangeToken GetReloadToken() => _token; + + public override ITransformer GetModel() => _model; + + public void Reload() + { + var previous = Interlocked.Exchange(ref _token, new ModelReloadToken()); + previous.OnReload(); + } + } + + private sealed class DisposeCountingTransformer : ITransformer, IDisposable + { + private readonly ITransformer _inner; + private int _disposeCount; + + public DisposeCountingTransformer(ITransformer inner) => _inner = inner; + + public int DisposeCount => _disposeCount; + + public bool IsRowToRowMapper => _inner.IsRowToRowMapper; + + public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => _inner.GetOutputSchema(inputSchema); + + public IDataView Transform(IDataView input) => _inner.Transform(input); + + public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => _inner.GetRowToRowMapper(inputSchema); + + public void Save(ModelSaveContext ctx) => _inner.Save(ctx); + + public void Dispose() + { + Interlocked.Increment(ref _disposeCount); + (_inner as IDisposable)?.Dispose(); + } + } + public class SentimentData { [ColumnName("Label"), LoadColumn(0)] From 57cf192605be7b7e2be47cba1a767b0a4e165ec4 Mon Sep 17 00:00:00 2001 From: Milos Kotlar Date: Wed, 17 Jun 2026 15:28:05 +0200 Subject: [PATCH 2/2] Use CAS/volatile for disposal flag in PoolLoader and PredictionEnginePool Make the _disposed flag an int updated via Interlocked.Exchange so Dispose runs exactly once, and read it with Volatile.Read on all hot paths to avoid races between rent/return, reload callbacks, and disposal. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Microsoft.Extensions.ML/PoolLoader.cs | 11 +++++------ src/Microsoft.Extensions.ML/PredictionEnginePool.cs | 11 +++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/Microsoft.Extensions.ML/PoolLoader.cs b/src/Microsoft.Extensions.ML/PoolLoader.cs index 1263a68f61..f678b943f5 100644 --- a/src/Microsoft.Extensions.ML/PoolLoader.cs +++ b/src/Microsoft.Extensions.ML/PoolLoader.cs @@ -27,7 +27,7 @@ internal class PoolLoader : IDisposable private readonly IDisposable _changeTokenRegistration; private readonly ConditionalWeakTable, ObjectPool>> _rentedEngines; - private bool _disposed; + private int _disposed; public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions poolOptions) { @@ -60,7 +60,7 @@ public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions Get() { var pool = Volatile.Read(ref _pool); - if (_disposed || pool == null) + if (Volatile.Read(ref _disposed) != 0 || pool == null) { throw new ObjectDisposedException(nameof(PoolLoader)); } @@ -96,7 +96,7 @@ public void Return(PredictionEngine engine) private void LoadPool() { - if (_disposed) + if (Volatile.Read(ref _disposed) != 0) { return; } @@ -110,7 +110,7 @@ private void LoadPool() DisposeGeneration(oldPool, oldModel, model); - if (_disposed && Interlocked.CompareExchange(ref _pool, null, newPool) == newPool) + if (Volatile.Read(ref _disposed) != 0 && Interlocked.CompareExchange(ref _pool, null, newPool) == newPool) { DisposeGeneration(newPool, model, null); } @@ -131,12 +131,11 @@ private static void DisposeGeneration( public void Dispose() { - if (_disposed) + if (Interlocked.Exchange(ref _disposed, 1) != 0) { return; } - _disposed = true; _changeTokenRegistration?.Dispose(); var pool = Interlocked.Exchange(ref _pool, null); diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePool.cs b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs index 9fb2c18c9e..e1c8fa0c07 100644 --- a/src/Microsoft.Extensions.ML/PredictionEnginePool.cs +++ b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Threading; using Microsoft.Extensions.Options; using Microsoft.ML; @@ -23,7 +24,7 @@ public class PredictionEnginePool : IDisposable private readonly IServiceProvider _serviceProvider; private readonly PoolLoader _defaultEnginePool; private readonly ConcurrentDictionary> _namedPools; - private bool _disposed; + private int _disposed; public PredictionEnginePool(IServiceProvider serviceProvider, IOptions mlContextOptions, @@ -86,7 +87,7 @@ public PredictionEngine GetPredictionEngine() /// public PredictionEngine GetPredictionEngine(string modelName) { - if (_disposed) + if (Volatile.Read(ref _disposed) != 0) { throw new ObjectDisposedException(nameof(PredictionEnginePool)); } @@ -147,7 +148,7 @@ public void ReturnPredictionEngine(string modelName, PredictionEngine public void Dispose() { - if (_disposed) + if (Interlocked.Exchange(ref _disposed, 1) != 0) { return; } - _disposed = true; - _defaultEnginePool?.Dispose(); foreach (var pool in _namedPools.Values) {