diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index e03b88fd..5cdb23e7 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -57,29 +57,12 @@ class ControlNetForForgeOfficial(scripts.Script): def show(self, is_img2img): return scripts.AlwaysVisible - @staticmethod - def get_default_ui_unit(is_ui=True): - cls = UiControlNetUnit if is_ui else external_code.ControlNetUnit - return cls( - enabled=False, - module="None", - model="None" - ) - - def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[ - ControlNetUiGroup, gr.State]: - group = ControlNetUiGroup( - is_img2img, - self.get_default_ui_unit(), - photopea, - ) + def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[ControlNetUiGroup, gr.State]: + default_unit = UiControlNetUnit(enabled=False, module="None", model="None") + group = ControlNetUiGroup(is_img2img, default_unit, photopea) return group, group.render(tabname, elem_id_tabname) def ui(self, is_img2img): - """this function should create gradio UI elements. See https://gradio.app/docs/#components - The return value should be an array of all components that are used in processing. - Values of those returned components will be passed to run() and process() functions. - """ infotext = Infotext() ui_groups = [] controls = [] @@ -107,65 +90,11 @@ class ControlNetForForgeOfficial(scripts.Script): if shared.opts.data.get("control_net_sync_field_args", True): self.infotext_fields = infotext.infotext_fields self.paste_field_names = infotext.paste_field_names - return tuple(controls) - @staticmethod - def get_remote_call(p, attribute, default=None, idx=0, strict=False, force=False): - if not force and not shared.opts.data.get("control_net_allow_script_control", False): - return default - - def get_element(obj, strict=False): - if not isinstance(obj, list): - return obj if not strict or idx == 0 else None - elif idx < len(obj): - return obj[idx] - else: - return None - - attribute_value = get_element(getattr(p, attribute, None), strict) - return attribute_value if attribute_value is not None else default - - def parse_remote_call(self, p, unit: external_code.ControlNetUnit, idx): - selector = self.get_remote_call - - unit.enabled = selector(p, "control_net_enabled", unit.enabled, idx, strict=True) - unit.module = selector(p, "control_net_module", unit.module, idx) - unit.model = selector(p, "control_net_model", unit.model, idx) - unit.weight = selector(p, "control_net_weight", unit.weight, idx) - unit.image = selector(p, "control_net_image", unit.image, idx) - unit.resize_mode = selector(p, "control_net_resize_mode", unit.resize_mode, idx) - unit.low_vram = selector(p, "control_net_lowvram", unit.low_vram, idx) - unit.processor_res = selector(p, "control_net_pres", unit.processor_res, idx) - unit.threshold_a = selector(p, "control_net_pthr_a", unit.threshold_a, idx) - unit.threshold_b = selector(p, "control_net_pthr_b", unit.threshold_b, idx) - unit.guidance_start = selector(p, "control_net_guidance_start", unit.guidance_start, idx) - unit.guidance_end = selector(p, "control_net_guidance_end", unit.guidance_end, idx) - unit.guidance_end = selector(p, "control_net_guidance_strength", unit.guidance_end, idx) - unit.control_mode = selector(p, "control_net_control_mode", unit.control_mode, idx) - unit.pixel_perfect = selector(p, "control_net_pixel_perfect", unit.pixel_perfect, idx) - - return unit - def get_enabled_units(self, p): units = external_code.get_all_units_in_processing(p) - if len(units) == 0: - # fill a null group - remote_unit = self.parse_remote_call(p, self.get_default_ui_unit(), 0) - if remote_unit.enabled: - units.append(remote_unit) - - enabled_units = [] - for idx, unit in enumerate(units): - local_unit = self.parse_remote_call(p, unit, idx) - if not local_unit.enabled: - continue - if hasattr(local_unit, "unfold_merged"): - enabled_units.extend(local_unit.unfold_merged()) - else: - enabled_units.append(copy(local_unit)) - - Infotext.write_infotext(enabled_units, p) + enabled_units = [x for x in units if x.enabled] return enabled_units def choose_input_image( @@ -346,33 +275,6 @@ class ControlNetForForgeOfficial(scripts.Script): return - @staticmethod - def check_sd_version_compatible(unit: external_code.ControlNetUnit) -> None: - """ - Checks whether the given ControlNet unit has model compatible with the currently - active sd model. An exception is thrown if ControlNet unit is detected to be - incompatible. - """ - sd_version = global_state.get_sd_version() - assert sd_version != StableDiffusionVersion.UNKNOWN - - if "revision" in unit.module.lower() and sd_version != StableDiffusionVersion.SDXL: - raise Exception(f"Preprocessor 'revision' only supports SDXL. Current SD base model is {sd_version}.") - - # No need to check if the ControlModelType does not require model to be present. - if unit.model is None or unit.model.lower() == "none": - return - - cnet_sd_version = StableDiffusionVersion.detect_from_model_name(unit.model) - - if cnet_sd_version == StableDiffusionVersion.UNKNOWN: - logger.warn(f"Unable to determine version for ControlNet model '{unit.model}'.") - return - - if not sd_version.is_compatible_with(cnet_sd_version): - raise Exception( - f"ControlNet model {unit.model}({cnet_sd_version}) is not compatible with sd model({sd_version})") - @staticmethod def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]: """Returns (h, w, hr_h, hr_w)."""