From acc6c5af170aa9f0cbaf89f63d9b4317fc580cdc Mon Sep 17 00:00:00 2001 From: Michael Staib Date: Fri, 7 Feb 2025 00:09:43 +0100 Subject: [PATCH 1/2] Reworked Operation Session Manager --- .../Subscriptions/IOperationManager.cs | 2 +- .../Subscriptions/IOperationSession.cs | 26 +++-- .../Subscriptions/OperationManager.cs | 94 +++++++++---------- .../Subscriptions/OperationSession.cs | 52 ++++------ .../ApolloSubscriptionProtocolHandler.cs | 2 +- .../GraphQLOverWebSocketProtocolHandler.cs | 2 +- .../Subscriptions/OperationManagerTests.cs | 18 ++-- 7 files changed, 94 insertions(+), 102 deletions(-) diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationManager.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationManager.cs index d012d786b10..cc3082a097c 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationManager.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationManager.cs @@ -25,7 +25,7 @@ public interface IOperationManager /// Returns true if the /// was accepted and registered for execution. /// - bool Enqueue(string sessionId, GraphQLRequest request); + bool Start(string sessionId, GraphQLRequest request); /// /// Completes a request that was previously enqueued with the operation manager. diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationSession.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationSession.cs index 24de59f3e6c..a308760be2c 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationSession.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationSession.cs @@ -9,11 +9,6 @@ namespace HotChocolate.AspNetCore.Subscriptions; /// public interface IOperationSession : IDisposable { - /// - /// An event that indicates that the underlying subscription has completed. - /// - event EventHandler? Completed; - /// /// Gets the subscription id that the client has provided. /// @@ -27,7 +22,22 @@ public interface IOperationSession : IDisposable /// /// Starts executing the operation. /// - /// The graphql request. - /// The cancellation token. - void BeginExecute(GraphQLRequest request, CancellationToken cancellationToken); + /// + /// The graphql request. + /// + /// + /// The completion handler that will be called when the operation is completed. + /// + /// + /// The cancellation token. + /// + void BeginExecute( + GraphQLRequest request, + IOperationSessionCompletionHandler completion, + CancellationToken cancellationToken); +} + +public interface IOperationSessionCompletionHandler +{ + void Complete(IOperationSession session); } diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationManager.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationManager.cs index 9196b135c58..31025b26c68 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationManager.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationManager.cs @@ -1,4 +1,5 @@ using System.Collections; +using System.Collections.Concurrent; using HotChocolate.Language; using Microsoft.Extensions.DependencyInjection; using static HotChocolate.AspNetCore.Properties.AspNetCoreResources; @@ -12,8 +13,8 @@ namespace HotChocolate.AspNetCore.Subscriptions; /// public sealed class OperationManager : IOperationManager { - private readonly ReaderWriterLockSlim _lock = new(); - private readonly Dictionary _subs = new(); + private readonly ConcurrentDictionary _subs = new(); + private readonly OperationSessionCompletionHandler _completion; private readonly CancellationTokenSource _cts; private readonly CancellationToken _cancellationToken; private readonly ISocketSession _socketSession; @@ -35,6 +36,7 @@ public OperationManager( _errorHandler = executor.Services.GetRequiredService(); _cts = new CancellationTokenSource(); _cancellationToken = _cts.Token; + _completion = new OperationSessionCompletionHandler(_subs); } internal OperationManager( @@ -50,10 +52,11 @@ internal OperationManager( _errorHandler = executor.Services.GetRequiredService(); _cts = new CancellationTokenSource(); _cancellationToken = _cts.Token; + _completion = new OperationSessionCompletionHandler(_subs); } /// - public bool Enqueue(string sessionId, GraphQLRequest request) + public bool Start(string sessionId, GraphQLRequest request) { if (string.IsNullOrEmpty(sessionId)) { @@ -72,30 +75,18 @@ public bool Enqueue(string sessionId, GraphQLRequest request) throw new ObjectDisposedException(nameof(OperationManager)); } - IOperationSession? session = null; - _lock.EnterWriteLock(); + var context = new StartSessionContext( + _createSession, + _completion, + request, + _cancellationToken); - try - { - if(!_subs.ContainsKey(sessionId)) - { - session = _createSession(sessionId); - _subs.Add(sessionId, session); - } - } - finally - { - _lock.ExitWriteLock(); - } + _subs.GetOrAdd( + sessionId, + static (key, ctx) => ctx.CreateSession(key), + context); - if (session is not null) - { - session.Completed += (_, _) => Complete(sessionId); - session.BeginExecute(request, _cancellationToken); - return true; - } - - return false; + return context.IsNewSession; } /// @@ -113,20 +104,10 @@ public bool Complete(string sessionId) throw new ObjectDisposedException(nameof(OperationManager)); } - _lock.EnterWriteLock(); - - try - { - if (_subs.TryGetValue(sessionId, out var session)) - { - _subs.Remove(sessionId); - session.Dispose(); - return true; - } - } - finally + if(_subs.TryRemove(sessionId, out var session)) { - _lock.ExitWriteLock(); + session.Dispose(); + return true; } return false; @@ -149,24 +130,37 @@ public void Dispose() /// public IEnumerator GetEnumerator() + => _subs.Values.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + private sealed class StartSessionContext( + Func createSession, + IOperationSessionCompletionHandler completion, + GraphQLRequest request, + CancellationToken cancellationToken) { - _lock.EnterReadLock(); - IOperationSession[] items; + public bool IsNewSession { get; private set; } - try - { - items = _subs.Values.ToArray(); - } - finally + public IOperationSession CreateSession(string sessionId) { - _lock.ExitReadLock(); + IsNewSession = true; + var session = createSession(sessionId); + session.BeginExecute(request, completion, cancellationToken); + return session; } + } - foreach (var session in items) + private sealed class OperationSessionCompletionHandler( + ConcurrentDictionary subs) + : IOperationSessionCompletionHandler + { + public void Complete(IOperationSession session) { - yield return session; + if (subs.TryRemove(session.Id, out _)) + { + session.Dispose(); + } } } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); } diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationSession.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationSession.cs index 2ff5521dce7..495db0e8534 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationSession.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationSession.cs @@ -13,8 +13,6 @@ internal sealed class OperationSession : IOperationSession private readonly IRequestExecutor _executor; private bool _disposed; - public event EventHandler? Completed; - public OperationSession( ISocketSession session, ISocketSessionInterceptor interceptor, @@ -34,10 +32,16 @@ public OperationSession( public bool IsCompleted { get; private set; } - public void BeginExecute(GraphQLRequest request, CancellationToken cancellationToken) - => SendResultsAsync(request, cancellationToken).FireAndForget(); + public void BeginExecute( + GraphQLRequest request, + IOperationSessionCompletionHandler completion, + CancellationToken cancellationToken) + => SendResultsAsync(request, completion, cancellationToken).FireAndForget(); - private async Task SendResultsAsync(GraphQLRequest request, CancellationToken cancellationToken) + private async Task SendResultsAsync( + GraphQLRequest request, + IOperationSessionCompletionHandler completion, + CancellationToken cancellationToken) { using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _ct); var ct = cts.Token; @@ -51,32 +55,28 @@ private async Task SendResultsAsync(GraphQLRequest request, CancellationToken ca switch (result) { - case IOperationResult queryResult: - if (queryResult.Data is null && queryResult.Errors is { Count: > 0, }) + case IOperationResult singleResult: + if (singleResult.Data is null && singleResult.Errors is { Count: > 0, }) { - await _session.Protocol.SendErrorMessageAsync( - _session, - Id, - queryResult.Errors, - ct); + await _session.Protocol.SendErrorMessageAsync(_session, Id, singleResult.Errors, ct).ConfigureAwait(false); } else { - await SendResultMessageAsync(queryResult, ct); + await SendResultMessageAsync(singleResult, ct).ConfigureAwait(false); } break; case IResponseStream responseStream: - await foreach (var item in responseStream.ReadResultsAsync().WithCancellation(ct)) + await foreach (var item in responseStream.ReadResultsAsync().WithCancellation(ct).ConfigureAwait(false)) { try { // use original cancellation token here to keep the websocket open for other streams. - await SendResultMessageAsync(item, cancellationToken); + await SendResultMessageAsync(item, cancellationToken).ConfigureAwait(false); } finally { - await item.DisposeAsync(); + await item.DisposeAsync().ConfigureAwait(false); } } break; @@ -94,7 +94,7 @@ await _session.Protocol.SendErrorMessageAsync( } catch (OperationCanceledException) when (ct.IsCancellationRequested) { - // the operation was cancelled so we do nothings + // the operation was cancelled so we do nothing } catch (Exception ex) { @@ -121,7 +121,8 @@ await _session.Protocol.SendErrorMessageAsync( } // signal that the subscription is completed. - Complete(); + completion.Complete(this); + IsCompleted = true; } } @@ -180,7 +181,7 @@ private async Task TrySendErrorMessageAsync(Exception exception, CancellationTok var errors = error is AggregateError aggregateError ? aggregateError.Errors - : new[] { error, }; + : [error]; await _session.Protocol.SendErrorMessageAsync(_session, Id, errors, ct); } @@ -192,19 +193,6 @@ error is AggregateError aggregateError } } - private void Complete() - { - try - { - IsCompleted = true; - Completed?.Invoke(this, EventArgs.Empty); - } - catch - { - // we ignore any error that might happen on invoking complete. - } - } - public void Dispose() { if (!_disposed) diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/Apollo/ApolloSubscriptionProtocolHandler.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/Apollo/ApolloSubscriptionProtocolHandler.cs index 07f68a04f2a..6385761963f 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/Apollo/ApolloSubscriptionProtocolHandler.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/Apollo/ApolloSubscriptionProtocolHandler.cs @@ -141,7 +141,7 @@ await connection.CloseAsync( return; } - if (!session.Operations.Enqueue(dataStartMessage.Id, dataStartMessage.Payload)) + if (!session.Operations.Start(dataStartMessage.Id, dataStartMessage.Payload)) { await connection.CloseAsync( Apollo_OnReceive_SubscriptionIdNotUnique, diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs index 57e1754a76f..70bc9ffae40 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs @@ -142,7 +142,7 @@ await SendConnectionAcceptMessage( return; } - if (!session.Operations.Enqueue(subscribeMessage.Id, subscribeMessage.Payload)) + if (!session.Operations.Start(subscribeMessage.Id, subscribeMessage.Payload)) { await connection.CloseSubscriptionIdNotUniqueAsync(cancellationToken); return; diff --git a/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/OperationManagerTests.cs b/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/OperationManagerTests.cs index ca999527f88..12111bc6903 100644 --- a/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/OperationManagerTests.cs +++ b/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/OperationManagerTests.cs @@ -25,7 +25,7 @@ public async Task Enqueue_SessionId_Is_Null() var subscriptions = new OperationManager(session.Object, interceptor.Object, executor); // act - void Action() => subscriptions.Enqueue(null!, new GraphQLRequest(null, queryId: "123")); + void Action() => subscriptions.Start(null!, new GraphQLRequest(null, queryId: "123")); // assert Assert.Equal( @@ -48,7 +48,7 @@ public async Task Enqueue_SessionId_Is_Empty() var subscriptions = new OperationManager(session.Object, interceptor.Object, executor); // act - void Action() => subscriptions.Enqueue("", new GraphQLRequest(null, queryId: "123")); + void Action() => subscriptions.Start("", new GraphQLRequest(null, queryId: "123")); // assert Assert.Equal( @@ -71,7 +71,7 @@ public async Task Enqueue_Request_Is_Null() var subscriptions = new OperationManager(session.Object, interceptor.Object, executor); // act - void Action() => subscriptions.Enqueue("abc", null!); + void Action() => subscriptions.Start("abc", null!); // assert Assert.Equal( @@ -106,7 +106,7 @@ public async Task Enqueue_On_Disposed_Manager() subscriptions.Dispose(); // act - void Fail() => subscriptions.Enqueue("abc", request); + void Fail() => subscriptions.Start("abc", request); // assert Assert.Throws(Fail); @@ -138,7 +138,7 @@ public async Task Enqueue_Request() var request = new GraphQLRequest(query); // act - var success = subscriptions.Enqueue("abc", request); + var success = subscriptions.Start("abc", request); var registered = subscriptions.ToArray(); // assert @@ -170,11 +170,11 @@ public async Task Enqueue_Request_With_Non_Unique_Id() var query = Utf8GraphQLParser.Parse( "subscription { onReview(episode: NEW_HOPE) { stars } }"); var request = new GraphQLRequest(query); - var success1 = subscriptions.Enqueue("abc", request); + var success1 = subscriptions.Start("abc", request); var registered1 = subscriptions.ToArray(); // act - var success2 = subscriptions.Enqueue("abc", request); + var success2 = subscriptions.Start("abc", request); var registered2 = subscriptions.ToArray(); // assert @@ -208,7 +208,7 @@ public async Task Complete_Request() var query = Utf8GraphQLParser.Parse( "subscription { onReview(episode: NEW_HOPE) { stars } }"); var request = new GraphQLRequest(query); - var success1 = subscriptions.Enqueue("abc", request); + var success1 = subscriptions.Start("abc", request); var registered1 = subscriptions.ToArray(); // act @@ -322,7 +322,7 @@ public async Task Dispose_OperationManager() var query = Utf8GraphQLParser.Parse( "subscription { onReview(episode: NEW_HOPE) { stars } }"); var request = new GraphQLRequest(query); - var success = subscriptions.Enqueue("abc", request); + var success = subscriptions.Start("abc", request); Assert.True(success); // act From cee8bfe040e7ceb7ee697bce69651671199f0268 Mon Sep 17 00:00:00 2001 From: Michael Staib Date: Fri, 7 Feb 2025 00:15:56 +0100 Subject: [PATCH 2/2] Allow access to the subscription request context --- .../Core/src/Execution/Processing/ISubscription.cs | 7 ++++++- .../Processing/SubscriptionExecutor.Subscription.cs | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/HotChocolate/Core/src/Execution/Processing/ISubscription.cs b/src/HotChocolate/Core/src/Execution/Processing/ISubscription.cs index f4a4c168d29..da033aa6cce 100644 --- a/src/HotChocolate/Core/src/Execution/Processing/ISubscription.cs +++ b/src/HotChocolate/Core/src/Execution/Processing/ISubscription.cs @@ -11,7 +11,12 @@ public interface ISubscription ulong Id { get; } /// - /// The compiled subscription operation. + /// Gets the compiled subscription operation. /// IOperation Operation { get; } + + /// + /// Gets the global request state. + /// + IDictionary ContextData { get; } } diff --git a/src/HotChocolate/Core/src/Execution/Processing/SubscriptionExecutor.Subscription.cs b/src/HotChocolate/Core/src/Execution/Processing/SubscriptionExecutor.Subscription.cs index 96dd8adead0..6bf0b0e5e8e 100644 --- a/src/HotChocolate/Core/src/Execution/Processing/SubscriptionExecutor.Subscription.cs +++ b/src/HotChocolate/Core/src/Execution/Processing/SubscriptionExecutor.Subscription.cs @@ -118,6 +118,9 @@ public IAsyncEnumerable ExecuteAsync() /// public IOperation Operation => _requestContext.Operation!; + /// + public IDictionary ContextData => _requestContext.ContextData; + public async ValueTask DisposeAsync() { if (!_disposed)