ModelSamplingFlux now takes a resolution and adjusts the shift with it.

If you want to sample Flux dev exactly how the reference code does use
the same resolution as your image in this node.
This commit is contained in:
comfyanonymous 2024-08-04 04:06:00 -04:00
parent f7a5107784
commit 56f3c660bf

View File

@ -2,6 +2,7 @@ import folder_paths
import comfy.sd import comfy.sd
import comfy.model_sampling import comfy.model_sampling
import comfy.latent_formats import comfy.latent_formats
import nodes
import torch import torch
class LCM(comfy.model_sampling.EPS): class LCM(comfy.model_sampling.EPS):
@ -174,7 +175,10 @@ class ModelSamplingFlux:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}), "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
@ -182,9 +186,15 @@ class ModelSamplingFlux:
CATEGORY = "advanced/model" CATEGORY = "advanced/model"
def patch(self, model, shift): def patch(self, model, max_shift, base_shift, width, height):
m = model.clone() m = model.clone()
x1 = 256
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
sampling_base = comfy.model_sampling.ModelSamplingFlux sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST sampling_type = comfy.model_sampling.CONST