UNet from Scratch
Now backend rewrite is about 50% finished. Estimated finish is in 72 hours. After that, many newer features will land.
This commit is contained in:
+12
-1
@@ -6,8 +6,10 @@ from diffusers import DiffusionPipeline
|
||||
from transformers import modeling_utils
|
||||
from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
|
||||
from backend.operations import using_forge_operations
|
||||
from backend.nn.autoencoder_kl import IntegratedAutoencoderKL
|
||||
from backend.nn.vae import IntegratedAutoencoderKL
|
||||
from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
|
||||
from backend.nn.unet import IntegratedUNet2DConditionModel
|
||||
|
||||
|
||||
dir_path = os.path.dirname(__file__)
|
||||
|
||||
@@ -54,6 +56,15 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
|
||||
load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
|
||||
'transformer.text_model.embeddings.position_ids'])
|
||||
return model
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
|
||||
config = IntegratedUNet2DConditionModel.load_config(config_path)
|
||||
|
||||
with using_forge_operations():
|
||||
model = IntegratedUNet2DConditionModel.from_config(config)
|
||||
|
||||
load_state_dict(model, sd)
|
||||
return model
|
||||
|
||||
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user