revise kernel

and add unused files
Author: lllyasviel
Date: 2024-08-07 16:51:24 -07:00
Committed by: GitHub
Parent: a07c758658
Commit: a6baf4a4b5
11 changed files with 700 additions and 52 deletions
@@ -19,11 +19,10 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
 from backend.diffusion_engine.sd15 import StableDiffusion
 from backend.diffusion_engine.sd20 import StableDiffusion2
 from backend.diffusion_engine.sdxl import StableDiffusionXL
+from backend.diffusion_engine.flux import Flux
 
-possible_models = [
-    StableDiffusion, StableDiffusion2, StableDiffusionXL,
-]
+possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
 
 logging.getLogger("diffusers").setLevel(logging.ERROR)
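Context for the hunk above: possible_models is presumably scanned in order to match a loaded checkpoint to an engine class, so appending Flux is what makes Flux checkpoints recognized at all. A minimal sketch of that dispatch pattern, in which find_engine and the matches predicate are hypothetical stand-ins rather than Forge's actual API:

```python
# Hypothetical sketch (not Forge code): how a registry like possible_models
# is typically consumed -- return the first engine that claims the checkpoint.
def find_engine(state_dict, possible_models):
    for engine_cls in possible_models:
        # matches() is a made-up predicate standing in for whatever
        # key/config inspection the real loader performs on state_dict.
        if engine_cls.matches(state_dict):
            return engine_cls
    raise ValueError('no engine recognizes this checkpoint')
```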
@@ -65,25 +64,25 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         ], log_name=cls_name)
         return model
 
-    # if cls_name == 'T5EncoderModel':
-    #     from backend.nn.t5 import IntegratedT5
-    #     config = read_arbitrary_config(config_path)
-    #
-    #     dtype = memory_management.text_encoder_dtype()
-    #     sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
-    #     need_cast = False
-    #
-    #     if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-    #         dtype = sd_dtype
-    #         need_cast = True
-    #
-    #     with modeling_utils.no_init_weights():
-    #         with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
-    #             model = IntegratedT5(config)
-    #
-    #     load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
-    #
-    #     return model
+    if cls_name == 'T5EncoderModel':
+        from backend.nn.t5 import IntegratedT5
+        config = read_arbitrary_config(config_path)
+
+        dtype = memory_management.text_encoder_dtype()
+        sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
+        need_cast = False
+
+        if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            dtype = sd_dtype
+            need_cast = True
+
+        with modeling_utils.no_init_weights():
+            with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
+                model = IntegratedT5(config)
+
+        load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
+
+        return model
 
     if cls_name == 'UNet2DConditionModel':
         unet_config = guess.unet_config.copy()
         state_dict_size = memory_management.state_dict_size(state_dict)
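The T5 branch enabled above keeps float8 checkpoints in their stored dtype and flags manual casting instead of up-casting at load time. A standalone sketch of just that dtype decision (default_dtype is a placeholder; in the diff the default comes from memory_management.text_encoder_dtype()):

```python
import torch  # float8 dtypes require torch >= 2.1

def pick_t5_dtype(state_dict, default_dtype=torch.float16):
    # Probe the same weight the diff uses to learn the stored precision.
    sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
    if sd_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # Keep float8 storage and cast manually at compute time.
        return sd_dtype, True
    # Otherwise load in the regular text-encoder dtype, no manual cast.
    return default_dtype, False
```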
@@ -97,20 +96,20 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         load_state_dict(model, state_dict)
         return model
 
-    # if cls_name == 'FluxTransformer2DModel':
-    #     from backend.nn.flux import IntegratedFluxTransformer2DModel
-    #     unet_config = guess.unet_config.copy()
-    #     state_dict_size = memory_management.state_dict_size(state_dict)
-    #     ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
-    #     ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
-    #     to_args = dict(device=ini_device, dtype=ini_dtype)
-    #
-    #     with using_forge_operations(**to_args):
-    #         model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
-    #         model.config = unet_config
-    #
-    #     load_state_dict(model, state_dict)
-    #     return model
+    if cls_name == 'FluxTransformer2DModel':
+        from backend.nn.flux import IntegratedFluxTransformer2DModel
+        unet_config = guess.unet_config.copy()
+        state_dict_size = memory_management.state_dict_size(state_dict)
+        ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
+        ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
+        to_args = dict(device=ini_device, dtype=ini_dtype)
+
+        with using_forge_operations(**to_args):
+            model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
+            model.config = unet_config
+
+        load_state_dict(model, state_dict)
+        return model
 
     print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
     return None
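The Flux branch above follows the same construct-then-load pattern as the UNet branch: pick an initial device/dtype from the checkpoint size, instantiate the module under those settings, then stream the weights in. A rough standalone analogue using plain PyTorch (load_transformer and its defaults are illustrative, not Forge's API):

```python
import torch

def load_transformer(module_cls, config, state_dict,
                     device='cpu', dtype=torch.bfloat16):
    # Instantiate directly at the target device/dtype, mirroring the
    # to_args = dict(device=..., dtype=...) pattern in the diff.
    model = module_cls(**config).to(device=device, dtype=dtype)
    # strict=False tolerates key mismatches; the diff's load_state_dict
    # helper handles the same situation with its own logging.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(f'missing: {missing}, unexpected: {unexpected}')
    return model
```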