@@ -57,7 +57,6 @@ struct ControlsView: View {
57
57
@State private var disclosedSteps = false
58
58
@State private var disclosedSeed = false
59
59
@State private var disclosedAdvanced = false
60
- @State private var useANE = ( Settings . shared. userSelectedAttentionVariant ?? ModelInfo . defaultAttention) == . splitEinsum
61
60
62
61
// TODO: refactor download with similar code in Loading.swift (iOS)
63
62
@State private var stateSubscriber : Cancellable ?
@@ -82,14 +81,18 @@ struct ControlsView: View {
82
81
mustShowSafetyCheckerDisclaimer = generation. disableSafety && !Settings. shared. safetyCheckerDisclaimerShown
83
82
}
84
83
85
- func updateANEState ( ) {
86
- Settings . shared. userSelectedAttentionVariant = useANE ? . splitEinsum : . original
84
+ func updateComputeUnitsState ( ) {
85
+ Settings . shared. userSelectedComputeUnits = generation . computeUnits
87
86
modelDidChange ( model: Settings . shared. currentModel)
88
87
}
89
88
89
+ func resetComputeUnitsState( ) {
90
+ generation. computeUnits = Settings . shared. userSelectedComputeUnits ?? ModelInfo . defaultComputeUnits
91
+ }
92
+
90
93
func modelDidChange( model: ModelInfo ) {
91
- guard pipelineLoader? . model != model || pipelineLoader? . variant != Settings . shared . userSelectedAttentionVariant else {
92
- print ( " Reusing same model \( model) with attention \( String ( describing : Settings . shared . userSelectedAttentionVariant ) ) " )
94
+ guard pipelineLoader? . model != model || pipelineLoader? . computeUnits != generation . computeUnits else {
95
+ print ( " Reusing same model \( model) with units \( generation . computeUnits ) " )
93
96
return
94
97
}
95
98
@@ -99,7 +102,7 @@ struct ControlsView: View {
99
102
pipelineLoader? . cancel ( )
100
103
pipelineState = . downloading( 0 )
101
104
Task . init {
102
- let loader = PipelineLoader ( model: model, variant : Settings . shared . userSelectedAttentionVariant , maxSeed: maxSeed)
105
+ let loader = PipelineLoader ( model: model, computeUnits : generation . computeUnits , maxSeed: maxSeed)
103
106
self . pipelineLoader = loader
104
107
stateSubscriber = loader. statePublisher. sink { state in
105
108
DispatchQueue . main. async {
@@ -128,8 +131,8 @@ struct ControlsView: View {
128
131
}
129
132
}
130
133
131
- func isModelDownloaded( _ model: ModelInfo , variant : AttentionVariant ? = nil ) -> Bool {
132
- PipelineLoader ( model: model, variant : variant ?? Settings . shared . userSelectedAttentionVariant ) . ready
134
+ func isModelDownloaded( _ model: ModelInfo , computeUnits : ComputeUnits ? = nil ) -> Bool {
135
+ PipelineLoader ( model: model, computeUnits : computeUnits ?? generation . computeUnits ) . ready
133
136
}
134
137
135
138
func modelLabel( _ model: ModelInfo ) -> Text {
@@ -301,21 +304,25 @@ struct ControlsView: View {
301
304
Divider ( )
302
305
DisclosureGroup ( isExpanded: $disclosedAdvanced) {
303
306
HStack {
304
- Toggle ( " Use Neural Engine " , isOn: $useANE) . onChange ( of: useANE) { value in
305
- guard let currentModel = ModelInfo . from ( modelVersion: model) else { return }
306
- let variantDownloaded = isModelDownloaded ( currentModel, variant: useANE ? . splitEinsum : . original)
307
- if variantDownloaded {
308
- updateANEState ( )
309
- } else {
310
- mustShowModelDownloadDisclaimer. toggle ( )
311
- }
312
- }
313
- . padding ( . leading, 10 )
307
+ Picker ( selection: $generation. computeUnits, label: Text ( " Use " ) ) {
308
+ Text ( " GPU " ) . tag ( ComputeUnits . cpuAndGPU)
309
+ Text ( " Neural Engine " ) . tag ( ComputeUnits . cpuAndNeuralEngine)
310
+ Text ( " GPU and Neural Engine " ) . tag ( ComputeUnits . all)
311
+ } . pickerStyle ( . radioGroup) . padding ( . leading)
314
312
Spacer ( )
315
313
}
314
+ . onChange ( of: generation. computeUnits) { units in
315
+ guard let currentModel = ModelInfo . from ( modelVersion: model) else { return }
316
+ let variantDownloaded = isModelDownloaded ( currentModel, computeUnits: units)
317
+ if variantDownloaded {
318
+ updateComputeUnitsState ( )
319
+ } else {
320
+ mustShowModelDownloadDisclaimer. toggle ( )
321
+ }
322
+ }
316
323
. alert ( " Download Required " , isPresented: $mustShowModelDownloadDisclaimer, actions: {
317
- Button ( " Cancel " , role: . destructive) { useANE . toggle ( ) }
318
- Button ( " Download " , role: . cancel) { updateANEState ( ) }
324
+ Button ( " Cancel " , role: . destructive) { resetComputeUnitsState ( ) }
325
+ Button ( " Download " , role: . cancel) { updateComputeUnitsState ( ) }
319
326
} , message: {
320
327
Text ( " This setting requires a new version of the selected model. " )
321
328
} )
0 commit comments