Add a ModelSamplingAuraFlow node to change the shift value.

Set the default AuraFlow shift value to 1.73 (sqrt(3)).
This commit is contained in:
comfyanonymous 2024-07-11 17:51:56 -04:00
parent 9f291d75b3
commit 8e012043a9
2 changed files with 16 additions and 2 deletions

View File

@ -564,6 +564,7 @@ class AuraFlow(supported_models_base.BASE):
sampling_settings = { sampling_settings = {
"multiplier": 1.0, "multiplier": 1.0,
"shift": 1.73,
} }
unet_extra_config = {} unet_extra_config = {}

View File

@ -144,7 +144,7 @@ class ModelSamplingSD3:
CATEGORY = "advanced/model" CATEGORY = "advanced/model"
def patch(self, model, shift): def patch(self, model, shift, multiplier=1000):
m = model.clone() m = model.clone()
sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
@ -154,10 +154,22 @@ class ModelSamplingSD3:
pass pass
model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift) model_sampling.set_parameters(shift=shift, multiplier=multiplier)
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return (m, )
class ModelSamplingAuraFlow(ModelSamplingSD3):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.73, "min": 0.0, "max": 100.0, "step":0.01}),
}}
FUNCTION = "patch_aura"
def patch_aura(self, model, shift):
return self.patch(model, shift, multiplier=1.0)
class ModelSamplingContinuousEDM: class ModelSamplingContinuousEDM:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -271,5 +283,6 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingContinuousV": ModelSamplingContinuousV, "ModelSamplingContinuousV": ModelSamplingContinuousV,
"ModelSamplingStableCascade": ModelSamplingStableCascade, "ModelSamplingStableCascade": ModelSamplingStableCascade,
"ModelSamplingSD3": ModelSamplingSD3, "ModelSamplingSD3": ModelSamplingSD3,
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"RescaleCFG": RescaleCFG, "RescaleCFG": RescaleCFG,
} }