# generate.py
# Forked from primecai/diffusion-self-distillation.
import argparse
import torch
from PIL import Image
from diffusers.utils import load_image
from pipeline import FluxConditionalPipeline
from transformer import FluxTransformer2DConditionalModel
# Module-level pipeline handle. It stays None until init_pipeline() assigns
# it, and process_image_and_text() calls it to run generation.
pipe = None
def init_pipeline(
    model_path: str,
    lora_path: str,
    *,
    base_model: str = "black-forest-labs/FLUX.1-dev",
    device: str = "cuda",
) -> None:
    """Build the conditional FLUX pipeline and publish it as the global ``pipe``.

    Args:
        model_path: Location handed to
            ``FluxTransformer2DConditionalModel.from_pretrained``.
        lora_path: LoRA weights handed to ``load_lora_weights``.
        base_model: Identifier handed to
            ``FluxConditionalPipeline.from_pretrained``. Defaults to the
            value that was previously hard-coded.
        device: Device the assembled pipeline is moved to. Defaults to the
            previously hard-coded ``"cuda"``.

    The global ``pipe`` is only assigned once every step has succeeded, so a
    failure part-way through never leaves a half-configured pipeline behind.
    """
    global pipe

    transformer = FluxTransformer2DConditionalModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True,
    )
    pipeline = FluxConditionalPipeline.from_pretrained(
        base_model,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipeline.load_lora_weights(lora_path)

    # NOTE(review): the original author left a disabled
    # `pipe.enable_model_cpu_offload()` call at this point -- presumably an
    # alternative to moving the whole pipeline onto the device; confirm
    # before enabling it.
    pipeline.to(device)

    pipe = pipeline
def process_image_and_text(image, text, gemini_prompt, guidance, i_guidance, t_guidance):
    """Run the global pipeline on a reference image and a text prompt.

    The image is center-cropped to a square, resized to 512x512 and passed
    to the pipeline as the conditioning image. Generation uses 28 inference
    steps and requests a 512 (height) x 1024 (width) output.

    Args:
        image: Image object exposing ``size``, ``crop`` and ``resize``
            (``main`` passes an RGB PIL image).
        text: Prompt text. Surrounding whitespace is stripped and line
            breaks are removed.
        gemini_prompt: Forwarded unchanged as the pipeline's
            ``gemini_prompt`` argument.
        guidance: Forwarded as ``guidance_scale``.
        i_guidance: Forwarded as ``guidance_scale_real_i``.
        t_guidance: Forwarded as ``guidance_scale_real_t``.

    Returns:
        The first entry of the pipeline result's ``images``.

    Raises:
        RuntimeError: If ``init_pipeline`` has not been called yet.
    """
    # Fail with an actionable message instead of the opaque
    # "'NoneType' object is not callable" the bare call would produce.
    if pipe is None:
        raise RuntimeError(
            "Pipeline is not initialized; call init_pipeline() before "
            "process_image_and_text()."
        )

    # Center-crop to the largest centered square, then scale to 512x512.
    width, height = image.size
    side = min(width, height)
    crop_box = (
        (width - side) // 2,
        (height - side) // 2,
        (width + side) // 2,
        (height + side) // 2,
    )
    square = image.crop(crop_box).resize((512, 512))
    control_image = load_image(square)

    # Line breaks are dropped outright (not replaced by spaces).
    prompt = text.strip().replace("\n", "").replace("\r", "")

    output = pipe(
        prompt=prompt,
        negative_prompt="",
        num_inference_steps=28,
        height=512,
        width=1024,
        guidance_scale=guidance,
        image=control_image,
        guidance_scale_real_i=i_guidance,
        guidance_scale_real_t=t_guidance,
        gemini_prompt=gemini_prompt,
    )
    return output.images[0]
def parse_args(argv=None):
    """Parse the command-line options of the generation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior this function always had.

    Returns:
        An ``argparse.Namespace`` with ``model_path``, ``lora_path``,
        ``image_path``, ``text``, ``disable_gemini_prompt``, ``guidance``,
        ``i_guidance``, ``t_guidance`` and ``output_path``.

    ``--image_path`` and ``--text`` are required; argparse exits with an
    error message when either is missing.
    """
    parser = argparse.ArgumentParser(description="Run Diffusion Self-Distillation.")

    # Checkpoint locations. The defaults are machine-specific paths kept
    # unchanged from the original script.
    parser.add_argument(
        "--model_path",
        type=str,
        default="/home/shengqu/repos/SimpleTuner/output/1x2_v1/checkpoint-172000/transformer",
        help="Path to the model checkpoint.",
    )
    parser.add_argument(
        "--lora_path",
        type=str,
        default="/home/shengqu/repos/SimpleTuner/output/1x2_v1/checkpoint-172000/pytorch_lora_weights.safetensors",
        help="Path to the lora checkpoint.",
    )

    # Required inputs.
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="Path to the input image.",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="The text prompt.",
    )

    # Generation controls.
    parser.add_argument(
        "--disable_gemini_prompt",
        action="store_true",
        help="Flag to disable gemini prompt. If not set, gemini_prompt is True.",
    )
    parser.add_argument(
        "--guidance",
        type=float,
        default=3.5,
        help="Guidance scale for the pipeline.",
    )
    parser.add_argument(
        "--i_guidance",
        type=float,
        default=1.0,
        help="Image guidance scale.",
    )
    parser.add_argument(
        "--t_guidance",
        type=float,
        default=1.0,
        help="Text guidance scale.",
    )

    # Output location.
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.png",
        help="Path to save the output image.",
    )

    return parser.parse_args(argv)
def main():
    """CLI entry point: parse options, generate one image and save it."""
    args = parse_args()

    # Decode the input image before building the pipeline, so a wrong
    # --image_path is reported immediately instead of after the model
    # weights have been loaded. The context manager closes the file handle
    # once the RGB copy has been made.
    with Image.open(args.image_path) as source:
        image = source.convert("RGB")

    # Initialize pipeline
    init_pipeline(args.model_path, args.lora_path)

    # Process image and text; the gemini prompt is on unless explicitly
    # disabled on the command line.
    result_image = process_image_and_text(
        image,
        args.text,
        not args.disable_gemini_prompt,
        args.guidance,
        args.i_guidance,
        args.t_guidance,
    )

    # Save the output
    result_image.save(args.output_path)
    print(f"Output saved to {args.output_path}")
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
# Example usage:
#   CUDA_VISIBLE_DEVICES=7 python generate.py \
#       --model_path /home/shengqu/repos/SimpleTuner/output/1x2_v1/checkpoint-172000/transformer \
#       --lora_path /home/shengqu/repos/SimpleTuner/output/1x2_v1/checkpoint-172000/pytorch_lora_weights.safetensors \
#       --image_path /home/shengqu/repos/dreambench_plus/conditioning_images/seededit_example.png \
#       --text "this character sitting on a chair" \
#       --output_path output.png