diffusion in fp8 landed
This commit is contained in:
@@ -1,21 +1,31 @@
|
||||
import torch
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared_items, shared, ui_common, sd_models
|
||||
from modules import sd_vae as sd_vae_module
|
||||
from modules_forge import main_thread
|
||||
from backend import args as backend_args
|
||||
|
||||
|
||||
ui_checkpoint: gr.Dropdown = None
|
||||
ui_vae: gr.Dropdown = None
|
||||
ui_clip_skip: gr.Slider = None
|
||||
|
||||
forge_unet_storage_dtype_options = {
|
||||
'None': None,
|
||||
'fp8e4m3': torch.float8_e4m3fn,
|
||||
'fp8e5m2': torch.float8_e5m2,
|
||||
}
|
||||
|
||||
def bind_to_opts(comp, k, save=False):
|
||||
|
||||
def bind_to_opts(comp, k, save=False, callback=None):
|
||||
def on_change(v):
|
||||
print(f'Setting Changed: {k} = {v}')
|
||||
shared.opts.set(k, v)
|
||||
if save:
|
||||
shared.opts.save(shared.config_filename)
|
||||
if callback is not None:
|
||||
callback()
|
||||
return
|
||||
|
||||
comp.change(on_change, inputs=[comp], show_progress=False)
|
||||
@@ -35,6 +45,7 @@ def make_checkpoint_manager_ui():
|
||||
ui_checkpoint = gr.Dropdown(
|
||||
value=shared.opts.sd_model_checkpoint,
|
||||
label="Checkpoint",
|
||||
elem_classes=['model_selection'],
|
||||
**sd_model_checkpoint_args()
|
||||
)
|
||||
ui_common.create_refresh_button(ui_checkpoint, shared_items.refresh_checkpoints, sd_model_checkpoint_args, f"forge_refresh_checkpoint")
|
||||
@@ -47,6 +58,9 @@ def make_checkpoint_manager_ui():
|
||||
)
|
||||
ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
|
||||
|
||||
ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion in FP8", value=shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
|
||||
bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=lambda: main_thread.async_run(model_load_entry))
|
||||
|
||||
ui_clip_skip = gr.Slider(label="Clip skip", value=shared.opts.CLIP_stop_at_last_layers, **{"minimum": 1, "maximum": 12, "step": 1})
|
||||
bind_to_opts(ui_clip_skip, 'CLIP_stop_at_last_layers', save=False)
|
||||
|
||||
@@ -54,7 +68,11 @@ def make_checkpoint_manager_ui():
|
||||
|
||||
|
||||
def model_load_entry():
|
||||
sd_models.load_model()
|
||||
backend_args.dynamic_args.update(dict(
|
||||
forge_unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
|
||||
))
|
||||
|
||||
sd_models.forge_model_reload()
|
||||
return
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user