Merge branch 'master' into hooks_part2

This commit is contained in:
Jedrzej Kosinski 2024-12-30 14:16:22 -06:00
commit bf21be066f
13 changed files with 24 additions and 14 deletions

View File

@ -467,6 +467,13 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
sigma_schedule = [1.0 - x for x in sigma_schedule]
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
sigmas = adj_idxs.new_zeros(n + 1)
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
return sigmas
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@ -913,7 +920,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic", "kl_optimal"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas(model_sampling, scheduler_name, steps):
@ -933,6 +940,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = beta_scheduler(model_sampling, steps)
elif scheduler_name == "linear_quadratic":
sigmas = linear_quadratic_schedule(model_sampling, steps)
elif scheduler_name == "kl_optimal":
sigmas = kl_optimal_scheduler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas

View File

@ -54,8 +54,8 @@ class DynamicPrompt:
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
def get_input_info(class_def, input_name, valid_inputs=None):
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:

View File

@ -93,7 +93,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
missing_keys = {}
for x in inputs:
input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x)
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
@ -555,7 +555,7 @@ def validate_inputs(prompt, item, validated):
received_types = {}
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
if x not in inputs:
if input_category == "required":

View File

@ -5,6 +5,7 @@ lint.ignore = ["ALL"]
lint.select = [
"S307", # suspicious-eval-usage
"T201", # print-usage
"W292",
"W293",
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f