Skip to content

Commit 6cb0bdc

Browse files
dolmerepcuenca
dolmere
andauthored
PromptTextField (huggingface#58)
* PromptTextField huggingface#53 Rather than limiting the prompt entry on textfield size limits or character limits, this PR adds a token tracker. As the user enters prompts an indicator displays the current number of used tokens versus the maximum of 75. Tokens are calculated using the currently selected model. For the moment this is a short-term goal to expose the token count to the end user with a change of text color to warn of approaching or exceeding the max token count. Note that in a1111 they handle this by merging excessive tokens. "AUTOMATIC1111 has no token limits. If a prompt contains more than 75 tokens, the limit of the CLIP tokenizer, it will start a new chunk of another 75 tokens, so the new “limit” becomes 150. The process can continue forever or until your computer runs out of memory… Each chunk of 75 tokens is processed independently, and the resulting representations are concatenated before feeding into Stable Diffusion’s U-Net." * Update Diffusion-macOS/ControlsView.swift spacing Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update Diffusion/Common/Views/PromptTextField.swift increase the max token count to account for start and end tokens. Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Moved token count below textfield On macOS token count now displays below the textfield right aligned. On iOS the token count displays below the text field left aligned. The text color in the text field was changed from .white to .primary to avoid dark mode/light mode issues. On iOS a bug was fixed where the selected model was not matching due to changes in the way palettized models were being selected. * Correctly select palettized version when needed v21Palettized * Update Diffusion/Common/Views/PromptTextField.swift Comment spacing style guide compliance. Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update Diffusion/Common/Views/PromptTextField.swift Comment spacing style guide compliance. 
Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update Diffusion/Common/Views/PromptTextField.swift update token count on appear to ensure that the model (if present) at load is used for field token calculations. Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Remove duplicated lines in project --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent b9816f9 commit 6cb0bdc

File tree

4 files changed

+188
-13
lines changed

4 files changed

+188
-13
lines changed

Diffusion-macOS/ControlsView.swift

+18-8
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ struct ControlsView: View {
7373
@State private var showStepsHelp = false
7474
@State private var showSeedHelp = false
7575
@State private var showAdvancedHelp = false
76+
@State private var positiveTokenCount: Int = 0
77+
@State private var negativeTokenCount: Int = 0
7678

7779
// Reasonable range for the slider
7880
let maxSeed: UInt32 = 1000
@@ -89,7 +91,7 @@ struct ControlsView: View {
8991
/// Resets `generation.computeUnits` to the compute units stored in
/// `Settings.shared.userSelectedComputeUnits`, falling back to
/// `ModelInfo.defaultComputeUnits` when that stored value is `nil`.
func resetComputeUnitsState() {
    let savedUnits = Settings.shared.userSelectedComputeUnits
    generation.computeUnits = savedUnits ?? ModelInfo.defaultComputeUnits
}
92-
94+
9395
func modelDidChange(model: ModelInfo) {
9496
guard pipelineLoader?.model != model || pipelineLoader?.computeUnits != generation.computeUnits else {
9597
print("Reusing same model \(model) with units \(generation.computeUnits)")
@@ -147,6 +149,19 @@ struct ControlsView: View {
147149
return selectedURL.path
148150
}
149151

152+
/// Builds the prompt entry area: the positive prompt field above the negative
/// prompt field, separated and surrounded by spacers, in a stack that may grow
/// to the full available height.
private func prompts() -> some View {
    // Both fields share the selected model binding so each one can count tokens for it.
    let positiveField = PromptTextField(text: $generation.positivePrompt, isPositivePrompt: true, model: $model)
    let negativeField = PromptTextField(text: $generation.negativePrompt, isPositivePrompt: false, model: $model)

    return VStack {
        Spacer()
        positiveField
            .padding(.top, 5)
        Spacer()
        negativeField
            .padding(.bottom, 5)
        Spacer()
    }
    .frame(maxHeight: .infinity)
}
164+
150165
var body: some View {
151166
VStack(alignment: .leading) {
152167

@@ -196,13 +211,7 @@ struct ControlsView: View {
196211

197212
DisclosureGroup(isExpanded: $disclosedPrompt) {
198213
Group {
199-
TextField("Positive prompt", text: $generation.positivePrompt,
200-
axis: .vertical).lineLimit(5)
201-
.textFieldStyle(.squareBorder)
202-
.listRowInsets(EdgeInsets(top: 0, leading: -20, bottom: 0, trailing: 20))
203-
TextField("Negative prompt", text: $generation.negativePrompt,
204-
axis: .vertical).lineLimit(5)
205-
.textFieldStyle(.squareBorder)
214+
prompts()
206215
}.padding(.leading, 10)
207216
} label: {
208217
HStack {
@@ -388,3 +397,4 @@ struct ControlsView: View {
388397
}
389398
}
390399
}
400+

Diffusion.xcodeproj/project.pbxproj

+14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
objects = {
88

99
/* Begin PBXBuildFile section */
10+
8CD8A53A2A456EF800BD8A98 /* PromptTextField.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8CD8A5392A456EF800BD8A98 /* PromptTextField.swift */; };
11+
8CD8A53C2A476E2C00BD8A98 /* PromptTextField.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8CD8A5392A456EF800BD8A98 /* PromptTextField.swift */; };
1012
EB067F872992E561004D1AD9 /* HelpContent.swift in Sources */ = {isa = PBXBuildFile; fileRef = EB067F862992E561004D1AD9 /* HelpContent.swift */; };
1113
EB25B3D62A3A2DC4000E25A1 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */; };
1214
EB25B3D82A3A2DD5000E25A1 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */; };
@@ -61,6 +63,7 @@
6163
/* End PBXContainerItemProxy section */
6264

6365
/* Begin PBXFileReference section */
66+
8CD8A5392A456EF800BD8A98 /* PromptTextField.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PromptTextField.swift; sourceTree = "<group>"; };
6467
EB067F862992E561004D1AD9 /* HelpContent.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = HelpContent.swift; sourceTree = "<group>"; };
6568
EB33A51E2954E1BC00B16357 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist; path = Info.plist; sourceTree = "<group>"; };
6669
EB560F0329A3C20800C0F8B8 /* Capabilities.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Capabilities.swift; sourceTree = "<group>"; };
@@ -134,6 +137,14 @@
134137
/* End PBXFrameworksBuildPhase section */
135138

136139
/* Begin PBXGroup section */
140+
8CD8A53B2A476E1C00BD8A98 /* Views */ = {
141+
isa = PBXGroup;
142+
children = (
143+
8CD8A5392A456EF800BD8A98 /* PromptTextField.swift */,
144+
);
145+
path = Views;
146+
sourceTree = "<group>";
147+
};
137148
8CF53E022A44AE0400E6358B /* Common */ = {
138149
isa = PBXGroup;
139150
children = (
@@ -142,6 +153,7 @@
142153
EBDD7DB72976AAFE00C1C4B2 /* State.swift */,
143154
EBDD7DB22973200200C1C4B2 /* Utils.swift */,
144155
EBB5BA5129425B07003A2A5B /* Pipeline */,
156+
8CD8A53B2A476E1C00BD8A98 /* Views */,
145157
);
146158
name = Common;
147159
path = Diffusion/Common;
@@ -476,6 +488,7 @@
476488
EBE756092941178600806B32 /* Loading.swift in Sources */,
477489
EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */,
478490
EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */,
491+
8CD8A53C2A476E2C00BD8A98 /* PromptTextField.swift in Sources */,
479492
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */,
480493
EBDD7DB32973200200C1C4B2 /* Utils.swift in Sources */,
481494
);
@@ -508,6 +521,7 @@
508521
EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */,
509522
EB067F872992E561004D1AD9 /* HelpContent.swift in Sources */,
510523
EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */,
524+
8CD8A53A2A456EF800BD8A98 /* PromptTextField.swift in Sources */,
511525
F1552031297109C300DC009B /* ControlsView.swift in Sources */,
512526
EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */,
513527
F155203429710B3600DC009B /* StatusView.swift in Sources */,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
//
2+
// PromptTextField.swift
3+
// Diffusion-macOS
4+
//
5+
// Created by Dolmere on 22/06/2023.
6+
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
7+
//
8+
9+
import SwiftUI
10+
import Combine
11+
import StableDiffusion
12+
13+
/// A multi-line prompt entry field that reports how many tokens the prompt uses.
///
/// Tokens are counted with a `BPETokenizer` built from the `merges.txt` and
/// `vocab.json` files found in the compiled folder of the selected model. The
/// count is shown below the field (right aligned on macOS, left aligned elsewhere)
/// and the text is tinted orange or red as the count approaches or exceeds
/// `maxTokenCount`.
struct PromptTextField: View {
    @State private var tokenCount: Int = 0
    @State private var tokenizer: BPETokenizer?
    @State private var currentModelVersion: String = ""

    // Selects the macOS placeholder text. It is fixed at init and never mutated,
    // so it is a plain constant rather than `@State`: assigning a `@State`
    // property that already has a default value inside `init` only takes effect
    // depending on the order of the other assignments.
    let isPositivePrompt: Bool

    @Binding var textBinding: String
    @Binding var model: String // the model version as it's stored in Settings

    // Token budget shown to the user. Per the change description this accounts
    // for the tokenizer's start and end tokens on top of 75 prompt tokens.
    private let maxTokenCount = 77

    // Counts within this many tokens of `maxTokenCount` are shown as a warning.
    private let warningMargin = 10

    /// Model description for the selected model version, if one is known.
    private var modelInfo: ModelInfo? {
        ModelInfo.from(modelVersion: model)
    }

    /// Last path component of the model URL for the preferred attention variant.
    private var filename: String? {
        let variant = modelInfo?.bestAttention ?? .original
        return modelInfo?.modelURL(for: variant).lastPathComponent
    }

    /// Location of the model file inside `PipelineLoader.models`.
    private var downloadedURL: URL? {
        filename.map { PipelineLoader.models.appendingPathComponent($0) }
    }

    /// `filename` with its path extension removed.
    private var packagesFilename: String? {
        (filename as NSString?)?.deletingPathExtension
    }

    /// Folder next to the model file that is expected to hold the tokenizer files.
    private var compiledURL: URL? {
        guard let packagesFilename = packagesFilename else { return nil }
        return downloadedURL?.deletingLastPathComponent().appendingPathComponent(packagesFilename)
    }

    /// Colour of the token counter: green while comfortably below the limit,
    /// orange within `warningMargin` tokens of it (the limit itself included),
    /// red once the limit is exceeded.
    private var textColor: Color {
        if tokenCount > maxTokenCount {
            return .red
        }
        if tokenCount > maxTokenCount - warningMargin {
            return .orange
        }
        return .green
    }

    /// Colour of the prompt text: the regular text colour until a warning applies.
    private var promptColor: Color {
        textColor == .green ? .primary : textColor
    }

    /// "used / max" token counter, shown only while the prompt is not empty.
    @ViewBuilder
    private var tokenCounter: some View {
        if !textBinding.isEmpty {
            Text("\(tokenCount)")
                .foregroundColor(textColor)
            Text(" / \(maxTokenCount)")
        }
    }

    // macOS initializer
    init(text: Binding<String>, isPositivePrompt: Bool, model: Binding<String>) {
        _textBinding = text
        self.isPositivePrompt = isPositivePrompt
        _model = model
    }

    // iOS initializer
    init(text: Binding<String>, isPositivePrompt: Bool, model: String) {
        _textBinding = text
        self.isPositivePrompt = isPositivePrompt
        _model = .constant(model)
    }

    var body: some View {
        VStack {
            #if os(macOS)
            TextField(isPositivePrompt ? "Positive prompt" : "Negative Prompt", text: $textBinding,
                      axis: .vertical)
                .lineLimit(20)
                .textFieldStyle(.squareBorder)
                .listRowInsets(EdgeInsets(top: 0, leading: -20, bottom: 0, trailing: 20))
                .foregroundColor(promptColor)
                .frame(minHeight: 30)
            // Hide the counter until a tokenizer for a known model is available.
            if modelInfo != nil && tokenizer != nil {
                HStack {
                    Spacer()
                    tokenCounter
                }
                .font(.caption)
            }
            #else
            TextField("Prompt", text: $textBinding, axis: .vertical)
                .lineLimit(20)
                .listRowInsets(EdgeInsets(top: 0, leading: -20, bottom: 0, trailing: 20))
                .foregroundColor(promptColor)
                .frame(minHeight: 30)
            HStack {
                tokenCounter
                Spacer()
            }
            .font(.caption)
            #endif
        }
        // Recount only when the text actually changes, instead of on every body
        // evaluation as a `Just` publisher would.
        .onChange(of: textBinding) { newText in
            updateTokenCount(newText: newText)
        }
        .onChange(of: model) { _ in
            updateTokenCount(newText: textBinding)
        }
        .onAppear {
            updateTokenCount(newText: textBinding)
        }
    }

    /// Recomputes `tokenCount` for `newText`.
    ///
    /// The tokenizer is (re)created only when the selected model version differs
    /// from the one it was last built for. Nothing is updated when the compiled
    /// model folder cannot be resolved or the tokenizer files fail to load.
    ///
    /// - Parameter newText: The prompt text to tokenize.
    private func updateTokenCount(newText: String) {
        // ensure that the compiled URL exists
        guard let compiledURL = compiledURL else { return }

        var activeTokenizer = tokenizer
        if model != currentModelVersion {
            do {
                let loaded = try BPETokenizer(
                    mergesAt: compiledURL.appendingPathComponent("merges.txt"),
                    vocabularyAt: compiledURL.appendingPathComponent("vocab.json")
                )
                tokenizer = loaded
                currentModelVersion = model
                activeTokenizer = loaded
            } catch {
                print("Failed to create tokenizer: \(error)")
                return
            }
        }

        let (tokens, _) = activeTokenizer?.tokenize(input: newText) ?? ([], [])
        // Called from onChange/onAppear actions, so state can be written directly.
        tokenCount = tokens.count
    }
}

Diffusion/Views/TextToImage.swift

+1-5
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,7 @@ struct TextToImage: View {
113113
var body: some View {
114114
VStack {
115115
HStack {
116-
TextField("Prompt", text: $generation.positivePrompt)
117-
.textFieldStyle(.roundedBorder)
118-
.onSubmit {
119-
submit()
120-
}
116+
PromptTextField(text: $generation.positivePrompt, isPositivePrompt: true, model: deviceSupportsQuantization ? ModelInfo.v21Palettized.modelVersion : ModelInfo.v21Base.modelVersion)
121117
Button("Generate") {
122118
submit()
123119
}

0 commit comments

Comments
 (0)