Skip to content

Commit 473164d

Browse files
committed
rework download progress
1 parent fb46593 commit 473164d

File tree

3 files changed

+68
-184
lines changed

3 files changed

+68
-184
lines changed

Tests.Vpn.Service/DownloaderTest.cs

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
using System.Security.Cryptography;
33
using System.Security.Cryptography.X509Certificates;
44
using System.Text;
5-
using System.Threading.Channels;
65
using Coder.Desktop.Vpn.Service;
76
using Microsoft.Extensions.Logging.Abstractions;
87

@@ -278,7 +277,7 @@ public async Task Download(CancellationToken ct)
278277
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
279278
NullDownloadValidator.Instance, ct);
280279
await dlTask.Task;
281-
Assert.That(dlTask.TotalBytes, Is.EqualTo(4));
280+
Assert.That(dlTask.BytesTotal, Is.EqualTo(4));
282281
Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
283282
Assert.That(dlTask.Progress, Is.EqualTo(1));
284283
Assert.That(dlTask.IsCompleted, Is.True);
@@ -301,13 +300,13 @@ public async Task DownloadSameDest(CancellationToken ct)
301300
NullDownloadValidator.Instance, ct);
302301
var dlTask0 = await startTask0;
303302
await dlTask0.Task;
304-
Assert.That(dlTask0.TotalBytes, Is.EqualTo(5));
303+
Assert.That(dlTask0.BytesTotal, Is.EqualTo(5));
305304
Assert.That(dlTask0.BytesWritten, Is.EqualTo(5));
306305
Assert.That(dlTask0.Progress, Is.EqualTo(1));
307306
Assert.That(dlTask0.IsCompleted, Is.True);
308307
var dlTask1 = await startTask1;
309308
await dlTask1.Task;
310-
Assert.That(dlTask1.TotalBytes, Is.EqualTo(5));
309+
Assert.That(dlTask1.BytesTotal, Is.EqualTo(5));
311310
Assert.That(dlTask1.BytesWritten, Is.EqualTo(5));
312311
Assert.That(dlTask1.Progress, Is.EqualTo(1));
313312
Assert.That(dlTask1.IsCompleted, Is.True);
@@ -320,9 +319,9 @@ public async Task DownloadWithXOriginalContentLength(CancellationToken ct)
320319
using var httpServer = new TestHttpServer(async ctx =>
321320
{
322321
ctx.Response.StatusCode = 200;
323-
ctx.Response.Headers.Add("X-Original-Content-Length", "6"); // wrong but should be used until complete
322+
ctx.Response.Headers.Add("X-Original-Content-Length", "4");
324323
ctx.Response.ContentType = "text/plain";
325-
ctx.Response.ContentLength64 = 4; // This should be ignored.
324+
// Don't set Content-Length.
326325
await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct);
327326
});
328327
var url = new Uri(httpServer.BaseUrl + "/test");
@@ -331,25 +330,30 @@ public async Task DownloadWithXOriginalContentLength(CancellationToken ct)
331330
var req = new HttpRequestMessage(HttpMethod.Get, url);
332331
var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct);
333332

334-
var progressChannel = Channel.CreateUnbounded<DownloadProgressEvent>();
335-
dlTask.ProgressChanged += (_, args) =>
336-
Assert.That(progressChannel.Writer.TryWrite(args), Is.True);
337-
338333
await dlTask.Task;
339-
Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); // should equal BytesWritten after completion
334+
Assert.That(dlTask.BytesTotal, Is.EqualTo(4));
340335
Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
341-
progressChannel.Writer.Complete();
342-
343-
var list = progressChannel.Reader.ReadAllAsync(ct).ToBlockingEnumerable(ct).ToList();
344-
Assert.That(list.Count, Is.GreaterThanOrEqualTo(2)); // there may be an item in the middle
345-
// The first item should be the initial progress with 0 bytes written.
346-
Assert.That(list[0].BytesWritten, Is.EqualTo(0));
347-
Assert.That(list[0].BytesTotal, Is.EqualTo(6)); // from X-Original-Content-Length
348-
Assert.That(list[0].Progress, Is.EqualTo(0.0d));
349-
// The last item should be final progress with the actual total bytes.
350-
Assert.That(list[^1].BytesWritten, Is.EqualTo(4));
351-
Assert.That(list[^1].BytesTotal, Is.EqualTo(4)); // from the actual bytes written
352-
Assert.That(list[^1].Progress, Is.EqualTo(1.0d));
336+
}
337+
338+
[Test(Description = "Download with mismatched Content-Length")]
339+
[CancelAfter(30_000)]
340+
public async Task DownloadWithMismatchedContentLength(CancellationToken ct)
341+
{
342+
using var httpServer = new TestHttpServer(async ctx =>
343+
{
344+
ctx.Response.StatusCode = 200;
345+
ctx.Response.Headers.Add("X-Original-Content-Length", "5"); // incorrect
346+
ctx.Response.ContentType = "text/plain";
347+
await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct);
348+
});
349+
var url = new Uri(httpServer.BaseUrl + "/test");
350+
var destPath = Path.Combine(_tempDir, "test");
351+
var manager = new Downloader(NullLogger<Downloader>.Instance);
352+
var req = new HttpRequestMessage(HttpMethod.Get, url);
353+
var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct);
354+
355+
var ex = Assert.ThrowsAsync<IOException>(() => dlTask.Task);
356+
Assert.That(ex.Message, Is.EqualTo("Downloaded file size does not match expected response content length: Expected=5, BytesWritten=4"));
353357
}
354358

355359
[Test(Description = "Download with custom headers")]

Vpn.Service/Downloader.cs

Lines changed: 10 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -338,51 +338,14 @@ internal static async Task TaskOrCancellation(Task task, CancellationToken cance
338338
}
339339
}
340340

341-
public class DownloadProgressEvent
342-
{
343-
// TODO: speed calculation would be nice
344-
public ulong BytesWritten { get; init; }
345-
public ulong? BytesTotal { get; init; } // null if unknown
346-
347-
public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value;
348-
349-
public override string ToString()
350-
{
351-
var s = FriendlyBytes(BytesWritten);
352-
if (BytesTotal != null)
353-
s += $" of {FriendlyBytes(BytesTotal.Value)}";
354-
else
355-
s += " of unknown";
356-
if (Progress != null)
357-
s += $" ({Progress:0%})";
358-
return s;
359-
}
360-
361-
private static readonly string[] ByteSuffixes = ["B", "KB", "MB", "GB", "TB", "PB", "EB"];
362-
363-
// Unfortunately this is copied from FriendlyByteConverter in App. Ideally
364-
// it should go into some shared utilities project, but it's overkill to do
365-
// that for a single tiny function until we have more shared code.
366-
private static string FriendlyBytes(ulong bytes)
367-
{
368-
if (bytes == 0)
369-
return $"0 {ByteSuffixes[0]}";
370-
371-
var place = Convert.ToInt32(Math.Floor(Math.Log(bytes, 1024)));
372-
var num = Math.Round(bytes / Math.Pow(1024, place), 1);
373-
return $"{num} {ByteSuffixes[place]}";
374-
}
375-
}
376-
377341
/// <summary>
378342
/// Downloads a Url to a file on disk. The download will be written to a temporary file first, then moved to the final
379343
/// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if
380344
/// it hasn't changed.
381345
/// </summary>
382346
public class DownloadTask
383347
{
384-
private const int BufferSize = 4096;
385-
private const int ProgressUpdateDelayMs = 50;
348+
private const int BufferSize = 64 * 1024;
386349
private const string XOriginalContentLengthHeader = "X-Original-Content-Length"; // overrides Content-Length if available
387350

388351
private static readonly HttpClient HttpClient = new(new HttpClientHandler
@@ -398,22 +361,13 @@ public class DownloadTask
398361
private readonly string _destinationPath;
399362
private readonly string _tempDestinationPath;
400363

401-
// ProgressChanged events are always delayed by up to 50ms to avoid
402-
// flooding.
403-
//
404-
// This will be called:
405-
// - once after the request succeeds but before the read/write routine
406-
// begins
407-
// - occasionally while the file is being downloaded (at least 50ms apart)
408-
// - once when the download is complete
409-
public EventHandler<DownloadProgressEvent>? ProgressChanged;
410-
411364
public readonly HttpRequestMessage Request;
412365

413366
public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync
367+
public bool DownloadStarted { get; private set; } // Whether we've received headers yet and started the actual download
414368
public ulong BytesWritten { get; private set; }
415-
public ulong? TotalBytes { get; private set; }
416-
public double? Progress => TotalBytes == null ? null : (double)BytesWritten / TotalBytes.Value;
369+
public ulong? BytesTotal { get; private set; }
370+
public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value;
417371
public bool IsCompleted => Task.IsCompleted;
418372

419373
internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator)
@@ -496,32 +450,27 @@ private async Task Start(CancellationToken ct = default)
496450
}
497451

498452
if (res.Content.Headers.ContentLength >= 0)
499-
TotalBytes = (ulong)res.Content.Headers.ContentLength;
453+
BytesTotal = (ulong)res.Content.Headers.ContentLength;
500454

501455
// X-Original-Content-Length overrules Content-Length if set.
502456
if (res.Headers.TryGetValues(XOriginalContentLengthHeader, out var headerValues))
503457
{
504458
// If there are multiple we only look at the first one.
505459
var headerValue = headerValues.ToList().FirstOrDefault();
506460
if (!string.IsNullOrEmpty(headerValue) && ulong.TryParse(headerValue, out var originalContentLength))
507-
TotalBytes = originalContentLength;
461+
BytesTotal = originalContentLength;
508462
else
509463
_logger.LogWarning(
510464
"Failed to parse {XOriginalContentLengthHeader} header value '{HeaderValue}'",
511465
XOriginalContentLengthHeader, headerValue);
512466
}
513467

514-
SendProgressUpdate(new DownloadProgressEvent
515-
{
516-
BytesWritten = 0,
517-
BytesTotal = TotalBytes,
518-
});
519-
520468
await Download(res, ct);
521469
}
522470

523471
private async Task Download(HttpResponseMessage res, CancellationToken ct)
524472
{
473+
DownloadStarted = true;
525474
try
526475
{
527476
var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null;
@@ -546,28 +495,13 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct)
546495
await tempFile.WriteAsync(buffer.AsMemory(0, n), ct);
547496
sha1?.TransformBlock(buffer, 0, n, null, 0);
548497
BytesWritten += (ulong)n;
549-
await QueueProgressUpdate(new DownloadProgressEvent
550-
{
551-
BytesWritten = BytesWritten,
552-
BytesTotal = TotalBytes,
553-
}, ct);
554498
}
555499
}
556500

557-
// Clear any pending progress updates to ensure they won't be sent
558-
// after the final update.
559-
await ClearQueuedProgressUpdate(ct);
560-
// Then write the final status update.
561-
TotalBytes = BytesWritten;
562-
SendProgressUpdate(new DownloadProgressEvent
563-
{
564-
BytesWritten = BytesWritten,
565-
BytesTotal = BytesWritten,
566-
});
567-
568-
if (TotalBytes != null && BytesWritten != TotalBytes)
501+
BytesTotal ??= BytesWritten;
502+
if (BytesWritten != BytesTotal)
569503
throw new IOException(
570-
$"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesWritten}");
504+
$"Downloaded file size does not match expected response content length: Expected={BytesTotal}, BytesWritten={BytesWritten}");
571505

572506
// Verify the ETag if it was sent by the server.
573507
if (res.Headers.Contains("ETag") && sha1 != null)
@@ -612,69 +546,4 @@ await QueueProgressUpdate(new DownloadProgressEvent
612546
throw;
613547
}
614548
}
615-
616-
// _progressEventLock protects _progressUpdateTask and _pendingProgressEvent.
617-
private readonly RaiiSemaphoreSlim _progressEventLock = new(1, 1);
618-
private readonly CancellationTokenSource _progressUpdateCts = new();
619-
private Task? _progressUpdateTask;
620-
private DownloadProgressEvent? _pendingProgressEvent;
621-
622-
// Can be called multiple times, but must not be called or in progress while
623-
// SendQueuedProgressUpdateNow is called.
624-
private async Task QueueProgressUpdate(DownloadProgressEvent e, CancellationToken ct)
625-
{
626-
using var _1 = await _progressEventLock.LockAsync(ct);
627-
_pendingProgressEvent = e;
628-
629-
if (_progressUpdateCts.IsCancellationRequested)
630-
throw new InvalidOperationException("Progress update task was cancelled, cannot queue new progress update");
631-
632-
// Start a task with a 50ms delay unless one is already running.
633-
var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _progressUpdateCts.Token);
634-
cts.CancelAfter(TimeSpan.FromSeconds(5));
635-
_progressUpdateTask ??= Task.Delay(ProgressUpdateDelayMs, cts.Token)
636-
.ContinueWith(t =>
637-
{
638-
cts.Cancel();
639-
using var _2 = _progressEventLock.Lock();
640-
_progressUpdateTask = null;
641-
if (t.IsFaulted || t.IsCanceled) return;
642-
643-
var ev = _pendingProgressEvent;
644-
if (ev != null) SendProgressUpdate(ev);
645-
}, cts.Token);
646-
}
647-
648-
// Must only be called after all QueueProgressUpdate calls have completed.
649-
private async Task ClearQueuedProgressUpdate(CancellationToken ct)
650-
{
651-
Task? t;
652-
using (var _ = _progressEventLock.LockAsync(ct))
653-
{
654-
await _progressUpdateCts.CancelAsync();
655-
t = _progressUpdateTask;
656-
}
657-
658-
// We can't continue to hold the lock here because the continuation
659-
// grabs a lock. We don't need to worry about a new task spawning after
660-
// this because the token is cancelled.
661-
if (t == null) return;
662-
try
663-
{
664-
await t.WaitAsync(ct);
665-
}
666-
catch (TaskCanceledException)
667-
{
668-
// Ignore
669-
}
670-
}
671-
672-
private void SendProgressUpdate(DownloadProgressEvent e)
673-
{
674-
var handler = ProgressChanged;
675-
if (handler == null)
676-
return;
677-
// Start a new task in the background to invoke the event.
678-
_ = Task.Run(() => handler.Invoke(this, e));
679-
}
680549
}

Vpn.Service/Manager.cs

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -450,34 +450,46 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected
450450
_logger.LogDebug("Skipping tunnel binary version validation");
451451
}
452452

453+
// Note: all ETag, signature and version validation is performed by the
454+
// DownloadTask.
453455
var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct);
454456

455-
var progressLock = new RaiiSemaphoreSlim(1, 1);
456-
var progressBroadcastCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
457-
downloadTask.ProgressChanged += (sender, ev) =>
457+
// Wait for the download to complete, sending progress updates every
458+
// 50ms.
459+
while (true)
458460
{
459-
using var _ = progressLock.Lock();
460-
if (progressBroadcastCts.IsCancellationRequested) return;
461-
_logger.LogInformation("Download progress: {ev}", ev);
461+
// Wait for the download to complete, or for a short delay before
462+
// we send a progress update.
463+
var delayTask = Task.Delay(TimeSpan.FromMilliseconds(50), ct);
464+
var winner = await Task.WhenAny([
465+
downloadTask.Task,
466+
delayTask,
467+
]);
468+
if (winner == downloadTask.Task)
469+
break;
470+
471+
// Task.WhenAny will not throw if the winner was cancelled, so
472+
// check CT afterward and not beforehand.
473+
ct.ThrowIfCancellationRequested();
474+
475+
if (!downloadTask.DownloadStarted)
476+
// Don't send progress updates if we don't know what the
477+
// progress is yet.
478+
continue;
462479

463480
var progress = new StartProgressDownloadProgress
464481
{
465-
BytesWritten = ev.BytesWritten,
482+
BytesWritten = downloadTask.BytesWritten,
466483
};
467-
if (ev.BytesTotal != null)
468-
progress.BytesTotal = ev.BytesTotal.Value;
469-
BroadcastStartProgress(StartProgressStage.Downloading, progress, progressBroadcastCts.Token)
470-
.Wait(progressBroadcastCts.Token);
471-
};
484+
if (downloadTask.BytesTotal != null)
485+
progress.BytesTotal = downloadTask.BytesTotal.Value;
472486

473-
// Awaiting this will check the checksum (via the ETag) if the file
474-
// exists, and will also validate the signature and version.
475-
await downloadTask.Task;
487+
await BroadcastStartProgress(StartProgressStage.Downloading, progress, ct);
488+
}
476489

477-
// Prevent any lagging progress events from being sent.
478-
// ReSharper disable once PossiblyMistakenUseOfCancellationToken
479-
using (await progressLock.LockAsync(ct))
480-
await progressBroadcastCts.CancelAsync();
490+
// Await again to re-throw any exceptions that occurred during the
491+
// download.
492+
await downloadTask.Task;
481493

482494
// We don't send a broadcast here as we immediately send one in the
483495
// parent routine.
@@ -486,7 +498,6 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected
486498

487499
private async Task BroadcastStartProgress(StartProgressStage stage, StartProgressDownloadProgress? downloadProgress = null, CancellationToken cancellationToken = default)
488500
{
489-
_logger.LogInformation("Start progress: {stage}", stage);
490501
await FallibleBroadcast(new ServiceMessage
491502
{
492503
StartProgress = new StartProgress

0 commit comments

Comments
 (0)