* ini

* remove shit

* Create control_model.py

* i

* i

* Update controlnet_supported.py

* Update controlnet_supported.py

* Update controlnet_supported.py

* i

* i

* Update controlnet_supported.py

* i

* Update controlnet_supported.py

* remove shits

* remove shit

* Update global_state.py

* i

* i

* Update legacy_preprocessors.py

* Update legacy_preprocessors.py

* remove shit

* Update batch_hijack.py

* remove shit

* remove shit

* i

* i

* i

* Update external_code.py

* Update global_state.py

* Update infotext.py

* Update utils.py

* Update external_code.py

* i

* i

* i

* Update controlnet_ui_group.py

* remove shit

* remove shit

* i

* Update controlnet.py

* Update controlnet.py

* Update controlnet.py

* Update controlnet.py

* Update controlnet.py

* i

* Update global_state.py

* Update global_state.py

* i

* Update global_state.py

* Update global_state.py

* Update global_state.py

* Update global_state.py

* Update controlnet_ui_group.py

* i

* Update global_state.py

* Update controlnet_ui_group.py

* Update controlnet_ui_group.py

* i

* Update controlnet_ui_group.py

* Update controlnet_ui_group.py

* Update controlnet_ui_group.py

* Update controlnet_ui_group.py
This commit is contained in:
lllyasviel
2024-01-29 14:25:03 -08:00
committed by GitHub
parent 7ef96eeffa
commit ac374e0b97
35 changed files with 7483 additions and 106 deletions
@@ -0,0 +1,38 @@
import gradio as gr
from typing import List
class ModalInterface(gr.Interface):
    """Builds the HTML for a modal dialog plus the element that opens it.

    Every instance takes the next value of the class-wide ``modal_id_counter``
    as its ``modal_id``; the id is embedded in the ``id`` attributes of the
    generated ``div`` elements.
    """

    # Next modal id to hand out; shared by all instances.
    modal_id_counter = 0

    def __init__(
        self,
        html_content: str,
        open_button_text: str,
        open_button_classes: List[str] = None,
        open_button_extra_attrs: str = '',
    ):
        """Store the modal configuration and reserve a modal id.

        Note that ``gr.Interface.__init__`` is not invoked by this constructor.

        Args:
            html_content: HTML placed inside the modal content ``div``.
            open_button_text: Text shown inside the opening element.
            open_button_classes: Extra CSS classes for the opening element.
                ``None`` (the default) means no extra classes.
            open_button_extra_attrs: Raw attribute text inserted into the
                opening element's tag.
        """
        self.html_content = html_content
        self.open_button_text = open_button_text
        # Copy into a fresh list so instances never share (or alias) a default
        # list object or the caller's list.
        self.open_button_classes = (
            list(open_button_classes) if open_button_classes is not None else []
        )
        self.open_button_extra_attrs = open_button_extra_attrs
        self.modal_id = ModalInterface.modal_id_counter
        ModalInterface.modal_id_counter += 1

    def __call__(self):
        """Shorthand for ``create_modal()`` with default visibility."""
        return self.create_modal()

    def create_modal(self, visible=True):
        """Return a ``gr.HTML`` component holding the modal and its opener.

        Args:
            visible: Passed through as the ``visible`` flag of the component.
        """
        html_code = f"""
<div id="cnet-modal-{self.modal_id}" class="cnet-modal">
<span class="cnet-modal-close">&times;</span>
<div class="cnet-modal-content">
{self.html_content}
</div>
</div>
<div id="cnet-modal-open-{self.modal_id}"
class="cnet-modal-open {' '.join(self.open_button_classes)}"
{self.open_button_extra_attrs}
>{self.open_button_text}</div>
"""
        return gr.HTML(value=html_code, visible=visible)
@@ -0,0 +1,154 @@
import base64
import gradio as gr
import json
from typing import List, Dict, Any, Tuple
from annotator.openpose import decode_json_as_poses, draw_poses
from annotator.openpose.animalpose import draw_animalposes
from lib_controlnet.controlnet_ui.modal import ModalInterface
from modules import shared
from lib_controlnet.logging import logger
def parse_data_url(data_url: str):
    """Decode the payload of a base64 ``data:`` URL and return it as bytes.

    The URL is split at its first comma; the part before the comma must
    contain ``;base64`` (checked with an assertion).
    """
    header, payload = data_url.split(",", 1)
    # Only base64-encoded payloads are supported.
    assert ";base64" in header
    return base64.b64decode(payload)
def encode_data_url(json_string: str) -> str:
    """Wrap a JSON string into a base64 ``data:application/json`` URL."""
    payload = base64.b64encode(json_string.encode("utf-8"))
    return "data:application/json;base64," + payload.decode("utf-8")
class OpenposeEditor:
    """Gradio widgets and callbacks tying a ControlNet unit to the openpose editor."""

    # Filename offered when the user clicks the download link.
    download_file = "pose.json"
    # URL the openpose editor is mounted on.
    editor_url = "/openpose_editor_index"

    def __init__(self) -> None:
        # The widgets are created later by render_edit() / render_upload().
        self.render_button = None
        self.pose_input = None
        self.download_link = None
        self.upload_link = None
        self.modal = None

    def render_edit(self):
        """Create the widgets of the preview image control button group."""
        # Hidden button; clicking it triggers a re-render of the generated image.
        self.render_button = gr.Button(
            elem_classes=["cnet-render-pose"],
            visible=False,
        )
        # Hidden textbox holding the pose JSON for the backend. The front-end
        # javascript writes the edited JSON data into it.
        self.pose_input = gr.Textbox(
            elem_classes=["cnet-pose-json"],
            visible=False,
        )
        # The iframe starts on about:blank so it does not navigate right away:
        # most controlnet units never need the openpose editor. Navigation
        # happens when the user first clicks 'Edit'; that logic lives in
        # `openpose_editor.js`.
        edit_title = f'title="Send pose to {OpenposeEditor.editor_url} for edit."'
        editor_modal = ModalInterface(
            '<iframe src="about:blank"></iframe>',
            open_button_text="Edit",
            open_button_classes=["cnet-edit-pose"],
            open_button_extra_attrs=edit_title,
        )
        self.modal = editor_modal.create_modal(visible=False)
        download_html = f"""<a href='' download='{OpenposeEditor.download_file}'>JSON</a>"""
        self.download_link = gr.HTML(
            value=download_html,
            visible=False,
            elem_classes=["cnet-download-pose"],
        )

    def render_upload(self):
        """Create the upload widget of the input image control button group."""
        upload_html = """
<label>Upload JSON</label>
<input type="file" accept=".json"/>
"""
        self.upload_link = gr.HTML(
            value=upload_html,
            visible=False,
            elem_classes=["cnet-upload-pose"],
        )

    def register_callbacks(
        self,
        generated_image: gr.Image,
        use_preview_as_input: gr.Checkbox,
        model: gr.Dropdown,
    ):
        """Attach the render-button and model-change event handlers.

        Args:
            generated_image: Image component that receives the drawn pose.
            use_preview_as_input: Checkbox that is switched on after rendering.
            model: Dropdown whose value decides whether the upload link shows.
        """

        def render_pose(pose_url: str) -> Tuple[Dict, Dict]:
            json_string = parse_data_url(pose_url).decode("utf-8")
            pose_data = json.loads(json_string)
            poses, animals, height, width = decode_json_as_poses(pose_data)
            logger.info("Preview as input is enabled.")
            # Draw human poses when present, otherwise the animal poses.
            if poses:
                canvas = draw_poses(
                    poses,
                    height,
                    width,
                    draw_body=True,
                    draw_hand=True,
                    draw_face=True,
                )
            else:
                canvas = draw_animalposes(animals, height, width)
            generated_image_update = gr.update(value=canvas, visible=True)
            use_preview_update = gr.update(value=True)
            # The trailing entries update this editor's own widgets.
            return (
                generated_image_update,
                use_preview_update,
                *self.update(json_string),
            )

        self.render_button.click(
            fn=render_pose,
            inputs=[self.pose_input],
            outputs=[generated_image, use_preview_as_input, *self.outputs()],
        )

        def update_upload_link(model: str) -> Dict:
            # The upload link is only shown for models whose name mentions openpose.
            show_upload = "openpose" in model.lower()
            return gr.update(visible=show_upload)

        model.change(fn=update_upload_link, inputs=[model], outputs=[self.upload_link])

    def outputs(self) -> List[Any]:
        """Return the components targeted by `update`, in matching order."""
        return [self.download_link, self.modal]

    def update(self, json_string: str) -> List[Dict]:
        """
        Called when there is a new JSON pose value generated by running
        preprocessor.

        Args:
            json_string: The new JSON string generated by preprocessor.

        Returns:
            Two gr.update events: one for the download link, one for the modal.
        """
        hint = "Download the pose as .json file"
        html = f"""<a href='{encode_data_url(json_string)}'
download='{OpenposeEditor.download_file}' title="{hint}">
JSON</a>"""
        visible = json_string != ""
        # The edit modal is additionally hidden when the
        # "controlnet_disable_openpose_edit" option is set.
        modal_visible = visible and not shared.opts.data.get(
            "controlnet_disable_openpose_edit", False
        )
        download_update = gr.update(value=html, visible=visible)
        modal_update = gr.update(visible=modal_visible)
        return [download_update, modal_update]
@@ -0,0 +1,182 @@
import gradio as gr
from lib_controlnet.controlnet_ui.modal import ModalInterface
# Inline SVG markup of the Photopea logo, embedded next to the "Edit" trigger text.
# The inline style sizes the icon to 0.75rem; the height declaration previously
# lacked its colon ("height 0.75rem"), which is not a valid CSS declaration.
PHOTOPEA_LOGO = """
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
width="100%" viewBox="0 0 256 256" enable-background="new 0 0 256 256" xml:space="preserve"
style="width: 0.75rem; height: 0.75rem; margin-left: 2px;"
>
<path fill="#18A497" opacity="1.000000" stroke="none"
d="
M1.000000,228.000000
C1.000000,162.312439 1.000000,96.624878 1.331771,30.719650
C2.026278,30.171114 2.594676,29.904894 2.721949,29.500008
C6.913495,16.165672 15.629609,7.322631 28.880219,2.875538
C29.404272,2.699659 29.633436,1.645129 30.000000,1.000000
C95.687561,1.000000 161.375122,1.000000 227.258057,1.317018
C227.660217,1.893988 227.815079,2.296565 228.081207,2.393433
C241.304657,7.206383 250.980164,15.550970 255.215851,29.410040
C255.321625,29.756128 256.383850,29.809898 257.000000,30.000000
C257.000000,95.687561 257.000000,161.375122 256.682983,227.257858
C256.106049,227.659790 255.699371,227.815521 255.607178,228.080658
C250.953033,241.462830 242.292618,250.822968 228.591782,255.214935
C228.239929,255.327698 228.190491,256.383820 228.000000,257.000000
C175.312439,257.000000 122.624878,257.000000 69.468582,256.531342
C68.672188,244.948196 68.218323,233.835587 68.052299,222.718674
C67.885620,211.557587 67.886772,200.390717 68.027298,189.229050
C68.255180,171.129044 68.084618,152.997421 69.151917,134.942368
C70.148468,118.083969 77.974228,103.689308 89.758743,91.961365
C104.435837,77.354736 122.313736,69.841736 143.417328,69.901505
C168.662338,69.972984 186.981964,90.486633 187.961487,114.156334
C189.042435,140.277435 166.783981,163.607941 140.303482,160.823074
C137.092346,160.485382 133.490692,158.365784 131.192612,155.987366
C126.434669,151.063141 126.975357,144.720825 129.168777,138.834930
C131.533630,132.489014 137.260605,130.548050 143.413757,130.046677
C150.288467,129.486496 156.424942,123.757378 157.035324,117.320816
C157.953949,107.633820 150.959381,101.769096 145.533951,101.194389
C132.238846,99.786079 120.699944,104.963120 111.676735,114.167313
C102.105782,123.930222 97.469498,136.194061 99.003151,150.234955
C100.540352,164.308228 107.108505,175.507980 118.864334,183.311539
C128.454544,189.677597 138.866959,191.786957 150.657837,190.245651
C166.242554,188.208420 179.874283,182.443329 191.251801,172.056793
C209.355011,155.530380 217.848694,134.938721 216.116119,110.085892
C214.834335,91.699440 207.721039,76.015915 195.289444,62.978828
C175.658447,42.391735 150.833389,37.257801 123.833740,42.281937
C98.675804,46.963364 78.315033,60.084667 62.208153,80.157814
C46.645889,99.552216 39.305275,121.796379 39.149052,146.201981
C38.912663,183.131317 39.666767,220.067017 40.000000,257.000000
C36.969406,257.000000 33.938812,257.000000 30.705070,256.668213
C30.298622,256.078369 30.144913,255.669220 29.884926,255.583878
C16.317770,251.131058 7.127485,242.317780 2.778462,228.591797
C2.667588,228.241821 1.613958,228.190567 1.000000,228.000000
z"/>
<path fill="#000000" opacity="1.000000" stroke="none"
d="
M40.468658,257.000000
C39.666767,220.067017 38.912663,183.131317 39.149052,146.201981
C39.305275,121.796379 46.645889,99.552216 62.208153,80.157814
C78.315033,60.084667 98.675804,46.963364 123.833740,42.281937
C150.833389,37.257801 175.658447,42.391735 195.289444,62.978828
C207.721039,76.015915 214.834335,91.699440 216.116119,110.085892
C217.848694,134.938721 209.355011,155.530380 191.251801,172.056793
C179.874283,182.443329 166.242554,188.208420 150.657837,190.245651
C138.866959,191.786957 128.454544,189.677597 118.864334,183.311539
C107.108505,175.507980 100.540352,164.308228 99.003151,150.234955
C97.469498,136.194061 102.105782,123.930222 111.676735,114.167313
C120.699944,104.963120 132.238846,99.786079 145.533951,101.194389
C150.959381,101.769096 157.953949,107.633820 157.035324,117.320816
C156.424942,123.757378 150.288467,129.486496 143.413757,130.046677
C137.260605,130.548050 131.533630,132.489014 129.168777,138.834930
C126.975357,144.720825 126.434669,151.063141 131.192612,155.987366
C133.490692,158.365784 137.092346,160.485382 140.303482,160.823074
C166.783981,163.607941 189.042435,140.277435 187.961487,114.156334
C186.981964,90.486633 168.662338,69.972984 143.417328,69.901505
C122.313736,69.841736 104.435837,77.354736 89.758743,91.961365
C77.974228,103.689308 70.148468,118.083969 69.151917,134.942368
C68.084618,152.997421 68.255180,171.129044 68.027298,189.229050
C67.886772,200.390717 67.885620,211.557587 68.052299,222.718674
C68.218323,233.835587 68.672188,244.948196 68.999924,256.531342
C59.645771,257.000000 50.291542,257.000000 40.468658,257.000000
z"/>
<path fill="#000000" opacity="1.000000" stroke="none"
d="
M257.000000,29.531342
C256.383850,29.809898 255.321625,29.756128 255.215851,29.410040
C250.980164,15.550970 241.304657,7.206383 228.081207,2.393433
C227.815079,2.296565 227.660217,1.893988 227.726715,1.317018
C237.593155,1.000000 247.186295,1.000000 257.000000,1.000000
C257.000000,10.353075 257.000000,19.707878 257.000000,29.531342
z"/>
<path fill="#000000" opacity="1.000000" stroke="none"
d="
M228.468658,257.000000
C228.190491,256.383820 228.239929,255.327698 228.591782,255.214935
C242.292618,250.822968 250.953033,241.462830 255.607178,228.080658
C255.699371,227.815521 256.106049,227.659790 256.682983,227.726517
C257.000000,237.593155 257.000000,247.186295 257.000000,257.000000
C247.646927,257.000000 238.292114,257.000000 228.468658,257.000000
z"/>
<path fill="#000000" opacity="1.000000" stroke="none"
d="
M1.000000,228.468658
C1.613958,228.190567 2.667588,228.241821 2.778462,228.591797
C7.127485,242.317780 16.317770,251.131058 29.884926,255.583878
C30.144913,255.669220 30.298622,256.078369 30.250959,256.668213
C20.406853,257.000000 10.813705,257.000000 1.000000,257.000000
C1.000000,247.646927 1.000000,238.292114 1.000000,228.468658
z"/>
<path fill="#000000" opacity="1.000000" stroke="none"
d="
M29.531342,1.000000
C29.633436,1.645129 29.404272,2.699659 28.880219,2.875538
C15.629609,7.322631 6.913495,16.165672 2.721949,29.500008
C2.594676,29.904894 2.026278,30.171114 1.331771,30.250992
C1.000000,20.406855 1.000000,10.813709 1.000000,1.000000
C10.353074,1.000000 19.707878,1.000000 29.531342,1.000000
z"/>
</svg>"""
class Photopea:
    """Gradio widgets for editing ControlNet images in a Photopea modal."""

    def __init__(self) -> None:
        self.modal = None
        # gr.HTML trigger elements created by render_child_trigger().
        self.triggers = []
        self.render_editor()

    def render_editor(self):
        """Render the editor modal."""
        # The iframe starts on about:blank so it does not navigate right away;
        # it only navigates when the user first clicks 'Edit'. The navigation
        # logic is in `photopea.js`.
        editor_html = """
<div class="photopea-button-group">
<button class="photopea-button photopea-fetch">Fetch from ControlNet</button>
<button class="photopea-button photopea-send">Send to ControlNet</button>
</div>
<iframe class="photopea-iframe" src="about:blank"></iframe>
"""
        with gr.Group(elem_classes=["cnet-photopea-edit"]):
            editor_modal = ModalInterface(
                editor_html,
                open_button_text="Edit",
                open_button_classes=["cnet-photopea-main-trigger"],
                open_button_extra_attrs="hidden",
            )
            self.modal = editor_modal.create_modal(visible=True)

    def render_child_trigger(self):
        """Create one more 'Edit' trigger element and remember it in `triggers`."""
        trigger_html = f"""<div class="cnet-photopea-child-trigger">
Edit {PHOTOPEA_LOGO}
</div>"""
        self.triggers.append(gr.HTML(trigger_html))

    def attach_photopea_output(self, generated_image: gr.Image):
        """Called in ControlNetUiGroup to attach preprocessor preview image Gradio element
        as the photopea output. If the front-end directly changes the img HTML element's src
        to reflect the edited image result from photopea, the backend won't be notified.

        In this method we let the front-end upload the result image to an invisible gr.Image
        instance and mirror the value to the preprocessor preview gr.Image. This is because
        the generated image gr.Image instance is inferred to be an output image by Gradio
        and has no ability to accept image upload directly.

        Arguments:
            generated_image: preprocessor result Gradio Image output element.

        Returns:
            None
        """

        def forward_image(img):
            # Hand the uploaded image through unchanged.
            return img

        output = gr.Image(
            visible=False,
            source="upload",
            type="numpy",
            elem_classes=["cnet-photopea-output"],
        )
        output.upload(
            fn=forward_image,
            inputs=[output],
            outputs=[generated_image],
        )
@@ -0,0 +1,318 @@
import os
import gradio as gr
from typing import Dict, List
from modules import scripts
from lib_controlnet.infotext import parse_unit, serialize_unit
from lib_controlnet.controlnet_ui.tool_button import ToolButton
from lib_controlnet.logging import logger
from lib_controlnet import external_code
# Emoji glyphs used as the text of the preset tool buttons.
save_symbol = "\U0001f4be" # 💾
delete_symbol = "\U0001f5d1\ufe0f" # 🗑️
refresh_symbol = "\U0001f504" # 🔄
reset_symbol = "\U000021A9" # ↩
# Reserved dropdown entry: it is appended to the preset choices, and a preset
# file carrying this name is skipped when presets are loaded from disk.
NEW_PRESET = "New Preset"
def load_presets(preset_dir: str) -> Dict[str, str]:
    """Read every ``<name>.txt`` preset file in a directory.

    The directory is created when it does not exist yet. A file whose name
    (without the ``.txt`` suffix) equals ``NEW_PRESET`` is ignored, as are
    directory entries that are not regular files.

    Args:
        preset_dir: Directory that holds the preset files.

    Returns:
        Mapping of preset name to the text content of its file.
    """
    if not os.path.exists(preset_dir):
        # exist_ok guards against the directory appearing between the check
        # and the creation.
        os.makedirs(preset_dir, exist_ok=True)
        return {}

    suffix = ".txt"
    presets = {}
    for filename in os.listdir(preset_dir):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing suffix so the name round-trips with the
        # f"{name}.txt" path used when presets are saved or deleted.
        name = filename[: -len(suffix)]
        if name == NEW_PRESET:
            continue
        path = os.path.join(preset_dir, filename)
        if not os.path.isfile(path):
            continue
        with open(path, "r") as f:
            presets[name] = f.read()
    return presets
def infer_control_type(module: str, model: str) -> str:
    """Pick the single control type whose name matches the module or model.

    A control type matches a string when any of its "/"-separated parts,
    lowercased, occurs in that string.

    Raises:
        ValueError: If zero or more than one control type matches.
    """

    def mentions(text: str, control_type: str) -> bool:
        for part in control_type.split("/"):
            if part.lower() in text:
                return True
        return False

    # NOTE(review): `preprocessor_filters` is not defined or imported in the
    # visible part of this file - confirm where it is expected to come from.
    candidates = []
    for control_type in preprocessor_filters.keys():
        if mentions(module, control_type) or mentions(model, control_type):
            candidates.append(control_type)

    if len(candidates) == 1:
        return candidates[0]
    raise ValueError(
        f"Unable to infer control type from module {module} and model {model}"
    )
class ControlNetPresetUI(object):
    """Preset dropdown and tool buttons for one ControlNet unit.

    A preset is the serialized infotext of a ControlNetUnit, stored as
    ``<name>.txt`` inside ``preset_directory`` and cached in the class-level
    ``presets`` dict (preset name -> infotext).
    """

    preset_directory = os.path.join(scripts.basedir(), "presets")
    presets = load_presets(preset_directory)

    def __init__(self, id_prefix: str):
        """Create the preset widgets.

        Args:
            id_prefix: Prefix for the element id of the "enter name" dialog.
        """
        with gr.Row():
            self.dropdown = gr.Dropdown(
                label="Presets",
                show_label=True,
                elem_classes=["cnet-preset-dropdown"],
                choices=ControlNetPresetUI.dropdown_choices(),
                value=NEW_PRESET,
            )
            self.reset_button = ToolButton(
                value=reset_symbol,
                elem_classes=["cnet-preset-reset"],
                tooltip="Reset preset",
                visible=False,
            )
            self.save_button = ToolButton(
                value=save_symbol,
                elem_classes=["cnet-preset-save"],
                tooltip="Save preset",
            )
            self.delete_button = ToolButton(
                value=delete_symbol,
                elem_classes=["cnet-preset-delete"],
                tooltip="Delete preset",
            )
            self.refresh_button = ToolButton(
                value=refresh_symbol,
                elem_classes=["cnet-preset-refresh"],
                tooltip="Refresh preset",
            )

        with gr.Box(
            elem_classes=["popup-dialog", "cnet-preset-enter-name"],
            elem_id=f"{id_prefix}_cnet_preset_enter_name",
        ) as self.name_dialog:
            with gr.Row():
                self.preset_name = gr.Textbox(
                    label="Preset name",
                    show_label=True,
                    lines=1,
                    elem_classes=["cnet-preset-name"],
                )
                self.confirm_preset_name = ToolButton(
                    value=save_symbol,
                    elem_classes=["cnet-preset-confirm-name"],
                    tooltip="Save preset",
                )

    @staticmethod
    def _units_for_comparison(preset_name: str, ui_states):
        """Build the (preset_unit, current_unit) pair used to detect changes.

        Images are cleared on both units, and module parameters the preset
        does not use (stored as -1) are set to -1 on the current unit too, so
        that ``vars()`` of the two units can be compared directly.
        """
        preset_unit = parse_unit(ControlNetPresetUI.presets[preset_name])
        current_unit = external_code.ControlNetUnit(*ui_states)
        preset_unit.image = None
        current_unit.image = None

        # Do not compare module param that are not used in preset.
        for module_param in ("processor_res", "threshold_a", "threshold_b"):
            if getattr(preset_unit, module_param) == -1:
                setattr(current_unit, module_param, -1)
        return preset_unit, current_unit

    @staticmethod
    def _is_valid_new_name(name: str) -> bool:
        """Whether a user-typed name may be used as a preset file name.

        Rejects empty/blank names and names containing a directory part, since
        the name is joined into a path below ``preset_directory``.
        """
        return bool(name and name.strip()) and os.path.basename(name) == name

    def register_callbacks(
        self,
        uigroup,
        control_type: gr.Radio,
        *ui_states,
    ):
        """Wire the preset widgets to the unit's UI components.

        Args:
            uigroup: Owner of the ``prevent_next_n_*`` counters that are
                incremented here when a preset is applied.
            control_type: Radio component holding the control type.
            *ui_states: UI components passed, in order, to the
                ``ControlNetUnit`` constructor.
        """

        def no_update():
            # Hide the delete button and leave control type plus every unit
            # field untouched.
            skip_count = len(vars(external_code.ControlNetUnit()).keys()) + 1
            return (
                gr.update(visible=False),
                *((gr.skip(),) * skip_count),
            )

        def apply_preset(name: str, control_type: str, *ui_states):
            if name == NEW_PRESET:
                return no_update()

            assert name in ControlNetPresetUI.presets

            preset_unit, current_unit = ControlNetPresetUI._units_for_comparison(
                name, ui_states
            )

            # No update necessary.
            if vars(current_unit) == vars(preset_unit):
                return no_update()

            unit = preset_unit

            try:
                new_control_type = infer_control_type(unit.module, unit.model)
            except ValueError as e:
                logger.error(e)
                new_control_type = control_type

            # Each value about to change bumps the matching counter on uigroup.
            if new_control_type != control_type:
                uigroup.prevent_next_n_module_update += 1

            if preset_unit.module != current_unit.module:
                uigroup.prevent_next_n_slider_value_update += 1

            if preset_unit.pixel_perfect != current_unit.pixel_perfect:
                uigroup.prevent_next_n_slider_value_update += 1

            return (
                gr.update(visible=True),
                gr.update(value=new_control_type),
                *[
                    gr.update(value=value) if value is not None else gr.update()
                    for value in vars(unit).values()
                ],
            )

        for element, action in (
            (self.dropdown, "change"),
            (self.reset_button, "click"),
        ):
            getattr(element, action)(
                fn=apply_preset,
                inputs=[self.dropdown, control_type, *ui_states],
                outputs=[self.delete_button, control_type, *ui_states],
                show_progress="hidden",
            ).then(
                fn=lambda: gr.update(visible=False),
                inputs=None,
                outputs=[self.reset_button],
            )

        def save_preset(name: str, *ui_states):
            # With the placeholder entry selected, reveal the name dialog
            # instead of saving.
            if name == NEW_PRESET:
                return gr.update(visible=True), gr.update(), gr.update()

            ControlNetPresetUI.save_preset(
                name, external_code.ControlNetUnit(*ui_states)
            )
            return (
                gr.update(),  # name dialog
                gr.update(choices=ControlNetPresetUI.dropdown_choices(), value=name),
                gr.update(visible=False),  # Reset button
            )

        self.save_button.click(
            fn=save_preset,
            inputs=[self.dropdown, *ui_states],
            outputs=[self.name_dialog, self.dropdown, self.reset_button],
            show_progress="hidden",
        ).then(
            fn=None,
            _js=f"""
(name) => {{
if (name === "{NEW_PRESET}")
popup(gradioApp().getElementById('{self.name_dialog.elem_id}'));
}}""",
            inputs=[self.dropdown],
        )

        def delete_preset(name: str):
            ControlNetPresetUI.delete_preset(name)
            return gr.Dropdown.update(
                choices=ControlNetPresetUI.dropdown_choices(),
                value=NEW_PRESET,
            ), gr.update(visible=False)

        self.delete_button.click(
            fn=delete_preset,
            inputs=[self.dropdown],
            outputs=[self.dropdown, self.reset_button],
            show_progress="hidden",
        )

        self.name_dialog.visible = False

        def save_new_preset(new_name: str, *ui_states):
            if new_name == NEW_PRESET:
                logger.warning(f"Cannot save preset with reserved name '{NEW_PRESET}'")
                return gr.update(visible=False), gr.update()

            # The name comes from a free-text box and becomes part of a file
            # path, so refuse blank names and anything with a directory part.
            if not ControlNetPresetUI._is_valid_new_name(new_name):
                logger.warning(f"Cannot save preset with invalid name '{new_name}'")
                return gr.update(visible=False), gr.update()

            ControlNetPresetUI.save_preset(
                new_name, external_code.ControlNetUnit(*ui_states)
            )
            return gr.update(visible=False), gr.update(
                choices=ControlNetPresetUI.dropdown_choices(), value=new_name
            )

        self.confirm_preset_name.click(
            fn=save_new_preset,
            inputs=[self.preset_name, *ui_states],
            outputs=[self.name_dialog, self.dropdown],
            show_progress="hidden",
        ).then(fn=None, _js="closePopup")

        self.refresh_button.click(
            fn=ControlNetPresetUI.refresh_preset,
            inputs=None,
            outputs=[self.dropdown],
            show_progress="hidden",
        )

        def update_reset_button(preset_name: str, *ui_states):
            if preset_name == NEW_PRESET:
                return gr.update(visible=False)

            # The selected name may no longer be cached (refresh only updates
            # the dropdown choices); there is nothing to reset to in that case.
            if preset_name not in ControlNetPresetUI.presets:
                return gr.update(visible=False)

            preset_unit, current_unit = ControlNetPresetUI._units_for_comparison(
                preset_name, ui_states
            )
            return gr.update(visible=vars(current_unit) != vars(preset_unit))

        # Re-evaluate the reset button whenever any non-image unit component
        # fires one of the events below.
        for ui_state in ui_states:
            if isinstance(ui_state, gr.Image):
                continue

            for action in ("edit", "click", "change", "clear", "release"):
                # "release" is only hooked for sliders.
                if action == "release" and not isinstance(ui_state, gr.Slider):
                    continue

                if hasattr(ui_state, action):
                    getattr(ui_state, action)(
                        fn=update_reset_button,
                        inputs=[self.dropdown, *ui_states],
                        outputs=[self.reset_button],
                    )

    @staticmethod
    def dropdown_choices() -> List[str]:
        """Return the cached preset names followed by the NEW_PRESET entry."""
        return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET]

    @staticmethod
    def save_preset(name: str, unit: external_code.ControlNetUnit):
        """Serialize `unit`, write it to ``<name>.txt`` and cache it."""
        infotext = serialize_unit(unit)
        with open(
            os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w"
        ) as f:
            f.write(infotext)

        ControlNetPresetUI.presets[name] = infotext

    @staticmethod
    def delete_preset(name: str):
        """Drop a cached preset and remove its file; unknown names are ignored."""
        if name not in ControlNetPresetUI.presets:
            return

        del ControlNetPresetUI.presets[name]

        file = os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt")
        if os.path.exists(file):
            os.unlink(file)

    @staticmethod
    def refresh_preset():
        """Reload the presets from disk and return the dropdown choices update."""
        ControlNetPresetUI.presets = load_presets(ControlNetPresetUI.preset_directory)
        return gr.update(choices=ControlNetPresetUI.dropdown_choices())
@@ -0,0 +1,12 @@
import gradio as gr
class ToolButton(gr.Button, gr.components.FormComponent):
    """Small button with single emoji as text, fits inside gradio forms"""

    def __init__(self, **kwargs):
        """Create the button with ``variant="tool"`` and the "cnet-toolbutton" class.

        ``elem_classes`` may be omitted, ``None``, a single class name or any
        iterable of class names; "cnet-toolbutton" is appended to it. All
        other keyword arguments are forwarded unchanged.
        """
        extra_classes = kwargs.pop('elem_classes', None)
        if extra_classes is None:
            extra_classes = []
        elif isinstance(extra_classes, str):
            # A bare string is one class name, not a sequence of characters.
            extra_classes = [extra_classes]
        super().__init__(
            variant="tool",
            elem_classes=[*extra_classes, "cnet-toolbutton"],
            **kwargs,
        )

    def get_block_name(self):
        """Report this component as a plain "button" block."""
        return "button"
@@ -0,0 +1,73 @@
from enum import Enum
from typing import Any
class StableDiffusionVersion(Enum):
    """The version family of stable diffusion model."""

    UNKNOWN = 0
    SD1x = 1
    SD2x = 2
    SDXL = 3

    @staticmethod
    def detect_from_model_name(model_name: str) -> "StableDiffusionVersion":
        """Based on the model name provided, guess what stable diffusion version it is.
        This might not be accurate without actually inspect the file content.

        All checks are case-insensitive and run in this order: SD1.x markers
        ("sd14", "sd15", "sd16"), SD2.x markers ("sd21", "2.1"), then "xl".
        """
        # Lowercase once so every marker is matched the same way; the SD2.x
        # check used to look at the raw name and missed e.g. "SD21".
        lowered = model_name.lower()
        if any(f"sd{v}" in lowered for v in ("14", "15", "16")):
            return StableDiffusionVersion.SD1x

        if "sd21" in lowered or "2.1" in lowered:
            return StableDiffusionVersion.SD2x

        if "xl" in lowered:
            return StableDiffusionVersion.SDXL

        return StableDiffusionVersion.UNKNOWN

    def encoder_block_num(self) -> int:
        """Return 9 for SDXL and 12 for every other version (UNKNOWN included)."""
        if self in (StableDiffusionVersion.SD1x, StableDiffusionVersion.SD2x, StableDiffusionVersion.UNKNOWN):
            return 12
        else:
            return 9  # SDXL

    def controlnet_layer_num(self) -> int:
        """Return the encoder block count plus one."""
        return self.encoder_block_num() + 1

    def is_compatible_with(self, other: "StableDiffusionVersion") -> bool:
        """ Incompatible only when one of version is SDXL and other is not. """
        return (
            any(v == StableDiffusionVersion.UNKNOWN for v in [self, other]) or
            sum(v == StableDiffusionVersion.SDXL for v in [self, other]) != 1
        )
class HiResFixOption(Enum):
    """Choice of hires-fix pass(es): both, low res only, or high res only."""

    BOTH = "Both"
    LOW_RES_ONLY = "Low res only"
    HIGH_RES_ONLY = "High res only"

    @staticmethod
    def from_value(value: Any) -> "HiResFixOption":
        """Convert a string, int index or member into a HiResFixOption.

        Strings are either a member value ("Both") or the qualified member
        name ("HiResFixOption.BOTH"); ints index the members in definition
        order; anything else must already be a HiResFixOption.
        """
        if isinstance(value, str):
            if value.startswith("HiResFixOption."):
                _, member_name = value.split(".")
                return getattr(HiResFixOption, member_name)
            return HiResFixOption(value)
        if isinstance(value, int):
            return list(HiResFixOption)[value]
        assert isinstance(value, HiResFixOption)
        return value
class InputMode(Enum):
    """How the input of a ControlNet unit is supplied."""

    # One image feeds one ControlNet unit.
    SIMPLE = "simple"
    # The input is a directory and there are N generations; every generation
    # consumes 1 input image from that directory.
    BATCH = "batch"
    # The input is a directory and there is 1 generation; that generation
    # consumes N input images from the directory.
    MERGE = "merge"
@@ -0,0 +1,460 @@
from dataclasses import dataclass
from enum import Enum
from copy import copy
from typing import List, Any, Optional, Union, Tuple, Dict
import numpy as np
from modules import scripts, processing, shared
from lib_controlnet import global_state
from lib_controlnet.logging import logger
from lib_controlnet.enums import HiResFixOption
from modules.api import api
def get_api_version() -> int:
    """Return the API version number (currently 2)."""
    api_version = 2
    return api_version
class ControlMode(Enum):
    """
    The improved guess mode.

    Member values are the human-readable labels of each mode.
    """

    BALANCED = "Balanced"
    PROMPT = "My prompt is more important"
    CONTROL = "ControlNet is more important"
class BatchOption(Enum):
    """How ControlNet units are distributed over the images of a batch."""

    # Every unit is applied to every image in the batch.
    DEFAULT = "All ControlNet units for all images in a batch"
    # Each unit is paired with one image of the batch.
    SEPARATE = "Each ControlNet unit for each image in a batch"
class ResizeMode(Enum):
    """
    Resize modes for ControlNet input images.
    """

    RESIZE = "Just Resize"
    INNER_FIT = "Crop and Resize"
    OUTER_FIT = "Resize and Fill"

    def int_value(self):
        """Return the integer code of the mode: RESIZE=0, INNER_FIT=1, OUTER_FIT=2."""
        codes = {
            ResizeMode.RESIZE: 0,
            ResizeMode.INNER_FIT: 1,
            ResizeMode.OUTER_FIT: 2,
        }
        # Every member has a code, so this can only fail if a member is added
        # without one.
        assert self in codes, "NOTREACHED"
        return codes[self]
# Alternative resize mode names mapped to the ResizeMode value they stand for.
resize_mode_aliases = {
    # Aliases of "Crop and Resize".
    "Inner Fit (Scale to Fit)": "Crop and Resize",
    # Aliases of "Resize and Fill".
    "Outer Fit (Shrink to Fit)": "Resize and Fill",
    "Scale to Fit (Inner Fit)": "Crop and Resize",
    "Envelope (Outer Fit)": "Resize and Fill",
}
def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
    """Convert a string, int or ResizeMode into a ResizeMode.

    Strings are first translated through ``resize_mode_aliases``. Ints index
    the members in definition order; 3 and any larger out-of-range value map
    to RESIZE. Any other value is returned unchanged.
    """
    if isinstance(value, str):
        canonical = resize_mode_aliases.get(value, value)
        return ResizeMode(canonical)

    if not isinstance(value, int):
        return value

    assert value >= 0
    if value == 3:  # 'Just Resize (Latent upscale)'
        return ResizeMode.RESIZE

    if value >= len(ResizeMode):
        logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.')
        return ResizeMode.RESIZE

    return list(ResizeMode)[value]
def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
    """Convert a string (member value) or int (member index) into a ControlMode.

    Values of any other type are returned unchanged.
    """
    if isinstance(value, str):
        return ControlMode(value)
    if isinstance(value, int):
        return list(ControlMode)[value]
    return value
def visualize_inpaint_mask(img):
    """Return a copy of a 4-channel image with channel 3 remapped to ``255 - c // 2``.

    Arrays that are not 3-dimensional with exactly 4 channels are returned
    as-is (the same object).
    """
    if img.ndim != 3 or img.shape[2] != 4:
        return img

    remapped = img.copy()
    remapped[:, :, 3] = 255 - remapped[:, :, 3] // 2
    return np.ascontiguousarray(remapped.copy())
def pixel_perfect_resolution(
        image: np.ndarray,
        target_H: int,
        target_W: int,
        resize_mode: ResizeMode,
) -> int:
    """
    Estimate the resolution to use when resizing an image while preserving aspect ratio.

    Two scaling factors are computed, target height over image height and
    target width over image width. For OUTER_FIT the smaller factor is used,
    so the whole image fits within the target dimensions (possibly leaving
    empty space). For every other mode the larger factor is used, so the
    target dimensions are fully covered (possibly cropping the image). The
    chosen factor is multiplied by the shorter side of the image.

    The inputs and the estimate are written to the debug log.

    Args:
        image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
        target_H (int): The target height for the image.
        target_W (int): The target width for the image.
        resize_mode (ResizeMode): The mode for resizing.

    Returns:
        int: The estimated resolution after resizing.
    """
    raw_H, raw_W, _ = image.shape

    scale_h = float(target_H) / float(raw_H)
    scale_w = float(target_W) / float(raw_W)

    choose_scale = min if resize_mode == ResizeMode.OUTER_FIT else max
    estimation = choose_scale(scale_h, scale_w) * float(min(raw_H, raw_W))

    debug_lines = (
        "Pixel Perfect Computation:",
        f"resize_mode = {resize_mode}",
        f"raw_H = {raw_H}",
        f"raw_W = {raw_W}",
        f"target_H = {target_H}",
        f"target_W = {target_W}",
        f"estimation = {estimation}",
    )
    for line in debug_lines:
        logger.debug(line)

    return int(np.round(estimation))
# A single image: either a numpy array or a string.
InputImage = Union[np.ndarray, str]
# The name is then rebound to a wider alias that additionally admits a dict
# with string keys whose values are single images, or a pair of single images.
InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage]
@dataclass
class ControlNetUnit:
    """
    Represents an entire ControlNet processing unit.
    """

    enabled: bool = True
    module: str = "none"
    model: str = "None"
    weight: float = 1.0
    image: Optional[Union[InputImage, List[InputImage]]] = None
    resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
    low_vram: bool = False
    processor_res: int = -1
    threshold_a: float = -1
    threshold_b: float = -1
    guidance_start: float = 0.0
    guidance_end: float = 1.0
    pixel_perfect: bool = False
    control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
    # Crop the input image according to the A1111 img2img mask. Only used
    # when `inpaint area` in A1111 is `Only masked`; in the API this
    # corresponds to `inpaint_full_res = True`.
    inpaint_crop_input_image: bool = True
    # How this unit is applied when hires fix is enabled in A1111. Ignored
    # when the generation does not use hires fix.
    hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH

    # Whether to save the detected map of this unit. False prevents saving
    # the detected map or sending it along with generated images via API.
    # Currently the option is only accessible in API calls.
    save_detected_map: bool = True

    # Per-layer weights of the ControlNet params.
    # ControlNet:
    #  - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
    #  - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
    # T2IAdapter:
    #  - SD1.5: 5 weights (4 encoder block + 1 middle block)
    #  - SDXL: 4 weights (3 encoder block + 1 middle block)
    # Note1: Advanced weighting disables `soft_injection`, so ControlMode =
    # BALANCED is recommended together with `advanced_weighting`.
    # Note2: The field `weight` is still used in some places, e.g.
    # reference_only, even when advanced_weighting is set.
    advanced_weighting: Optional[List[float]] = None

    def __eq__(self, other):
        """Units are equal when both are ControlNetUnit with identical attributes."""
        return isinstance(other, ControlNetUnit) and vars(self) == vars(other)

    def accepts_multiple_inputs(self) -> bool:
        """This unit can accept multiple input images."""
        multi_input_modules = (
            "ip-adapter_clip_sdxl",
            "ip-adapter_clip_sdxl_plus_vith",
            "ip-adapter_clip_sd15",
            "ip-adapter_face_id",
            "ip-adapter_face_id_plus",
            "instant_id_face_embedding",
        )
        return self.module in multi_input_modules
def to_base64_nparray(encoding: str):
    """
    Decode a base64 image and return it as a uint8 numpy array, the image
    type the extension uses.
    """
    decoded_image = api.decode_base64_to_image(encoding)
    return np.array(decoded_image).astype('uint8')
def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from a StableDiffusionProcessing,
    using its script runner and script arguments.
    """
    script_runner, script_args = p.scripts, p.script_args
    return get_all_units(script_runner, script_args)
def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Collect ControlNet processing units from the arguments of all scripts.

    Locates the ControlNet script in `script_runner` and converts the slice of
    `script_args` that belongs to it. Returns an empty list when no ControlNet
    script is found.
    """
    cn_script = find_cn_script(script_runner)
    if not cn_script:
        return []

    first, last = cn_script.args_from, cn_script.args_to
    return get_all_units_from(script_args[first:last])
def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from ControlNet script arguments.
    Use `external_code.get_all_units` to fetch units from the list of all scripts arguments.

    Args:
        script_args: The script arguments that belong to the ControlNet script.

    Returns:
        One unit (converted with `to_processing_unit`) for every argument that
        is, or can be treated like, a `ControlNetUnit`. Other arguments are
        ignored. A warning is logged when no unit is found at all.
    """
    # Field names of a default unit. Computed once here rather than once per
    # inspected argument.
    expected_fields = set(vars(ControlNetUnit()).keys())

    def is_stale_unit(script_arg: Any) -> bool:
        """ Returns whether the script_arg is potentially an stale version of
        ControlNetUnit created before module reload."""
        return (
            'ControlNetUnit' in type(script_arg).__name__ and
            not isinstance(script_arg, ControlNetUnit)
        )

    def is_controlnet_unit(script_arg: Any) -> bool:
        """ Returns whether the script_arg is ControlNetUnit or anything that
        can be treated like ControlNetUnit. """
        return (
            isinstance(script_arg, (ControlNetUnit, dict)) or
            (
                hasattr(script_arg, '__dict__') and
                expected_fields.issubset(set(vars(script_arg).keys()))
            )
        )

    all_units = [
        to_processing_unit(script_arg)
        for script_arg in script_args
        if is_controlnet_unit(script_arg)
    ]
    if not all_units:
        # Adjacent string literals are concatenated verbatim, so each piece
        # carries its own trailing space.
        logger.warning(
            "No ControlNetUnit detected in args. It is very likely that you are having an extension conflict. "
            f"Here are args received by ControlNet: {script_args}.")
    if any(is_stale_unit(script_arg) for script_arg in script_args):
        logger.debug(
            "Stale version of ControlNetUnit detected. The ControlNetUnit received "
            "by ControlNet is created before the newest load of ControlNet extension. "
            "They will still be used by ControlNet as long as they provide same fields "
            "defined in the newest version of ControlNetUnit."
        )
    return all_units
def get_single_unit_from(script_args: List[Any], index: int = 0) -> Optional[ControlNetUnit]:
    """
    Fetch a single ControlNet processing unit from ControlNet script arguments.
    The list must not contain script positional arguments. It must only contain processing units.

    Args:
        script_args: The ControlNet script arguments.
        index: Zero-based position of the wanted unit.

    Returns:
        The unit at `index` converted with `to_processing_unit`, or `None` when
        `index` is negative, past the end of `script_args`, or refers to a
        `None` entry.
    """
    # Negative indices are rejected instead of counting from the end.
    if index < 0 or index >= len(script_args):
        return None

    candidate = script_args[index]
    if candidate is None:
        return None
    return to_processing_unit(candidate)
def get_max_models_num():
    """
    Return the configured ControlNet unit count.

    Reads the `control_net_unit_count` option and falls back to 3 when it is
    not set.
    """
    return shared.opts.data.get("control_net_unit_count", 3)
def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
    """
    Turn `unit` into a `ControlNetUnit`.

    A dict is converted: legacy key names listed in `ext_compat_keys` are
    renamed, a separate `mask` entry is folded into the `image` entry, and keys
    that `ControlNetUnit` does not define are dropped. Any other value is
    returned unchanged.
    """
    ext_compat_keys = {
        'guessmode': 'guess_mode',
        'guidance': 'guidance_end',
        'lowvram': 'low_vram',
        'input_image': 'image'
    }

    if not isinstance(unit, dict):
        # temporary, check #602: non-dict values are passed through without
        # validating that they really are ControlNetUnit instances.
        return unit

    # Work on a renamed copy so the caller's dict is left untouched.
    fields = {ext_compat_keys.get(key, key): value for key, value in unit.items()}

    mask = fields.pop('mask', None)
    if 'image' in fields and not isinstance(fields['image'], dict):
        image = fields['image']
        if mask is not None:
            fields['image'] = {'image': image, 'mask': mask}
        elif not image:
            # A falsy image is normalised to None.
            fields['image'] = None

    if 'guess_mode' in fields:
        logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.')

    known_names = vars(ControlNetUnit)
    return ControlNetUnit(**{key: value for key, value in fields.items() if key in known_names})
def update_cn_script_in_processing(
    p: processing.StableDiffusionProcessing,
    cn_units: List[ControlNetUnit],
    **_kwargs,  # for backwards compatibility
):
    """
    Replace the ControlNet arguments stored on `p` with `cn_units`.

    The new argument sequence is computed by `update_cn_script` from
    `p.scripts` and `p.script_args_value`, then assigned to `p.script_args`.
    Per that function's contract, the arguments come back unchanged if any of
    the following is true:
    - ControlNet is not present in `p.scripts`
    - the arguments are not filled with script arguments for scripts that are
      processed before ControlNet
    """
    refreshed_args = update_cn_script(p.scripts, p.script_args_value, cn_units)
    p.script_args = refreshed_args
def update_cn_script(
    script_runner: scripts.ScriptRunner,
    script_args: Union[Tuple[Any], List[Any]],
    cn_units: List[ControlNetUnit],
) -> Union[Tuple[Any], List[Any]]:
    """
    Returns: The updated `script_args` with given `cn_units` used as ControlNet
    script args. The result has the same type (tuple or list) as the input.

    Does not update `script_args` if any of the following is true:
    - ControlNet is not present in `script_runner`
    - `script_args` is not filled with script arguments for scripts that are
    processed before ControlNet

    Side effect: `args_to` of the ControlNet script, and `args_from`/`args_to`
    of the always-on scripts iterated after it, are shifted to match the new
    number of ControlNet arguments.

    Raises:
        AssertionError: If `script_args` is neither a tuple nor a list.
    """
    script_args_type = type(script_args)
    assert script_args_type in (tuple, list), script_args_type

    cn_script = find_cn_script(script_runner)
    if cn_script is None or len(script_args) < cn_script.args_from:
        return script_args

    # Shallow copy, made only once it is known that an update will happen.
    updated_script_args = list(script_args)

    # fill in remaining parameters to satisfy max models, just in case script needs it.
    max_models = shared.opts.data.get("control_net_unit_count", 3)
    # One fresh disabled unit per slot: `[unit] * n` would repeat a single
    # shared instance, so mutating one padding unit would change all of them.
    padding = [ControlNetUnit(enabled=False) for _ in range(max(max_models - len(cn_units), 0))]
    cn_units = cn_units + padding

    cn_script_args_diff = 0
    for script in script_runner.alwayson_scripts:
        if script is cn_script:
            cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
            updated_script_args[script.args_from:script.args_to] = cn_units
            script.args_to = script.args_from + len(cn_units)
        else:
            # Zero until the ControlNet script has been seen, so scripts
            # iterated before it keep their ranges.
            script.args_from += cn_script_args_diff
            script.args_to += cn_script_args_diff

    return script_args_type(updated_script_args)
def update_cn_script_in_place(
    script_runner: scripts.ScriptRunner,
    script_args: List[Any],
    cn_units: List[ControlNetUnit],
    **_kwargs,  # for backwards compatibility
):
    """
    @Deprecated(Raises assertion error if script_args passed in is Tuple)

    Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`.
    `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.

    Does not update `script_args` if any of the following is true:
    - ControlNet is not present in `script_runner`
    - `script_args` is not filled with script arguments for scripts that are processed before ControlNet

    Side effect: `args_to` of the ControlNet script, and `args_from`/`args_to`
    of the always-on scripts iterated after it, are shifted to match the new
    number of ControlNet arguments.

    Raises:
        AssertionError: If `script_args` is not a list.
    """
    assert isinstance(script_args, list), type(script_args)

    cn_script = find_cn_script(script_runner)
    if cn_script is None or len(script_args) < cn_script.args_from:
        return

    # fill in remaining parameters to satisfy max models, just in case script needs it.
    max_models = shared.opts.data.get("control_net_unit_count", 3)
    # One fresh disabled unit per slot: `[unit] * n` would repeat a single
    # shared instance, so mutating one padding unit would change all of them.
    padding = [ControlNetUnit(enabled=False) for _ in range(max(max_models - len(cn_units), 0))]
    cn_units = cn_units + padding

    cn_script_args_diff = 0
    for script in script_runner.alwayson_scripts:
        if script is cn_script:
            cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
            script_args[script.args_from:script.args_to] = cn_units
            script.args_to = script.args_from + len(cn_units)
        else:
            # Zero until the ControlNet script has been seen, so scripts
            # iterated before it keep their ranges.
            script.args_from += cn_script_args_diff
            script.args_to += cn_script_args_diff
def find_cn_script(script_runner: scripts.ScriptRunner) -> Optional[scripts.Script]:
    """
    Return the first always-on script of `script_runner` that is a ControlNet
    script, or `None` when `script_runner` is `None` or has no such script.
    """
    if script_runner is None:
        return None

    matches = (
        candidate
        for candidate in script_runner.alwayson_scripts
        if is_cn_script(candidate)
    )
    return next(matches, None)
def is_cn_script(script: scripts.Script) -> bool:
    """
    Tell whether `script` is the ControlNet script, judged by its title
    (compared case-insensitively with 'controlnet').
    """
    title = script.title()
    return title.lower() == 'controlnet'
@@ -0,0 +1,138 @@
import os.path
import stat
from collections import OrderedDict
from modules import shared, sd_models
from lib_controlnet.enums import StableDiffusionVersion
from modules_forge.shared import controlnet_dir, supported_preprocessors
# File extensions that mark a file as a ControlNet model.
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]


def traverse_all_files(curr_path, model_list):
    """Recursively collect model files below `curr_path`.

    Args:
        curr_path: Directory to scan. A path that is missing or is not a
            directory contributes nothing.
        model_list: List that receives one `(path, stat_result)` tuple per file
            whose extension is in `CN_MODEL_EXTS`. It is extended in place.

    Returns:
        `model_list`, after the matches have been appended.
    """
    # Check before scanning: os.scandir raises for a missing path or a plain
    # file, so a filter inside the comprehension would run too late.
    if not os.path.isdir(curr_path):
        return model_list

    # The context manager closes the directory handle right after listing.
    with os.scandir(curr_path) as entries:
        f_list = [
            (os.path.join(curr_path, entry.name), entry.stat())
            for entry in entries
        ]

    for f_info in f_list:
        fname, fstat = f_info
        if os.path.splitext(fname)[1] in CN_MODEL_EXTS:
            model_list.append(f_info)
        elif stat.S_ISDIR(fstat.st_mode):
            model_list = traverse_all_files(fname, model_list)
    return model_list
def get_all_models(sort_by, filter_by, path):
    """Build an ordered mapping of model display names to file paths.

    Args:
        sort_by: "name", "date" (newest first) or "path name"; any other value
            keeps the traversal order.
        filter_by: Case-insensitive substring the file's base name must
            contain; surrounding spaces are ignored and an empty filter keeps
            everything.
        path: Directory searched with `traverse_all_files`.

    Returns:
        An `OrderedDict` whose keys are "<name> [<hash>]" and whose values are
        the matching file paths.
    """
    fileinfos = traverse_all_files(path, [])

    needle = filter_by.strip(" ")
    if needle:
        fileinfos = [
            info for info in fileinfos
            if needle.lower() in os.path.basename(info[0]).lower()
        ]

    if sort_by == "name":
        fileinfos = sorted(fileinfos, key=lambda info: os.path.basename(info[0]))
    elif sort_by == "date":
        fileinfos = sorted(fileinfos, key=lambda info: -info[1].st_mtime)
    elif sort_by == "path name":
        fileinfos = sorted(fileinfos)

    res = OrderedDict()
    for filename, _file_stat in fileinfos:
        name = os.path.splitext(os.path.basename(filename))[0]
        # Prevent a hypothetical "None.pt" from being listed.
        if name == "None":
            continue
        res[f"{name} [{sd_models.model_hash(filename)}]"] = filename
    return res
# Maps each model display name to its file name/path. The 'None' entry is a
# placeholder. Both globals are rebuilt by `update_controlnet_filenames`.
controlnet_filename_dict = {'None': 'model.safetensors'}
# Model display names; reassigned from the keys of `controlnet_filename_dict`
# whenever `update_controlnet_filenames` runs.
controlnet_names = ['None']
def get_preprocessor(name):
    """Look up a preprocessor by name; returns `None` for an unknown name."""
    return supported_preprocessors.get(name)
def get_sorted_preprocessors():
    """Return all preprocessors as an ordered name -> preprocessor mapping.

    The 'None' preprocessor always comes first. The rest follow in descending
    order of a key built from the zero-padded sorting priority and the name.
    """
    def sort_key(preprocessor):
        return str(preprocessor.sorting_priority).zfill(8) + preprocessor.name

    ascending = sorted(
        [p for key, p in supported_preprocessors.items() if key != 'None'],
        key=sort_key,
    )

    results = OrderedDict()
    results['None'] = supported_preprocessors['None']
    for preprocessor in reversed(ascending):
        results[preprocessor.name] = preprocessor
    return results
def get_all_controlnet_names():
    """Return the module-level list of known ControlNet model names."""
    names = controlnet_names
    return names
def get_controlnet_filename(controlnet_name):
    """Return the file registered for `controlnet_name`.

    Raises:
        KeyError: If the name is not in `controlnet_filename_dict`.
    """
    filename = controlnet_filename_dict[controlnet_name]
    return filename
def get_all_preprocessor_names():
    """Return preprocessor names in the order of `get_sorted_preprocessors`."""
    return list(get_sorted_preprocessors())
def get_all_preprocessor_tags():
    """Return 'All' followed by every distinct preprocessor tag, sorted."""
    unique_tags = set()
    for preprocessor in supported_preprocessors.values():
        unique_tags.update(preprocessor.tags)
    return ['All'] + sorted(unique_tags)
def get_filtered_preprocessors(tag):
    """Return the preprocessors carrying `tag`.

    The tag 'All' returns the full `supported_preprocessors` mapping as is.
    Any other tag returns a new dict, in sorted order, holding the matching
    preprocessors plus the 'None' entry.
    """
    if tag == 'All':
        return supported_preprocessors

    selected = {}
    for name, preprocessor in get_sorted_preprocessors().items():
        if tag in preprocessor.tags or name == 'None':
            selected[name] = preprocessor
    return selected
def get_filtered_preprocessor_names(tag):
    """Return the names of the preprocessors selected by `tag`."""
    selected = get_filtered_preprocessors(tag)
    return list(selected.keys())
def get_filtered_controlnet_names(tag):
    """Return the model names that fit the preprocessors selected by `tag`.

    A model name is kept when it contains (case-insensitively) any of the
    selected preprocessors' model filename filters, or when it is 'None'.
    """
    filename_filters = []
    for preprocessor in get_filtered_preprocessors(tag).values():
        filename_filters += preprocessor.model_filename_filers

    def matches_any_filter(model_name):
        return any(f.lower() in model_name.lower() for f in filename_filters)

    return [
        model_name
        for model_name in controlnet_names
        if matches_any_filter(model_name) or model_name == 'None'
    ]
def update_controlnet_filenames():
    """Rescan the model directories and rebuild the module-level registries.

    Searches `controlnet_dir` plus the directories configured through the
    `control_net_models_path` option and the `controlnet_dir` command-line
    option (when set and existing). Then replaces `controlnet_filename_dict`
    and `controlnet_names` with the result.
    """
    global controlnet_filename_dict, controlnet_names

    controlnet_filename_dict = {'None': 'model.safetensors'}
    controlnet_names = ['None']

    configured_dirs = (
        shared.opts.data.get("control_net_models_path", None),
        getattr(shared.cmd_opts, 'controlnet_dir', None),
    )
    search_paths = [controlnet_dir]
    for candidate in configured_dirs:
        if candidate is not None and os.path.exists(candidate):
            search_paths.append(candidate)

    # The listing options are the same for every searched directory.
    sort_by = shared.opts.data.get("control_net_models_sort_models_by", "name")
    filter_by = shared.opts.data.get("control_net_models_name_filter", "")
    for path in search_paths:
        controlnet_filename_dict.update(get_all_models(sort_by, filter_by, path))

    controlnet_names = list(controlnet_filename_dict.keys())
def get_sd_version() -> StableDiffusionVersion:
    """Map the flags of the loaded `shared.sd_model` to a `StableDiffusionVersion`.

    The flags are checked in the order SDXL, SD2, SD1; `UNKNOWN` is returned
    when none of them is set.
    """
    if shared.sd_model.is_sdxl:
        return StableDiffusionVersion.SDXL
    if shared.sd_model.is_sd2:
        return StableDiffusionVersion.SD2x
    if shared.sd_model.is_sd1:
        return StableDiffusionVersion.SD1x
    return StableDiffusionVersion.UNKNOWN
@@ -0,0 +1,135 @@
from typing import List, Tuple, Union
import gradio as gr
from modules.processing import StableDiffusionProcessing
from lib_controlnet import external_code
from lib_controlnet.logging import logger
def field_to_displaytext(fieldname: str) -> str:
    """Turn a snake_case field name into display text, e.g. "low_vram" -> "Low Vram"."""
    words = fieldname.split("_")
    return " ".join(map(str.capitalize, words))
def displaytext_to_field(text: str) -> str:
    """Turn display text back into a field name, e.g. "Low Vram" -> "low_vram"."""
    words = text.split(" ")
    return "_".join(map(str.lower, words))
def parse_value(value: str) -> Union[str, float, int, bool]:
    """Interpret an infotext value string.

    "True"/"False" become booleans. Otherwise the value is parsed as an int,
    then as a float, and is returned unchanged if neither parse succeeds.
    """
    if value in ("True", "False"):
        return value == "True"

    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue
    return value  # Plain string.
def serialize_unit(unit: external_code.ControlNetUnit) -> str:
    """Serialize `unit` into a "Field Name: value, ..." infotext fragment.

    Returns an empty string, after logging an error, when any value contains
    a "," or ":" since those are the separators of the format.
    """
    excluded_fields = (
        "image",
        "enabled",
        # Note: "advanced_weighting" is excluded as it is an API-only field.
        "advanced_weighting",
        # Note: "inpaint_crop_image" is img2img inpaint only flag, which does not
        # provide much information when restoring the unit.
        "inpaint_crop_input_image",
    )

    log_value = {}
    for field in vars(external_code.ControlNetUnit()).keys():
        if field in excluded_fields:
            continue
        # Note: exclude hidden slider values.
        if getattr(unit, field) != -1:
            log_value[field_to_displaytext(field)] = getattr(unit, field)

    def has_separator(value) -> bool:
        return not ("," not in str(value) and ":" not in str(value))

    if any(has_separator(value) for value in log_value.values()):
        logger.error(f"Unexpected tokens encountered:\n{log_value}")
        return ""

    pairs = [f"{field}: {value}" for field, value in log_value.items()]
    return ", ".join(pairs)
def parse_unit(text: str) -> external_code.ControlNetUnit:
    """Rebuild an enabled `ControlNetUnit` from a "Field Name: value, ..." string.

    Each comma-separated item must split into exactly one key and one value
    around ": "; otherwise unpacking raises `ValueError`.
    """
    fields = {}
    for item in text.split(","):
        key, value = item.strip().split(": ")
        fields[displaytext_to_field(key)] = parse_value(value)
    return external_code.ControlNetUnit(enabled=True, **fields)
class Infotext(object):
    """Glue between ControlNet units and the generation infotext.

    An instance collects the (component, locator) pairs registered through
    `register_unit`. The static helpers write units into the generation
    parameters and read them back from pasted infotext.
    """

    def __init__(self) -> None:
        # (gradio component, locator string) pairs appended by `register_unit`.
        self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
        # Locator strings, in the same order as `infotext_fields`.
        self.paste_field_names: List[str] = []

    @staticmethod
    def unit_prefix(unit_index: int) -> str:
        """Return the infotext key of the unit at `unit_index`."""
        return f"ControlNet {unit_index}"

    def register_unit(self, unit_index: int, uigroup) -> None:
        """Register the unit's UI group. By registering the unit, A1111 will be
        able to paste values from infotext to IOComponents.

        Args:
            unit_index: The index of the ControlNet unit
            uigroup: The ControlNetUiGroup instance that contains all gradio
                     iocomponents.
        """
        unit_prefix = Infotext.unit_prefix(unit_index)
        for field in vars(external_code.ControlNetUnit()).keys():
            # Exclude image for infotext.
            if field == "image":
                continue

            # Every field in ControlNetUnit should have a corresponding
            # IOComponent in ControlNetUiGroup.
            io_component = getattr(uigroup, field)
            component_locator = f"{unit_prefix} {field}"
            self.infotext_fields.append((io_component, component_locator))
            self.paste_field_names.append(component_locator)

    @staticmethod
    def write_infotext(
        units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing
    ):
        """Write the serialized form of every enabled unit to
        `p.extra_generation_params`, keyed by `unit_prefix` of its index."""
        p.extra_generation_params.update(
            {
                Infotext.unit_prefix(i): serialize_unit(unit)
                for i, unit in enumerate(units)
                if unit.enabled
            }
        )

    @staticmethod
    def on_infotext_pasted(infotext: str, results: dict) -> None:
        """Parse ControlNet infotext string and write result to `results` dict.

        Every entry of `results` whose key starts with "ControlNet" is parsed
        as a serialized unit. Each non-None field of that unit, except
        `image`, is added to `results` under "<key> <field>". Entries that fail
        to parse are reported with a warning and skipped.

        Raises:
            AssertionError: If a "ControlNet" entry is not a string.
        """
        # Collected separately because `results` must not change size while it
        # is being iterated.
        updates = {}
        for k, v in results.items():
            if not k.startswith("ControlNet"):
                continue

            assert isinstance(v, str), f"Expect string but got {v}."
            try:
                for field, value in vars(parse_unit(v)).items():
                    if field == "image":
                        continue
                    if value is None:
                        logger.debug(f"InfoText: Skipping {field} because value is None.")
                        continue

                    component_locator = f"{k} {field}"
                    updates[component_locator] = value
                    logger.debug(f"InfoText: Setting {component_locator} = {value}")
            except Exception as e:
                # `Logger.warn` is a deprecated alias; use `warning`.
                logger.warning(
                    f"Failed to parse infotext, legacy format infotext is no longer supported:\n{v}\n{e}"
                )

        results.update(updates)
@@ -0,0 +1,41 @@
import logging
import copy
import sys
from modules import shared
class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the record's level name in ANSI colour codes."""

    COLORS = {
        "DEBUG": "\033[0;36m",  # CYAN
        "INFO": "\033[0;32m",  # GREEN
        "WARNING": "\033[0;33m",  # YELLOW
        "ERROR": "\033[0;31m",  # RED
        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
        "RESET": "\033[0m",  # RESET COLOR
    }

    def format(self, record):
        """Format a shallow copy of `record` with a coloured level name.

        The record passed in is left unmodified.
        """
        reset = self.COLORS["RESET"]
        tinted = copy.copy(record)
        # Unknown level names are wrapped in the reset sequence.
        color = self.COLORS.get(tinted.levelname, reset)
        tinted.levelname = f"{color}{tinted.levelname}{reset}"
        return super().format(tinted)
# Create a new logger
logger = logging.getLogger("ControlNet")
logger.propagate = False

# Add handler if we don't have one.
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(handler)

# Configure logger
loglevel_string = getattr(shared.cmd_opts, "controlnet_loglevel", "INFO")
# `str()` tolerates an option that is present but set to None.
loglevel = getattr(logging, str(loglevel_string).upper(), None)
if not isinstance(loglevel, int):
    # An unrecognised level name would make `setLevel` raise at import time;
    # fall back to INFO instead.
    loglevel = logging.INFO
logger.setLevel(loglevel)
@@ -0,0 +1,88 @@
# High Quality Edge Thinning using Pure Python
# Written by Lvmin Zhang
# 2023 April
# Stanford University
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
import cv2
import numpy as np
# Base hit-or-miss kernels used for thinning.
lvmin_kernels_raw = [
    np.array([
        [-1, -1, -1],
        [0, 1, 0],
        [1, 1, 1]
    ], dtype=np.int32),
    np.array([
        [0, -1, -1],
        [1, 1, -1],
        [0, 1, 0]
    ], dtype=np.int32)
]

# Every base kernel at 0, 1, 2 and 3 quarter turns, grouped by turn count.
lvmin_kernels = [
    np.rot90(kernel, k=quarter_turns, axes=(0, 1))
    for quarter_turns in range(4)
    for kernel in lvmin_kernels_raw
]

# Base hit-or-miss kernels used for the pruning pass.
lvmin_prunings_raw = [
    np.array([
        [-1, -1, -1],
        [-1, 1, -1],
        [0, 0, -1]
    ], dtype=np.int32),
    np.array([
        [-1, -1, -1],
        [-1, 1, -1],
        [-1, 0, 0]
    ], dtype=np.int32)
]

# Every pruning kernel at 0, 1, 2 and 3 quarter turns, grouped by turn count.
lvmin_prunings = [
    np.rot90(kernel, k=quarter_turns, axes=(0, 1))
    for quarter_turns in range(4)
    for kernel in lvmin_prunings_raw
]
def remove_pattern(x, kernel):
    """Zero out every pixel of `x` where the hit-or-miss `kernel` matches.

    `x` is modified in place. Returns `(x, changed)` where `changed` tells
    whether at least one pixel matched.
    """
    hits = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
    hit_coords = np.where(hits > 127)
    x[hit_coords] = 0
    return x, hit_coords[0].shape[0] > 0
def thin_one_time(x, kernels):
    """Apply `remove_pattern` once for each kernel, in order.

    Returns `(image, is_done)` where `is_done` is True when no kernel removed
    any pixel.
    """
    current = x
    changed = False
    for kernel in kernels:
        current, has_update = remove_pattern(current, kernel)
        if has_update:
            changed = True
    return current, not changed
def lvmin_thin(x, prunings=True):
    """Thin `x` by applying the `lvmin_kernels` pass repeatedly.

    Stops as soon as a pass removes nothing, or after 32 passes. When
    `prunings` is true, one pass with `lvmin_prunings` follows.
    """
    thinned = x
    for _ in range(32):
        thinned, converged = thin_one_time(thinned, lvmin_kernels)
        if converged:
            break

    if prunings:
        thinned, _ = thin_one_time(thinned, lvmin_prunings)
    return thinned
def nake_nms(x):
    """Keep the pixels of `x` that equal their dilation along a 3-pixel line.

    Four line-shaped kernels (horizontal, vertical and both diagonals) are
    tried. A pixel that is unchanged by dilation with any of them is copied to
    the output; all other output pixels stay zero.
    """
    line_kernels = [
        np.array(rows, dtype=np.uint8)
        for rows in (
            [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
            [[0, 1, 0], [0, 1, 0], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
        )
    ]

    y = np.zeros_like(x)
    for kernel in line_kernels:
        unchanged_by_dilation = cv2.dilate(x, kernel=kernel) == x
        np.putmask(y, unchanged_by_dilation, x)
    return y
@@ -0,0 +1,180 @@
import torch
import os
import functools
import time
import base64
import numpy as np
import safetensors.torch
import cv2
import logging
from typing import Any, Callable, Dict, List
from modules.safe import unsafe_torch_load
from lib_controlnet.logging import logger
def load_state_dict(ckpt_path, location="cpu"):
    """Load a checkpoint file and return its state dict.

    `.safetensors` files (extension compared case-insensitively) are read with
    `safetensors.torch.load_file`; everything else goes through
    `unsafe_torch_load`. The loaded object is then unwrapped with
    `get_state_dict`.
    """
    extension = os.path.splitext(ckpt_path)[1].lower()
    if extension == ".safetensors":
        loaded = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        loaded = unsafe_torch_load(ckpt_path, map_location=torch.device(location))

    state_dict = get_state_dict(loaded)
    logger.info(f"Loaded state_dict from [{ckpt_path}]")
    return state_dict
def get_state_dict(d):
    """Return `d["state_dict"]` when that key exists, otherwise `d` itself."""
    unwrapped = d.get("state_dict", d)
    return unwrapped
def ndarray_lru_cache(max_size: int = 128, typed: bool = False):
    """
    Build an `lru_cache`-style decorator that tolerates numpy array arguments.

    Numpy arrays are mutable and therefore not hashable, so they cannot be used
    as cache keys directly. Before the cached call is made, every `np.ndarray`
    argument (including arrays nested inside tuples) is viewed as a
    `HashableNpArray`, an `np.ndarray` subclass whose `__hash__` hashes the
    array bytes and whose `__eq__` uses `np.array_equal`.
    """

    def decorator(func: Callable):
        """The actual decorator that accept function as input."""

        class HashableNpArray(np.ndarray):
            def __new__(cls, input_array):
                # A view shares its data with the input array; nothing is copied.
                return np.asarray(input_array).view(cls)

            def __eq__(self, other) -> bool:
                return np.array_equal(self, other)

            def __hash__(self):
                # Hash the bytes representing the data of the array.
                return hash(self.tobytes())

        def make_hashable(item: Any):
            """Wrap arrays (also inside tuples); leave other values alone."""
            if isinstance(item, tuple):
                return tuple(make_hashable(member) for member in item)
            if isinstance(item, np.ndarray):
                return HashableNpArray(item)
            return item

        @functools.lru_cache(maxsize=max_size, typed=typed)
        def cached_func(*args, **kwargs):
            """This function only accepts `HashableNpArray` as input params."""
            return func(*args, **kwargs)

        # Preserves original function.__name__ and __doc__.
        @functools.wraps(func)
        def decorated_func(*args, **kwargs):
            """The decorated function that delegates the original function."""
            hashable_args = [make_hashable(arg) for arg in args]
            hashable_kwargs = {name: make_hashable(arg) for name, arg in kwargs.items()}
            return cached_func(*hashable_args, **hashable_kwargs)

        return decorated_func

    return decorator
def timer_decorator(func):
    """Wrap `func` so that its run time is reported on the debug logger.

    The logger level is checked once, at decoration time: unless it is exactly
    `logging.DEBUG`, `func` is returned unwrapped.
    """
    if logger.level != logging.DEBUG:
        return func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        started = time.time()
        result = func(*args, **kwargs)
        duration = time.time() - started
        # Only report function that are significant enough.
        if duration > 1e-3:
            logger.debug(f"{func.__name__} ran in: {duration:.3f} sec")
        return result

    return wrapper
class TimeMeta(type):
    """Metaclass that passes every callable attribute of the class being
    created through `timer_decorator`."""

    def __new__(cls, name, bases, attrs):
        # Iterate over a snapshot of the names; the values are replaced in the
        # same namespace dict that is handed on to `type.__new__`.
        for attr_name in list(attrs):
            member = attrs[attr_name]
            if callable(member):
                attrs[attr_name] = timer_decorator(member)
        return super().__new__(cls, name, bases, attrs)
# Optional SVG support: `svgsupport` records whether the libraries needed to
# rasterise SVG input could be imported.
try:
    import io
    from svglib.svglib import svg2rlg
    from reportlab.graphics import renderPM
except ImportError:
    svgsupport = False
else:
    svgsupport = True
def svg_preprocess(inputs: Dict, preprocess: Callable):
    """Rasterise a base64 SVG input to PNG, then hand `inputs` to `preprocess`.

    Returns `None` for empty `inputs`. When `inputs["image"]` is a base64 SVG
    data URL and SVG support is available, the entry is replaced in place by a
    base64 PNG data URL before `preprocess` is called.
    """
    if not inputs:
        return None

    svg_prefix = "data:image/svg+xml;base64,"
    if inputs["image"].startswith(svg_prefix) and svgsupport:
        svg_bytes = base64.b64decode(inputs["image"].replace(svg_prefix, ""))
        drawing = svg2rlg(io.BytesIO(svg_bytes))
        png_bytes = renderPM.drawToString(drawing, fmt="PNG")
        png_base64 = str(base64.b64encode(png_bytes), "utf-8")
        inputs["image"] = "data:image/png;base64," + png_base64

    return preprocess(inputs)
def get_unique_axis0(data):
    """Return the distinct rows of 2-D `data`, in `np.lexsort` order.

    Rows are sorted with `np.lexsort` over the columns (the last column is the
    primary key). A row is kept when it is the first row or differs from the
    row before it.
    """
    rows = np.asanyarray(data)
    rows = rows[np.lexsort(rows.T)]

    keep = np.ones(len(rows), dtype=np.bool_)
    keep[1:] = np.any(rows[:-1, :] != rows[1:, :], axis=-1)
    return rows[keep]
def read_image(img_path: str) -> str:
    """Read image from specified path and return a base64 string.

    The image is loaded with `cv2.imread`, re-encoded as PNG and base64 encoded.

    Raises:
        IOError: If the file cannot be read as an image or cannot be encoded.
    """
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None instead of raising. Raise
    # here so callers that catch IOError see the problem.
    if img is None:
        raise IOError(f"Failed to read image: {img_path}")

    success, png_buffer = cv2.imencode(".png", img)
    if not success:
        raise IOError(f"Failed to encode image as PNG: {img_path}")

    return base64.b64encode(png_buffer).decode("utf-8")
def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]:
    """Read every file in `img_dir` whose name ends with one of `suffixes`.

    Returns the base64 strings produced by `read_image`. Files that raise
    `IOError` are logged and skipped.
    """
    images = []
    matching_names = (name for name in os.listdir(img_dir) if name.endswith(suffixes))
    for filename in matching_names:
        img_path = os.path.join(img_dir, filename)
        try:
            images.append(read_image(img_path))
        except IOError:
            logger.error(f"Error opening {img_path}")
    return images
def align_dim_latent(x: int) -> int:
    """Round a pixel dimension (w/h) down to a multiple of 8.

    Stable diffusion uses a 1:8 ratio for latent/pixel, i.e.,
    1 latent unit == 8 pixel unit.
    """
    whole_latent_units = x // 8
    return whole_latent_units * 8