forked from huggingface/swift-coreml-diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDownloader.swift
110 lines (95 loc) · 4.1 KB
/
Downloader.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
//
// Downloader.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Foundation
import Combine
/// Downloads a single remote file to a local destination and publishes its
/// progress through `downloadState`.
///
/// The transfer is kicked off from the initializer. Callers can either observe
/// `downloadState` or block on `waitUntilDone()`.
class Downloader: NSObject, ObservableObject {
    /// Local file URL the finished download is moved to.
    private(set) var destination: URL

    /// Lifecycle of the download as seen by observers.
    enum DownloadState {
        case notStarted
        case downloading(Double)
        case completed(URL)
        case failed(Error)
    }

    /// Current state; new subscribers immediately receive the latest value.
    private(set) lazy var downloadState: CurrentValueSubject<DownloadState, Never> = CurrentValueSubject(.notStarted)

    /// Subscription used by `waitUntilDone()`; stored so it outlives the call.
    private var stateSubscriber: Cancellable?

    /// Session that owns the download task; `self` is its delegate.
    private var urlSession: URLSession? = nil

    /// Creates the downloader and starts (or re-attaches to) the download.
    ///
    /// - Parameters:
    ///   - url: Remote location to fetch.
    ///   - destination: Local file URL where the downloaded file is placed.
    ///   - authToken: Optional token sent as an `Authorization: Bearer` header.
    init(from url: URL, to destination: URL, using authToken: String? = nil) {
        self.destination = destination
        super.init()

        #if os(macOS)
        let configuration = URLSessionConfiguration.default
        #else
        // A background configuration lets the transfer continue while the app
        // is not in the foreground, which matters for long downloads on devices.
        let configuration = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download")
        configuration.isDiscretionary = false
        configuration.sessionSendsLaunchEvents = true
        #endif

        urlSession = URLSession(configuration: configuration, delegate: self, delegateQueue: OperationQueue())
        downloadState.value = .downloading(0)

        urlSession?.getAllTasks { pendingTasks in
            // A task for this URL may already be pending in the session
            // (e.g. a background task); in that case do not start a second one.
            let alreadyRunning = pendingTasks.contains { $0.originalRequest?.url == url }
            if alreadyRunning {
                print("Already downloading \(url)")
                return
            }

            print("Starting download of \(url)")
            var request = URLRequest(url: url)
            if let token = authToken {
                request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
            }
            self.urlSession?.downloadTask(with: request).resume()
        }
    }

    /// Blocks the calling thread until the download has completed or failed.
    ///
    /// - Returns: The URL carried by the `.completed` state.
    /// - Throws: The error carried by the `.failed` state.
    @discardableResult
    func waitUntilDone() throws -> URL {
        // Blocking on a semaphore is the simple alternative to streaming the
        // bytes ourselves (buffering, writing to disk, and so on).
        let finished = DispatchSemaphore(value: 0)
        stateSubscriber = downloadState.sink { newState in
            switch newState {
            case .completed, .failed:
                finished.signal()
            case .notStarted, .downloading:
                break
            }
        }
        finished.wait()

        let finalState = downloadState.value
        if case .completed(let savedURL) = finalState {
            return savedURL
        }
        if case .failed(let failure) = finalState {
            throw failure
        }
        // NOTE(review): throwing a String presumably relies on a `String: Error`
        // conformance declared elsewhere in the project - confirm.
        throw("Should never happen, lol")
    }

    /// Cancels all outstanding tasks and invalidates the session.
    func cancel() {
        urlSession?.invalidateAndCancel()
    }
}
/// URLSession delegate callbacks that translate transfer events into
/// `downloadState` updates.
extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
    /// Publishes fractional progress (bytes written / bytes expected).
    func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) {
        // The expected size is NSURLSessionTransferSizeUnknown (-1) when the
        // server does not announce a length. Dividing by it (or by zero) would
        // publish a negative or non-finite progress value, so skip the update
        // and keep the last known progress instead.
        guard totalBytesExpectedToWrite > 0 else { return }
        downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))
    }

    /// Moves the temporary file to `destination` and publishes `.completed`,
    /// or publishes `.failed` if the response or the file is unusable.
    func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
        // HTTP error responses (401, 404, 5xx, ...) are not transport errors:
        // the task still "finishes" and hands over the error body as a file.
        // Reject those so an error page is never stored as the download.
        // Non-HTTP responses (e.g. file URLs) carry no status code and pass through.
        if let response = downloadTask.response as? HTTPURLResponse,
           !(200..<300).contains(response.statusCode) {
            downloadState.value = .failed("Download failed with HTTP status code \(response.statusCode)")
            return
        }
        guard FileManager.default.fileExists(atPath: location.path) else {
            downloadState.value = .failed("Invalid download location received: \(location)")
            return
        }
        do {
            // Any error thrown by the move is reported through `.failed` below.
            try FileManager.default.moveItem(at: location, to: destination)
            downloadState.value = .completed(destination)
        } catch {
            downloadState.value = .failed(error)
        }
    }

    /// Publishes `.failed` for transport-level errors; otherwise logs the HTTP status.
    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
        if let error = error {
            downloadState.value = .failed(error)
        } else if let response = task.response as? HTTPURLResponse {
            print("HTTP response status code: \(response.statusCode)")
            // let headers = response.allHeaderFields
            // print("HTTP response headers: \(headers)")
        }
    }
}