Significantly reduce thread abuse for faster model moving

This will move all major gradio calls into the main thread rather than random gradio threads.
This ensures that all torch.module.to() are performed in main thread to completely possible avoid GPU fragments.
In my test now model moving is 0.7 ~ 1.2 seconds faster, which means all 6GB/8GB VRAM users will get 0.7 ~ 1.2 seconds faster per image on SDXL.
This commit is contained in:
lllyasviel
2024-02-08 10:13:59 -08:00
parent 291ec743b6
commit f06ba8e60b
8 changed files with 122 additions and 31 deletions
+6 -1
View File
@@ -15,6 +15,7 @@ import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.scripts
from modules_forge import main_thread
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
@@ -146,7 +147,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
return batch_results
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -244,3 +245,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
processed.images = []
return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
return main_thread.run_and_wait_result(img2img_function, id_task, mode, prompt, negative_prompt, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps, sampler_name, mask_blur, mask_alpha, inpainting_fill, n_iter, batch_size, cfg_scale, image_cfg_scale, denoising_strength, selected_scale_tab, height, width, scale_by, resize_mode, inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings_texts, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, request, *args)
+3 -18
View File
@@ -149,24 +149,9 @@ def initialize_rest(*, reload_script_modules=False):
sd_unet.list_unets()
startup_timer.record("scripts list_unets")
def load_model():
"""
Accesses shared.sd_model property to load model.
After it's available, if it has been loaded before this access by some extension,
its optimization may be None because the list of optimizaers has neet been filled
by that time, so we apply optimization again.
"""
from modules import devices
devices.torch_npu_set_device()
shared.sd_model # noqa: B018
if sd_hijack.current_optimizer is None:
sd_hijack.apply_optimizations()
devices.first_time_calculation()
if not shared.cmd_opts.skip_load_model_at_start:
Thread(target=load_model).start()
from modules_forge import main_thread
import modules.sd_models
main_thread.async_run(modules.sd_models.model_data.get_sd_model)
from modules import shared_items
shared_items.reload_hypernetworks()
+4 -3
View File
@@ -170,10 +170,11 @@ def configure_sigint_handler():
def configure_opts_onchange():
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
from modules.call_queue import wrap_queued_call
from modules_forge import main_thread
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights)), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+5
View File
@@ -511,6 +511,11 @@ def start():
else:
webui.webui()
from modules_forge import main_thread
main_thread.loop()
return
def dump_sysinfo():
from modules import sysinfo
+8 -1
View File
@@ -2,6 +2,8 @@ import datetime
import logging
import threading
import time
import traceback
import torch
from modules import errors, shared, devices
from typing import Optional
@@ -134,6 +136,7 @@ class State:
devices.torch_gc()
@torch.inference_mode()
def set_current_image(self):
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
if not shared.parallel_processing_allowed:
@@ -142,6 +145,7 @@ class State:
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
@torch.inference_mode()
def do_set_current_image(self):
if self.current_latent is None:
return
@@ -156,11 +160,14 @@ class State:
self.current_image_sampling_step = self.sampling_step
except Exception:
except Exception as e:
# traceback.print_exc()
# print(e)
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
# we silently ignore this error
errors.record_exception()
@torch.inference_mode()
def assign_current_image(self, image):
self.current_image = image
self.id_live_preview += 1
+11 -2
View File
@@ -9,6 +9,7 @@ import modules.shared as shared
from modules.ui import plaintext_to_html
from PIL import Image
import gradio as gr
from modules_forge import main_thread
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
@@ -56,7 +57,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
return p
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
assert len(gallery) > 0, 'No image to upscale'
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
@@ -100,7 +101,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
def txt2img(id_task: str, request: gr.Request, *args):
def txt2img_function(id_task: str, request: gr.Request, *args):
p = txt2img_create_processing(id_task, request, *args)
with closing(p):
@@ -119,3 +120,11 @@ def txt2img(id_task: str, request: gr.Request, *args):
processed.images = []
return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
return main_thread.run_and_wait_result(txt2img_upscale_function, id_task, request, gallery, gallery_index, generation_info, *args)
def txt2img(id_task: str, request: gr.Request, *args):
return main_thread.run_and_wait_result(txt2img_function, id_task, request, *args)