forge 2.0.0
see also discussions
+49 -19
@@ -3,6 +3,7 @@ import torch
 import logging
 import importlib
 
+import backend.args
 import huggingface_guess
 
 from diffusers import DiffusionPipeline
@@ -69,9 +70,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         config = read_arbitrary_config(config_path)
 
         dtype = memory_management.text_encoder_dtype()
-        sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
+        sd_dtype = memory_management.state_dict_dtype(state_dict)
 
         if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
             print(f'Using Detected T5 Data Type: {sd_dtype}')
             dtype = sd_dtype
 
         with modeling_utils.no_init_weights():
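One part of this hunk worth spelling out: the old code inferred the T5 dtype from a single hard-coded tensor key, while the new memory_management.state_dict_dtype call derives it from the state dict as a whole. Below is a rough, hypothetical illustration of the difference; the helper names are simplified stand-ins, not Forge's implementation.

import torch

# Old approach: trust the dtype of one hard-coded tensor key.
def dtype_from_fixed_key(state_dict):
    return state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype

# New approach (roughly): classify the state dict as a whole, so a checkpoint
# that stores most tensors in fp8 is detected even if individual keys differ.
def dtype_from_scan(state_dict):
    dtypes = [v.dtype for v in state_dict.values() if torch.is_tensor(v)]
    for fp8 in (torch.float8_e4m3fn, torch.float8_e5m2):
        if fp8 in dtypes:
            return fp8
    return max(set(dtypes), key=dtypes.count)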
@@ -81,32 +83,60 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
 
         return model
-    if cls_name == 'UNet2DConditionModel':
+    if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
+        model_loader = None
+        if cls_name == 'UNet2DConditionModel':
+            model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
+        if cls_name == 'FluxTransformer2DModel':
+            from backend.nn.flux import IntegratedFluxTransformer2DModel
+            model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
+
         unet_config = guess.unet_config.copy()
         state_dict_size = memory_management.state_dict_size(state_dict)
-        ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
-        ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
-        to_args = dict(device=ini_device, dtype=ini_dtype)
+        state_dict_dtype = memory_management.state_dict_dtype(state_dict)
 
-        with using_forge_operations(**to_args):
-            model = IntegratedUNet2DConditionModel.from_config(unet_config).to(**to_args)
-            model._internal_dict = unet_config
+        storage_dtype = memory_management.unet_dtype(model_params=state_dict_size, supported_dtypes=guess.supported_inference_dtypes)
+
+        unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype')
+
+        if unet_storage_dtype_overwrite is not None:
+            storage_dtype = unet_storage_dtype_overwrite
+        else:
+            if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4']:
+                print(f'Using Detected UNet Type: {state_dict_dtype}')
+                storage_dtype = state_dict_dtype
+                if state_dict_dtype in ['nf4', 'fp4']:
+                    print(f'Using pre-quant state dict!')
+
+        load_device = memory_management.get_torch_device()
+        computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes)
+        offload_device = memory_management.unet_offload_device()
+
+        if storage_dtype in ['nf4', 'fp4']:
+            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype)
+            with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
+                model = model_loader(unet_config)
+        else:
+            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=storage_dtype)
+            need_manual_cast = storage_dtype != computation_dtype
+            to_args = dict(device=initial_device, dtype=storage_dtype)
+
+            with using_forge_operations(**to_args, manual_cast_enabled=need_manual_cast):
+                model = model_loader(unet_config).to(**to_args)
 
         load_state_dict(model, state_dict)
-        return model
-    if cls_name == 'FluxTransformer2DModel':
-        from backend.nn.flux import IntegratedFluxTransformer2DModel
-        unet_config = guess.unet_config.copy()
-        state_dict_size = memory_management.state_dict_size(state_dict)
-        ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
-        ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
-        to_args = dict(device=ini_device, dtype=ini_dtype)
-
-        with using_forge_operations(**to_args):
-            model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
         if hasattr(model, '_internal_dict'):
             model._internal_dict = unet_config
         else:
             model.config = unet_config
 
-        load_state_dict(model, state_dict)
+        model.storage_dtype = storage_dtype
+        model.computation_dtype = computation_dtype
+        model.load_device = load_device
+        model.initial_device = initial_device
+        model.offload_device = offload_device
 
         return model
 
     print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
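A condensed restatement of the new UNet storage-dtype decision may help when reading the discussions: an explicit forge_unet_storage_dtype override wins; otherwise an fp8/nf4/fp4 state dict keeps its detected storage dtype, and manual casting is enabled only when storage and computation dtypes differ (never for the bitsandbytes formats, which go through the bnb_dtype path instead). The sketch below is a simplified stand-in with hypothetical names, not the backend API.

import torch

def plan_unet_dtypes(state_dict_dtype, default_storage_dtype, computation_dtype, overwrite=None):
    # An explicit forge_unet_storage_dtype override wins over detection.
    if overwrite is not None:
        storage_dtype = overwrite
    elif state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4']:
        # A pre-quantized checkpoint keeps its detected storage format.
        storage_dtype = state_dict_dtype
    else:
        storage_dtype = default_storage_dtype

    # bitsandbytes formats ('nf4'/'fp4') are handled by specialized layers, so no
    # manual casting; everything else casts manually when storage != computation.
    need_manual_cast = storage_dtype not in ['nf4', 'fp4'] and storage_dtype != computation_dtype
    return storage_dtype, need_manual_cast

# Example: an fp8 checkpoint computed in bfloat16 -> (torch.float8_e4m3fn, True)
print(plan_unet_dtypes(torch.float8_e4m3fn, torch.float16, torch.bfloat16))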