mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Compare commits
7 Commits
20a700a9f1
...
83c0c43734
Author | SHA1 | Date | |
---|---|---|---|
|
83c0c43734 | ||
|
ff838657fa | ||
|
2307ff6746 | ||
|
d0f3752e33 | ||
|
c515bdf371 | ||
|
1a864435f6 | ||
|
ce31337813 |
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
from aiohttp import web
|
||||
import logging
|
||||
|
||||
|
||||
class AppSettings():
|
||||
@ -11,8 +12,12 @@ class AppSettings():
|
||||
file = self.user_manager.get_request_user_filepath(
|
||||
request, "comfy.settings.json")
|
||||
if os.path.isfile(file):
|
||||
try:
|
||||
with open(file) as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
logging.error(f"The user settings file is corrupted: {file}")
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
|
@ -111,7 +111,7 @@ class CLIP:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
self.use_clip_schedule = False
|
||||
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||
|
||||
def clone(self):
|
||||
n = CLIP(no_init=True)
|
||||
@ -898,7 +898,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
if output_model:
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded straight to GPU")
|
||||
logging.info("loaded diffusion model directly to GPU")
|
||||
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
@ -227,8 +227,9 @@ class T5(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
model_dim = config_dict["d_model"]
|
||||
inner_dim = config_dict["d_kv"] * config_dict["num_heads"]
|
||||
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, inner_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
|
||||
self.dtype = dtype
|
||||
self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)
|
||||
|
||||
|
@ -5,19 +5,27 @@ import torch
|
||||
class DifferentialDiffusion():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL", ),
|
||||
}}
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL", ),
|
||||
"strength": ("FLOAT", {
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0
|
||||
}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "apply"
|
||||
CATEGORY = "_for_testing"
|
||||
INIT = False
|
||||
|
||||
def apply(self, model):
|
||||
def apply(self, model, strength=1.0):
|
||||
model = model.clone()
|
||||
model.set_model_denoise_mask_function(self.forward)
|
||||
model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
|
||||
return (model, )
|
||||
|
||||
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
|
||||
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
|
||||
model = extra_options["model"]
|
||||
step_sigmas = extra_options["sigmas"]
|
||||
sigma_to = model.inner_model.model_sampling.sigma_min
|
||||
@ -31,7 +39,15 @@ class DifferentialDiffusion():
|
||||
|
||||
threshold = (current_ts - ts_to) / (ts_from - ts_to)
|
||||
|
||||
return (denoise_mask >= threshold).to(denoise_mask.dtype)
|
||||
# Generate the binary mask based on the threshold
|
||||
binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype)
|
||||
|
||||
# Blend binary mask with the original denoise_mask using strength
|
||||
if strength and strength < 1:
|
||||
blended_mask = strength * binary_mask + (1 - strength) * denoise_mask
|
||||
return blended_mask
|
||||
else:
|
||||
return binary_mask
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
@ -4,7 +4,8 @@ lint.ignore = ["ALL"]
|
||||
# Enable specific rules
|
||||
lint.select = [
|
||||
"S307", # suspicious-eval-usage
|
||||
"T201", # print-usage
|
||||
"S102", # exec
|
||||
"T", # print-usage
|
||||
"W",
|
||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
|
Loading…
Reference in New Issue
Block a user