diff --git a/src/Renci.SshNet/.editorconfig b/src/Renci.SshNet/.editorconfig index ee5bc33b9..a596998dc 100644 --- a/src/Renci.SshNet/.editorconfig +++ b/src/Renci.SshNet/.editorconfig @@ -191,3 +191,6 @@ dotnet_diagnostic.MA0042.severity = none # S3236: Caller information arguments should not be provided explicitly dotnet_diagnostic.S3236.severity = none + +# S3358: Ternary operators should not be nested +dotnet_diagnostic.S3358.severity = none diff --git a/src/Renci.SshNet/ISubsystemSession.cs b/src/Renci.SshNet/ISubsystemSession.cs index 4fa6b28b2..c87747bbd 100644 --- a/src/Renci.SshNet/ISubsystemSession.cs +++ b/src/Renci.SshNet/ISubsystemSession.cs @@ -1,5 +1,6 @@ using System; using System.Threading; +using System.Threading.Tasks; using Renci.SshNet.Common; @@ -39,15 +40,40 @@ internal interface ISubsystemSession : IDisposable void Disconnect(); /// - /// Waits a specified time for a given to get signaled. + /// Waits a specified time for a given to be signaled. /// /// The handle to wait for. - /// The number of millieseconds wait for to get signaled, or -1 to wait indefinitely. + /// The number of milliseconds to wait for to be signaled, or -1 to wait indefinitely. /// The connection was closed by the server. /// The channel was closed. /// The handle did not get signaled within the specified timeout. void WaitOnHandle(WaitHandle waitHandle, int millisecondsTimeout); + /// + /// Asynchronously waits for a given to be signaled. + /// + /// The handle to wait for. + /// The number of milliseconds to wait for to be signaled, or -1 to wait indefinitely. + /// The cancellation token to observe. + /// The connection was closed by the server. + /// The channel was closed. + /// The handle did not get signaled within the specified timeout. + /// A representing the wait. + Task WaitOnHandleAsync(WaitHandle waitHandle, int millisecondsTimeout, CancellationToken cancellationToken); + + /// + /// Asynchronously waits for a given to complete. + /// + /// The type of the result which is being awaited. + /// The handle to wait for. + /// The number of milliseconds to wait for to complete, or -1 to wait indefinitely. + /// The cancellation token to observe. + /// The connection was closed by the server. + /// The channel was closed. + /// The handle did not get signaled within the specified timeout. + /// A representing the wait. + Task WaitOnHandleAsync(TaskCompletionSource tcs, int millisecondsTimeout, CancellationToken cancellationToken); + /// /// Blocks the current thread until the specified gets signaled, using a /// 32-bit signed integer to specify the time interval in milliseconds. diff --git a/src/Renci.SshNet/SftpClient.cs b/src/Renci.SshNet/SftpClient.cs index e6d4efc16..7cd304dc6 100644 --- a/src/Renci.SshNet/SftpClient.cs +++ b/src/Renci.SshNet/SftpClient.cs @@ -1027,6 +1027,8 @@ public void UploadFile(Stream input, string path, Action? uploadCallback /// public void UploadFile(Stream input, string path, bool canOverride, Action? uploadCallback = null) { + ThrowHelper.ThrowIfNull(input); + ThrowHelper.ThrowIfNullOrWhiteSpace(path); CheckDisposed(); var flags = Flags.Write | Flags.Truncate; @@ -1040,15 +1042,31 @@ public void UploadFile(Stream input, string path, bool canOverride, Action public Task UploadFileAsync(Stream input, string path, CancellationToken cancellationToken = default) { + ThrowHelper.ThrowIfNull(input); + ThrowHelper.ThrowIfNullOrWhiteSpace(path); CheckDisposed(); - return InternalUploadFileAsync(input, path, cancellationToken); + return InternalUploadFile( + input, + path, + Flags.Write | Flags.Truncate | Flags.CreateNewOrOpen, + asyncResult: null, + uploadCallback: null, + isAsync: true, + cancellationToken); } /// @@ -1163,9 +1181,9 @@ public IAsyncResult BeginUploadFile(Stream input, string path, AsyncCallback? as /// public IAsyncResult BeginUploadFile(Stream input, string path, bool canOverride, AsyncCallback? asyncCallback, object? state, Action? uploadCallback = null) { - CheckDisposed(); ThrowHelper.ThrowIfNull(input); ThrowHelper.ThrowIfNullOrWhiteSpace(path); + CheckDisposed(); var flags = Flags.Write | Flags.Truncate; @@ -1180,19 +1198,28 @@ public IAsyncResult BeginUploadFile(Stream input, string path, bool canOverride, var asyncResult = new SftpUploadAsyncResult(asyncCallback, state); - ThreadAbstraction.ExecuteThread(() => + _ = DoUploadAndSetResult(); + + async Task DoUploadAndSetResult() { try { - InternalUploadFile(input, path, flags, asyncResult, uploadCallback); + await InternalUploadFile( + input, + path, + flags, + asyncResult, + uploadCallback, + isAsync: true, + CancellationToken.None).ConfigureAwait(false); asyncResult.SetAsCompleted(exception: null, completedSynchronously: false); } catch (Exception exp) { - asyncResult.SetAsCompleted(exception: exp, completedSynchronously: false); + asyncResult.SetAsCompleted(exp, completedSynchronously: false); } - }); + } return asyncResult; } @@ -2284,11 +2311,16 @@ private List InternalSynchronizeDirectories(string sourcePath, string var remoteFileName = string.Format(CultureInfo.InvariantCulture, @"{0}/{1}", destinationPath, localFile.Name); try { -#pragma warning disable CA2000 // Dispose objects before losing scope; false positive using (var file = File.OpenRead(localFile.FullName)) -#pragma warning restore CA2000 // Dispose objects before losing scope; false positive { - InternalUploadFile(file, remoteFileName, uploadFlag, asyncResult: null, uploadCallback: null); + InternalUploadFile( + file, + remoteFileName, + uploadFlag, + asyncResult: null, + uploadCallback: null, + isAsync: false, + CancellationToken.None).GetAwaiter().GetResult(); } uploadedFiles.Add(localFile); @@ -2455,37 +2487,42 @@ private async Task InternalDownloadFileAsync(string path, Stream output, Cancell } } - /// - /// Internals the upload file. - /// - /// The input. - /// The path. - /// The flags. - /// An that references the asynchronous request. - /// The upload callback. - /// is . - /// is or contains whitespace. - /// Client not connected. - private void InternalUploadFile(Stream input, string path, Flags flags, SftpUploadAsyncResult? asyncResult, Action? uploadCallback) +#pragma warning disable S6966 // Awaitable method should be used + private async Task InternalUploadFile( + Stream input, + string path, + Flags flags, + SftpUploadAsyncResult? asyncResult, + Action? uploadCallback, + bool isAsync, + CancellationToken cancellationToken) { - ThrowHelper.ThrowIfNull(input); - ThrowHelper.ThrowIfNullOrWhiteSpace(path); + Debug.Assert(isAsync || cancellationToken == default); if (_sftpSession is null) { throw new SshConnectionException("Client not connected."); } - var fullPath = _sftpSession.GetCanonicalPath(path); + string fullPath; + byte[] handle; - var handle = _sftpSession.RequestOpen(fullPath, flags); + if (isAsync) + { + fullPath = await _sftpSession.GetCanonicalPathAsync(path, cancellationToken).ConfigureAwait(false); + handle = await _sftpSession.RequestOpenAsync(fullPath, flags, cancellationToken).ConfigureAwait(false); + } + else + { + fullPath = _sftpSession.GetCanonicalPath(path); + handle = _sftpSession.RequestOpen(fullPath, flags); + } ulong offset = 0; // create buffer of optimal length var buffer = new byte[_sftpSession.CalculateOptimalWriteLength(_bufferSize, handle)]; - int bytesRead; var expectedResponses = 0; // We will send out all the write requests without waiting for each response. @@ -2495,8 +2532,21 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo ExceptionDispatchInfo? exception = null; - while ((bytesRead = input.Read(buffer, 0, buffer.Length)) != 0) + while (true) { + var bytesRead = isAsync +#if NET + ? await input.ReadAsync(buffer, cancellationToken).ConfigureAwait(false) +#else + ? await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false) +#endif + : input.Read(buffer, 0, buffer.Length); + + if (bytesRead == 0) + { + break; + } + if (asyncResult is not null && asyncResult.IsUploadCanceled) { break; @@ -2555,34 +2605,28 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo if (Volatile.Read(ref expectedResponses) != 0) { - _sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout); + if (isAsync) + { + await _sftpSession.WaitOnHandleAsync(mres.WaitHandle, _operationTimeout, cancellationToken).ConfigureAwait(false); + } + else + { + _sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout); + } } exception?.Throw(); - _sftpSession.RequestClose(handle); - } - - private async Task InternalUploadFileAsync(Stream input, string path, CancellationToken cancellationToken) - { - ThrowHelper.ThrowIfNull(input); - ThrowHelper.ThrowIfNullOrWhiteSpace(path); - - if (_sftpSession is null) + if (isAsync) { - throw new SshConnectionException("Client not connected."); + await _sftpSession.RequestCloseAsync(handle, cancellationToken).ConfigureAwait(false); } - - cancellationToken.ThrowIfCancellationRequested(); - - var fullPath = await _sftpSession.GetCanonicalPathAsync(path, cancellationToken).ConfigureAwait(false); - var openStreamTask = SftpFileStream.OpenAsync(_sftpSession, fullPath, FileMode.Create, FileAccess.Write, (int)_bufferSize, cancellationToken); - - using (var output = await openStreamTask.ConfigureAwait(false)) + else { - await input.CopyToAsync(output, 81920, cancellationToken).ConfigureAwait(false); + _sftpSession.RequestClose(handle); } } +#pragma warning restore S6966 // Awaitable method should be used /// /// Called when client is connected to the server. diff --git a/src/Renci.SshNet/SubsystemSession.cs b/src/Renci.SshNet/SubsystemSession.cs index 6c6200f4b..e75a35f2b 100644 --- a/src/Renci.SshNet/SubsystemSession.cs +++ b/src/Renci.SshNet/SubsystemSession.cs @@ -249,56 +249,81 @@ public void WaitOnHandle(WaitHandle waitHandle, int millisecondsTimeout) } } - protected async Task WaitOnHandleAsync(TaskCompletionSource tcs, int millisecondsTimeout, CancellationToken cancellationToken) + public async Task WaitOnHandleAsync(WaitHandle waitHandle, int millisecondsTimeout, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); - - var errorOccuredReg = ThreadPool.RegisterWaitForSingleObject( - _errorOccuredWaitHandle, - (tcs, _) => ((TaskCompletionSource)tcs).TrySetException(_exception), - state: tcs, - millisecondsTimeOutInterval: -1, - executeOnlyOnce: true); - - var sessionDisconnectedReg = ThreadPool.RegisterWaitForSingleObject( - _sessionDisconnectedWaitHandle, - static (tcs, _) => ((TaskCompletionSource)tcs).TrySetException(new SshException("Connection was closed by the server.")), - state: tcs, - millisecondsTimeOutInterval: -1, - executeOnlyOnce: true); - - var channelClosedReg = ThreadPool.RegisterWaitForSingleObject( - _channelClosedWaitHandle, - static (tcs, _) => ((TaskCompletionSource)tcs).TrySetException(new SshException("Channel was closed.")), - state: tcs, - millisecondsTimeOutInterval: -1, - executeOnlyOnce: true); - - using var timeoutCts = new CancellationTokenSource(millisecondsTimeout); - using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token); - - using var tokenReg = linkedCts.Token.Register( - static s => - { - (var tcs, var cancellationToken) = ((TaskCompletionSource, CancellationToken))s; - _ = tcs.TrySetCanceled(cancellationToken); - }, - state: (tcs, cancellationToken), - useSynchronizationContext: false); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - try + using RegisteredWait reg = new( + waitHandle, + (tcs, _) => ((TaskCompletionSource)tcs).TrySetResult(null), + state: tcs); + + _ = await WaitOnHandleAsync(tcs, millisecondsTimeout, cancellationToken).ConfigureAwait(false); + } + + public Task WaitOnHandleAsync(TaskCompletionSource tcs, int millisecondsTimeout, CancellationToken cancellationToken) + { + return tcs.Task.IsCompleted ? tcs.Task + : cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) + : DoWaitAsync(tcs, millisecondsTimeout, cancellationToken); + + async Task DoWaitAsync(TaskCompletionSource tcs, int millisecondsTimeout, CancellationToken cancellationToken) { - return await tcs.Task.ConfigureAwait(false); + using RegisteredWait errorOccuredReg = new( + _errorOccuredWaitHandle, + (tcs, _) => ((TaskCompletionSource)tcs).TrySetException(_exception), + state: tcs); + + using RegisteredWait sessionDisconnectedReg = new( + _sessionDisconnectedWaitHandle, + static (tcs, _) => ((TaskCompletionSource)tcs).TrySetException(new SshException("Connection was closed by the server.")), + state: tcs); + + using RegisteredWait channelClosedReg = new( + _channelClosedWaitHandle, + static (tcs, _) => ((TaskCompletionSource)tcs).TrySetException(new SshException("Channel was closed.")), + state: tcs); + + using var timeoutCts = new CancellationTokenSource(millisecondsTimeout); + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token); + + using var tokenReg = linkedCts.Token.Register( + static s => + { + (var tcs, var cancellationToken) = ((TaskCompletionSource, CancellationToken))s; + _ = tcs.TrySetCanceled(cancellationToken); + }, + state: (tcs, cancellationToken), + useSynchronizationContext: false); + + try + { + return await tcs.Task.ConfigureAwait(false); + } + catch (OperationCanceledException oce) when (timeoutCts.IsCancellationRequested) + { + throw new SshOperationTimeoutException("Operation has timed out.", oce); + } } - catch (OperationCanceledException oce) when (timeoutCts.IsCancellationRequested) + } + + private readonly struct RegisteredWait : IDisposable + { + private readonly RegisteredWaitHandle _handle; + + public RegisteredWait(WaitHandle waitObject, WaitOrTimerCallback callback, object state) { - throw new SshOperationTimeoutException("Operation has timed out.", oce); + _handle = ThreadPool.RegisterWaitForSingleObject( + waitObject, + callback, + state, + millisecondsTimeOutInterval: -1, + executeOnlyOnce: true); } - finally + + public void Dispose() { - _ = errorOccuredReg.Unregister(waitObject: null); - _ = sessionDisconnectedReg.Unregister(waitObject: null); - _ = channelClosedReg.Unregister(waitObject: null); + _ = _handle.Unregister(waitObject: null); } }