forked from huggingface/swift-coreml-diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDownloader.swift
109 lines (94 loc) · 3.95 KB
/
Downloader.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
//
// Downloader.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Foundation
import Combine
import Path
/// Downloads a single remote file into a local destination using a background `URLSession`,
/// publishing progress and the final outcome through `downloadState`.
class Downloader: NSObject, ObservableObject {
    /// Local file URL the downloaded file is moved to once the transfer finishes.
    private(set) var destination: URL

    /// Lifecycle of a download, as published by `downloadState`.
    enum DownloadState {
        /// Initial value, before `init` starts the transfer.
        case notStarted
        /// Transfer in progress; the payload is the fraction completed.
        case downloading(Double)
        /// The file was downloaded and moved to the associated URL.
        case completed(URL)
        /// The transfer or the final file move failed.
        case failed(Error)
    }

    /// Current download state. Delegate callbacks update it on the session's `OperationQueue`,
    /// so subscribers may be invoked off the main thread.
    private(set) lazy var downloadState: CurrentValueSubject<DownloadState, Never> = CurrentValueSubject(.notStarted)

    /// Subscription used by `waitUntilDone()` while it blocks.
    private var stateSubscriber: Cancellable?

    /// Background session driving the transfer; created in `init`.
    private var urlSession: URLSession? = nil

    /// Creates a background session and starts downloading `url`.
    ///
    /// The task is created asynchronously, once `getAllTasks` reports the session's existing tasks.
    /// If a pending task for the same URL already exists, that task is left to finish instead of
    /// starting a new one.
    ///
    /// - Parameters:
    ///   - url: Remote file to download.
    ///   - destination: Local file URL the finished download is moved to (overwriting any existing file).
    ///   - authToken: Optional token, sent as a `Bearer` `Authorization` header when provided.
    init(from url: URL, to destination: URL, using authToken: String? = nil) {
        self.destination = destination
        super.init()

        // .background allows downloads to proceed in the background
        let config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download")
        urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue())
        downloadState.value = .downloading(0)
        urlSession?.getAllTasks { tasks in
            // If there's an existing pending background task with the same URL, let it proceed.
            guard !tasks.contains(where: { $0.originalRequest?.url == url }) else {
                print("Already downloading \(url)")
                return
            }
            print("Starting download of \(url)")
            var request = URLRequest(url: url)
            if let authToken = authToken {
                request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
            }
            self.urlSession?.downloadTask(with: request).resume()
        }
    }

    /// Blocks the calling thread until the download reaches a terminal state.
    ///
    /// NOTE(review): this waits on a semaphore. Avoid calling it from the main thread or from the
    /// session's delegate queue.
    ///
    /// - Returns: The URL carried by `.completed`, i.e. the destination the file was moved to.
    /// - Throws: The error carried by `.failed`.
    @discardableResult
    func waitUntilDone() throws -> URL {
        // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
        let semaphore = DispatchSemaphore(value: 0)
        // Written once by the subscriber below and read only after `semaphore.wait()` returns,
        // so the semaphore orders the write before the read.
        var terminalState: DownloadState = .notStarted
        stateSubscriber = downloadState
            // Deliver only the first terminal state, then finish. The semaphore is signalled exactly
            // once, and later state changes cannot alter the value returned below.
            .first { state in
                switch state {
                case .completed, .failed: return true
                case .notStarted, .downloading: return false
                }
            }
            .sink { state in
                terminalState = state
                semaphore.signal()
            }
        semaphore.wait()
        // The wait is over; release the subscription instead of keeping it alive with the object.
        stateSubscriber = nil

        switch terminalState {
        case .completed(let url): return url
        case .failed(let error): throw error
        case .notStarted, .downloading: throw("Should never happen, lol")
        }
    }

    /// Invalidates the session and cancels its outstanding tasks.
    ///
    /// A running task then completes with a cancellation error, which the delegate publishes as `.failed`.
    func cancel() {
        urlSession?.invalidateAndCancel()
    }
}
/// Session and download-task delegate callbacks that translate URLSession events into `downloadState`.
extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
    /// Publishes progress as the fraction reported by the task's `Progress` object.
    func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten _: Int64, totalBytesExpectedToWrite _: Int64) {
        downloadState.value = .downloading(downloadTask.progress.fractionCompleted)
    }

    /// Moves the finished temporary file into `destination`, overwriting any existing file.
    ///
    /// Publishes `.completed(destination)` on success. Publishes `.failed` when the server answered
    /// with a non-2xx HTTP status, when either URL cannot be turned into a `Path`, or when the move throws.
    func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
        // URLSession hands us the response body even for HTTP errors (401, 404, ...). Without this
        // check, an error page would be moved into `destination` and reported as `.completed`.
        if let response = downloadTask.response as? HTTPURLResponse,
           !(200..<300).contains(response.statusCode) {
            downloadState.value = .failed("Download failed with HTTP status code \(response.statusCode)")
            return
        }
        guard let path = Path(url: location) else {
            downloadState.value = .failed("Invalid download location received: \(location)")
            return
        }
        guard let toPath = Path(url: destination) else {
            downloadState.value = .failed("Invalid destination: \(destination)")
            return
        }
        do {
            try path.move(to: toPath, overwrite: true)
            downloadState.value = .completed(destination)
        } catch {
            downloadState.value = .failed(error)
        }
    }

    /// Publishes transport errors as `.failed`, then lets the session wind down.
    ///
    /// Without a transport error, only the HTTP status is logged, because the outcome was already
    /// published by `didFinishDownloadingTo`.
    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
        if let error = error {
            downloadState.value = .failed(error)
        } else if let response = task.response as? HTTPURLResponse {
            print("HTTP response status code: \(response.statusCode)")
            // let headers = response.allHeaderFields
            // print("HTTP response headers: \(headers)")
        }
        // Per the URLSession documentation, a session keeps a strong reference to its delegate until
        // it is invalidated. Invalidating here (after any remaining tasks finish) releases this
        // Downloader. It also frees the fixed background identifier for the next session that uses it.
        session.finishTasksAndInvalidate()
    }

    /// Makes sure waiters are released if the session is invalidated before a terminal state was published.
    ///
    /// For example, `cancel()` might run before the download task was created, so no task error is ever delivered.
    func urlSession(_: URLSession, didBecomeInvalidWithError error: Error?) {
        switch downloadState.value {
        case .completed, .failed:
            // Outcome already published; nothing to do.
            break
        case .notStarted, .downloading:
            if let error = error {
                downloadState.value = .failed(error)
            } else {
                downloadState.value = .failed("Download session was invalidated before the download finished")
            }
        }
    }
}