Use relative imports for k_diffusion.

This commit is contained in:
comfyanonymous 2023-02-08 16:51:19 -05:00
parent 3e22815a9a
commit bbdcf0b737

View File

@ -1,5 +1,5 @@
import k_diffusion.sampling
import k_diffusion.external
from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external
import torch
import contextlib
import model_management
@ -185,9 +185,9 @@ class KSampler:
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
self.model = model
if self.model.parameterization == "v":
self.model_wrap = k_diffusion.external.CompVisVDenoiser(self.model, quantize=True)
self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True)
else:
self.model_wrap = k_diffusion.external.CompVisDenoiser(self.model, quantize=True)
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True)
self.model_k = CFGDenoiserComplex(self.model_wrap)
self.device = device
if scheduler not in self.SCHEDULERS:
@ -209,7 +209,7 @@ class KSampler:
discard_penultimate_sigma = True
if self.scheduler == "karras":
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device)
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
elif self.scheduler == "simple":
@ -269,9 +269,9 @@ class KSampler:
with precision_scope(self.device):
if self.sampler == "sample_dpm_fast":
samples = k_diffusion.sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
elif self.sampler == "sample_dpm_adaptive":
samples = k_diffusion.sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
else:
samples = getattr(k_diffusion.sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
return samples.to(torch.float32)