Small optimizations.

This commit is contained in:
comfyanonymous 2024-12-18 18:23:28 -05:00
parent 0c04a6ae78
commit cbbf077593

View File

@ -30,10 +30,10 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar) self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar) self.var = torch.exp(self.logvar)
if self.deterministic: if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self): def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x return x
def kl(self, other=None): def kl(self, other=None):