Skip to content

Commit 0977a72

Browse files
authored
Add generation metadata to saved pngs from the generated thumbnails (from right-click or ... menu on thumbnails) (#1942)
* adding generation data to pngs without adding binary dependencies on Pillow which are difficult to fulfil for all platforms * fixing circular imports caused by type-hinting Job * interim commit for choice to save metadata * working again after refactoring * Fixing ruff problems * updating tests * Fixing tests and cleaning up code * remove commented out lines no longer needed * may make writing png slightly faster * adding changes for naming, stray print, translations * remove some comments for linting exceptions caused by circular imports * fixing tests and formatting * fixing types for tests of test_text * ruff reformat
1 parent 30107fc commit 0977a72

File tree

7 files changed

+311
-2
lines changed

7 files changed

+311
-2
lines changed

ai_diffusion/image.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from .settings import settings, ImageFileFormat
1111
from .util import clamp, ensure, is_linux, client_logger as log
1212

13+
import struct
14+
import zlib
15+
1316

1417
def multiple_of(number, multiple):
1518
"""Round up to the nearest multiple of a number."""
@@ -399,6 +402,50 @@ def _mask_op(lhs: "Image", rhs: "Image", mode: QPainter.CompositionMode):
399402
result.reinterpretAsFormat(QImage.Format.Format_Grayscale8)
400403
return Image(result)
401404

405+
@staticmethod
def save_png_w_itxt(img_path: Union[str, Path], png_data: bytes, keyword: str, text: str):
    """Write `png_data` to `img_path` with an extra iTXt chunk placed right after IHDR.

    The text chunk is stored uncompressed, with an empty language tag and an empty
    translated keyword. All original chunks are copied through unchanged.

    Args:
        img_path: Destination file. It is only created once the input was validated.
        png_data: A complete, encoded PNG file.
        keyword: iTXt keyword, 1-79 Latin-1 characters without NUL.
        text: Text payload, stored as UTF-8.

    Raises:
        ValueError: If `png_data` does not start with the PNG signature, contains a
            truncated chunk, has no IHDR chunk, or if `keyword` is not a valid PNG
            keyword (`UnicodeEncodeError` for non Latin-1 characters is a `ValueError`).
    """
    signature = b"\x89PNG\r\n\x1a\n"
    if png_data[:8] != signature:
        raise ValueError("Not a valid PNG file")

    keyword_bytes = keyword.encode("latin1")
    if not 1 <= len(keyword_bytes) <= 79 or b"\x00" in keyword_bytes:
        raise ValueError("PNG text keyword must be 1-79 Latin-1 characters without NUL")

    itxt_data = b"".join(
        (
            keyword_bytes,
            b"\x00",  # keyword terminator
            b"\x00",  # compression flag: 0 (not compressed)
            b"\x00",  # compression method: 0
            b"\x00",  # language tag: empty
            b"\x00",  # translated keyword: empty
            text.encode("utf-8"),
        )
    )
    itxt_chunk = b"".join(
        (
            struct.pack(">I", len(itxt_data)),
            b"iTXt",
            itxt_data,
            struct.pack(">I", zlib.crc32(b"iTXt" + itxt_data) & 0xFFFFFFFF),
        )
    )

    # Assemble the whole output in memory first, so invalid input never leaves a
    # half-written file behind and the result is written with a single call.
    view = memoryview(png_data)
    total = len(png_data)
    parts = [signature]
    offset = 8
    ihdr_inserted = False
    while offset < total:
        if offset + 8 > total:
            raise ValueError("Not a valid PNG file: truncated chunk header")
        length, chunk_type = struct.unpack_from(">I4s", png_data, offset)
        end = offset + 12 + length  # 4 length + 4 type + data + 4 crc
        if end > total:
            raise ValueError("Not a valid PNG file: truncated chunk data")
        parts.append(view[offset:end])  # copy the original chunk verbatim
        offset = end

        if not ihdr_inserted and chunk_type == b"IHDR":
            parts.append(itxt_chunk)
            ihdr_inserted = True

    if not ihdr_inserted:
        raise ValueError("Not a valid PNG file: missing IHDR chunk")

    with open(img_path, "wb") as f:
        f.write(b"".join(parts))
448+
402449
@classmethod
403450
def mask_subtract(cls, lhs: "Image", rhs: "Image"):
404451
return cls._mask_op(rhs, lhs, QPainter.CompositionMode.CompositionMode_SourceOut)
@@ -535,6 +582,10 @@ def save(self, filepath: Union[str, Path]):
535582
finally:
536583
file.close()
537584

585+
def save_png_with_metadata(self, filepath: Union[str, Path], metadata_text: str):
    """Encode this image as PNG and save it to `filepath` with `metadata_text` embedded.

    The text is stored in an iTXt chunk under the keyword "parameters".
    """
    encoded = self.to_bytes(ImageFileFormat.png)
    self.save_png_w_itxt(filepath, bytes(encoded), "parameters", metadata_text)
588+
538589
def debug_save(self, name):
539590
if settings.debug_image_folder:
540591
self.save(Path(settings.debug_image_folder, f"{name}.png"))

ai_diffusion/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .region import Region, RegionLink, RootRegion, process_regions, get_region_inpaint_mask
3636
from .resources import ControlMode
3737
from .resolution import compute_bounds, compute_relative_bounds
38+
from .text import create_img_metadata
3839

3940

4041
class QueueMode(Enum):
@@ -1391,4 +1392,9 @@ def _save_job_result(model: Model, job: Job | None, index: int):
13911392
base_image = model._get_current_image(Bounds(0, 0, *model.document.extent))
13921393
result_image = job.results[index]
13931394
base_image.draw_image(result_image, job.params.bounds.offset)
1394-
base_image.save(path)
1395+
1396+
if settings.save_image_metadata:
1397+
metadata_text = create_img_metadata(job.params)
1398+
base_image.save_png_with_metadata(filepath=path, metadata_text=metadata_text)
1399+
else:
1400+
base_image.save(path)

ai_diffusion/settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,13 @@ class Settings(QObject):
215215
_("Translate text prompts from the selected language to English"),
216216
)
217217

218+
save_image_metadata: bool
219+
_save_image_metadata = Setting(
220+
_("Save Image Metadata"),
221+
False,
222+
_("When saving generated images from thumbnails, include metadata in the PNG"),
223+
)
224+
218225
prompt_line_count: int
219226
_prompt_line_count = Setting(
220227
_("Prompt Line Count"), 2, _("Size of the text editor for image descriptions")

ai_diffusion/text.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .files import FileCollection, FileSource
88
from .localization import translate as _
99
from .util import client_logger as log
10+
from .jobs import JobParams
1011

1112

1213
class LoraId(NamedTuple):
@@ -219,3 +220,53 @@ def edit_attention(text: str, positive: bool) -> str:
219220
if weight == 1.0 and open_bracket == "("
220221
else f"{open_bracket}{attention_string}:{weight:.1f}{close_bracket}"
221222
)
223+
224+
225+
def create_img_metadata(params: "JobParams") -> str:
    """Create the text metadata embedded in PNG files, in a style like Automatic1111.

    Reads "prompt", "negative_prompt", "sampler", "checkpoint", "strength" and "loras"
    from `params.metadata`, plus `params.seed` and the width/height of `params.bounds`.
    Missing entries and entries stored as None both fall back to a neutral default.

    Returns:
        Three lines: "Prompt: ...", "Negative prompt: ..." and a comma separated
        settings line. "Denoising strength" is appended to the settings line only
        when a strength other than 1.0 is set.
    """
    meta = params.metadata

    # `dict.get(key, default)` still returns None when the key is present with a None
    # value, which would crash `.strip()` and `re.match` below - so fall back explicitly.
    prompt = meta.get("prompt") or ""
    neg_prompt = meta.get("negative_prompt") or ""
    sampler_info = meta.get("sampler") or ""
    model = meta.get("checkpoint")
    if model is None:
        model = "Unknown"
    seed = params.seed
    width = params.bounds.width
    height = params.bounds.height
    strength = meta.get("strength")
    loras = meta.get("loras") or []

    # Try to extract sampler, steps, and cfg scale from "sampler",
    # which looks like "<preset> - <sampler> (<steps> / <cfg>)"
    match = re.match(r".*?-\s*(.+?)\s*\((\d+)\s*/\s*([\d.]+)\)", sampler_info)
    if match:
        sampler, steps, cfg_scale = match.groups()
    else:
        sampler, steps, cfg_scale = sampler_info, "Unknown", "Unknown"

    # Embed LoRAs in the prompt; entries may be dicts or (name, weight) sequences
    lora_tags = []
    for lora in loras:
        if isinstance(lora, dict):
            name, weight = lora.get("name"), lora.get("weight", 0.0)
        elif isinstance(lora, (list, tuple)) and len(lora) >= 2:
            name, weight = lora[0], lora[1]
        else:
            continue
        # Skip entries without a usable name or with no effect (zero/missing weight)
        if not name or not weight:
            continue
        lora_tags.append(f" <lora:{name}:{weight}>")

    full_prompt = f"{prompt.strip()}{''.join(lora_tags)}"

    # Construct output
    settings_line = (
        f"Steps: {steps}, Sampler: {sampler}, CFG scale: {cfg_scale}, Seed: {seed}, "
        f"Size: {width}x{height}, Model hash: unknown, Model: {model}"
    )
    if strength is not None and strength != 1.0:
        settings_line += f", Denoising strength: {strength}"

    return "\n".join(
        (f"Prompt: {full_prompt}", f"Negative prompt: {neg_prompt}", settings_line)
    )

ai_diffusion/ui/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ def __init__(self):
485485
)
486486
self.add("new_seed_after_apply", SwitchSetting(S._new_seed_after_apply, parent=self))
487487
self.add("debug_dump_workflow", SwitchSetting(S._debug_dump_workflow, parent=self))
488+
self.add("save_image_metadata", SwitchSetting(S._save_image_metadata, parent=self))
488489

489490
languages = [(lang.name, lang.id) for lang in Localization.available]
490491
self._widgets["language"].set_items(languages)

tests/test_image.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import pytest
22
import numpy as np
3+
import struct
4+
import zlib
5+
36
from PyQt5.QtGui import QImage, qRgba
47
from PyQt5.QtCore import Qt, QByteArray
58
from PIL import Image as PILImage
@@ -344,3 +347,57 @@ def test_downscale():
344347
img = create_test_image(12, 8)
345348
result = Image.scale(img, Extent(6, 4))
346349
assert result.width == 6 and result.height == 4
350+
351+
352+
def test_save_png_w_itxt_valid(tmp_path):
    """A minimal PNG gains a well-formed iTXt chunk directly after its IHDR chunk."""
    # Create a minimal valid PNG file: signature, IHDR (1x1, 8-bit RGB) and IEND
    png_header = b"\x89PNG\r\n\x1a\n"
    ihdr_data = b"\x00\x00\x00\x01" + b"\x00\x00\x00\x01" + b"\x08\x02\x00\x00\x00"
    ihdr_chunk = (
        struct.pack(">I", len(ihdr_data))
        + b"IHDR"
        + ihdr_data
        + struct.pack(">I", zlib.crc32(b"IHDR" + ihdr_data) & 0xFFFFFFFF)
    )
    iend_chunk = b"\x00\x00\x00\x00IEND" + struct.pack(">I", zlib.crc32(b"IEND") & 0xFFFFFFFF)
    png_data = png_header + ihdr_chunk + iend_chunk

    file_path = tmp_path / "test_image.png"

    Image.save_png_w_itxt(
        img_path=file_path, png_data=png_data, keyword="testkey", text="testvalue"
    )

    # Check that the file still starts with PNG header
    data = file_path.read_bytes()
    assert data.startswith(png_header)
    # Check that iTXt chunk is present
    assert b"iTXt" in data
    assert b"testkey" in data
    assert b"testvalue" in data

    # Check the exact layout: original chunks untouched, uncompressed iTXt with a
    # valid CRC inserted between IHDR and IEND
    itxt_data = b"testkey" + b"\x00" * 5 + b"testvalue"
    itxt_chunk = (
        struct.pack(">I", len(itxt_data))
        + b"iTXt"
        + itxt_data
        + struct.pack(">I", zlib.crc32(b"iTXt" + itxt_data) & 0xFFFFFFFF)
    )
    assert data == png_header + ihdr_chunk + itxt_chunk + iend_chunk
379+
380+
381+
def test_save_png_w_itxt_invalid(tmp_path):
    """Data without the PNG signature is rejected and no output file is created."""
    # Not a PNG file
    file_path = tmp_path / "not_png.txt"
    png_data = b"not a png"

    with pytest.raises(ValueError, match="Not a valid PNG file"):
        Image.save_png_w_itxt(img_path=file_path, png_data=png_data, keyword="key", text="value")

    assert not file_path.exists()
391+
392+
393+
def test_save_png_with_metadata(tmp_path):
    """Saving a small image with metadata yields a PNG file containing that text."""
    target = tmp_path / "test_meta.png"
    red_square = Image.create(Extent(2, 2), Qt.GlobalColor.red)

    red_square.save_png_with_metadata(target, "my test metadata in the png")

    # The written file must carry the PNG signature and the metadata text
    written = target.read_bytes()
    assert written.startswith(b"\x89PNG\r\n\x1a\n")
    assert b"my test metadata in the png" in written

tests/test_text.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1-
from ai_diffusion.text import merge_prompt, extract_loras, edit_attention, select_on_cursor_pos
1+
from ai_diffusion.text import (
2+
merge_prompt,
3+
extract_loras,
4+
edit_attention,
5+
select_on_cursor_pos,
6+
create_img_metadata,
7+
)
28
from ai_diffusion.api import LoraInput
39
from ai_diffusion.files import File, FileCollection
10+
from ai_diffusion.jobs import JobParams
11+
from ai_diffusion.image import Bounds
412

513

614
def test_merge_prompt():
@@ -84,6 +92,134 @@ def test_extract_loras_meta():
8492
)
8593

8694

95+
def test_create_img_metadata_basic():
    """All fields present: prompt, negative prompt and the full settings line appear."""
    job_params = JobParams(
        bounds=Bounds(0, 0, 512, 768),
        name="test",
        metadata={
            "prompt": "A cat",
            "negative_prompt": "dog",
            "sampler": "Euler - euler_a (20 / 7.5)",
            "checkpoint": "model.ckpt",
            "strength": 0.8,
            "loras": [],
        },
        seed=12345,
    )

    text = create_img_metadata(job_params)

    expected_fragments = (
        "Prompt: A cat",
        "Negative prompt: dog",
        "Steps: 20, Sampler: euler_a, CFG scale: 7.5, Seed: 12345, Size: 512x768, Model hash: unknown, Model: model.ckpt, Denoising strength: 0.8",
    )
    for fragment in expected_fragments:
        assert fragment in text
119+
120+
121+
def test_create_img_metadata_sampler_unmatched():
    """A sampler string that does not match the pattern is passed through verbatim."""
    job_params = JobParams(
        bounds=Bounds(0, 0, 256, 256),
        name="test",
        metadata={
            "prompt": "Test",
            "negative_prompt": "",
            "sampler": "UnknownSampler",
            "checkpoint": "unknown.ckpt",
            "loras": [],
        },
        seed=12345,
    )

    text = create_img_metadata(job_params)

    # Steps and CFG scale cannot be extracted, so they are reported as unknown
    for fragment in ("Sampler: UnknownSampler", "Steps: Unknown", "CFG scale: Unknown"):
        assert fragment in text
142+
143+
144+
def test_create_img_metadata_loras_dict_and_tuple():
    """LoRAs given as dict, tuple or list all end up as <lora:name:weight> tags."""
    lora_entries = [{"name": "lora1", "weight": 0.7}, ("lora2", 0.5), ["lora3", 0.9]]
    job_params = JobParams(
        bounds=Bounds(0, 0, 128, 128),
        name="test",
        metadata={
            "prompt": "Prompt",
            "negative_prompt": "",
            "sampler": "Euler - euler_a (10 / 5.0)",
            "checkpoint": "loramodel.ckpt",
            "loras": lora_entries,
        },
        seed=0,
    )

    text = create_img_metadata(job_params)

    for tag in ("<lora:lora1:0.7>", "<lora:lora2:0.5>", "<lora:lora3:0.9>"):
        assert tag in text
165+
166+
167+
def test_create_img_metadata_strength_none_and_one():
    """Neither a missing (None) strength nor full strength (1.0) is written out."""
    bounds = Bounds(0, 0, 64, 64)

    def params_with_strength(strength):
        # Identical job parameters apart from the strength under test
        return JobParams(
            bounds=bounds,
            name="test",
            metadata={
                "prompt": "Prompt",
                "negative_prompt": "",
                "sampler": "Euler - euler_a (5 / 2.0)",
                "checkpoint": "model.ckpt",
                "strength": strength,
                "loras": [],
            },
            seed=12345,
        )

    results = [create_img_metadata(params_with_strength(s)) for s in (None, 1.0)]

    for text in results:
        assert "Denoising strength" not in text
202+
203+
204+
def test_create_img_metadata_missing_metadata_fields():
    """Empty metadata still produces every field, using defaults where needed."""
    empty_job = JobParams(
        bounds=Bounds(0, 0, 100, 200),
        name="test",
        metadata={},
        seed=999,
    )

    text = create_img_metadata(empty_job)

    expected_fragments = (
        "Prompt: ",
        "Negative prompt: ",
        "Steps: Unknown",
        "Sampler: ",
        "CFG scale: Unknown",
        "Seed: 999",
        "Size: 100x200",
        "Model: Unknown",
    )
    for fragment in expected_fragments:
        assert fragment in text
221+
222+
87223
class TestEditAttention:
88224
def test_empty_selection(self):
89225
assert edit_attention("", positive=True) == ""

0 commit comments

Comments
 (0)