revise kernel
and add unused files
This commit is contained in:
+35
-36
@@ -19,11 +19,10 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
|
||||
from backend.diffusion_engine.sd15 import StableDiffusion
|
||||
from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
||||
from backend.diffusion_engine.flux import Flux
|
||||
|
||||
|
||||
possible_models = [
|
||||
StableDiffusion, StableDiffusion2, StableDiffusionXL,
|
||||
]
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
@@ -65,25 +64,25 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
], log_name=cls_name)
|
||||
|
||||
return model
|
||||
# if cls_name == 'T5EncoderModel':
|
||||
# from backend.nn.t5 import IntegratedT5
|
||||
# config = read_arbitrary_config(config_path)
|
||||
#
|
||||
# dtype = memory_management.text_encoder_dtype()
|
||||
# sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
|
||||
# need_cast = False
|
||||
#
|
||||
# if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
# dtype = sd_dtype
|
||||
# need_cast = True
|
||||
#
|
||||
# with modeling_utils.no_init_weights():
|
||||
# with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
|
||||
# model = IntegratedT5(config)
|
||||
#
|
||||
# load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
|
||||
#
|
||||
# return model
|
||||
if cls_name == 'T5EncoderModel':
|
||||
from backend.nn.t5 import IntegratedT5
|
||||
config = read_arbitrary_config(config_path)
|
||||
|
||||
dtype = memory_management.text_encoder_dtype()
|
||||
sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
|
||||
need_cast = False
|
||||
|
||||
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
dtype = sd_dtype
|
||||
need_cast = True
|
||||
|
||||
with modeling_utils.no_init_weights():
|
||||
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
|
||||
model = IntegratedT5(config)
|
||||
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
|
||||
|
||||
return model
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_size = memory_management.state_dict_size(state_dict)
|
||||
@@ -97,20 +96,20 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
|
||||
load_state_dict(model, state_dict)
|
||||
return model
|
||||
# if cls_name == 'FluxTransformer2DModel':
|
||||
# from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
# unet_config = guess.unet_config.copy()
|
||||
# state_dict_size = memory_management.state_dict_size(state_dict)
|
||||
# ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
|
||||
# ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
|
||||
# to_args = dict(device=ini_device, dtype=ini_dtype)
|
||||
#
|
||||
# with using_forge_operations(**to_args):
|
||||
# model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
|
||||
# model.config = unet_config
|
||||
#
|
||||
# load_state_dict(model, state_dict)
|
||||
# return model
|
||||
if cls_name == 'FluxTransformer2DModel':
|
||||
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_size = memory_management.state_dict_size(state_dict)
|
||||
ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
|
||||
ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
|
||||
to_args = dict(device=ini_device, dtype=ini_dtype)
|
||||
|
||||
with using_forge_operations(**to_args):
|
||||
model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
|
||||
model.config = unet_config
|
||||
|
||||
load_state_dict(model, state_dict)
|
||||
return model
|
||||
|
||||
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user