mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Load clipvision model to GPU for faster performance.
This commit is contained in:
parent
1300a1bb4c
commit
a094b45c93
@ -2,14 +2,27 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
|
||||
from .utils import load_torch_file, transformers_convert
|
||||
import os
|
||||
import torch
|
||||
import contextlib
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
config = CLIPVisionConfig.from_json_file(json_config)
|
||||
with comfy.ops.use_comfy_ops():
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = torch.float32
|
||||
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
|
||||
self.dtype = torch.float16
|
||||
|
||||
with comfy.ops.use_comfy_ops(offload_device, self.dtype):
|
||||
with modeling_utils.no_init_weights():
|
||||
self.model = CLIPVisionModelWithProjection(config)
|
||||
self.model.to(self.dtype)
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.processor = CLIPImageProcessor(crop_size=224,
|
||||
do_center_crop=True,
|
||||
do_convert_rgb=True,
|
||||
@ -27,7 +40,16 @@ class ClipVisionModel():
|
||||
img = torch.clip((255. * image), 0, 255).round().int()
|
||||
img = list(map(lambda a: a, img))
|
||||
inputs = self.processor(images=img, return_tensors="pt")
|
||||
outputs = self.model(**inputs)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = inputs['pixel_values'].to(self.load_device)
|
||||
|
||||
if self.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||
|
||||
with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
|
||||
outputs = self.model(pixel_values=pixel_values)
|
||||
return outputs
|
||||
|
||||
def convert_to_transformers(sd, prefix):
|
||||
|
Loading…
Reference in New Issue
Block a user