mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
.sigma and .timestep now return tensors on the same device as the input.
This commit is contained in:
parent
488de0b4df
commit
f30b992b18
@ -65,15 +65,15 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
def timestep(self, sigma):
|
def timestep(self, sigma):
|
||||||
log_sigma = sigma.log()
|
log_sigma = sigma.log()
|
||||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
|
||||||
|
|
||||||
def sigma(self, timestep):
|
def sigma(self, timestep):
|
||||||
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
|
t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
|
||||||
low_idx = t.floor().long()
|
low_idx = t.floor().long()
|
||||||
high_idx = t.ceil().long()
|
high_idx = t.ceil().long()
|
||||||
w = t.frac()
|
w = t.frac()
|
||||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||||
return log_sigma.exp()
|
return log_sigma.exp().to(timestep.device)
|
||||||
|
|
||||||
def percent_to_sigma(self, percent):
|
def percent_to_sigma(self, percent):
|
||||||
if percent <= 0.0:
|
if percent <= 0.0:
|
||||||
|
@ -56,15 +56,15 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
|||||||
def timestep(self, sigma):
|
def timestep(self, sigma):
|
||||||
log_sigma = sigma.log()
|
log_sigma = sigma.log()
|
||||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||||
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
|
return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
|
||||||
|
|
||||||
def sigma(self, timestep):
|
def sigma(self, timestep):
|
||||||
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
|
t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
|
||||||
low_idx = t.floor().long()
|
low_idx = t.floor().long()
|
||||||
high_idx = t.ceil().long()
|
high_idx = t.ceil().long()
|
||||||
w = t.frac()
|
w = t.frac()
|
||||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||||
return log_sigma.exp()
|
return log_sigma.exp().to(timestep.device)
|
||||||
|
|
||||||
def percent_to_sigma(self, percent):
|
def percent_to_sigma(self, percent):
|
||||||
if percent <= 0.0:
|
if percent <= 0.0:
|
||||||
|
Loading…
Reference in New Issue
Block a user