Skip to content

Commit a6d42d2

Browse files
authored
CSHARP-5777: Avoid ThreadPool-dependent IO methods in sync API (#1805) (#1811)
1 parent 2c6ad30 commit a6d42d2

File tree

4 files changed

+172
-106
lines changed

4 files changed

+172
-106
lines changed

src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs

Lines changed: 102 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -36,44 +36,15 @@ public static void EfficientCopyTo(this Stream input, Stream output)
3636
}
3737
}
3838

39-
public static int Read(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
40-
{
41-
try
42-
{
43-
using var manualResetEvent = new ManualResetEventSlim();
44-
var readOperation = stream.BeginRead(
45-
buffer,
46-
offset,
47-
count,
48-
state => ((ManualResetEventSlim)state.AsyncState).Set(),
49-
manualResetEvent);
50-
51-
if (readOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken))
52-
{
53-
return stream.EndRead(readOperation);
54-
}
55-
}
56-
catch (OperationCanceledException)
57-
{
58-
// Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed.
59-
}
60-
catch (ObjectDisposedException)
61-
{
62-
throw new IOException();
63-
}
64-
65-
try
66-
{
67-
stream.Dispose();
68-
}
69-
catch
70-
{
71-
// Ignore any exceptions
72-
}
73-
74-
cancellationToken.ThrowIfCancellationRequested();
75-
throw new TimeoutException();
76-
}
39+
public static int Read(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) =>
40+
ExecuteOperationWithTimeout(
41+
stream,
42+
(str, state) => str.Read(state.Buffer, state.Offset, state.Count),
43+
buffer,
44+
offset,
45+
count,
46+
timeout,
47+
cancellationToken);
7748

7849
public static async Task<int> ReadAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
7950
{
@@ -217,46 +188,19 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] destination,
217188
}
218189
}
219190

220-
public static void Write(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
221-
{
222-
try
223-
{
224-
using var manualResetEvent = new ManualResetEventSlim();
225-
var writeOperation = stream.BeginWrite(
226-
buffer,
227-
offset,
228-
count,
229-
state => ((ManualResetEventSlim)state.AsyncState).Set(),
230-
manualResetEvent);
231-
232-
if (writeOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken))
191+
public static void Write(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) =>
192+
ExecuteOperationWithTimeout(
193+
stream,
194+
(str, state) =>
233195
{
234-
stream.EndWrite(writeOperation);
235-
return;
236-
}
237-
}
238-
catch (OperationCanceledException)
239-
{
240-
// Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed.
241-
}
242-
catch (ObjectDisposedException)
243-
{
244-
// It's possible to get ObjectDisposedException when the connection pool was closed with interruptInUseConnections set to true.
245-
throw new IOException();
246-
}
247-
248-
try
249-
{
250-
stream.Dispose();
251-
}
252-
catch
253-
{
254-
// Ignore any exceptions
255-
}
256-
257-
cancellationToken.ThrowIfCancellationRequested();
258-
throw new TimeoutException();
259-
}
196+
str.Write(state.Buffer, state.Offset, state.Count);
197+
return true;
198+
},
199+
buffer,
200+
offset,
201+
count,
202+
timeout,
203+
cancellationToken);
260204

261205
public static async Task WriteAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
262206
{
@@ -325,5 +269,86 @@ public static async Task WriteBytesAsync(this Stream stream, OperationContext op
325269
count -= bytesToWrite;
326270
}
327271
}
272+
273+
private static TResult ExecuteOperationWithTimeout<TResult>(Stream stream, Func<Stream, (byte[] Buffer, int Offset, int Count), TResult> operation, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
274+
{
275+
StreamDisposeCallbackState callbackState = null;
276+
Timer timer = null;
277+
CancellationTokenRegistration cancellationSubscription = default;
278+
if (timeout != Timeout.InfiniteTimeSpan)
279+
{
280+
callbackState = new StreamDisposeCallbackState(stream);
281+
timer = new Timer(DisposeStreamCallback, callbackState, timeout, Timeout.InfiniteTimeSpan);
282+
}
283+
284+
if (cancellationToken.CanBeCanceled)
285+
{
286+
callbackState ??= new StreamDisposeCallbackState(stream);
287+
cancellationSubscription = cancellationToken.Register(DisposeStreamCallback, callbackState);
288+
}
289+
290+
try
291+
{
292+
var result = operation(stream, (buffer, offset, count));
293+
if (callbackState?.TryChangeStateFromInProgress(OperationState.Done) == false)
294+
{
295+
// if cannot change the state - then the stream was/will be disposed, throw here
296+
throw new IOException();
297+
}
298+
299+
return result;
300+
}
301+
catch (IOException)
302+
{
303+
if (callbackState?.OperationState == OperationState.Interrupted)
304+
{
305+
cancellationToken.ThrowIfCancellationRequested();
306+
throw new TimeoutException();
307+
}
308+
309+
throw;
310+
}
311+
finally
312+
{
313+
timer?.Dispose();
314+
cancellationSubscription.Dispose();
315+
}
316+
317+
static void DisposeStreamCallback(object state)
318+
{
319+
var disposeCallbackState = (StreamDisposeCallbackState)state;
320+
if (!disposeCallbackState.TryChangeStateFromInProgress(OperationState.Interrupted))
321+
{
322+
// If the state can't be changed - then I/O had already succeeded
323+
return;
324+
}
325+
326+
try
327+
{
328+
disposeCallbackState.Stream.Dispose();
329+
}
330+
catch (Exception)
331+
{
332+
// callbacks should not fail, suppress any exceptions here
333+
}
334+
}
335+
}
336+
337+
private record StreamDisposeCallbackState(Stream Stream)
338+
{
339+
private int _operationState = 0;
340+
341+
public OperationState OperationState => (OperationState)_operationState;
342+
343+
public bool TryChangeStateFromInProgress(OperationState newState) =>
344+
Interlocked.CompareExchange(ref _operationState, (int)newState, (int)OperationState.InProgress) == (int)OperationState.InProgress;
345+
}
346+
347+
private enum OperationState
348+
{
349+
InProgress = 0,
350+
Done,
351+
Interrupted,
352+
}
328353
}
329354
}

tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -713,19 +713,8 @@ public async Task SendMessage_should_put_the_message_on_the_stream_and_raise_the
713713

714714
private void SetupStreamRead(Mock<Stream> streamMock, TaskCompletionSource<int> tcs)
715715
{
716-
streamMock.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
717-
.Returns((byte[] _, int __, int ___, AsyncCallback callback, object state) =>
718-
{
719-
var innerTcs = new TaskCompletionSource<int>(state);
720-
tcs.Task.ContinueWith(t =>
721-
{
722-
innerTcs.TrySetException(t.Exception.InnerException);
723-
callback(innerTcs.Task);
724-
});
725-
return innerTcs.Task;
726-
});
727-
streamMock.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
728-
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
716+
streamMock.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
717+
.Returns((byte[] _, int __, int ___) => tcs.Task.GetAwaiter().GetResult());
729718
streamMock.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
730719
.Returns(tcs.Task);
731720
streamMock.Setup(s => s.Close()).Callback(() => tcs.TrySetException(new ObjectDisposedException("stream")));

tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2013-present MongoDB Inc.
1+
/* Copyright 2010-present MongoDB Inc.
22
*
33
* Licensed under the Apache License, Version 2.0 (the "License");
44
* you may not use this file except in compliance with the License.
@@ -90,20 +90,18 @@ public async Task ReadBytes_with_byte_array_should_have_expected_effect_for_part
9090
var bytes = new byte[] { 1, 2, 3 };
9191
var n = 0;
9292
var position = 0;
93-
Task<int> ReadPartial (byte[] buffer, int offset, int count)
93+
int ReadPartial (byte[] buffer, int offset, int count)
9494
{
9595
var length = partition[n++];
9696
Buffer.BlockCopy(bytes, position, buffer, offset, length);
9797
position += length;
98-
return Task.FromResult(length);
98+
return length;
9999
}
100100

101101
mockStream.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
102-
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count));
103-
mockStream.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
104-
.Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count));
105-
mockStream.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
106-
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
102+
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count)));
103+
mockStream.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
104+
.Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count));
107105
var destination = new byte[3];
108106

109107
if (async)
@@ -203,6 +201,49 @@ await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeo
203201
.ParamName.Should().Be("stream");
204202
}
205203

204+
[Theory]
205+
[ParameterAttributeData]
206+
public async Task ReadBytes_with_byte_array_throws_on_timeout([Values(true, false)]bool async)
207+
{
208+
var streamMock = new Mock<Stream>();
209+
SetupStreamRead(streamMock);
210+
var stream = streamMock.Object;
211+
212+
var destination = new byte[2];
213+
var timeout = TimeSpan.FromMilliseconds(10);
214+
215+
var exception = async ?
216+
await Record.ExceptionAsync(() => stream.ReadAsync(destination, 0, 2, timeout, CancellationToken.None)) :
217+
Record.Exception(() => stream.Read(destination, 0, 2, timeout, CancellationToken.None));
218+
219+
exception.Should().BeOfType<TimeoutException>();
220+
}
221+
222+
[Theory]
223+
[ParameterAttributeData]
224+
public async Task ReadBytes_with_byte_array_throws_on_cancellation([Values(true, false)]bool async)
225+
{
226+
var streamMock = new Mock<Stream>();
227+
SetupStreamRead(streamMock);
228+
var stream = streamMock.Object;
229+
230+
var destination = new byte[2];
231+
using var cancellationTokenSource = new CancellationTokenSource(10);
232+
233+
var exception = async ?
234+
await Record.ExceptionAsync(() => stream.ReadAsync(destination, 0, 2, Timeout.InfiniteTimeSpan, cancellationTokenSource.Token)) :
235+
Record.Exception(() => stream.Read(destination, 0, 2, Timeout.InfiniteTimeSpan, cancellationTokenSource.Token));
236+
237+
if (async)
238+
{
239+
exception.Should().BeOfType<TaskCanceledException>();
240+
}
241+
else
242+
{
243+
exception.Should().BeOfType<OperationCanceledException>();
244+
}
245+
}
246+
206247
[Theory]
207248
[InlineData(true, 0, new byte[] { 0, 0 })]
208249
[InlineData(true, 1, new byte[] { 1, 0 })]
@@ -267,20 +308,18 @@ public async Task ReadBytes_with_byte_buffer_should_have_expected_effect_for_par
267308
var destination = new ByteArrayBuffer(new byte[3], 3);
268309
var n = 0;
269310
var position = 0;
270-
Task<int> ReadPartial (byte[] buffer, int offset, int count)
311+
int ReadPartial (byte[] buffer, int offset, int count)
271312
{
272313
var length = partition[n++];
273314
Buffer.BlockCopy(bytes, position, buffer, offset, length);
274315
position += length;
275-
return Task.FromResult(length);
316+
return length;
276317
}
277318

278319
mockStream.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
279-
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count));
280-
mockStream.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
281-
.Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count));
282-
mockStream.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
283-
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
320+
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count)));
321+
mockStream.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
322+
.Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count));
284323

285324
if (async)
286325
{
@@ -533,5 +572,18 @@ private Mock<IByteBuffer> CreateMockByteBuffer(int length)
533572
mockBuffer.SetupGet(b => b.Length).Returns(length);
534573
return mockBuffer;
535574
}
575+
576+
private void SetupStreamRead(Mock<Stream> streamMock, TaskCompletionSource<int> readTaskCompletionSource = null)
577+
{
578+
readTaskCompletionSource ??= new TaskCompletionSource<int>();
579+
streamMock.Setup(s => s.Close()).Callback(() =>
580+
{
581+
readTaskCompletionSource.SetException(new IOException());
582+
});
583+
streamMock.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>())).Returns(() =>
584+
readTaskCompletionSource.Task.GetAwaiter().GetResult());
585+
streamMock.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>())).Returns(() =>
586+
readTaskCompletionSource.Task);
587+
}
536588
}
537589
}

tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public void Heartbeat_should_be_emitted_before_connection_open()
9696

9797
var mockStream = new Mock<Stream>();
9898
mockStream
99-
.Setup(s => s.BeginWrite(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
99+
.Setup(s => s.Write(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
100100
.Callback(() => EnqueueEvent(HelloReceivedEvent))
101101
.Throws(new Exception("Stream is closed."));
102102

0 commit comments

Comments
 (0)