Fix clip vision lowvram mode not working.

This commit is contained in:
comfyanonymous 2023-12-27 13:50:57 -05:00
parent e478b1794e
commit c782144433

View File

@ -151,7 +151,7 @@ class CLIPVisionEmbeddings(torch.nn.Module):
def forward(self, pixel_values): def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
class CLIPVision(torch.nn.Module): class CLIPVision(torch.nn.Module):