diff --git a/backend/loader.py b/backend/loader.py
index 809a99f6..b479a007 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -2,10 +2,13 @@
 import os
 import torch
 import logging
 import importlib
+
 import huggingface_guess
 from diffusers import DiffusionPipeline
 from transformers import modeling_utils
+from backend import memory_management
+
 from backend.state_dict import try_filter_state_dict, load_state_dict
 from backend.operations import using_forge_operations
 from backend.nn.vae import IntegratedAutoencoderKL
@@ -57,9 +60,13 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         return model
 
     if cls_name == 'UNet2DConditionModel':
+        unet_config = guess.unet_config.copy()
+        state_dict_size = memory_management.state_dict_size(state_dict)
+        unet_config['dtype'] = memory_management.unet_dtype(model_params=state_dict_size)
+
         with using_forge_operations():
-            model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
-            model._internal_dict = guess.unet_config
+            model = IntegratedUNet2DConditionModel.from_config(unet_config)
+            model._internal_dict = unet_config
 
         load_state_dict(model, state_dict)
         return model
diff --git a/backend/memory_management.py b/backend/memory_management.py
index 2c3699d2..95a2f8e5 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -8,7 +8,7 @@ import platform
 from enum import Enum
 
 from backend import stream
-from backend.args import args
+from backend.args import args, dynamic_args
 
 
 class VRAMState(Enum):
@@ -289,9 +289,8 @@ if 'rtx' in torch_device_name.lower():
 current_loaded_models = []
 
 
-def module_size(module, exclude_device=None):
+def state_dict_size(sd, exclude_device=None):
     module_mem = 0
-    sd = module.state_dict()
 
     for k in sd:
         t = sd[k]
@@ -303,6 +302,10 @@
     return module_mem
 
 
+def module_size(module, exclude_device=None):
+    return state_dict_size(module.state_dict(), exclude_device=exclude_device)
+
+
 class LoadedModel:
     def __init__(self, model, memory_required):
         self.model = model
@@ -563,20 +566,31 @@ def unet_inital_load_device(parameters, dtype):
 
 
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+    unet_storage_dtype_overwrite = dynamic_args.get('forge_unet_storage_dtype')
+
+    if unet_storage_dtype_overwrite is not None:
+        return unet_storage_dtype_overwrite
+
     if args.unet_in_bf16:
         return torch.bfloat16
+
     if args.unet_in_fp16:
         return torch.float16
+
     if args.unet_in_fp8_e4m3fn:
         return torch.float8_e4m3fn
+
     if args.unet_in_fp8_e5m2:
         return torch.float8_e5m2
+
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
+
     if should_use_bf16(device, model_params=model_params, manual_cast=True):
         if torch.bfloat16 in supported_dtypes:
             return torch.bfloat16
+
     return torch.float32
 
 
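Note: the memory_management.py hunks above give a runtime override the highest priority when picking the UNet storage dtype: `dynamic_args.get('forge_unet_storage_dtype')` now wins over every `unet_in_*` command-line flag and over the fp16/bf16 capability probes. A minimal sketch of that precedence chain, assuming simplified stand-ins (`DYNAMIC_ARGS`, `UNET_IN_BF16`) for the real `backend.args` objects:

import torch

DYNAMIC_ARGS = {}      # stand-in for backend.args.dynamic_args
UNET_IN_BF16 = False   # stand-in for args.unet_in_bf16

def unet_dtype_sketch():
    # 1) A runtime override set by the UI beats everything else.
    overwrite = DYNAMIC_ARGS.get('forge_unet_storage_dtype')
    if overwrite is not None:
        return overwrite
    # 2) Then the CLI flags, then the device capability checks (elided here).
    if UNET_IN_BF16:
        return torch.bfloat16
    return torch.float32   # final fallback, as in the diff

DYNAMIC_ARGS['forge_unet_storage_dtype'] = torch.float8_e4m3fn
assert unet_dtype_sketch() is torch.float8_e4m3fn
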
diff --git a/backend/modules/k_model.py b/backend/modules/k_model.py
index 16236227..9fe3200c 100644
--- a/backend/modules/k_model.py
+++ b/backend/modules/k_model.py
@@ -1,6 +1,6 @@
 import torch
 
-from backend import memory_management, attention
+from backend import memory_management, attention, operations
 from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
 
 
@@ -11,6 +11,11 @@ class KModel(torch.nn.Module):
         self.storage_dtype = storage_dtype
         self.computation_dtype = computation_dtype
 
+        need_manual_cast = self.storage_dtype != self.computation_dtype
+        operations.shift_manual_cast(model, enabled=need_manual_cast)
+
+        print(f'K-Model Created: {dict(storage_dtype=storage_dtype, computation_dtype=computation_dtype, manual_cast=need_manual_cast)}')
+
         self.diffusion_model = model
         self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
 
diff --git a/backend/operations.py b/backend/operations.py
index 8a40fa09..a066223d 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -171,3 +171,10 @@ def using_forge_operations(parameters_manual_cast=False, operations=None):
             for op_name in op_names:
                 setattr(torch.nn, op_name, backups[op_name])
     return
+
+
+def shift_manual_cast(model, enabled):
+    for m in model.modules():
+        if hasattr(m, 'parameters_manual_cast'):
+            m.parameters_manual_cast = enabled
+    return
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 32983d34..d3222dfe 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -426,16 +426,45 @@ def get_obj_from_str(string, reload=False):
     pass
 
 
-@torch.no_grad()
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
-    checkpoint_info = checkpoint_info or select_checkpoint()
+    pass
+
+
+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+    pass
+
+
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
+    pass
+
+
+def unload_model_weights(sd_model=None, info=None):
+    pass
+
+
+def apply_token_merging(sd_model, token_merging_ratio):
+    if token_merging_ratio <= 0:
+        return
+
+    print(f'token_merging_ratio = {token_merging_ratio}')
+
+    from backend.misc.tomesd import TomePatcher
+
+    sd_model.forge_objects.unet = TomePatcher().patch(
+        model=sd_model.forge_objects.unet,
+        ratio=token_merging_ratio
+    )
+
+    return
+
+
+@torch.no_grad()
+def forge_model_reload():
+    checkpoint_info = select_checkpoint()
 
     timer = Timer()
 
     if model_data.sd_model:
-        if model_data.sd_model.filename == checkpoint_info.filename:
-            return model_data.sd_model
-
         model_data.sd_model = None
         model_data.loaded_sd_models = []
         memory_management.unload_all_models()
@@ -444,10 +473,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     timer.record("unload existing model")
 
-    if already_loaded_state_dict is not None:
-        state_dict = already_loaded_state_dict
-    else:
-        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
@@ -489,31 +515,3 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     print(f"Model loaded in {timer.summary()}.")
 
     return sd_model
-
-
-def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
-    pass
-
-
-def reload_model_weights(sd_model=None, info=None, forced_reload=False):
-    pass
-
-
-def unload_model_weights(sd_model=None, info=None):
-    pass
-
-
-def apply_token_merging(sd_model, token_merging_ratio):
-    if token_merging_ratio <= 0:
-        return
-
-    print(f'token_merging_ratio = {token_merging_ratio}')
-
-    from backend.misc.tomesd import TomePatcher
-
-    sd_model.forge_objects.unet = TomePatcher().patch(
-        model=sd_model.forge_objects.unet,
-        ratio=token_merging_ratio
-    )
-
-    return
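Note: `shift_manual_cast` walks every submodule and flips `parameters_manual_cast` on any layer that exposes it, and `KModel` enables it exactly when the storage dtype differs from the computation dtype (e.g. fp8 weights computed in fp16). A runnable toy sketch of that traversal, where `ManualCastLinear` is a hypothetical stand-in for forge's patched layer classes:

import torch

class ManualCastLinear(torch.nn.Linear):
    # Patched layers in forge carry this flag; plain torch layers do not.
    parameters_manual_cast = False

model = torch.nn.Sequential(ManualCastLinear(4, 4), torch.nn.ReLU())

def shift_manual_cast(model, enabled):
    # Visit every submodule; only opted-in layers have the attribute.
    for m in model.modules():
        if hasattr(m, 'parameters_manual_cast'):
            m.parameters_manual_cast = enabled

shift_manual_cast(model, enabled=True)   # storage_dtype != computation_dtype
assert model[0].parameters_manual_cast is True
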
diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py
index f336612d..5ff1fe8f 100644
--- a/modules_forge/main_entry.py
+++ b/modules_forge/main_entry.py
@@ -1,21 +1,31 @@
+import torch
 import gradio as gr
 
 from modules import shared_items, shared, ui_common, sd_models
 from modules import sd_vae as sd_vae_module
 from modules_forge import main_thread
+from backend import args as backend_args
 
 
 ui_checkpoint: gr.Dropdown = None
 ui_vae: gr.Dropdown = None
 ui_clip_skip: gr.Slider = None
 
+forge_unet_storage_dtype_options = {
+    'None': None,
+    'fp8e4m3': torch.float8_e4m3fn,
+    'fp8e5m2': torch.float8_e5m2,
+}
 
-def bind_to_opts(comp, k, save=False):
+
+def bind_to_opts(comp, k, save=False, callback=None):
     def on_change(v):
         print(f'Setting Changed: {k} = {v}')
         shared.opts.set(k, v)
         if save:
             shared.opts.save(shared.config_filename)
+        if callback is not None:
+            callback()
         return
 
     comp.change(on_change, inputs=[comp], show_progress=False)
@@ -35,6 +45,7 @@ def make_checkpoint_manager_ui():
     ui_checkpoint = gr.Dropdown(
         value=shared.opts.sd_model_checkpoint,
         label="Checkpoint",
+        elem_classes=['model_selection'],
         **sd_model_checkpoint_args()
     )
     ui_common.create_refresh_button(ui_checkpoint, shared_items.refresh_checkpoints, sd_model_checkpoint_args, f"forge_refresh_checkpoint")
@@ -47,6 +58,9 @@ def make_checkpoint_manager_ui():
     )
     ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
 
+    ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion in FP8", value=shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
+    bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=lambda: main_thread.async_run(model_load_entry))
+
     ui_clip_skip = gr.Slider(label="Clip skip", value=shared.opts.CLIP_stop_at_last_layers, **{"minimum": 1, "maximum": 12, "step": 1})
     bind_to_opts(ui_clip_skip, 'CLIP_stop_at_last_layers', save=False)
 
@@ -54,7 +68,11 @@
 
 
 def model_load_entry():
-    sd_models.load_model()
+    backend_args.dynamic_args.update(dict(
+        forge_unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
+    ))
+
+    sd_models.forge_model_reload()
     return
 
 
diff --git a/style.css b/style.css
index 06e6597d..9abbe8e2 100644
--- a/style.css
+++ b/style.css
@@ -438,7 +438,8 @@ div.toprow-compact-tools{
     box-shadow: none;
     background: none;
 }
-#quicksettings > div.gradio-dropdown{
+
+#quicksettings > div.model_selection{
     min-width: 24em !important;
 }
 
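Note: end to end, selecting a value in the new "Diffusion in FP8" radio stores the chosen torch dtype in `dynamic_args` and then reloads the checkpoint so `unet_dtype` picks it up. A toy sketch of that flow, where `on_radio_change` and `reload_model` are illustrative stand-ins for the gradio change callback and `main_thread.async_run(model_load_entry)`:

import torch

forge_unet_storage_dtype_options = {
    'None': None,
    'fp8e4m3': torch.float8_e4m3fn,
    'fp8e5m2': torch.float8_e5m2,
}

dynamic_args = {}   # stand-in for backend.args.dynamic_args

def reload_model():
    # Stand-in for main_thread.async_run(model_load_entry).
    print(f'reloading with dynamic_args={dynamic_args}')

def on_radio_change(choice):
    # Mirrors model_load_entry: publish the override, then reload.
    dynamic_args['forge_unet_storage_dtype'] = forge_unet_storage_dtype_options[choice]
    reload_model()

on_radio_change('fp8e4m3')   # subsequent unet_dtype() calls return float8_e4m3fn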