diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationManager.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationManager.cs
index d012d786b10..cc3082a097c 100644
--- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationManager.cs
+++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationManager.cs
@@ -25,7 +25,7 @@ public interface IOperationManager
/// Returns true if the
/// was accepted and registered for execution.
///
- bool Enqueue(string sessionId, GraphQLRequest request);
+ bool Start(string sessionId, GraphQLRequest request);
///
/// Completes a request that was previously enqueued with the operation manager.
diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationSession.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationSession.cs
index 24de59f3e6c..a308760be2c 100644
--- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationSession.cs
+++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/IOperationSession.cs
@@ -9,11 +9,6 @@ namespace HotChocolate.AspNetCore.Subscriptions;
///
public interface IOperationSession : IDisposable
{
- ///
- /// An event that indicates that the underlying subscription has completed.
- ///
- event EventHandler? Completed;
-
///
/// Gets the subscription id that the client has provided.
///
@@ -27,7 +22,22 @@ public interface IOperationSession : IDisposable
///
/// Starts executing the operation.
///
- /// The graphql request.
- /// The cancellation token.
- void BeginExecute(GraphQLRequest request, CancellationToken cancellationToken);
+ ///
+ /// The graphql request.
+ ///
+ ///
+ /// The completion handler that will be called when the operation is completed.
+ ///
+ ///
+ /// The cancellation token.
+ ///
+ void BeginExecute(
+ GraphQLRequest request,
+ IOperationSessionCompletionHandler completion,
+ CancellationToken cancellationToken);
+}
+
+public interface IOperationSessionCompletionHandler
+{
+ void Complete(IOperationSession session);
}
diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationManager.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationManager.cs
index 9196b135c58..31025b26c68 100644
--- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationManager.cs
+++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationManager.cs
@@ -1,4 +1,5 @@
using System.Collections;
+using System.Collections.Concurrent;
using HotChocolate.Language;
using Microsoft.Extensions.DependencyInjection;
using static HotChocolate.AspNetCore.Properties.AspNetCoreResources;
@@ -12,8 +13,8 @@ namespace HotChocolate.AspNetCore.Subscriptions;
///
public sealed class OperationManager : IOperationManager
{
- private readonly ReaderWriterLockSlim _lock = new();
- private readonly Dictionary _subs = new();
+ private readonly ConcurrentDictionary _subs = new();
+ private readonly OperationSessionCompletionHandler _completion;
private readonly CancellationTokenSource _cts;
private readonly CancellationToken _cancellationToken;
private readonly ISocketSession _socketSession;
@@ -35,6 +36,7 @@ public OperationManager(
_errorHandler = executor.Services.GetRequiredService();
_cts = new CancellationTokenSource();
_cancellationToken = _cts.Token;
+ _completion = new OperationSessionCompletionHandler(_subs);
}
internal OperationManager(
@@ -50,10 +52,11 @@ internal OperationManager(
_errorHandler = executor.Services.GetRequiredService();
_cts = new CancellationTokenSource();
_cancellationToken = _cts.Token;
+ _completion = new OperationSessionCompletionHandler(_subs);
}
///
- public bool Enqueue(string sessionId, GraphQLRequest request)
+ public bool Start(string sessionId, GraphQLRequest request)
{
if (string.IsNullOrEmpty(sessionId))
{
@@ -72,30 +75,18 @@ public bool Enqueue(string sessionId, GraphQLRequest request)
throw new ObjectDisposedException(nameof(OperationManager));
}
- IOperationSession? session = null;
- _lock.EnterWriteLock();
+ var context = new StartSessionContext(
+ _createSession,
+ _completion,
+ request,
+ _cancellationToken);
- try
- {
- if(!_subs.ContainsKey(sessionId))
- {
- session = _createSession(sessionId);
- _subs.Add(sessionId, session);
- }
- }
- finally
- {
- _lock.ExitWriteLock();
- }
+ _subs.GetOrAdd(
+ sessionId,
+ static (key, ctx) => ctx.CreateSession(key),
+ context);
- if (session is not null)
- {
- session.Completed += (_, _) => Complete(sessionId);
- session.BeginExecute(request, _cancellationToken);
- return true;
- }
-
- return false;
+ return context.IsNewSession;
}
///
@@ -113,20 +104,10 @@ public bool Complete(string sessionId)
throw new ObjectDisposedException(nameof(OperationManager));
}
- _lock.EnterWriteLock();
-
- try
- {
- if (_subs.TryGetValue(sessionId, out var session))
- {
- _subs.Remove(sessionId);
- session.Dispose();
- return true;
- }
- }
- finally
+ if(_subs.TryRemove(sessionId, out var session))
{
- _lock.ExitWriteLock();
+ session.Dispose();
+ return true;
}
return false;
@@ -149,24 +130,37 @@ public void Dispose()
///
public IEnumerator GetEnumerator()
+ => _subs.Values.GetEnumerator();
+
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+
+ private sealed class StartSessionContext(
+ Func createSession,
+ IOperationSessionCompletionHandler completion,
+ GraphQLRequest request,
+ CancellationToken cancellationToken)
{
- _lock.EnterReadLock();
- IOperationSession[] items;
+ public bool IsNewSession { get; private set; }
- try
- {
- items = _subs.Values.ToArray();
- }
- finally
+ public IOperationSession CreateSession(string sessionId)
{
- _lock.ExitReadLock();
+ IsNewSession = true;
+ var session = createSession(sessionId);
+ session.BeginExecute(request, completion, cancellationToken);
+ return session;
}
+ }
- foreach (var session in items)
+ private sealed class OperationSessionCompletionHandler(
+ ConcurrentDictionary subs)
+ : IOperationSessionCompletionHandler
+ {
+ public void Complete(IOperationSession session)
{
- yield return session;
+ if (subs.TryRemove(session.Id, out _))
+ {
+ session.Dispose();
+ }
}
}
-
- IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationSession.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationSession.cs
index 2ff5521dce7..495db0e8534 100644
--- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationSession.cs
+++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/OperationSession.cs
@@ -13,8 +13,6 @@ internal sealed class OperationSession : IOperationSession
private readonly IRequestExecutor _executor;
private bool _disposed;
- public event EventHandler? Completed;
-
public OperationSession(
ISocketSession session,
ISocketSessionInterceptor interceptor,
@@ -34,10 +32,16 @@ public OperationSession(
public bool IsCompleted { get; private set; }
- public void BeginExecute(GraphQLRequest request, CancellationToken cancellationToken)
- => SendResultsAsync(request, cancellationToken).FireAndForget();
+ public void BeginExecute(
+ GraphQLRequest request,
+ IOperationSessionCompletionHandler completion,
+ CancellationToken cancellationToken)
+ => SendResultsAsync(request, completion, cancellationToken).FireAndForget();
- private async Task SendResultsAsync(GraphQLRequest request, CancellationToken cancellationToken)
+ private async Task SendResultsAsync(
+ GraphQLRequest request,
+ IOperationSessionCompletionHandler completion,
+ CancellationToken cancellationToken)
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _ct);
var ct = cts.Token;
@@ -51,32 +55,28 @@ private async Task SendResultsAsync(GraphQLRequest request, CancellationToken ca
switch (result)
{
- case IOperationResult queryResult:
- if (queryResult.Data is null && queryResult.Errors is { Count: > 0, })
+ case IOperationResult singleResult:
+ if (singleResult.Data is null && singleResult.Errors is { Count: > 0, })
{
- await _session.Protocol.SendErrorMessageAsync(
- _session,
- Id,
- queryResult.Errors,
- ct);
+ await _session.Protocol.SendErrorMessageAsync(_session, Id, singleResult.Errors, ct).ConfigureAwait(false);
}
else
{
- await SendResultMessageAsync(queryResult, ct);
+ await SendResultMessageAsync(singleResult, ct).ConfigureAwait(false);
}
break;
case IResponseStream responseStream:
- await foreach (var item in responseStream.ReadResultsAsync().WithCancellation(ct))
+ await foreach (var item in responseStream.ReadResultsAsync().WithCancellation(ct).ConfigureAwait(false))
{
try
{
// use original cancellation token here to keep the websocket open for other streams.
- await SendResultMessageAsync(item, cancellationToken);
+ await SendResultMessageAsync(item, cancellationToken).ConfigureAwait(false);
}
finally
{
- await item.DisposeAsync();
+ await item.DisposeAsync().ConfigureAwait(false);
}
}
break;
@@ -94,7 +94,7 @@ await _session.Protocol.SendErrorMessageAsync(
}
catch (OperationCanceledException) when (ct.IsCancellationRequested)
{
- // the operation was cancelled so we do nothings
+ // the operation was cancelled so we do nothing
}
catch (Exception ex)
{
@@ -121,7 +121,8 @@ await _session.Protocol.SendErrorMessageAsync(
}
// signal that the subscription is completed.
- Complete();
+ completion.Complete(this);
+ IsCompleted = true;
}
}
@@ -180,7 +181,7 @@ private async Task TrySendErrorMessageAsync(Exception exception, CancellationTok
var errors =
error is AggregateError aggregateError
? aggregateError.Errors
- : new[] { error, };
+ : [error];
await _session.Protocol.SendErrorMessageAsync(_session, Id, errors, ct);
}
@@ -192,19 +193,6 @@ error is AggregateError aggregateError
}
}
- private void Complete()
- {
- try
- {
- IsCompleted = true;
- Completed?.Invoke(this, EventArgs.Empty);
- }
- catch
- {
- // we ignore any error that might happen on invoking complete.
- }
- }
-
public void Dispose()
{
if (!_disposed)
diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/Apollo/ApolloSubscriptionProtocolHandler.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/Apollo/ApolloSubscriptionProtocolHandler.cs
index 07f68a04f2a..6385761963f 100644
--- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/Apollo/ApolloSubscriptionProtocolHandler.cs
+++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/Apollo/ApolloSubscriptionProtocolHandler.cs
@@ -141,7 +141,7 @@ await connection.CloseAsync(
return;
}
- if (!session.Operations.Enqueue(dataStartMessage.Id, dataStartMessage.Payload))
+ if (!session.Operations.Start(dataStartMessage.Id, dataStartMessage.Payload))
{
await connection.CloseAsync(
Apollo_OnReceive_SubscriptionIdNotUnique,
diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs
index 57e1754a76f..70bc9ffae40 100644
--- a/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs
+++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs
@@ -142,7 +142,7 @@ await SendConnectionAcceptMessage(
return;
}
- if (!session.Operations.Enqueue(subscribeMessage.Id, subscribeMessage.Payload))
+ if (!session.Operations.Start(subscribeMessage.Id, subscribeMessage.Payload))
{
await connection.CloseSubscriptionIdNotUniqueAsync(cancellationToken);
return;
diff --git a/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/OperationManagerTests.cs b/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/OperationManagerTests.cs
index ca999527f88..12111bc6903 100644
--- a/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/OperationManagerTests.cs
+++ b/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/OperationManagerTests.cs
@@ -25,7 +25,7 @@ public async Task Enqueue_SessionId_Is_Null()
var subscriptions = new OperationManager(session.Object, interceptor.Object, executor);
// act
- void Action() => subscriptions.Enqueue(null!, new GraphQLRequest(null, queryId: "123"));
+ void Action() => subscriptions.Start(null!, new GraphQLRequest(null, queryId: "123"));
// assert
Assert.Equal(
@@ -48,7 +48,7 @@ public async Task Enqueue_SessionId_Is_Empty()
var subscriptions = new OperationManager(session.Object, interceptor.Object, executor);
// act
- void Action() => subscriptions.Enqueue("", new GraphQLRequest(null, queryId: "123"));
+ void Action() => subscriptions.Start("", new GraphQLRequest(null, queryId: "123"));
// assert
Assert.Equal(
@@ -71,7 +71,7 @@ public async Task Enqueue_Request_Is_Null()
var subscriptions = new OperationManager(session.Object, interceptor.Object, executor);
// act
- void Action() => subscriptions.Enqueue("abc", null!);
+ void Action() => subscriptions.Start("abc", null!);
// assert
Assert.Equal(
@@ -106,7 +106,7 @@ public async Task Enqueue_On_Disposed_Manager()
subscriptions.Dispose();
// act
- void Fail() => subscriptions.Enqueue("abc", request);
+ void Fail() => subscriptions.Start("abc", request);
// assert
Assert.Throws(Fail);
@@ -138,7 +138,7 @@ public async Task Enqueue_Request()
var request = new GraphQLRequest(query);
// act
- var success = subscriptions.Enqueue("abc", request);
+ var success = subscriptions.Start("abc", request);
var registered = subscriptions.ToArray();
// assert
@@ -170,11 +170,11 @@ public async Task Enqueue_Request_With_Non_Unique_Id()
var query = Utf8GraphQLParser.Parse(
"subscription { onReview(episode: NEW_HOPE) { stars } }");
var request = new GraphQLRequest(query);
- var success1 = subscriptions.Enqueue("abc", request);
+ var success1 = subscriptions.Start("abc", request);
var registered1 = subscriptions.ToArray();
// act
- var success2 = subscriptions.Enqueue("abc", request);
+ var success2 = subscriptions.Start("abc", request);
var registered2 = subscriptions.ToArray();
// assert
@@ -208,7 +208,7 @@ public async Task Complete_Request()
var query = Utf8GraphQLParser.Parse(
"subscription { onReview(episode: NEW_HOPE) { stars } }");
var request = new GraphQLRequest(query);
- var success1 = subscriptions.Enqueue("abc", request);
+ var success1 = subscriptions.Start("abc", request);
var registered1 = subscriptions.ToArray();
// act
@@ -322,7 +322,7 @@ public async Task Dispose_OperationManager()
var query = Utf8GraphQLParser.Parse(
"subscription { onReview(episode: NEW_HOPE) { stars } }");
var request = new GraphQLRequest(query);
- var success = subscriptions.Enqueue("abc", request);
+ var success = subscriptions.Start("abc", request);
Assert.True(success);
// act
diff --git a/src/HotChocolate/Core/src/Execution/Processing/ISubscription.cs b/src/HotChocolate/Core/src/Execution/Processing/ISubscription.cs
index f4a4c168d29..da033aa6cce 100644
--- a/src/HotChocolate/Core/src/Execution/Processing/ISubscription.cs
+++ b/src/HotChocolate/Core/src/Execution/Processing/ISubscription.cs
@@ -11,7 +11,12 @@ public interface ISubscription
ulong Id { get; }
///
- /// The compiled subscription operation.
+ /// Gets the compiled subscription operation.
///
IOperation Operation { get; }
+
+ ///
+ /// Gets the global request state.
+ ///
+ IDictionary ContextData { get; }
}
diff --git a/src/HotChocolate/Core/src/Execution/Processing/SubscriptionExecutor.Subscription.cs b/src/HotChocolate/Core/src/Execution/Processing/SubscriptionExecutor.Subscription.cs
index 96dd8adead0..6bf0b0e5e8e 100644
--- a/src/HotChocolate/Core/src/Execution/Processing/SubscriptionExecutor.Subscription.cs
+++ b/src/HotChocolate/Core/src/Execution/Processing/SubscriptionExecutor.Subscription.cs
@@ -118,6 +118,9 @@ public IAsyncEnumerable ExecuteAsync()
///
public IOperation Operation => _requestContext.Operation!;
+ ///
+ public IDictionary ContextData => _requestContext.ContextData;
+
public async ValueTask DisposeAsync()
{
if (!_disposed)