From 78f79266a9be9d9316e7feb6d6a0ed6ac616547d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 21 Jun 2025 21:19:41 -0700 Subject: [PATCH] Allow padding in ImageStitch node to be white. (#8631) --- comfy_extras/nodes_images.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index b1e0d466695..8d5fcdb855c 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -304,10 +304,23 @@ def stitch( image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled" ).movedim(1, -1) + color_map = { + "white": 1.0, + "black": 0.0, + "red": (1.0, 0.0, 0.0), + "green": (0.0, 1.0, 0.0), + "blue": (0.0, 0.0, 1.0), + } + + color_val = color_map[spacing_color] + # When not matching sizes, pad to align non-concat dimensions if not match_image_size: h1, w1 = image1.shape[1:3] h2, w2 = image2.shape[1:3] + pad_value = 0.0 + if not isinstance(color_val, tuple): + pad_value = color_val if direction in ["left", "right"]: # For horizontal concat, pad heights to match @@ -316,11 +329,11 @@ def stitch( if h1 < target_h: pad_h = target_h - h1 pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 - image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) if h2 < target_h: pad_h = target_h - h2 pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 - image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) else: # up, down # For vertical concat, pad widths to match if w1 != w2: @@ -328,11 +341,11 @@ def stitch( if w1 < target_w: pad_w = target_w - w1 pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 - image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) if w2 < target_w: pad_w = target_w - w2 pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 - image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) # Ensure same number of channels if image1.shape[-1] != image2.shape[-1]: @@ -366,15 +379,6 @@ def stitch( if spacing_width > 0: spacing_width = spacing_width + (spacing_width % 2) # Ensure even - color_map = { - "white": 1.0, - "black": 0.0, - "red": (1.0, 0.0, 0.0), - "green": (0.0, 1.0, 0.0), - "blue": (0.0, 0.0, 1.0), - } - color_val = color_map[spacing_color] - if direction in ["left", "right"]: spacing_shape = ( image1.shape[0],