ModelSamplingFlux now takes a resolution and adjusts the shift with it.

If you want to sample Flux dev exactly how the reference code does use
the same resolution as your image in this node.
This commit is contained in:
comfyanonymous 2024-08-04 04:06:00 -04:00
parent f7a5107784
commit 56f3c660bf

View File

@ -2,6 +2,7 @@ import folder_paths
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats
import nodes
import torch
class LCM(comfy.model_sampling.EPS):
@ -174,7 +175,10 @@ class ModelSamplingFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
}}
RETURN_TYPES = ("MODEL",)
@ -182,9 +186,15 @@ class ModelSamplingFlux:
CATEGORY = "advanced/model"
def patch(self, model, shift):
def patch(self, model, max_shift, base_shift, width, height):
m = model.clone()
x1 = 256
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST