ControlNet API (#162)

* ControlNet API

* update cache key

* nits

* disable controlnet tests
This commit is contained in:
Chenlei Hu
2024-02-10 06:16:13 +00:00
committed by GitHub
parent bd0878754c
commit 5a7e755528
14 changed files with 630 additions and 30 deletions
@@ -11,9 +11,10 @@ from modules.api.api import decode_base64_to_image
import gradio as gr
from lib_controlnet import global_state, external_code
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
prepare_mask, judge_image_type
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \
@@ -21,6 +22,7 @@ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusion
from lib_controlnet.infotext import Infotext
from modules_forge.forge_util import HWC3, numpy_to_pytorch
from lib_controlnet.enums import HiResFixOption
from lib_controlnet.api import controlnet_api
import numpy as np
import functools
@@ -67,7 +69,7 @@ class ControlNetForForgeOfficial(scripts.Script):
max_models = shared.opts.data.get("control_net_unit_count", 3)
gen_type = "img2img" if is_img2img else "txt2img"
elem_id_tabname = gen_type + "_controlnet"
default_unit = UiControlNetUnit(enabled=False, module="None", model="None")
default_unit = ControlNetUnit(enabled=False, module="None", model="None")
with gr.Group(elem_id=elem_id_tabname):
with gr.Accordion(f"ControlNet Integrated", open=False, elem_id="controlnet",
elem_classes=["controlnet"]):
@@ -95,13 +97,19 @@ class ControlNetForForgeOfficial(scripts.Script):
return tuple(controls)
def get_enabled_units(self, units):
# Parse dict from API calls.
units = [
ControlNetUnit.from_dict(unit) if isinstance(unit, dict) else unit
for unit in units
]
assert all(isinstance(unit, ControlNetUnit) for unit in units)
enabled_units = [x for x in units if x.enabled]
return enabled_units
@staticmethod
def try_crop_image_with_a1111_mask(
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
input_image: np.ndarray,
resize_mode: external_code.ResizeMode,
preprocessor
@@ -252,7 +260,7 @@ class ControlNetForForgeOfficial(scripts.Script):
@torch.no_grad()
def process_unit_after_click_generate(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):
@@ -279,8 +287,6 @@ class ControlNetForForgeOfficial(scripts.Script):
return tqdm(iterable) if use_tqdm else iterable
for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1):
# p.extra_result_images.append(input_image)
if unit.pixel_perfect:
unit.processor_res = external_code.pixel_perfect_resolution(
input_image,
@@ -319,14 +325,20 @@ class ControlNetForForgeOfficial(scripts.Script):
hr_option = HiResFixOption.BOTH
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False):
if (
(is_high_res and hr_option.high_res_enabled) or
(not is_high_res and hr_option.low_res_enabled)
) and unit.save_detected_map:
p.extra_result_images.append(img)
if preprocessor_output_is_image:
params.control_cond = []
params.control_cond_for_hr_fix = []
for preprocessor_output in preprocessor_outputs:
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
if hr_option.low_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond))
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))
params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()
@@ -334,8 +346,7 @@ class ControlNetForForgeOfficial(scripts.Script):
if has_high_res_fix:
for preprocessor_output in preprocessor_outputs:
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
if hr_option.high_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond_for_hr_fix), is_high_res=True)
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
else:
@@ -343,7 +354,7 @@ class ControlNetForForgeOfficial(scripts.Script):
else:
params.control_cond = preprocessor_output
params.control_cond_for_hr_fix = preprocessor_output
p.extra_result_images.append(input_image)
attach_extra_result_image(input_image)
if len(control_masks) > 0:
params.control_mask = []
@@ -352,15 +363,13 @@ class ControlNetForForgeOfficial(scripts.Script):
for input_mask in control_masks:
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
if hr_option.low_res_enabled:
p.extra_result_images.append(control_mask)
attach_extra_result_image(control_mask)
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
params.control_mask.append(control_mask)
if has_high_res_fix:
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
if hr_option.high_res_enabled:
p.extra_result_images.append(control_mask_for_hr_fix)
attach_extra_result_image(control_mask_for_hr_fix, is_high_res=True)
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
@@ -390,7 +399,7 @@ class ControlNetForForgeOfficial(scripts.Script):
@torch.no_grad()
def process_unit_before_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):
@@ -473,14 +482,14 @@ class ControlNetForForgeOfficial(scripts.Script):
return
@staticmethod
def bound_check_params(unit: external_code.ControlNetUnit) -> None:
def bound_check_params(unit: ControlNetUnit) -> None:
"""
Checks and corrects negative parameters in ControlNetUnit 'unit'.
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
their default values if negative.
Args:
unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
unit (ControlNetUnit): The ControlNetUnit instance to check.
"""
preprocessor = global_state.get_preprocessor(unit.module)
@@ -498,7 +507,7 @@ class ControlNetForForgeOfficial(scripts.Script):
@torch.no_grad()
def process_unit_after_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):
@@ -577,3 +586,4 @@ script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted)
script_callbacks.on_after_component(ControlNetUiGroup.on_after_component)
script_callbacks.on_before_reload(ControlNetUiGroup.reset)
script_callbacks.on_app_started(controlnet_api)