Skip to content

Adapt InternalUploadFile for async #1653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Renci.SshNet/.editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,6 @@ dotnet_diagnostic.MA0042.severity = none

# S3236: Caller information arguments should not be provided explicitly
dotnet_diagnostic.S3236.severity = none

# S3358: Ternary operators should not be nested
dotnet_diagnostic.S3358.severity = none
30 changes: 28 additions & 2 deletions src/Renci.SshNet/ISubsystemSession.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Threading;
using System.Threading.Tasks;

using Renci.SshNet.Common;

Expand Down Expand Up @@ -39,15 +40,40 @@ internal interface ISubsystemSession : IDisposable
void Disconnect();

/// <summary>
/// Waits a specified time for a given <see cref="WaitHandle"/> to get signaled.
/// Waits a specified time for a given <see cref="WaitHandle"/> to be signaled.
/// </summary>
/// <param name="waitHandle">The handle to wait for.</param>
/// <param name="millisecondsTimeout">The number of millieseconds wait for <paramref name="waitHandle"/> to get signaled, or <c>-1</c> to wait indefinitely.</param>
/// <param name="millisecondsTimeout">The number of milliseconds to wait for <paramref name="waitHandle"/> to be signaled, or <c>-1</c> to wait indefinitely.</param>
/// <exception cref="SshException">The connection was closed by the server.</exception>
/// <exception cref="SshException">The channel was closed.</exception>
/// <exception cref="SshOperationTimeoutException">The handle did not get signaled within the specified timeout.</exception>
void WaitOnHandle(WaitHandle waitHandle, int millisecondsTimeout);

/// <summary>
/// Asynchronously waits for a given <see cref="WaitHandle"/> to be signaled.
/// </summary>
/// <param name="waitHandle">The handle to wait for.</param>
/// <param name="millisecondsTimeout">The number of milliseconds to wait for <paramref name="waitHandle"/> to be signaled, or <c>-1</c> to wait indefinitely.</param>
/// <param name="cancellationToken">The cancellation token to observe.</param>
/// <exception cref="SshException">The connection was closed by the server.</exception>
/// <exception cref="SshException">The channel was closed.</exception>
/// <exception cref="SshOperationTimeoutException">The handle did not get signaled within the specified timeout.</exception>
/// <returns>A <see cref="Task"/> representing the wait.</returns>
Task WaitOnHandleAsync(WaitHandle waitHandle, int millisecondsTimeout, CancellationToken cancellationToken);

/// <summary>
/// Asynchronously waits for a given <see cref="TaskCompletionSource{T}"/> to complete.
/// </summary>
/// <typeparam name="T">The type of the result which is being awaited.</typeparam>
/// <param name="tcs">The handle to wait for.</param>
/// <param name="millisecondsTimeout">The number of milliseconds to wait for <paramref name="tcs"/> to complete, or <c>-1</c> to wait indefinitely.</param>
/// <param name="cancellationToken">The cancellation token to observe.</param>
/// <exception cref="SshException">The connection was closed by the server.</exception>
/// <exception cref="SshException">The channel was closed.</exception>
/// <exception cref="SshOperationTimeoutException">The handle did not get signaled within the specified timeout.</exception>
/// <returns>A <see cref="Task"/> representing the wait.</returns>
Task<T> WaitOnHandleAsync<T>(TaskCompletionSource<T> tcs, int millisecondsTimeout, CancellationToken cancellationToken);

/// <summary>
/// Blocks the current thread until the specified <see cref="WaitHandle"/> gets signaled, using a
/// 32-bit signed integer to specify the time interval in milliseconds.
Expand Down
138 changes: 91 additions & 47 deletions src/Renci.SshNet/SftpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,8 @@ public void UploadFile(Stream input, string path, Action<ulong>? uploadCallback
/// <inheritdoc/>
public void UploadFile(Stream input, string path, bool canOverride, Action<ulong>? uploadCallback = null)
{
ThrowHelper.ThrowIfNull(input);
ThrowHelper.ThrowIfNullOrWhiteSpace(path);
CheckDisposed();

var flags = Flags.Write | Flags.Truncate;
Expand All @@ -1040,15 +1042,31 @@ public void UploadFile(Stream input, string path, bool canOverride, Action<ulong
flags |= Flags.CreateNew;
}

InternalUploadFile(input, path, flags, asyncResult: null, uploadCallback);
InternalUploadFile(
input,
path,
flags,
asyncResult: null,
uploadCallback,
isAsync: false,
default).GetAwaiter().GetResult();
}

/// <inheritdoc />
public Task UploadFileAsync(Stream input, string path, CancellationToken cancellationToken = default)
{
ThrowHelper.ThrowIfNull(input);
ThrowHelper.ThrowIfNullOrWhiteSpace(path);
CheckDisposed();

return InternalUploadFileAsync(input, path, cancellationToken);
return InternalUploadFile(
input,
path,
Flags.Write | Flags.Truncate | Flags.CreateNewOrOpen,
asyncResult: null,
uploadCallback: null,
isAsync: true,
cancellationToken);
}

/// <summary>
Expand Down Expand Up @@ -1163,9 +1181,9 @@ public IAsyncResult BeginUploadFile(Stream input, string path, AsyncCallback? as
/// </remarks>
public IAsyncResult BeginUploadFile(Stream input, string path, bool canOverride, AsyncCallback? asyncCallback, object? state, Action<ulong>? uploadCallback = null)
{
CheckDisposed();
ThrowHelper.ThrowIfNull(input);
ThrowHelper.ThrowIfNullOrWhiteSpace(path);
CheckDisposed();

var flags = Flags.Write | Flags.Truncate;

Expand All @@ -1180,19 +1198,28 @@ public IAsyncResult BeginUploadFile(Stream input, string path, bool canOverride,

var asyncResult = new SftpUploadAsyncResult(asyncCallback, state);

ThreadAbstraction.ExecuteThread(() =>
_ = DoUploadAndSetResult();

async Task DoUploadAndSetResult()
{
try
{
InternalUploadFile(input, path, flags, asyncResult, uploadCallback);
await InternalUploadFile(
input,
path,
flags,
asyncResult,
uploadCallback,
isAsync: true,
CancellationToken.None).ConfigureAwait(false);

asyncResult.SetAsCompleted(exception: null, completedSynchronously: false);
}
catch (Exception exp)
{
asyncResult.SetAsCompleted(exception: exp, completedSynchronously: false);
asyncResult.SetAsCompleted(exp, completedSynchronously: false);
}
});
}

return asyncResult;
}
Expand Down Expand Up @@ -2284,11 +2311,16 @@ private List<FileInfo> InternalSynchronizeDirectories(string sourcePath, string
var remoteFileName = string.Format(CultureInfo.InvariantCulture, @"{0}/{1}", destinationPath, localFile.Name);
try
{
#pragma warning disable CA2000 // Dispose objects before losing scope; false positive
using (var file = File.OpenRead(localFile.FullName))
#pragma warning restore CA2000 // Dispose objects before losing scope; false positive
{
InternalUploadFile(file, remoteFileName, uploadFlag, asyncResult: null, uploadCallback: null);
InternalUploadFile(
file,
remoteFileName,
uploadFlag,
asyncResult: null,
uploadCallback: null,
isAsync: false,
CancellationToken.None).GetAwaiter().GetResult();
}

uploadedFiles.Add(localFile);
Expand Down Expand Up @@ -2455,37 +2487,42 @@ private async Task InternalDownloadFileAsync(string path, Stream output, Cancell
}
}

/// <summary>
/// Internals the upload file.
/// </summary>
/// <param name="input">The input.</param>
/// <param name="path">The path.</param>
/// <param name="flags">The flags.</param>
/// <param name="asyncResult">An <see cref="IAsyncResult"/> that references the asynchronous request.</param>
/// <param name="uploadCallback">The upload callback.</param>
/// <exception cref="ArgumentNullException"><paramref name="input" /> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentException"><paramref name="path" /> is <see langword="null"/> or contains whitespace.</exception>
/// <exception cref="SshConnectionException">Client not connected.</exception>
private void InternalUploadFile(Stream input, string path, Flags flags, SftpUploadAsyncResult? asyncResult, Action<ulong>? uploadCallback)
#pragma warning disable S6966 // Awaitable method should be used
private async Task InternalUploadFile(
Stream input,
string path,
Flags flags,
SftpUploadAsyncResult? asyncResult,
Action<ulong>? uploadCallback,
bool isAsync,
CancellationToken cancellationToken)
{
ThrowHelper.ThrowIfNull(input);
ThrowHelper.ThrowIfNullOrWhiteSpace(path);
Debug.Assert(isAsync || cancellationToken == default);

if (_sftpSession is null)
{
throw new SshConnectionException("Client not connected.");
}

var fullPath = _sftpSession.GetCanonicalPath(path);
string fullPath;
byte[] handle;

var handle = _sftpSession.RequestOpen(fullPath, flags);
if (isAsync)
{
fullPath = await _sftpSession.GetCanonicalPathAsync(path, cancellationToken).ConfigureAwait(false);
handle = await _sftpSession.RequestOpenAsync(fullPath, flags, cancellationToken).ConfigureAwait(false);
}
else
{
fullPath = _sftpSession.GetCanonicalPath(path);
handle = _sftpSession.RequestOpen(fullPath, flags);
}

ulong offset = 0;

// create buffer of optimal length
var buffer = new byte[_sftpSession.CalculateOptimalWriteLength(_bufferSize, handle)];

int bytesRead;
var expectedResponses = 0;

// We will send out all the write requests without waiting for each response.
Expand All @@ -2495,8 +2532,21 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo

ExceptionDispatchInfo? exception = null;

while ((bytesRead = input.Read(buffer, 0, buffer.Length)) != 0)
while (true)
{
var bytesRead = isAsync
#if NET
? await input.ReadAsync(buffer, cancellationToken).ConfigureAwait(false)
#else
? await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)
#endif
: input.Read(buffer, 0, buffer.Length);

if (bytesRead == 0)
{
break;
}

if (asyncResult is not null && asyncResult.IsUploadCanceled)
{
break;
Expand Down Expand Up @@ -2555,34 +2605,28 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo

if (Volatile.Read(ref expectedResponses) != 0)
{
_sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout);
if (isAsync)
{
await _sftpSession.WaitOnHandleAsync(mres.WaitHandle, _operationTimeout, cancellationToken).ConfigureAwait(false);
}
else
{
_sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout);
}
}

exception?.Throw();

_sftpSession.RequestClose(handle);
}

private async Task InternalUploadFileAsync(Stream input, string path, CancellationToken cancellationToken)
{
ThrowHelper.ThrowIfNull(input);
ThrowHelper.ThrowIfNullOrWhiteSpace(path);

if (_sftpSession is null)
if (isAsync)
{
throw new SshConnectionException("Client not connected.");
await _sftpSession.RequestCloseAsync(handle, cancellationToken).ConfigureAwait(false);
}

cancellationToken.ThrowIfCancellationRequested();

var fullPath = await _sftpSession.GetCanonicalPathAsync(path, cancellationToken).ConfigureAwait(false);
var openStreamTask = SftpFileStream.OpenAsync(_sftpSession, fullPath, FileMode.Create, FileAccess.Write, (int)_bufferSize, cancellationToken);

using (var output = await openStreamTask.ConfigureAwait(false))
else
{
await input.CopyToAsync(output, 81920, cancellationToken).ConfigureAwait(false);
_sftpSession.RequestClose(handle);
}
}
#pragma warning restore S6966 // Awaitable method should be used

/// <summary>
/// Called when client is connected to the server.
Expand Down
Loading