.sigma and .timestep now return tensors on the same device as the input.

This commit is contained in:
comfyanonymous 2023-11-27 16:41:33 -05:00
parent 488de0b4df
commit f30b992b18
2 changed files with 6 additions and 6 deletions

View File

@ -65,15 +65,15 @@ class ModelSamplingDiscrete(torch.nn.Module):
def timestep(self, sigma): def timestep(self, sigma):
log_sigma = sigma.log() log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
def sigma(self, timestep): def sigma(self, timestep):
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long() low_idx = t.floor().long()
high_idx = t.ceil().long() high_idx = t.ceil().long()
w = t.frac() w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp() return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent): def percent_to_sigma(self, percent):
if percent <= 0.0: if percent <= 0.0:

View File

@ -56,15 +56,15 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
def timestep(self, sigma): def timestep(self, sigma):
log_sigma = sigma.log() log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
def sigma(self, timestep): def sigma(self, timestep):
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long() low_idx = t.floor().long()
high_idx = t.ceil().long() high_idx = t.ceil().long()
w = t.frac() w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp() return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent): def percent_to_sigma(self, percent):
if percent <= 0.0: if percent <= 0.0: