make results more consistent to A1111
This commit is contained in:
+2
-4
@@ -54,7 +54,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
to_args = dict(device=memory_management.text_encoder_device(), dtype=memory_management.text_encoder_dtype())
|
||||
|
||||
with modeling_utils.no_init_weights():
|
||||
with using_forge_operations(**to_args):
|
||||
with using_forge_operations(**to_args, manual_cast_enabled=True):
|
||||
model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True).to(**to_args)
|
||||
|
||||
load_state_dict(model, state_dict, ignore_errors=[
|
||||
@@ -70,14 +70,12 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
|
||||
dtype = memory_management.text_encoder_dtype()
|
||||
sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
|
||||
need_cast = False
|
||||
|
||||
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
dtype = sd_dtype
|
||||
need_cast = True
|
||||
|
||||
with modeling_utils.no_init_weights():
|
||||
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
|
||||
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
|
||||
model = IntegratedT5(config)
|
||||
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
|
||||
|
||||
Reference in New Issue
Block a user