From a094b45c93e29193ba2695e834aefe7bf4635c06 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 28 Aug 2023 15:26:29 -0400
Subject: [PATCH] Load clipvision model to GPU for faster performance.

---
 comfy/clip_vision.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index a887e51b..8635e577 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -2,14 +2,27 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
 from .utils import load_torch_file, transformers_convert
 import os
 import torch
+import contextlib
+
 import comfy.ops
+import comfy.model_patcher
+import comfy.model_management
 
 class ClipVisionModel():
     def __init__(self, json_config):
         config = CLIPVisionConfig.from_json_file(json_config)
-        with comfy.ops.use_comfy_ops():
+        self.load_device = comfy.model_management.text_encoder_device()
+        offload_device = comfy.model_management.text_encoder_offload_device()
+        self.dtype = torch.float32
+        if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
+            self.dtype = torch.float16
+
+        with comfy.ops.use_comfy_ops(offload_device, self.dtype):
             with modeling_utils.no_init_weights():
                 self.model = CLIPVisionModelWithProjection(config)
+        self.model.to(self.dtype)
+
+        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
         self.processor = CLIPImageProcessor(crop_size=224,
                                             do_center_crop=True,
                                             do_convert_rgb=True,
@@ -27,7 +40,16 @@ class ClipVisionModel():
         img = torch.clip((255. * image), 0, 255).round().int()
         img = list(map(lambda a: a, img))
         inputs = self.processor(images=img, return_tensors="pt")
-        outputs = self.model(**inputs)
+        comfy.model_management.load_model_gpu(self.patcher)
+        pixel_values = inputs['pixel_values'].to(self.load_device)
+
+        if self.dtype != torch.float32:
+            precision_scope = torch.autocast
+        else:
+            precision_scope = lambda a, b: contextlib.nullcontext(a)
+
+        with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
+            outputs = self.model(pixel_values=pixel_values)
         return outputs
 
 def convert_to_transformers(sd, prefix):
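For reference, a minimal sketch of how the patched code path gets exercised from Python. It assumes a ComfyUI checkout on the import path and an existing CLIP vision checkpoint; the checkpoint file name below is a placeholder, and it assumes comfy.clip_vision.load() is used to build the ClipVisionModel, as the CLIPVisionLoader node does.

    import torch
    import comfy.clip_vision

    # Placeholder checkpoint path; substitute any CLIP vision checkpoint on disk.
    clip_vision = comfy.clip_vision.load("models/clip_vision/example_clip_vision.safetensors")

    # ComfyUI image tensors are [batch, height, width, channels] with values in 0..1.
    image = torch.rand(1, 512, 512, 3)

    # With this patch, encode_image() loads the model onto the text-encoder device
    # via the ModelPatcher and runs under torch.autocast when fp16 was selected.
    outputs = clip_vision.encode_image(image)
    print(outputs.image_embeds.shape)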