implement multi image prompting for gpt-image-1 and fix transparency in outputs (#7763)

* implement multi image prompting for GPTI Image 1

* fix transparency not working

* fix ruff
This commit is contained in:
thot experiment 2025-04-23 13:10:10 -07:00 committed by GitHub
parent e8ddc2be95
commit 2c1d686ec6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -53,7 +53,7 @@ def validate_and_cast_response (response):
raise Exception("Failed to download the image")
img = Image.open(io.BytesIO(img_response.content))
img = img.convert("RGB") # Ensure RGB format
img = img.convert("RGBA")
# Convert to numpy array, normalize to float32 between 0 and 1
img_array = np.array(img).astype(np.float32) / 255.0
@ -339,25 +339,38 @@ class OpenAIGPTImage1(ComfyNodeABC):
model = "gpt-image-1"
path = "/proxy/openai/images/generations"
request_class = OpenAIImageGenerationRequest
img_binary = None
img_binaries = []
mask_binary = None
files = []
if image is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
scaled_image = downscale_input(image).squeeze()
batch_size = image.shape[0]
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
img_binary = img_byte_arr#.getvalue()
img_binary.name = "image.png"
for i in range(batch_size):
single_image = image[i:i+1]
scaled_image = downscale_input(single_image).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
img_binary = img_byte_arr
img_binary.name = f"image_{i}.png"
img_binaries.append(img_binary)
if batch_size == 1:
files.append(("image", img_binary))
else:
files.append(("image[]", img_binary))
if mask is not None:
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if image is None:
raise Exception("Cannot use a mask without an input image")
if mask.shape[1:] != image.shape[1:-1]:
@ -373,14 +386,10 @@ class OpenAIGPTImage1(ComfyNodeABC):
mask_img_byte_arr = io.BytesIO()
mask_img.save(mask_img_byte_arr, format='PNG')
mask_img_byte_arr.seek(0)
mask_binary = mask_img_byte_arr#.getvalue()
mask_binary = mask_img_byte_arr
mask_binary.name = "mask.png"
files.append(("mask", mask_binary))
files = {}
if img_binary:
files["image"] = img_binary
if mask_binary:
files["mask"] = mask_binary
# Build the operation
operation = SynchronousOperation(