revise kernel
This commit is contained in:
+41
-24
@@ -7,8 +7,9 @@ import huggingface_guess
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import modeling_utils
|
||||
from backend import memory_management
|
||||
|
||||
from backend import memory_management
|
||||
from backend.utils import read_arbitrary_config
|
||||
from backend.state_dict import try_filter_state_dict, load_state_dict
|
||||
from backend.operations import using_forge_operations
|
||||
from backend.nn.vae import IntegratedAutoencoderKL
|
||||
@@ -20,7 +21,9 @@ from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
||||
|
||||
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL]
|
||||
possible_models = [
|
||||
StableDiffusion, StableDiffusion2, StableDiffusionXL,
|
||||
]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
@@ -62,38 +65,52 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
], log_name=cls_name)
|
||||
|
||||
return model
|
||||
if cls_name == 'T5EncoderModel':
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
config = T5Config.from_pretrained(config_path)
|
||||
|
||||
dtype = memory_management.text_encoder_dtype()
|
||||
sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
|
||||
|
||||
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
dtype = sd_dtype
|
||||
|
||||
with modeling_utils.no_init_weights():
|
||||
with using_forge_operations(device=memory_management.cpu, dtype=dtype):
|
||||
model = IntegratedCLIP(T5EncoderModel, config)
|
||||
|
||||
load_state_dict(model, state_dict, log_name=cls_name)
|
||||
|
||||
return model
|
||||
# if cls_name == 'T5EncoderModel':
|
||||
# from backend.nn.t5 import IntegratedT5
|
||||
# config = read_arbitrary_config(config_path)
|
||||
#
|
||||
# dtype = memory_management.text_encoder_dtype()
|
||||
# sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
|
||||
# need_cast = False
|
||||
#
|
||||
# if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
# dtype = sd_dtype
|
||||
# need_cast = True
|
||||
#
|
||||
# with modeling_utils.no_init_weights():
|
||||
# with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
|
||||
# model = IntegratedT5(config)
|
||||
#
|
||||
# load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
|
||||
#
|
||||
# return model
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_size = memory_management.state_dict_size(state_dict)
|
||||
ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
|
||||
ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
|
||||
to_args = dict(device=ini_device, dtype=ini_dtype)
|
||||
|
||||
unet_config['dtype'] = ini_dtype
|
||||
unet_config['device'] = ini_device
|
||||
|
||||
with using_forge_operations(device=ini_device, dtype=ini_dtype):
|
||||
model = IntegratedUNet2DConditionModel.from_config(unet_config)
|
||||
with using_forge_operations(**to_args):
|
||||
model = IntegratedUNet2DConditionModel.from_config(unet_config).to(**to_args)
|
||||
model._internal_dict = unet_config
|
||||
|
||||
load_state_dict(model, state_dict)
|
||||
return model
|
||||
# if cls_name == 'FluxTransformer2DModel':
|
||||
# from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
# unet_config = guess.unet_config.copy()
|
||||
# state_dict_size = memory_management.state_dict_size(state_dict)
|
||||
# ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
|
||||
# ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
|
||||
# to_args = dict(device=ini_device, dtype=ini_dtype)
|
||||
#
|
||||
# with using_forge_operations(**to_args):
|
||||
# model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
|
||||
# model.config = unet_config
|
||||
#
|
||||
# load_state_dict(model, state_dict)
|
||||
# return model
|
||||
|
||||
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user