Cnet (#22)
* ini * remove shit * Create control_model.py * i * i * Update controlnet_supported.py * Update controlnet_supported.py * Update controlnet_supported.py * i * i * Update controlnet_supported.py * i * Update controlnet_supported.py * remove shits * remove shit * Update global_state.py * i * i * Update legacy_preprocessors.py * Update legacy_preprocessors.py * remove shit * Update batch_hijack.py * remove shit * remove shit * i * i * i * Update external_code.py * Update global_state.py * Update infotext.py * Update utils.py * Update external_code.py * i * i * i * Update controlnet_ui_group.py * remove shit * remove shit * i * Update controlnet.py * Update controlnet.py * Update controlnet.py * Update controlnet.py * Update controlnet.py * i * Update global_state.py * Update global_state.py * i * Update global_state.py * Update global_state.py * Update global_state.py * Update global_state.py * Update controlnet_ui_group.py * i * Update global_state.py * Update controlnet_ui_group.py * Update controlnet_ui_group.py * i * Update controlnet_ui_group.py * Update controlnet_ui_group.py * Update controlnet_ui_group.py * Update controlnet_ui_group.py
This commit is contained in:
+1348
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,38 @@
|
||||
import gradio as gr
|
||||
from typing import List
|
||||
|
||||
|
||||
class ModalInterface(gr.Interface):
    """Builds the HTML for a modal dialog together with the element that opens it.

    Each instance takes the next value of the class-wide ``modal_id_counter`` so
    that the generated element ids (``cnet-modal-<id>`` and
    ``cnet-modal-open-<id>``) do not collide between instances.

    Note: ``__init__`` does not call the ``gr.Interface`` initializer; the
    instance is only used to render a ``gr.HTML`` component.
    """

    modal_id_counter = 0

    def __init__(
        self,
        html_content: str,
        open_button_text: str,
        open_button_classes: List[str] = None,
        open_button_extra_attrs: str = '',
    ):
        """
        Args:
            html_content: Raw HTML placed inside the modal body.
            open_button_text: Content of the element that opens the modal.
            open_button_classes: Extra CSS classes for the open element.
                ``None`` (the default) means no extra classes.
            open_button_extra_attrs: Raw attribute text inserted into the
                open element's tag.
        """
        self.html_content = html_content
        self.open_button_text = open_button_text
        # Copy the classes: a mutable default (or the caller's own list) must
        # not be shared between instances.
        self.open_button_classes = list(open_button_classes or [])
        self.open_button_extra_attrs = open_button_extra_attrs
        self.modal_id = ModalInterface.modal_id_counter
        ModalInterface.modal_id_counter += 1

    def __call__(self):
        """Shortcut for ``create_modal()`` with its default visibility."""
        return self.create_modal()

    def create_modal(self, visible=True):
        """Return a ``gr.HTML`` component holding the modal and its open element.

        Args:
            visible: Passed through as the ``visible`` flag of the component.
        """
        html_code = f"""
        <div id="cnet-modal-{self.modal_id}" class="cnet-modal">
            <span class="cnet-modal-close">×</span>
            <div class="cnet-modal-content">
                {self.html_content}
            </div>
        </div>
        <div id="cnet-modal-open-{self.modal_id}"
            class="cnet-modal-open {' '.join(self.open_button_classes)}"
            {self.open_button_extra_attrs}
        >{self.open_button_text}</div>
        """
        return gr.HTML(value=html_code, visible=visible)
|
||||
+154
@@ -0,0 +1,154 @@
|
||||
import base64
|
||||
import gradio as gr
|
||||
import json
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
from annotator.openpose import decode_json_as_poses, draw_poses
|
||||
from annotator.openpose.animalpose import draw_animalposes
|
||||
from lib_controlnet.controlnet_ui.modal import ModalInterface
|
||||
from modules import shared
|
||||
from lib_controlnet.logging import logger
|
||||
|
||||
|
||||
def parse_data_url(data_url: str) -> bytes:
    """Decode the payload of a base64 ``data:`` URL.

    Args:
        data_url: Text of the form ``<media type>;base64,<payload>``.

    Returns:
        The decoded payload bytes.

    Raises:
        ValueError: If the URL has no ``,`` separator, is not marked as
            base64-encoded, or the payload is not valid base64.
    """
    # Everything before the first comma is the media type, the rest is data.
    media_type, separator, data = data_url.partition(",")
    if not separator:
        raise ValueError("Invalid data URL: missing ',' separator.")

    # An explicit raise (instead of ``assert``) keeps the check under ``-O``.
    if ";base64" not in media_type:
        raise ValueError("Invalid data URL: payload is not base64-encoded.")

    return base64.b64decode(data)
|
||||
|
||||
|
||||
def encode_data_url(json_string: str) -> str:
    """Return ``json_string`` wrapped in a base64 ``data:application/json`` URL."""
    payload = base64.b64encode(json_string.encode("utf-8"))
    return "data:application/json;base64," + payload.decode("utf-8")
|
||||
|
||||
|
||||
class OpenposeEditor(object):
    """Gradio components and callbacks for editing, downloading and uploading pose JSON.

    ``render_edit()`` and ``render_upload()`` create the components;
    ``register_callbacks()`` wires them up and therefore has to run after both.
    """

    # Filename used when user click the download link.
    download_file = "pose.json"
    # URL the openpose editor is mounted on.
    editor_url = "/openpose_editor_index"

    def __init__(self) -> None:
        # Placeholders; the components are created by the render_* methods.
        self.render_button = None
        self.pose_input = None
        self.download_link = None
        self.upload_link = None
        self.modal = None

    def render_edit(self):
        """Renders the buttons in preview image control button group."""
        # The hidden button to trigger a re-render of generated image.
        self.render_button = gr.Button(visible=False, elem_classes=["cnet-render-pose"])
        # The hidden element that stores the pose json for backend retrieval.
        # The front-end javascript will write the edited JSON data to the element.
        self.pose_input = gr.Textbox(visible=False, elem_classes=["cnet-pose-json"])

        self.modal = ModalInterface(
            # Use about:blank here as placeholder so that the iframe does not
            # immediately navigate. Most of controlnet units do not need
            # openpose editor active. Only navigate when the user first click
            # 'Edit'. The navigation logic is in `openpose_editor.js`.
            '<iframe src="about:blank"></iframe>',
            open_button_text="Edit",
            open_button_classes=["cnet-edit-pose"],
            open_button_extra_attrs=f'title="Send pose to {OpenposeEditor.editor_url} for edit."',
        ).create_modal(visible=False)
        self.download_link = gr.HTML(
            value=f"""<a href='' download='{OpenposeEditor.download_file}'>JSON</a>""",
            visible=False,
            elem_classes=["cnet-download-pose"],
        )

    def render_upload(self):
        """Renders the button in input image control button group."""
        self.upload_link = gr.HTML(
            value="""
            <label>Upload JSON</label>
            <input type="file" accept=".json"/>
            """,
            visible=False,
            elem_classes=["cnet-upload-pose"],
        )

    def register_callbacks(
        self,
        generated_image: gr.Image,
        use_preview_as_input: gr.Checkbox,
        model: gr.Dropdown,
    ):
        """Attach the event handlers of the editor.

        Args:
            generated_image: Image component that receives the drawn pose.
            use_preview_as_input: Checkbox that is switched on whenever a pose
                is rendered.
            model: Model dropdown; the upload link is shown only while its
                value contains "openpose" (case-insensitive).
        """

        def render_pose(pose_url: str) -> Tuple[Dict, Dict]:
            # The front-end hands over the pose JSON as a base64 data URL.
            json_string = parse_data_url(pose_url).decode("utf-8")
            poses, animals, height, width = decode_json_as_poses(
                json.loads(json_string)
            )
            logger.info("Preview as input is enabled.")
            return (
                # Generated image.
                gr.update(
                    value=(
                        draw_poses(
                            poses,
                            height,
                            width,
                            draw_body=True,
                            draw_hand=True,
                            draw_face=True,
                        )
                        if poses
                        else draw_animalposes(animals, height, width)
                    ),
                    visible=True,
                ),
                # Use preview as input.
                gr.update(value=True),
                # Self content.
                *self.update(json_string),
            )

        self.render_button.click(
            fn=render_pose,
            inputs=[self.pose_input],
            outputs=[generated_image, use_preview_as_input, *self.outputs()],
        )

        def update_upload_link(model: str) -> Dict:
            # A dropdown may report no selection as None; treat that as
            # "not an openpose model" instead of failing on ``.lower()``.
            return gr.update(visible="openpose" in (model or "").lower())

        model.change(fn=update_upload_link, inputs=[model], outputs=[self.upload_link])

    def outputs(self) -> List[Any]:
        """Components updated by ``update()``, in the order ``update()`` returns them."""
        return [
            self.download_link,
            self.modal,
        ]

    def update(self, json_string: str) -> List[Dict]:
        """
        Called when there is a new JSON pose value generated by running
        preprocessor.

        Args:
            json_string: The new JSON string generated by preprocessor.

        Returns:
            A list of two gr.update payloads: one for the download link and
            one for the editor modal, matching the order of ``outputs()``.
        """
        hint = "Download the pose as .json file"
        html = f"""<a href='{encode_data_url(json_string)}'
                      download='{OpenposeEditor.download_file}' title="{hint}">
                      JSON</a>"""

        # Both elements are hidden while there is no pose to offer.
        visible = json_string != ""
        return [
            # Download link update.
            gr.update(value=html, visible=visible),
            # Modal update.
            gr.update(
                visible=visible
                and not shared.opts.data.get("controlnet_disable_openpose_edit", False)
            ),
        ]
|
||||
@@ -0,0 +1,182 @@
|
||||
import gradio as gr
|
||||
|
||||
from lib_controlnet.controlnet_ui.modal import ModalInterface
|
||||
|
||||
PHOTOPEA_LOGO = """
|
||||
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
width="100%" viewBox="0 0 256 256" enable-background="new 0 0 256 256" xml:space="preserve"
|
||||
style="width: 0.75rem; height 0.75rem; margin-left: 2px;"
|
||||
>
|
||||
<path fill="#18A497" opacity="1.000000" stroke="none"
|
||||
d="
|
||||
M1.000000,228.000000
|
||||
C1.000000,162.312439 1.000000,96.624878 1.331771,30.719650
|
||||
C2.026278,30.171114 2.594676,29.904894 2.721949,29.500008
|
||||
C6.913495,16.165672 15.629609,7.322631 28.880219,2.875538
|
||||
C29.404272,2.699659 29.633436,1.645129 30.000000,1.000000
|
||||
C95.687561,1.000000 161.375122,1.000000 227.258057,1.317018
|
||||
C227.660217,1.893988 227.815079,2.296565 228.081207,2.393433
|
||||
C241.304657,7.206383 250.980164,15.550970 255.215851,29.410040
|
||||
C255.321625,29.756128 256.383850,29.809898 257.000000,30.000000
|
||||
C257.000000,95.687561 257.000000,161.375122 256.682983,227.257858
|
||||
C256.106049,227.659790 255.699371,227.815521 255.607178,228.080658
|
||||
C250.953033,241.462830 242.292618,250.822968 228.591782,255.214935
|
||||
C228.239929,255.327698 228.190491,256.383820 228.000000,257.000000
|
||||
C175.312439,257.000000 122.624878,257.000000 69.468582,256.531342
|
||||
C68.672188,244.948196 68.218323,233.835587 68.052299,222.718674
|
||||
C67.885620,211.557587 67.886772,200.390717 68.027298,189.229050
|
||||
C68.255180,171.129044 68.084618,152.997421 69.151917,134.942368
|
||||
C70.148468,118.083969 77.974228,103.689308 89.758743,91.961365
|
||||
C104.435837,77.354736 122.313736,69.841736 143.417328,69.901505
|
||||
C168.662338,69.972984 186.981964,90.486633 187.961487,114.156334
|
||||
C189.042435,140.277435 166.783981,163.607941 140.303482,160.823074
|
||||
C137.092346,160.485382 133.490692,158.365784 131.192612,155.987366
|
||||
C126.434669,151.063141 126.975357,144.720825 129.168777,138.834930
|
||||
C131.533630,132.489014 137.260605,130.548050 143.413757,130.046677
|
||||
C150.288467,129.486496 156.424942,123.757378 157.035324,117.320816
|
||||
C157.953949,107.633820 150.959381,101.769096 145.533951,101.194389
|
||||
C132.238846,99.786079 120.699944,104.963120 111.676735,114.167313
|
||||
C102.105782,123.930222 97.469498,136.194061 99.003151,150.234955
|
||||
C100.540352,164.308228 107.108505,175.507980 118.864334,183.311539
|
||||
C128.454544,189.677597 138.866959,191.786957 150.657837,190.245651
|
||||
C166.242554,188.208420 179.874283,182.443329 191.251801,172.056793
|
||||
C209.355011,155.530380 217.848694,134.938721 216.116119,110.085892
|
||||
C214.834335,91.699440 207.721039,76.015915 195.289444,62.978828
|
||||
C175.658447,42.391735 150.833389,37.257801 123.833740,42.281937
|
||||
C98.675804,46.963364 78.315033,60.084667 62.208153,80.157814
|
||||
C46.645889,99.552216 39.305275,121.796379 39.149052,146.201981
|
||||
C38.912663,183.131317 39.666767,220.067017 40.000000,257.000000
|
||||
C36.969406,257.000000 33.938812,257.000000 30.705070,256.668213
|
||||
C30.298622,256.078369 30.144913,255.669220 29.884926,255.583878
|
||||
C16.317770,251.131058 7.127485,242.317780 2.778462,228.591797
|
||||
C2.667588,228.241821 1.613958,228.190567 1.000000,228.000000
|
||||
z"/>
|
||||
<path fill="#000000" opacity="1.000000" stroke="none"
|
||||
d="
|
||||
M40.468658,257.000000
|
||||
C39.666767,220.067017 38.912663,183.131317 39.149052,146.201981
|
||||
C39.305275,121.796379 46.645889,99.552216 62.208153,80.157814
|
||||
C78.315033,60.084667 98.675804,46.963364 123.833740,42.281937
|
||||
C150.833389,37.257801 175.658447,42.391735 195.289444,62.978828
|
||||
C207.721039,76.015915 214.834335,91.699440 216.116119,110.085892
|
||||
C217.848694,134.938721 209.355011,155.530380 191.251801,172.056793
|
||||
C179.874283,182.443329 166.242554,188.208420 150.657837,190.245651
|
||||
C138.866959,191.786957 128.454544,189.677597 118.864334,183.311539
|
||||
C107.108505,175.507980 100.540352,164.308228 99.003151,150.234955
|
||||
C97.469498,136.194061 102.105782,123.930222 111.676735,114.167313
|
||||
C120.699944,104.963120 132.238846,99.786079 145.533951,101.194389
|
||||
C150.959381,101.769096 157.953949,107.633820 157.035324,117.320816
|
||||
C156.424942,123.757378 150.288467,129.486496 143.413757,130.046677
|
||||
C137.260605,130.548050 131.533630,132.489014 129.168777,138.834930
|
||||
C126.975357,144.720825 126.434669,151.063141 131.192612,155.987366
|
||||
C133.490692,158.365784 137.092346,160.485382 140.303482,160.823074
|
||||
C166.783981,163.607941 189.042435,140.277435 187.961487,114.156334
|
||||
C186.981964,90.486633 168.662338,69.972984 143.417328,69.901505
|
||||
C122.313736,69.841736 104.435837,77.354736 89.758743,91.961365
|
||||
C77.974228,103.689308 70.148468,118.083969 69.151917,134.942368
|
||||
C68.084618,152.997421 68.255180,171.129044 68.027298,189.229050
|
||||
C67.886772,200.390717 67.885620,211.557587 68.052299,222.718674
|
||||
C68.218323,233.835587 68.672188,244.948196 68.999924,256.531342
|
||||
C59.645771,257.000000 50.291542,257.000000 40.468658,257.000000
|
||||
z"/>
|
||||
<path fill="#000000" opacity="1.000000" stroke="none"
|
||||
d="
|
||||
M257.000000,29.531342
|
||||
C256.383850,29.809898 255.321625,29.756128 255.215851,29.410040
|
||||
C250.980164,15.550970 241.304657,7.206383 228.081207,2.393433
|
||||
C227.815079,2.296565 227.660217,1.893988 227.726715,1.317018
|
||||
C237.593155,1.000000 247.186295,1.000000 257.000000,1.000000
|
||||
C257.000000,10.353075 257.000000,19.707878 257.000000,29.531342
|
||||
z"/>
|
||||
<path fill="#000000" opacity="1.000000" stroke="none"
|
||||
d="
|
||||
M228.468658,257.000000
|
||||
C228.190491,256.383820 228.239929,255.327698 228.591782,255.214935
|
||||
C242.292618,250.822968 250.953033,241.462830 255.607178,228.080658
|
||||
C255.699371,227.815521 256.106049,227.659790 256.682983,227.726517
|
||||
C257.000000,237.593155 257.000000,247.186295 257.000000,257.000000
|
||||
C247.646927,257.000000 238.292114,257.000000 228.468658,257.000000
|
||||
z"/>
|
||||
<path fill="#000000" opacity="1.000000" stroke="none"
|
||||
d="
|
||||
M1.000000,228.468658
|
||||
C1.613958,228.190567 2.667588,228.241821 2.778462,228.591797
|
||||
C7.127485,242.317780 16.317770,251.131058 29.884926,255.583878
|
||||
C30.144913,255.669220 30.298622,256.078369 30.250959,256.668213
|
||||
C20.406853,257.000000 10.813705,257.000000 1.000000,257.000000
|
||||
C1.000000,247.646927 1.000000,238.292114 1.000000,228.468658
|
||||
z"/>
|
||||
<path fill="#000000" opacity="1.000000" stroke="none"
|
||||
d="
|
||||
M29.531342,1.000000
|
||||
C29.633436,1.645129 29.404272,2.699659 28.880219,2.875538
|
||||
C15.629609,7.322631 6.913495,16.165672 2.721949,29.500008
|
||||
C2.594676,29.904894 2.026278,30.171114 1.331771,30.250992
|
||||
C1.000000,20.406855 1.000000,10.813709 1.000000,1.000000
|
||||
C10.353074,1.000000 19.707878,1.000000 29.531342,1.000000
|
||||
z"/>
|
||||
</svg>"""
|
||||
|
||||
|
||||
class Photopea(object):
    """Gradio components for the Photopea editor modal and its trigger buttons."""

    def __init__(self) -> None:
        self.modal = None
        self.triggers = []
        # The editor modal is rendered as soon as the object is created.
        self.render_editor()

    def render_editor(self):
        """Render the editor modal."""
        # Use about:blank here as placeholder so that the iframe does not
        # immediately navigate. Only navigate when the user first click
        # 'Edit'. The navigation logic is in `photopea.js`.
        editor_html = """
            <div class="photopea-button-group">
                <button class="photopea-button photopea-fetch">Fetch from ControlNet</button>
                <button class="photopea-button photopea-send">Send to ControlNet</button>
            </div>
            <iframe class="photopea-iframe" src="about:blank"></iframe>
            """
        with gr.Group(elem_classes=["cnet-photopea-edit"]):
            modal_builder = ModalInterface(
                editor_html,
                open_button_text="Edit",
                open_button_classes=["cnet-photopea-main-trigger"],
                open_button_extra_attrs="hidden",
            )
            self.modal = modal_builder.create_modal(visible=True)

    def render_child_trigger(self):
        """Render one more "Edit" trigger element and remember it in ``self.triggers``."""
        trigger = gr.HTML(
            f"""<div class="cnet-photopea-child-trigger">
            Edit {PHOTOPEA_LOGO}
            </div>"""
        )
        self.triggers.append(trigger)

    def attach_photopea_output(self, generated_image: gr.Image):
        """Use the preprocessor preview image as the destination of Photopea results.

        Called in ControlNetUiGroup. If the front-end only changed the ``src``
        of the preview ``img`` element, the backend would never learn about the
        edited image. Instead the front-end uploads the result into an
        invisible gr.Image created here, and its value is mirrored to the
        preview gr.Image. The detour is needed because Gradio infers the
        preview gr.Image to be an output image, which cannot accept uploads
        directly.

        Arguments:
            generated_image: preprocessor result Gradio Image output element.

        Returns:
            None
        """

        def mirror(img):
            # Forward the uploaded image unchanged.
            return img

        output = gr.Image(
            visible=False,
            source="upload",
            type="numpy",
            elem_classes=["cnet-photopea-output"],
        )
        output.upload(
            fn=mirror,
            inputs=[output],
            outputs=[generated_image],
        )
|
||||
@@ -0,0 +1,318 @@
|
||||
import os
|
||||
import gradio as gr
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from modules import scripts
|
||||
from lib_controlnet.infotext import parse_unit, serialize_unit
|
||||
from lib_controlnet.controlnet_ui.tool_button import ToolButton
|
||||
from lib_controlnet.logging import logger
|
||||
from lib_controlnet import external_code
|
||||
|
||||
# Unicode glyphs used as the captions of the preset tool buttons.
save_symbol = "\U0001f4be" # 💾
delete_symbol = "\U0001f5d1\ufe0f" # 🗑️
refresh_symbol = "\U0001f504" # 🔄
reset_symbol = "\U000021A9" # ↩

# Reserved dropdown entry meaning "no saved preset selected"; a preset can
# neither be saved under nor loaded with this name.
NEW_PRESET = "New Preset"
|
||||
|
||||
|
||||
def load_presets(preset_dir: str) -> Dict[str, str]:
    """Read every ``*.txt`` file of ``preset_dir`` into a name -> content mapping.

    The directory is created when it does not exist yet. A file whose base
    name equals the reserved ``NEW_PRESET`` entry is ignored.

    Args:
        preset_dir: Directory that holds the preset files.

    Returns:
        Dict mapping each preset name (file name without the trailing
        ``.txt``) to the text stored in the file.
    """
    suffix = ".txt"

    # exist_ok avoids the check-then-create race of a separate exists() test;
    # a freshly created directory simply yields no presets below.
    os.makedirs(preset_dir, exist_ok=True)

    presets = {}
    for filename in os.listdir(preset_dir):
        if not filename.endswith(suffix):
            continue

        # Strip only the trailing extension. ``str.replace`` would also drop
        # a ".txt" in the middle of the name, so a preset saved as
        # "a.txt.v2" would come back under a different name.
        name = filename[: -len(suffix)]
        if name == NEW_PRESET:
            continue

        with open(os.path.join(preset_dir, filename), "r") as f:
            presets[name] = f.read()
    return presets
|
||||
|
||||
|
||||
def infer_control_type(module: str, model: str) -> str:
    """Infer the control type that a preprocessor module / model pair belongs to.

    A control type (a key of ``preprocessor_filters``, possibly several names
    joined by "/") is a candidate when any of its lower-cased names occurs as
    a substring of ``module`` or of ``model``.

    Args:
        module: Preprocessor module name.
        model: Model name.

    Returns:
        The single matching control type.

    Raises:
        ValueError: If zero or more than one control type matches.
    """
    # NOTE(review): `preprocessor_filters` is not imported in the visible part
    # of this file - confirm it is in scope at runtime.

    def mentions(text: str, control_type: str) -> bool:
        for alias in control_type.split("/"):
            if alias.lower() in text:
                return True
        return False

    candidates = []
    for control_type in preprocessor_filters.keys():
        if mentions(module, control_type) or mentions(model, control_type):
            candidates.append(control_type)

    if len(candidates) == 1:
        return candidates[0]

    raise ValueError(
        f"Unable to infer control type from module {module} and model {model}"
    )
|
||||
|
||||
|
||||
class ControlNetPresetUI(object):
    """Preset dropdown with save / delete / refresh / reset buttons.

    Presets are stored as ``<name>.txt`` files inside ``preset_directory`` and
    cached in the class-level ``presets`` dict, which all instances share.
    """

    preset_directory = os.path.join(scripts.basedir(), "presets")
    # Preset name -> serialized unit text, loaded when the class is created.
    presets = load_presets(preset_directory)

    def __init__(self, id_prefix: str):
        """Create the preset row and the (initially shown) name dialog.

        Args:
            id_prefix: Prefix for the element id of the name dialog.
        """
        with gr.Row():
            self.dropdown = gr.Dropdown(
                label="Presets",
                show_label=True,
                elem_classes=["cnet-preset-dropdown"],
                choices=ControlNetPresetUI.dropdown_choices(),
                value=NEW_PRESET,
            )
            self.reset_button = ToolButton(
                value=reset_symbol,
                elem_classes=["cnet-preset-reset"],
                tooltip="Reset preset",
                visible=False,
            )
            self.save_button = ToolButton(
                value=save_symbol,
                elem_classes=["cnet-preset-save"],
                tooltip="Save preset",
            )
            self.delete_button = ToolButton(
                value=delete_symbol,
                elem_classes=["cnet-preset-delete"],
                tooltip="Delete preset",
            )
            self.refresh_button = ToolButton(
                value=refresh_symbol,
                elem_classes=["cnet-preset-refresh"],
                tooltip="Refresh preset",
            )

        with gr.Box(
            elem_classes=["popup-dialog", "cnet-preset-enter-name"],
            elem_id=f"{id_prefix}_cnet_preset_enter_name",
        ) as self.name_dialog:
            with gr.Row():
                self.preset_name = gr.Textbox(
                    label="Preset name",
                    show_label=True,
                    lines=1,
                    elem_classes=["cnet-preset-name"],
                )
                self.confirm_preset_name = ToolButton(
                    value=save_symbol,
                    elem_classes=["cnet-preset-confirm-name"],
                    tooltip="Save preset",
                )

    def register_callbacks(
        self,
        uigroup,
        control_type: gr.Radio,
        *ui_states,
    ):
        """Wire the preset controls to the components of a ControlNet unit.

        Args:
            uigroup: Object whose ``prevent_next_n_module_update`` and
                ``prevent_next_n_slider_value_update`` counters are incremented
                when applying a preset changes control type, module or
                pixel-perfect.
            control_type: The control type radio component.
            *ui_states: Components whose values, in order, are the positional
                arguments of ``external_code.ControlNetUnit``.
        """

        def apply_preset(name: str, control_type: str, *ui_states):
            def no_change():
                # Hide the delete button and skip control type + unit fields.
                return (
                    gr.update(visible=False),
                    *(
                        (gr.skip(),)
                        * (len(vars(external_code.ControlNetUnit()).keys()) + 1)
                    ),
                )

            if name == NEW_PRESET:
                return no_change()

            assert name in ControlNetPresetUI.presets

            preset_unit, current_unit = ControlNetPresetUI._units_for_comparison(
                name, ui_states
            )

            # No update necessary.
            if vars(current_unit) == vars(preset_unit):
                return no_change()

            unit = preset_unit

            try:
                new_control_type = infer_control_type(unit.module, unit.model)
            except ValueError as e:
                logger.error(e)
                new_control_type = control_type

            # NOTE(review): the counters below presumably make uigroup ignore
            # the change events triggered by this update - confirm in uigroup.
            if new_control_type != control_type:
                uigroup.prevent_next_n_module_update += 1

            if preset_unit.module != current_unit.module:
                uigroup.prevent_next_n_slider_value_update += 1

            if preset_unit.pixel_perfect != current_unit.pixel_perfect:
                uigroup.prevent_next_n_slider_value_update += 1

            return (
                gr.update(visible=True),
                gr.update(value=new_control_type),
                *[
                    gr.update(value=value) if value is not None else gr.update()
                    for value in vars(unit).values()
                ],
            )

        # Selecting a preset and pressing "reset" both (re)apply the preset,
        # then hide the reset button.
        for element, action in (
            (self.dropdown, "change"),
            (self.reset_button, "click"),
        ):
            getattr(element, action)(
                fn=apply_preset,
                inputs=[self.dropdown, control_type, *ui_states],
                outputs=[self.delete_button, control_type, *ui_states],
                show_progress="hidden",
            ).then(
                fn=lambda: gr.update(visible=False),
                inputs=None,
                outputs=[self.reset_button],
            )

        def save_preset(name: str, *ui_states):
            # With "New Preset" selected there is no name yet: only show the
            # name dialog (the chained JS below pops it up).
            if name == NEW_PRESET:
                return gr.update(visible=True), gr.update(), gr.update()

            ControlNetPresetUI.save_preset(
                name, external_code.ControlNetUnit(*ui_states)
            )
            return (
                gr.update(),  # name dialog
                gr.update(choices=ControlNetPresetUI.dropdown_choices(), value=name),
                gr.update(visible=False),  # Reset button
            )

        self.save_button.click(
            fn=save_preset,
            inputs=[self.dropdown, *ui_states],
            outputs=[self.name_dialog, self.dropdown, self.reset_button],
            show_progress="hidden",
        ).then(
            fn=None,
            _js=f"""
            (name) => {{
                if (name === "{NEW_PRESET}")
                    popup(gradioApp().getElementById('{self.name_dialog.elem_id}'));
            }}""",
            inputs=[self.dropdown],
        )

        def delete_preset(name: str):
            ControlNetPresetUI.delete_preset(name)
            return gr.Dropdown.update(
                choices=ControlNetPresetUI.dropdown_choices(),
                value=NEW_PRESET,
            ), gr.update(visible=False)

        self.delete_button.click(
            fn=delete_preset,
            inputs=[self.dropdown],
            outputs=[self.dropdown, self.reset_button],
            show_progress="hidden",
        )

        self.name_dialog.visible = False

        def save_new_preset(new_name: str, *ui_states):
            if new_name == NEW_PRESET:
                # ``Logger.warn`` is a deprecated alias; use ``warning``.
                logger.warning(f"Cannot save preset with reserved name '{NEW_PRESET}'")
                return gr.update(visible=False), gr.update()

            ControlNetPresetUI.save_preset(
                new_name, external_code.ControlNetUnit(*ui_states)
            )
            return gr.update(visible=False), gr.update(
                choices=ControlNetPresetUI.dropdown_choices(), value=new_name
            )

        self.confirm_preset_name.click(
            fn=save_new_preset,
            inputs=[self.preset_name, *ui_states],
            outputs=[self.name_dialog, self.dropdown],
            show_progress="hidden",
        ).then(fn=None, _js="closePopup")

        self.refresh_button.click(
            fn=ControlNetPresetUI.refresh_preset,
            inputs=None,
            outputs=[self.dropdown],
            show_progress="hidden",
        )

        def update_reset_button(preset_name: str, *ui_states):
            # The reset button is shown only while the UI differs from the
            # selected preset.
            if preset_name == NEW_PRESET:
                return gr.update(visible=False)

            preset_unit, current_unit = ControlNetPresetUI._units_for_comparison(
                preset_name, ui_states
            )
            return gr.update(visible=vars(current_unit) != vars(preset_unit))

        for ui_state in ui_states:
            # Image components are never compared, so they need no listener.
            if isinstance(ui_state, gr.Image):
                continue

            for action in ("edit", "click", "change", "clear", "release"):
                if action == "release" and not isinstance(ui_state, gr.Slider):
                    continue

                if hasattr(ui_state, action):
                    getattr(ui_state, action)(
                        fn=update_reset_button,
                        inputs=[self.dropdown, *ui_states],
                        outputs=[self.reset_button],
                    )

    @staticmethod
    def _units_for_comparison(preset_name: str, ui_states):
        """Return ``(preset_unit, current_unit)`` prepared for comparison.

        The preset unit is parsed from the cached text of ``preset_name``
        (``KeyError`` if unknown); the current unit is built from
        ``ui_states``. Images are cleared on both units.
        """
        preset_unit = parse_unit(ControlNetPresetUI.presets[preset_name])
        current_unit = external_code.ControlNetUnit(*ui_states)
        preset_unit.image = None
        current_unit.image = None

        # Do not compare module param that are not used in preset.
        for module_param in ("processor_res", "threshold_a", "threshold_b"):
            if getattr(preset_unit, module_param) == -1:
                setattr(current_unit, module_param, -1)

        return preset_unit, current_unit

    @staticmethod
    def dropdown_choices() -> List[str]:
        """Names of all cached presets followed by the ``NEW_PRESET`` entry."""
        return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET]

    @staticmethod
    def save_preset(name: str, unit: external_code.ControlNetUnit):
        """Serialize ``unit`` to ``<name>.txt`` and update the preset cache."""
        infotext = serialize_unit(unit)
        with open(
            os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w"
        ) as f:
            f.write(infotext)

        ControlNetPresetUI.presets[name] = infotext

    @staticmethod
    def delete_preset(name: str):
        """Drop ``name`` from the cache and remove its file; unknown names are ignored."""
        if name not in ControlNetPresetUI.presets:
            return

        del ControlNetPresetUI.presets[name]

        file = os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt")
        if os.path.exists(file):
            os.unlink(file)

    @staticmethod
    def refresh_preset():
        """Reload the presets from disk and return a dropdown choices update."""
        ControlNetPresetUI.presets = load_presets(ControlNetPresetUI.preset_directory)
        return gr.update(choices=ControlNetPresetUI.dropdown_choices())
|
||||
@@ -0,0 +1,12 @@
|
||||
import gradio as gr
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
    """Small button with single emoji as text, fits inside gradio forms"""

    def __init__(self, **kwargs):
        """Create the button with ``variant="tool"`` and the ``cnet-toolbutton`` class.

        Args:
            **kwargs: Forwarded to ``gr.Button``. ``elem_classes`` may be a
                list of class names, a single class name, or ``None``; the
                ``cnet-toolbutton`` class is appended to it.
        """
        extra_classes = kwargs.pop('elem_classes', None)
        # Normalize the forms gradio accepts, so that appending cannot fail
        # on ``None`` or on a bare string.
        if extra_classes is None:
            extra_classes = []
        elif isinstance(extra_classes, str):
            extra_classes = [extra_classes]

        super().__init__(
            variant="tool",
            elem_classes=list(extra_classes) + ["cnet-toolbutton"],
            **kwargs,
        )

    def get_block_name(self):
        """Report the block name "button"."""
        return "button"
|
||||
@@ -0,0 +1,73 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StableDiffusionVersion(Enum):
    """The version family of stable diffusion model."""

    UNKNOWN = 0
    SD1x = 1
    SD2x = 2
    SDXL = 3

    @staticmethod
    def detect_from_model_name(model_name: str) -> "StableDiffusionVersion":
        """Based on the model name provided, guess what stable diffusion version it is.
        This might not be accurate without actually inspect the file content.

        Matching is case-insensitive and checked in this order: ``sd14`` /
        ``sd15`` / ``sd16`` -> SD1x, ``sd21`` or ``2.1`` -> SD2x, ``xl`` ->
        SDXL, anything else -> UNKNOWN.
        """
        # Lowercase once so that every check is case-insensitive.
        name = model_name.lower()

        if any(f"sd{v}" in name for v in ("14", "15", "16")):
            return StableDiffusionVersion.SD1x

        if "sd21" in name or "2.1" in name:
            return StableDiffusionVersion.SD2x

        if "xl" in name:
            return StableDiffusionVersion.SDXL

        return StableDiffusionVersion.UNKNOWN

    def encoder_block_num(self) -> int:
        """Return 9 for SDXL and 12 for every other member (including UNKNOWN)."""
        if self in (StableDiffusionVersion.SD1x, StableDiffusionVersion.SD2x, StableDiffusionVersion.UNKNOWN):
            return 12
        else:
            return 9  # SDXL

    def controlnet_layer_num(self) -> int:
        """Return ``encoder_block_num() + 1``."""
        return self.encoder_block_num() + 1

    def is_compatible_with(self, other: "StableDiffusionVersion") -> bool:
        """ Incompatible only when one of version is SDXL and other is not. """
        return (
            any(v == StableDiffusionVersion.UNKNOWN for v in [self, other]) or
            sum(v == StableDiffusionVersion.SDXL for v in [self, other]) != 1
        )
|
||||
|
||||
|
||||
class HiResFixOption(Enum):
    """Selectable hi-res-fix options, keyed by their display labels."""

    BOTH = "Both"
    LOW_RES_ONLY = "Low res only"
    HIGH_RES_ONLY = "High res only"

    @staticmethod
    def from_value(value: Any) -> "HiResFixOption":
        """Convert ``value`` into a ``HiResFixOption`` member.

        Accepted forms:
            * ``"HiResFixOption.<MEMBER>"`` - looked up by attribute name;
            * any other string - looked up by member value;
            * an int - position in definition order;
            * a ``HiResFixOption`` - returned unchanged.
        """
        if isinstance(value, str):
            if value.startswith("HiResFixOption."):
                _, field = value.split(".")
                return getattr(HiResFixOption, field)
            return HiResFixOption(value)

        if isinstance(value, int):
            members = [option for option in HiResFixOption]
            return members[value]

        assert isinstance(value, HiResFixOption)
        return value
|
||||
|
||||
|
||||
class InputMode(Enum):
    """How input images are supplied to a ControlNet unit."""

    # One image for one ControlNet unit.
    SIMPLE = "simple"

    # The input is a directory and there are N generations; every generation
    # consumes one image from the directory.
    BATCH = "batch"

    # The input is a directory and there is a single generation that consumes
    # all N images from the directory.
    MERGE = "merge"
|
||||
@@ -0,0 +1,460 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from copy import copy
|
||||
from typing import List, Any, Optional, Union, Tuple, Dict
|
||||
import numpy as np
|
||||
from modules import scripts, processing, shared
|
||||
from lib_controlnet import global_state
|
||||
from lib_controlnet.logging import logger
|
||||
from lib_controlnet.enums import HiResFixOption
|
||||
|
||||
from modules.api import api
|
||||
|
||||
|
||||
def get_api_version() -> int:
    """Return the API version number (currently 2)."""
    api_version = 2
    return api_version
|
||||
|
||||
|
||||
class ControlMode(Enum):
    """The improved guess mode.

    Member values are the display labels of the three modes.
    """

    BALANCED = "Balanced"
    PROMPT = "My prompt is more important"
    CONTROL = "ControlNet is more important"
|
||||
|
||||
|
||||
class BatchOption(Enum):
    """Ways of pairing ControlNet units with the images of a batch.

    Member values are the display labels of the options.
    """

    DEFAULT = "All ControlNet units for all images in a batch"
    SEPARATE = "Each ControlNet unit for each image in a batch"
|
||||
|
||||
|
||||
class ResizeMode(Enum):
    """
    Resize modes for ControlNet input images.
    """

    RESIZE = "Just Resize"
    INNER_FIT = "Crop and Resize"
    OUTER_FIT = "Resize and Fill"

    def int_value(self):
        """Return the integer code of the mode: RESIZE=0, INNER_FIT=1, OUTER_FIT=2."""
        codes = {
            ResizeMode.RESIZE: 0,
            ResizeMode.INNER_FIT: 1,
            ResizeMode.OUTER_FIT: 2,
        }
        code = codes.get(self)
        # Every member has a code, so the lookup cannot miss.
        assert code is not None, "NOTREACHED"
        return code
|
||||
|
||||
|
||||
# Alternative resize-mode labels mapped to the ResizeMode value they stand
# for; resize_mode_from_value() consults this table for string inputs.
resize_mode_aliases = {
    'Inner Fit (Scale to Fit)': 'Crop and Resize',
    'Outer Fit (Shrink to Fit)': 'Resize and Fill',
    'Scale to Fit (Inner Fit)': 'Crop and Resize',
    'Envelope (Outer Fit)': 'Resize and Fill',
}
|
||||
|
||||
|
||||
def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
    """
    Coerce a string label, integer index, or ResizeMode into a ResizeMode.

    Strings are first translated through `resize_mode_aliases` and then looked
    up by enum value. Integers index the enum in definition order; 3 maps to
    RESIZE, and any other out-of-range index logs a warning and falls back to
    RESIZE. Any other input is returned unchanged.
    """
    if isinstance(value, str):
        canonical_label = resize_mode_aliases.get(value, value)
        return ResizeMode(canonical_label)

    if not isinstance(value, int):
        return value

    assert value >= 0
    if value == 3:  # 'Just Resize (Latent upscale)'
        return ResizeMode.RESIZE

    if value >= len(ResizeMode):
        logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.')
        return ResizeMode.RESIZE

    return list(ResizeMode)[value]
|
||||
|
||||
|
||||
def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
    """
    Coerce a string label, integer index, or ControlMode into a ControlMode.

    Strings are looked up by enum value, integers index the enum in definition
    order, and any other input is returned unchanged.
    """
    if isinstance(value, str):
        return ControlMode(value)
    if isinstance(value, int):
        return list(ControlMode)[value]
    return value
|
||||
|
||||
|
||||
def visualize_inpaint_mask(img):
    """
    Return a copy of a 4-channel image with its last channel remapped for display.

    For a 3-D array whose last axis has length 4, channel 3 is replaced by
    `255 - channel // 2`; the input array is left untouched. Any other input
    is returned as-is (the same object).
    """
    is_four_channel = img.ndim == 3 and img.shape[2] == 4
    if not is_four_channel:
        return img

    preview = img.copy()
    preview[:, :, 3] = 255 - preview[:, :, 3] // 2
    return np.ascontiguousarray(preview.copy())
|
||||
|
||||
|
||||
def pixel_perfect_resolution(
        image: np.ndarray,
        target_H: int,
        target_W: int,
        resize_mode: ResizeMode,
) -> int:
    """
    Estimate the resolution for resizing an image while preserving aspect ratio.

    Two scaling factors are computed, target height / image height and
    target width / image width. The chosen factor is multiplied by the shorter
    side of the image:

    - OUTER_FIT uses the smaller factor, so the whole image fits inside the
      target dimensions, potentially leaving some empty space.
    - Every other mode uses the larger factor, so the target dimensions are
      fully filled, potentially cropping the image.

    The inputs and the estimate are written to the debug log.

    Args:
        image (np.ndarray): Array whose shape unpacks as (height, width, channels).
        target_H (int): The target height for the image.
        target_W (int): The target width for the image.
        resize_mode (ResizeMode): The mode for resizing.

    Returns:
        int: The estimated resolution, rounded to the nearest integer.
    """
    raw_H, raw_W, _ = image.shape

    k0 = float(target_H) / float(raw_H)
    k1 = float(target_W) / float(raw_W)

    choose_factor = min if resize_mode == ResizeMode.OUTER_FIT else max
    estimation = choose_factor(k0, k1) * float(min(raw_H, raw_W))

    debug_lines = (
        f"Pixel Perfect Computation:",
        f"resize_mode = {resize_mode}",
        f"raw_H = {raw_H}",
        f"raw_W = {raw_W}",
        f"target_H = {target_H}",
        f"target_W = {target_W}",
        f"estimation = {estimation}",
    )
    for debug_line in debug_lines:
        logger.debug(debug_line)

    return int(np.round(estimation))
|
||||
|
||||
|
||||
# An input image is an array or a string, either on its own, as the values of a
# str-keyed dict, or as a pair.
InputImage = Union[
    Dict[str, Union[np.ndarray, str]],
    Tuple[Union[np.ndarray, str], Union[np.ndarray, str]],
    np.ndarray,
    str,
]
|
||||
|
||||
|
||||
@dataclass
class ControlNetUnit:
    """
    Represents an entire ControlNet processing unit.

    Field order and defaults are part of the constructor interface; the
    comments next to the fields describe what each option is for.
    """
    enabled: bool = True
    module: str = "none"
    model: str = "None"
    weight: float = 1.0
    image: Optional[Union[InputImage, List[InputImage]]] = None
    resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
    low_vram: bool = False
    processor_res: int = -1
    threshold_a: float = -1
    threshold_b: float = -1
    guidance_start: float = 0.0
    guidance_end: float = 1.0
    pixel_perfect: bool = False
    control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
    # Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area`
    # in A1111 is set to `Only masked`. In API, this correspond to `inpaint_full_res = True`.
    inpaint_crop_input_image: bool = True
    # If hires fix is enabled in A1111, how should this ControlNet unit be applied.
    # The value is ignored if the generation is not using hires fix.
    hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH

    # Whether save the detected map of this unit. Setting this option to False prevents saving the
    # detected map or sending detected map along with generated images via API.
    # Currently the option is only accessible in API calls.
    save_detected_map: bool = True

    # Weight for each layer of ControlNet params.
    # For ControlNet:
    # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
    # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
    # For T2IAdapter
    # - SD1.5: 5 weights (4 encoder block + 1 middle block)
    # - SDXL: 4 weights (3 encoder block + 1 middle block)
    # Note1: Setting advanced weighting will disable `soft_injection`, i.e.
    # It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
    # Note2: The field `weight` is still used in some places, e.g. reference_only,
    # even advanced_weighting is set.
    advanced_weighting: Optional[List[float]] = None

    def __eq__(self, other):
        """Two units are equal when both are ControlNetUnit and their instance attributes match."""
        return isinstance(other, ControlNetUnit) and vars(self) == vars(other)

    def accepts_multiple_inputs(self) -> bool:
        """This unit can accept multiple input images."""
        multi_input_modules = (
            "ip-adapter_clip_sdxl",
            "ip-adapter_clip_sdxl_plus_vith",
            "ip-adapter_clip_sd15",
            "ip-adapter_face_id",
            "ip-adapter_face_id_plus",
            "instant_id_face_embedding",
        )
        return self.module in multi_input_modules
|
||||
|
||||
|
||||
def to_base64_nparray(encoding: str):
    """
    Decode a base64-encoded image into a numpy array cast to 'uint8'.

    Decoding is delegated to `api.decode_base64_to_image`.
    """
    decoded_image = api.decode_base64_to_image(encoding)
    return np.array(decoded_image).astype('uint8')
|
||||
|
||||
|
||||
def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from a StableDiffusionProcessing.

    Reads `p.scripts` and `p.script_args` and delegates to `get_all_units`.
    """
    script_runner, script_args = p.scripts, p.script_args
    return get_all_units(script_runner, script_args)
|
||||
|
||||
|
||||
def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from an existing script runner.
    Use this function to fetch units from the list of all scripts arguments.

    Returns an empty list when the runner has no ControlNet script.
    """
    cn_script = find_cn_script(script_runner)
    if not cn_script:
        return []

    # Only the ControlNet script's own slice of the argument list is inspected.
    cn_args = script_args[cn_script.args_from:cn_script.args_to]
    return get_all_units_from(cn_args)
|
||||
|
||||
|
||||
def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from ControlNet script arguments.
    Use `external_code.get_all_units` to fetch units from the list of all scripts arguments.

    Every argument that is a ControlNetUnit, a dict, or an object exposing all
    ControlNetUnit fields is converted with `to_processing_unit`; other
    arguments are ignored. A warning is logged when nothing qualifies, and a
    debug message when a stale ControlNetUnit class is detected.
    """
    # Field names of the current ControlNetUnit, computed once for all arguments.
    unit_fields = set(vars(ControlNetUnit()).keys())

    def is_stale_unit(script_arg: Any) -> bool:
        """ Returns whether the script_arg is potentially an stale version of
        ControlNetUnit created before module reload."""
        return (
            'ControlNetUnit' in type(script_arg).__name__ and
            not isinstance(script_arg, ControlNetUnit)
        )

    def is_controlnet_unit(script_arg: Any) -> bool:
        """ Returns whether the script_arg is ControlNetUnit or anything that
        can be treated like ControlNetUnit. """
        return (
            isinstance(script_arg, (ControlNetUnit, dict)) or
            (
                hasattr(script_arg, '__dict__') and
                unit_fields.issubset(set(vars(script_arg).keys()))
            )
        )

    all_units = [
        to_processing_unit(script_arg)
        for script_arg in script_args
        if is_controlnet_unit(script_arg)
    ]
    if not all_units:
        # Adjacent literals are concatenated verbatim, so each piece carries
        # its own trailing space.
        logger.warning(
            "No ControlNetUnit detected in args. It is very likely that you are having an extension conflict. "
            f"Here are args received by ControlNet: {script_args}.")
    if any(is_stale_unit(script_arg) for script_arg in script_args):
        logger.debug(
            "Stale version of ControlNetUnit detected. The ControlNetUnit received "
            "by ControlNet is created before the newest load of ControlNet extension. "
            "They will still be used by ControlNet as long as they provide same fields "
            "defined in the newest version of ControlNetUnit."
        )

    return all_units
|
||||
|
||||
|
||||
def get_single_unit_from(script_args: List[Any], index: int = 0) -> Optional[ControlNetUnit]:
    """
    Fetch a single ControlNet processing unit from ControlNet script arguments.
    The list must not contain script positional arguments. It must only contain processing units.

    Returns the converted unit at position `index`, or None when that position
    is out of range, negative, or holds None.
    """
    for position, candidate in enumerate(script_args):
        remaining = index - position
        if remaining < 0:
            break
        if remaining == 0 and candidate is not None:
            return to_processing_unit(candidate)

    return None
|
||||
|
||||
|
||||
def get_max_models_num():
    """
    Fetch the maximum number of allowed ControlNet models.

    Reads the "control_net_unit_count" option, defaulting to 3 when it is unset.
    """
    return shared.opts.data.get("control_net_unit_count", 3)
|
||||
|
||||
|
||||
def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
    """
    Convert different types to processing unit.
    If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details.

    For a dict input:
    - legacy keys are renamed via `ext_compat_keys`;
    - a separate 'mask' entry is folded into the image as
      `{'image': ..., 'mask': ...}`;
    - an empty image (None, '', empty container) becomes None, while a numpy
      array is always kept as-is;
    - keys that are not ControlNetUnit fields are dropped.

    The caller's dict is not modified. Non-dict input is returned unchanged.
    """
    from dataclasses import fields

    ext_compat_keys = {
        'guessmode': 'guess_mode',
        'guidance': 'guidance_end',
        'lowvram': 'low_vram',
        'input_image': 'image'
    }

    if isinstance(unit, dict):
        # Work on a renamed copy so the caller's dict stays untouched.
        unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()}

        mask = unit.pop('mask', None)

        if 'image' in unit and not isinstance(unit['image'], dict):
            image = unit['image']
            if mask is not None:
                unit['image'] = {'image': image, 'mask': mask}
            elif isinstance(image, np.ndarray):
                # Truth-testing a multi-element array raises ValueError, so
                # arrays are passed through without an emptiness check.
                pass
            elif not image:
                unit['image'] = None

        if 'guess_mode' in unit:
            logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.')

        # Filter on real dataclass fields; the class namespace also holds
        # methods and dunders, which are not constructor arguments.
        field_names = {f.name for f in fields(ControlNetUnit)}
        unit = ControlNetUnit(**{k: v for k, v in unit.items() if k in field_names})

    # temporary, check #602
    # assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]'
    return unit
|
||||
|
||||
|
||||
def update_cn_script_in_processing(
        p: processing.StableDiffusionProcessing,
        cn_units: List[ControlNetUnit],
        **_kwargs,  # for backwards compatibility
):
    """
    Replace the ControlNet script arguments of `p` with `cn_units`.

    The new argument sequence is computed by `update_cn_script` from
    `p.scripts` and `p.script_args_value`, then assigned to `p.script_args`.
    `cn_units` and its elements are not modified. You can call this function
    repeatedly, as many times as you want.

    The arguments are left as they were if any of the following is true:
    - ControlNet is not present in `p.scripts`
    - the arguments are not filled with script arguments for scripts that are processed before ControlNet
    """
    refreshed_args = update_cn_script(p.scripts, p.script_args_value, cn_units)
    p.script_args = refreshed_args
|
||||
|
||||
|
||||
def update_cn_script(
        script_runner: scripts.ScriptRunner,
        script_args: Union[Tuple[Any], List[Any]],
        cn_units: List[ControlNetUnit],
) -> Union[Tuple[Any], List[Any]]:
    """
    Returns: The updated `script_args` with given `cn_units` used as ControlNet
    script args. The result has the same container type (tuple or list) as the
    input; `script_args` itself is not modified.

    Side effect: `args_from` / `args_to` of the scripts in
    `script_runner.alwayson_scripts` are adjusted when the number of ControlNet
    arguments changes.

    Does not update `script_args` if any of the following is true:
    - ControlNet is not present in `script_runner`
    - `script_args` is not filled with script arguments for scripts that are
      processed before ControlNet
    """
    script_args_type = type(script_args)
    assert script_args_type in (tuple, list), script_args_type

    cn_script = find_cn_script(script_runner)

    if cn_script is None or len(script_args) < cn_script.args_from:
        return script_args

    updated_script_args = list(script_args)

    # fill in remaining parameters to satisfy max models, just in case script needs it.
    max_models = shared.opts.data.get("control_net_unit_count", 3)
    # One fresh disabled unit per empty slot; repeating a single instance would
    # make every padded slot alias the same mutable object.
    padding = [ControlNetUnit(enabled=False) for _ in range(max(max_models - len(cn_units), 0))]
    cn_units = list(cn_units) + padding

    cn_script_args_diff = 0
    for script in script_runner.alwayson_scripts:
        if script is cn_script:
            cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
            updated_script_args[script.args_from:script.args_to] = cn_units
            script.args_to = script.args_from + len(cn_units)
        else:
            # Scripts visited after ControlNet are shifted by the size change;
            # before that point the offset is still 0.
            script.args_from += cn_script_args_diff
            script.args_to += cn_script_args_diff

    return script_args_type(updated_script_args)
|
||||
|
||||
|
||||
def update_cn_script_in_place(
        script_runner: scripts.ScriptRunner,
        script_args: List[Any],
        cn_units: List[ControlNetUnit],
        **_kwargs,  # for backwards compatibility
):
    """
    @Deprecated(Raises assertion error if script_args passed in is Tuple)

    Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`.
    `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.

    Side effect: `args_from` / `args_to` of the scripts in
    `script_runner.alwayson_scripts` are adjusted when the number of ControlNet
    arguments changes.

    Does not update `script_args` if any of the following is true:
    - ControlNet is not present in `script_runner`
    - `script_args` is not filled with script arguments for scripts that are processed before ControlNet
    """
    assert isinstance(script_args, list), type(script_args)

    cn_script = find_cn_script(script_runner)
    if cn_script is None or len(script_args) < cn_script.args_from:
        return

    # fill in remaining parameters to satisfy max models, just in case script needs it.
    max_models = shared.opts.data.get("control_net_unit_count", 3)
    # One fresh disabled unit per empty slot; repeating a single instance would
    # make every padded slot alias the same mutable object.
    padding = [ControlNetUnit(enabled=False) for _ in range(max(max_models - len(cn_units), 0))]
    cn_units = list(cn_units) + padding

    cn_script_args_diff = 0
    for script in script_runner.alwayson_scripts:
        if script is cn_script:
            cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
            script_args[script.args_from:script.args_to] = cn_units
            script.args_to = script.args_from + len(cn_units)
        else:
            # Scripts visited after ControlNet are shifted by the size change;
            # before that point the offset is still 0.
            script.args_from += cn_script_args_diff
            script.args_to += cn_script_args_diff
|
||||
|
||||
|
||||
def find_cn_script(script_runner: scripts.ScriptRunner) -> Optional[scripts.Script]:
    """
    Find the ControlNet script in `script_runner`. Returns `None` if `script_runner` does not contain a ControlNet script.

    A `script_runner` of `None` also yields `None`. Only the runner's
    `alwayson_scripts` are searched, and the first match wins.
    """
    if script_runner is None:
        return None

    matches = (script for script in script_runner.alwayson_scripts if is_cn_script(script))
    return next(matches, None)
|
||||
|
||||
|
||||
def is_cn_script(script: scripts.Script) -> bool:
    """
    Determine whether `script` is a ControlNet script.

    The check is a case-insensitive comparison of `script.title()` with 'controlnet'.
    """
    normalized_title = script.title().lower()
    return normalized_title == 'controlnet'
|
||||
@@ -0,0 +1,138 @@
|
||||
import os.path
|
||||
import stat
|
||||
from collections import OrderedDict
|
||||
|
||||
from modules import shared, sd_models
|
||||
from lib_controlnet.enums import StableDiffusionVersion
|
||||
from modules_forge.shared import controlnet_dir, supported_preprocessors
|
||||
|
||||
|
||||
# File extensions (compared case-sensitively) that mark a file as a model.
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]


def traverse_all_files(curr_path, model_list):
    """
    Recursively collect model files below `curr_path`.

    Args:
        curr_path: Directory to scan. A path that is not an existing directory
            contributes nothing instead of raising.
        model_list: List that `(path, os.stat_result)` tuples are appended to.

    Returns:
        `model_list`, extended with every entry whose extension is in
        `CN_MODEL_EXTS`. Sub-directories are descended into.
    """
    # The directory check has to happen before scanning; as a comprehension
    # filter it would only run after os.scandir had already failed.
    if not os.path.isdir(curr_path):
        return model_list

    # The context manager closes the scandir iterator deterministically.
    with os.scandir(curr_path) as entries:
        f_list = [
            (os.path.join(curr_path, entry.name), entry.stat())
            for entry in entries
        ]

    for f_info in f_list:
        fname, fstat = f_info
        if os.path.splitext(fname)[1] in CN_MODEL_EXTS:
            model_list.append(f_info)
        elif stat.S_ISDIR(fstat.st_mode):
            model_list = traverse_all_files(fname, model_list)
    return model_list
|
||||
|
||||
|
||||
def get_all_models(sort_by, filter_by, path):
    """
    Build an ordered mapping of model display names to file paths under `path`.

    Args:
        sort_by: "name" (by file name), "date" (newest modification time
            first) or "path name" (by the `(path, stat)` tuple); any other
            value keeps traversal order.
        filter_by: Case-insensitive substring a file name must contain;
            surrounding spaces are ignored and an empty filter keeps everything.
        path: Root directory handed to `traverse_all_files`.

    Returns:
        OrderedDict mapping "<name> [<hash>]" to the file path, where the hash
        comes from `sd_models.model_hash`.
    """
    fileinfos = traverse_all_files(path, [])

    needle = filter_by.strip(" ")
    if needle:
        needle = needle.lower()
        fileinfos = [
            info for info in fileinfos
            if needle in os.path.basename(info[0]).lower()
        ]

    if sort_by == "path name":
        fileinfos = sorted(fileinfos)
    elif sort_by == "name":
        fileinfos = sorted(fileinfos, key=lambda info: os.path.basename(info[0]))
    elif sort_by == "date":
        fileinfos = sorted(fileinfos, key=lambda info: info[1].st_mtime, reverse=True)

    res = OrderedDict()
    for filename, _ in fileinfos:
        name = os.path.splitext(os.path.basename(filename))[0]
        # Prevent a hypothetical "None.pt" from being listed.
        if name == "None":
            continue
        res[f"{name} [{sd_models.model_hash(filename)}]"] = filename

    return res
|
||||
|
||||
|
||||
# Module-level registry: model display name -> model file name/path.
controlnet_filename_dict = {'None': 'model.safetensors'}
# Display names of the registry, in registry order.
controlnet_names = list(controlnet_filename_dict)
|
||||
|
||||
|
||||
def get_preprocessor(name):
    """Return the entry registered under `name` in `supported_preprocessors`, or None if absent."""
    return supported_preprocessors.get(name)
|
||||
|
||||
|
||||
def get_sorted_preprocessors():
    """
    Return the preprocessors as an OrderedDict keyed by `name`.

    The 'None' entry always comes first. The rest are ordered descending by
    the string key `str(sorting_priority).zfill(8) + name`.
    """
    candidates = [p for key, p in supported_preprocessors.items() if key != 'None']
    ascending = sorted(candidates, key=lambda p: str(p.sorting_priority).zfill(8) + p.name)

    results = OrderedDict()
    results['None'] = supported_preprocessors['None']
    for preprocessor in reversed(ascending):
        results[preprocessor.name] = preprocessor
    return results
|
||||
|
||||
|
||||
def get_all_controlnet_names():
    """Return the module-level list of known ControlNet model names (not a copy)."""
    all_names = controlnet_names
    return all_names
|
||||
|
||||
|
||||
def get_controlnet_filename(controlnet_name):
    """
    Return the file registered for `controlnet_name`.

    Raises:
        KeyError: If the name is not in `controlnet_filename_dict`.
    """
    filename = controlnet_filename_dict[controlnet_name]
    return filename
|
||||
|
||||
|
||||
def get_all_preprocessor_names():
    """Return the preprocessor names in the order produced by `get_sorted_preprocessors`."""
    return [name for name in get_sorted_preprocessors()]
|
||||
|
||||
|
||||
def get_all_preprocessor_tags():
    """Return 'All' followed by every distinct preprocessor tag in sorted order."""
    unique_tags = set()
    for preprocessor in supported_preprocessors.values():
        unique_tags.update(preprocessor.tags)
    return ['All'] + sorted(unique_tags)
|
||||
|
||||
|
||||
def get_filtered_preprocessors(tag):
    """
    Return the preprocessors carrying `tag`, keyed by name.

    For 'All' the `supported_preprocessors` registry object itself is returned
    (in registry order). For any other tag a new dict is built from the sorted
    preprocessors; the 'None' entry is always kept.
    """
    if tag == 'All':
        return supported_preprocessors

    filtered = {}
    for name, preprocessor in get_sorted_preprocessors().items():
        if tag in preprocessor.tags or name == 'None':
            filtered[name] = preprocessor
    return filtered
|
||||
|
||||
|
||||
def get_filtered_preprocessor_names(tag):
    """Return the names of the preprocessors selected by `get_filtered_preprocessors(tag)`."""
    return [name for name in get_filtered_preprocessors(tag).keys()]
|
||||
|
||||
|
||||
def get_filtered_controlnet_names(tag):
    """
    Return the model names that match the preprocessors selected by `tag`.

    A name is kept when it contains, case-insensitively, any of the
    `model_filename_filers` substrings of the selected preprocessors. The name
    'None' is always kept.
    """
    lowered_filters = [
        name_filter.lower()
        for preprocessor in get_filtered_preprocessors(tag).values()
        for name_filter in preprocessor.model_filename_filers
    ]

    selected = []
    for model_name in controlnet_names:
        lowered_name = model_name.lower()
        if any(f in lowered_name for f in lowered_filters) or model_name == 'None':
            selected.append(model_name)
    return selected
|
||||
|
||||
|
||||
def update_controlnet_filenames():
    """
    Rebuild the module-level model registry from disk.

    `controlnet_filename_dict` and `controlnet_names` are reset to the 'None'
    placeholder and then filled from `controlnet_dir` plus two optional extra
    directories (the "control_net_models_path" option and the `controlnet_dir`
    command-line option) when those are set and exist. Sorting and filtering
    follow the "control_net_models_sort_models_by" and
    "control_net_models_name_filter" options.
    """
    global controlnet_filename_dict, controlnet_names

    controlnet_filename_dict = {'None': 'model.safetensors'}
    controlnet_names = ['None']

    ext_dirs = (shared.opts.data.get("control_net_models_path", None), getattr(shared.cmd_opts, 'controlnet_dir', None))
    paths = [controlnet_dir]
    paths.extend(
        extra_dir for extra_dir in ext_dirs
        if extra_dir is not None and os.path.exists(extra_dir)
    )

    sort_by = shared.opts.data.get("control_net_models_sort_models_by", "name")
    filter_by = shared.opts.data.get("control_net_models_name_filter", "")
    for path in paths:
        controlnet_filename_dict.update(get_all_models(sort_by, filter_by, path))

    controlnet_names = list(controlnet_filename_dict.keys())
|
||||
|
||||
|
||||
def get_sd_version() -> StableDiffusionVersion:
    """
    Map the flags of `shared.sd_model` to a StableDiffusionVersion.

    `is_sdxl`, `is_sd2` and `is_sd1` are checked in that order; UNKNOWN is
    returned when none of them is set.
    """
    if shared.sd_model.is_sdxl:
        return StableDiffusionVersion.SDXL
    if shared.sd_model.is_sd2:
        return StableDiffusionVersion.SD2x
    if shared.sd_model.is_sd1:
        return StableDiffusionVersion.SD1x
    return StableDiffusionVersion.UNKNOWN
|
||||
@@ -0,0 +1,135 @@
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
|
||||
from lib_controlnet import external_code
|
||||
from lib_controlnet.logging import logger
|
||||
|
||||
|
||||
def field_to_displaytext(fieldname: str) -> str:
    """Turn a snake_case field name into space-separated capitalized words."""
    words = fieldname.split("_")
    return " ".join(map(str.capitalize, words))
|
||||
|
||||
|
||||
def displaytext_to_field(text: str) -> str:
    """Turn space-separated display text into a lowercase, underscore-joined field name."""
    words = text.split(" ")
    return "_".join(map(str.lower, words))
|
||||
|
||||
|
||||
def parse_value(value: str) -> Union[str, float, int, bool]:
    """
    Parse an infotext value string into the most specific type.

    "True" / "False" become booleans. Otherwise `int` is tried, then `float`,
    and the string itself is returned when neither conversion succeeds.
    """
    if value in ("True", "False"):
        return value == "True"

    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue

    return value  # Plain string.
|
||||
|
||||
|
||||
def serialize_unit(unit: external_code.ControlNetUnit) -> str:
    """
    Serialize a unit into an infotext string of the form "Field: value, Field: value".

    Excluded fields and fields whose value equals -1 are omitted. If any
    remaining value contains ',' or ':' once converted to a string, an error
    is logged and an empty string is returned.
    """
    excluded_fields = (
        "image",
        "enabled",
        # Note: "advanced_weighting" is excluded as it is an API-only field.
        "advanced_weighting",
        # Note: "inpaint_crop_image" is img2img inpaint only flag, which does not
        # provide much information when restoring the unit.
        "inpaint_crop_input_image",
    )

    log_value = {}
    for field in vars(external_code.ControlNetUnit()).keys():
        if field in excluded_fields:
            continue
        # Note: exclude hidden slider values.
        if getattr(unit, field) != -1:
            log_value[field_to_displaytext(field)] = getattr(unit, field)

    has_reserved_token = any(
        "," in str(v) or ":" in str(v) for v in log_value.values()
    )
    if has_reserved_token:
        logger.error(f"Unexpected tokens encountered:\n{log_value}")
        return ""

    return ", ".join(f"{field}: {value}" for field, value in log_value.items())
|
||||
|
||||
|
||||
def parse_unit(text: str) -> external_code.ControlNetUnit:
    """
    Build an enabled ControlNetUnit from an infotext string.

    `text` is split on ','; each stripped item must split on ': ' into exactly
    a key and a value, otherwise unpacking raises ValueError. Keys are mapped
    with `displaytext_to_field` and values with `parse_value`.
    """
    parsed_fields = {}
    for item in text.split(","):
        key, value = item.strip().split(": ")
        parsed_fields[displaytext_to_field(key)] = parse_value(value)

    return external_code.ControlNetUnit(enabled=True, **parsed_fields)
|
||||
|
||||
|
||||
class Infotext(object):
    """Collects the infotext-to-component bindings of ControlNet units and
    reads/writes ControlNet entries of the generation infotext."""

    def __init__(self) -> None:
        # (component, locator) pairs and the locator names, filled by `register_unit`.
        self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
        self.paste_field_names: List[str] = []

    @staticmethod
    def unit_prefix(unit_index: int) -> str:
        """Return the infotext key used for the unit at `unit_index`."""
        return f"ControlNet {unit_index}"

    def register_unit(self, unit_index: int, uigroup) -> None:
        """Register the unit's UI group. By registering the unit, A1111 will be
        able to paste values from infotext to IOComponents.

        Args:
            unit_index: The index of the ControlNet unit
            uigroup: The ControlNetUiGroup instance that contains all gradio
                     iocomponents.
        """
        unit_prefix = Infotext.unit_prefix(unit_index)
        for field in vars(external_code.ControlNetUnit()).keys():
            # Exclude image for infotext.
            if field == "image":
                continue

            # Every field in ControlNetUnit should have a corresponding
            # IOComponent in ControlNetUiGroup.
            io_component = getattr(uigroup, field)
            component_locator = f"{unit_prefix} {field}"
            self.infotext_fields.append((io_component, component_locator))
            self.paste_field_names.append(component_locator)

    @staticmethod
    def write_infotext(
        units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing
    ):
        """Write infotext to `p`.

        Every enabled unit is serialized into `p.extra_generation_params`
        under its `unit_prefix` key.
        """
        p.extra_generation_params.update(
            {
                Infotext.unit_prefix(i): serialize_unit(unit)
                for i, unit in enumerate(units)
                if unit.enabled
            }
        )

    @staticmethod
    def on_infotext_pasted(infotext: str, results: dict) -> None:
        """Parse ControlNet infotext string and write result to `results` dict.

        Each `results` entry whose key starts with "ControlNet" is parsed into
        a unit; its fields (except `image` and fields that are None) are added
        to `results` as "<key> <field>". Entries that fail to parse are
        reported with a warning and skipped.
        """
        # Collected separately so `results` is not mutated while it is iterated.
        updates = {}
        for k, v in results.items():
            if not k.startswith("ControlNet"):
                continue

            assert isinstance(v, str), f"Expect string but got {v}."
            try:
                for field, value in vars(parse_unit(v)).items():
                    if field == "image":
                        continue
                    if value is None:
                        logger.debug(f"InfoText: Skipping {field} because value is None.")
                        continue

                    component_locator = f"{k} {field}"
                    updates[component_locator] = value
                    logger.debug(f"InfoText: Setting {component_locator} = {value}")
            except Exception as e:
                # `Logger.warn` is a deprecated alias; `warning` is the supported name.
                logger.warning(
                    f"Failed to parse infotext, legacy format infotext is no longer supported:\n{v}\n{e}"
                )

        results.update(updates)
|
||||
@@ -0,0 +1,41 @@
|
||||
import logging
|
||||
import copy
|
||||
import sys
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the record's level name in an ANSI color sequence."""

    COLORS = {
        "DEBUG": "\033[0;36m",  # CYAN
        "INFO": "\033[0;32m",  # GREEN
        "WARNING": "\033[0;33m",  # YELLOW
        "ERROR": "\033[0;31m",  # RED
        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
        "RESET": "\033[0m",  # RESET COLOR
    }

    def format(self, record):
        """Format a shallow copy of `record` with a colored level name.

        The original record is left unmodified. Unknown level names are
        wrapped in the reset sequence.
        """
        tinted = copy.copy(record)
        reset = self.COLORS["RESET"]
        color = self.COLORS.get(tinted.levelname, reset)
        tinted.levelname = f"{color}{tinted.levelname}{reset}"
        return super().format(tinted)
|
||||
|
||||
|
||||
# Create a new logger
logger = logging.getLogger("ControlNet")
logger.propagate = False

# Add handler if we don't have one.
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(handler)

# Configure logger
loglevel_string = getattr(shared.cmd_opts, "controlnet_loglevel", "INFO")
# str() guards against a non-string option value (e.g. None).
loglevel = getattr(logging, str(loglevel_string).upper(), None)
if isinstance(loglevel, int):
    logger.setLevel(loglevel)
else:
    # An unknown level name must not make importing this module fail
    # (setLevel rejects None); fall back to INFO and report it.
    loglevel = logging.INFO
    logger.setLevel(loglevel)
    logger.warning(f"Unrecognized log level {loglevel_string!r}. Fall back to INFO.")
|
||||
@@ -0,0 +1,88 @@
|
||||
# High Quality Edge Thinning using Pure Python
|
||||
# Written by Lvmin Zhang
|
||||
# 2023 April
|
||||
# Stanford University
|
||||
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
|
||||
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Two base 3x3 kernels for the thinning pass, as int32 arrays.
lvmin_kernels_raw = [
    np.array(rows, dtype=np.int32)
    for rows in (
        [[-1, -1, -1],
         [0, 1, 0],
         [1, 1, 1]],
        [[0, -1, -1],
         [1, 1, -1],
         [0, 1, 0]],
    )
]

# Each base kernel at 0, 1, 2 and 3 quarter turns, grouped by rotation.
lvmin_kernels = [
    np.rot90(base_kernel, k=quarter_turns, axes=(0, 1))
    for quarter_turns in range(4)
    for base_kernel in lvmin_kernels_raw
]
|
||||
|
||||
# Two base 3x3 kernels for the pruning pass, as int32 arrays.
lvmin_prunings_raw = [
    np.array(rows, dtype=np.int32)
    for rows in (
        [[-1, -1, -1],
         [-1, 1, -1],
         [0, 0, -1]],
        [[-1, -1, -1],
         [-1, 1, -1],
         [-1, 0, 0]],
    )
]

# Each base kernel at 0, 1, 2 and 3 quarter turns, grouped by rotation.
lvmin_prunings = [
    np.rot90(base_kernel, k=quarter_turns, axes=(0, 1))
    for quarter_turns in range(4)
    for base_kernel in lvmin_prunings_raw
]
|
||||
|
||||
|
||||
def remove_pattern(x, kernel):
    """
    Zero every position of `x` where the hit-or-miss transform with `kernel` exceeds 127.

    `x` is modified in place.

    Returns:
        Tuple `(x, changed)` where `changed` tells whether any position matched.
    """
    hit_map = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
    hit_coords = np.where(hit_map > 127)
    x[hit_coords] = 0
    matched_any = hit_coords[0].shape[0] > 0
    return x, matched_any
|
||||
|
||||
|
||||
def thin_one_time(x, kernels):
    """
    Apply `remove_pattern` once for every kernel in `kernels`, in order.

    Returns:
        Tuple `(y, is_done)`; `is_done` is True when no kernel changed anything.
    """
    y = x
    any_update = False
    for kernel in kernels:
        y, has_update = remove_pattern(y, kernel)
        any_update = any_update or has_update
    return y, not any_update
|
||||
|
||||
|
||||
def lvmin_thin(x, prunings=True):
    """
    Repeat thinning passes with `lvmin_kernels`, then optionally prune.

    At most 32 passes are run, stopping early once a pass changes nothing.
    When `prunings` is truthy, one extra pass with `lvmin_prunings` follows.
    """
    y = x
    passes_left = 32
    while passes_left > 0:
        passes_left -= 1
        y, is_done = thin_one_time(y, lvmin_kernels)
        if is_done:
            break
    if prunings:
        y, _ = thin_one_time(y, lvmin_prunings)
    return y
|
||||
|
||||
|
||||
def nake_nms(x):
    """
    Keep the values of `x` that equal their dilation along at least one of four 3x3 line kernels.

    The kernels are horizontal, vertical and the two diagonals. All other
    positions of the result are zero.
    """
    line_kernels = [
        np.array(rows, dtype=np.uint8)
        for rows in (
            [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
            [[0, 1, 0], [0, 1, 0], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
        )
    ]
    y = np.zeros_like(x)
    for line_kernel in line_kernels:
        is_line_maximum = cv2.dilate(x, kernel=line_kernel) == x
        np.putmask(y, is_line_maximum, x)
    return y
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
import torch
|
||||
import os
|
||||
import functools
|
||||
import time
|
||||
import base64
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import cv2
|
||||
import logging
|
||||
|
||||
from typing import Any, Callable, Dict, List
|
||||
from modules.safe import unsafe_torch_load
|
||||
from lib_controlnet.logging import logger
|
||||
|
||||
|
||||
def load_state_dict(ckpt_path, location="cpu"):
    """
    Load a checkpoint file and return its state dict.

    Files with a `.safetensors` extension (case-insensitive) are read with
    `safetensors.torch.load_file`; everything else goes through
    `unsafe_torch_load`. A nested "state_dict" entry is unwrapped by
    `get_state_dict`.
    """
    extension = os.path.splitext(ckpt_path)[1].lower()
    if extension == ".safetensors":
        loaded = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        loaded = unsafe_torch_load(ckpt_path, map_location=torch.device(location))

    state_dict = get_state_dict(loaded)
    logger.info(f"Loaded state_dict from [{ckpt_path}]")
    return state_dict
|
||||
|
||||
|
||||
def get_state_dict(d):
    """Return `d["state_dict"]` when that key exists, otherwise `d` itself."""
    unwrapped = d.get("state_dict", d)
    return unwrapped
|
||||
|
||||
|
||||
def ndarray_lru_cache(max_size: int = 128, typed: bool = False):
    """
    Decorator factory adding an LRU cache to functions that take numpy arrays.

    Numpy arrays are mutable and therefore not usable as cache keys. Every
    `np.ndarray` argument (also inside tuples, recursively) is wrapped as a
    `HashableNpArray` view before it reaches `functools.lru_cache`.

    `HashableNpArray` behaves like `np.ndarray` except that `__hash__` hashes
    the array's bytes and `__eq__` compares with `np.array_equal`.

    Args:
        max_size: Forwarded to `lru_cache` as `maxsize`.
        typed: Forwarded to `lru_cache` as `typed`.
    """

    def decorator(func: Callable):
        """The actual decorator that accept function as input."""

        class HashableNpArray(np.ndarray):
            def __new__(cls, input_array):
                # A view shares its data with the input array; nothing is copied.
                return np.asarray(input_array).view(cls)

            def __eq__(self, other) -> bool:
                return np.array_equal(self, other)

            def __hash__(self):
                # Hash the bytes representing the data of the array.
                return hash(self.tobytes())

        def to_hashable(item: Any):
            """Wrap arrays, recursing into tuples; leave everything else alone."""
            if isinstance(item, np.ndarray):
                return HashableNpArray(item)
            if isinstance(item, tuple):
                return tuple(to_hashable(member) for member in item)
            return item

        @functools.lru_cache(maxsize=max_size, typed=typed)
        def cached_func(*args, **kwargs):
            """This function only accepts `HashableNpArray` as input params."""
            return func(*args, **kwargs)

        # Preserves original function.__name__ and __doc__.
        @functools.wraps(func)
        def decorated_func(*args, **kwargs):
            """The decorated function that delegates the original function."""
            hashable_args = [to_hashable(arg) for arg in args]
            hashable_kwargs = {key: to_hashable(arg) for key, arg in kwargs.items()}
            return cached_func(*hashable_args, **hashable_kwargs)

        return decorated_func

    return decorator
|
||||
|
||||
|
||||
def timer_decorator(func):
    """Time the decorated function and output the result to debug logger.

    The decision is taken once, at decoration time: when the module logger
    would not emit DEBUG records, `func` is returned unchanged and no timing
    overhead is added. Otherwise a wrapper is returned that measures each
    call and logs calls that took longer than one millisecond.

    Args:
        func: The callable to time.

    Returns:
        `func` itself, or a timing wrapper with the same name and docstring.
    """
    # `isEnabledFor` honours inherited and more verbose levels, unlike a
    # strict comparison of `logger.level` against `logging.DEBUG`.
    if not logger.isEnabledFor(logging.DEBUG):
        return func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high resolution; the wall clock
        # (`time.time`) can jump and is unsuitable for measuring durations.
        started = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - started
        # Only report function that are significant enough.
        if duration > 1e-3:
            logger.debug(f"{func.__name__} ran in: {duration:.3f} sec")
        return result

    return wrapper
|
||||
|
||||
|
||||
class TimeMeta(type):
    """Metaclass that passes every callable attribute of a class being
    created through `timer_decorator`."""

    def __new__(cls, name, bases, attrs):
        # Snapshot the keys first; matching values are swapped in place.
        for key in tuple(attrs):
            member = attrs[key]
            if not callable(member):
                continue
            attrs[key] = timer_decorator(member)
        return super().__new__(cls, name, bases, attrs)
|
||||
|
||||
|
||||
# Optional SVG support: `svgsupport` is True only when svglib and reportlab
# (plus `io`) could all be imported.
try:
    import io
    from svglib.svglib import svg2rlg
    from reportlab.graphics import renderPM
except ImportError:
    svgsupport = False
else:
    svgsupport = True
|
||||
|
||||
|
||||
def svg_preprocess(inputs: Dict, preprocess: Callable):
    """Convert a base64 SVG data URL to a PNG data URL, then run `preprocess`.

    When `inputs["image"]` starts with the SVG data-URL prefix and SVG support
    is available, the payload is decoded, rendered to PNG and written back
    into `inputs["image"]` (the dict is modified in place). In every other
    case `inputs` is forwarded unchanged.

    Args:
        inputs: Mapping holding the image data URL under the `"image"` key.
        preprocess: Callable invoked with `inputs`.

    Returns:
        None when `inputs` is falsy, otherwise the result of `preprocess`.
    """
    if not inputs:
        return None

    svg_prefix = "data:image/svg+xml;base64,"
    image = inputs["image"]
    # The prefix test comes first so `svgsupport` is only consulted for SVGs.
    if image.startswith(svg_prefix) and svgsupport:
        raw_svg = base64.b64decode(image.replace(svg_prefix, ""))
        rendered = svg2rlg(io.BytesIO(raw_svg))
        png_bytes = renderPM.drawToString(rendered, fmt="PNG")
        png_b64 = str(base64.b64encode(png_bytes), "utf-8")
        inputs["image"] = "data:image/png;base64," + png_b64

    return preprocess(inputs)
|
||||
|
||||
|
||||
def get_unique_axis0(data):
    """Return the distinct rows of a 2-D array-like.

    Rows are first ordered with `np.lexsort` over the columns (the last
    column is the primary sort key), then every row that equals its
    predecessor is dropped.
    """
    rows = np.asanyarray(data)
    rows = rows[np.lexsort(rows.T)]

    # The first row is always kept; later rows only if they differ from the
    # row right before them.
    keep = np.ones(len(rows), dtype=np.bool_)
    keep[1:] = np.any(rows[:-1, :] != rows[1:, :], axis=-1)
    return rows[keep]
|
||||
|
||||
|
||||
def read_image(img_path: str) -> str:
    """Read image from specified path and return a base64 string.

    The image is decoded with OpenCV, re-encoded as PNG and the PNG bytes
    are returned base64-encoded.

    Args:
        img_path: Path of the image file to read.

    Returns:
        Base64 text (no data-URL prefix) of the PNG-encoded image.

    Raises:
        IOError: If the file cannot be read/decoded or PNG encoding fails.
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise IOError(f"Failed to read image: {img_path}")

    success, png_buffer = cv2.imencode(".png", img)
    if not success:
        raise IOError(f"Failed to encode image as PNG: {img_path}")

    return base64.b64encode(png_buffer).decode("utf-8")
|
||||
|
||||
|
||||
def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]:
    """Try read all images in given img_dir.

    File names are matched against `suffixes` case-insensitively, so
    `photo.PNG` is picked up as well. Files that cannot be read are logged
    and skipped instead of aborting the whole scan.

    Args:
        img_dir: Directory to scan (non-recursive).
        suffixes: A file-name suffix, or an iterable of suffixes, to accept.

    Returns:
        Base64 strings as produced by `read_image`, one per readable image.
    """
    if isinstance(suffixes, str):
        suffixes = (suffixes,)
    wanted = tuple(suffix.lower() for suffix in suffixes)

    images = []
    for filename in os.listdir(img_dir):
        if not filename.lower().endswith(wanted):
            continue
        img_path = os.path.join(img_dir, filename)
        try:
            images.append(read_image(img_path))
        except (IOError, cv2.error):
            # OpenCV reports undecodable input as cv2.error, not IOError.
            logger.error(f"Error opening {img_path}")
    return images
|
||||
|
||||
|
||||
def align_dim_latent(x: int) -> int:
    """Round a pixel dimension (width/height) down to a multiple of 8.

    Stable diffusion uses a 1:8 latent/pixel ratio, i.e. one latent unit
    covers 8 pixel units.
    """
    remainder = x % 8
    return x - remainder
|
||||
Reference in New Issue
Block a user