revise kernel

This commit is contained in:
lllyasviel
2024-08-07 13:28:12 -07:00
committed by GitHub
parent 1ef0844225
commit 14a759b5ca
10 changed files with 317 additions and 420 deletions
+41 -24
View File
@@ -7,8 +7,9 @@ import huggingface_guess
from diffusers import DiffusionPipeline
from transformers import modeling_utils
from backend import memory_management
from backend import memory_management
from backend.utils import read_arbitrary_config
from backend.state_dict import try_filter_state_dict, load_state_dict
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
@@ -20,7 +21,9 @@ from backend.diffusion_engine.sd20 import StableDiffusion2
from backend.diffusion_engine.sdxl import StableDiffusionXL
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL]
possible_models = [
StableDiffusion, StableDiffusion2, StableDiffusionXL,
]
logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -62,38 +65,52 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
], log_name=cls_name)
return model
if cls_name == 'T5EncoderModel':
from transformers import T5EncoderModel, T5Config
config = T5Config.from_pretrained(config_path)
dtype = memory_management.text_encoder_dtype()
sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
dtype = sd_dtype
with modeling_utils.no_init_weights():
with using_forge_operations(device=memory_management.cpu, dtype=dtype):
model = IntegratedCLIP(T5EncoderModel, config)
load_state_dict(model, state_dict, log_name=cls_name)
return model
# if cls_name == 'T5EncoderModel':
# from backend.nn.t5 import IntegratedT5
# config = read_arbitrary_config(config_path)
#
# dtype = memory_management.text_encoder_dtype()
# sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
# need_cast = False
#
# if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
# dtype = sd_dtype
# need_cast = True
#
# with modeling_utils.no_init_weights():
# with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
# model = IntegratedT5(config)
#
# load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
#
# return model
if cls_name == 'UNet2DConditionModel':
unet_config = guess.unet_config.copy()
state_dict_size = memory_management.state_dict_size(state_dict)
ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
to_args = dict(device=ini_device, dtype=ini_dtype)
unet_config['dtype'] = ini_dtype
unet_config['device'] = ini_device
with using_forge_operations(device=ini_device, dtype=ini_dtype):
model = IntegratedUNet2DConditionModel.from_config(unet_config)
with using_forge_operations(**to_args):
model = IntegratedUNet2DConditionModel.from_config(unet_config).to(**to_args)
model._internal_dict = unet_config
load_state_dict(model, state_dict)
return model
# if cls_name == 'FluxTransformer2DModel':
# from backend.nn.flux import IntegratedFluxTransformer2DModel
# unet_config = guess.unet_config.copy()
# state_dict_size = memory_management.state_dict_size(state_dict)
# ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
# ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
# to_args = dict(device=ini_device, dtype=ini_dtype)
#
# with using_forge_operations(**to_args):
# model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
# model.config = unet_config
#
# load_state_dict(model, state_dict)
# return model
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
return None