Forge Space and BiRefNet
This commit is contained in:
@@ -0,0 +1,159 @@
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import time
|
||||
import gradio as gr
|
||||
import importlib.util
|
||||
import shutil
|
||||
|
||||
from gradio.context import Context
|
||||
from threading import Thread
|
||||
from huggingface_hub import snapshot_download
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
spaces = []
|
||||
|
||||
|
||||
def build_html(title, installed=False, url=None):
|
||||
if not installed:
|
||||
return f'<div>{title}</div><div style="color: grey;">Not Installed</div>'
|
||||
|
||||
if isinstance(url, str):
|
||||
return f'<div>{title}</div><div>Currently Running: <a href="{url}" style="color: green;" target="_blank">{url}</a></div>'
|
||||
else:
|
||||
return f'<div>{title}</div><div style="color: grey;">Installed, Ready to Launch</div>'
|
||||
|
||||
|
||||
class ForgeSpace:
|
||||
def __init__(self, root_path, title, repo_id=None, repo_type='space', revision=None, **kwargs):
|
||||
self.title = title
|
||||
self.root_path = root_path
|
||||
self.hf_path = os.path.join(root_path, 'huggingface_space_mirror')
|
||||
self.repo_id = repo_id
|
||||
self.repo_type = repo_type
|
||||
self.revision = revision
|
||||
self.is_running = False
|
||||
self.gradio_metas = None
|
||||
|
||||
self.label = gr.HTML(build_html(title=title, url=None), elem_classes=['forge_space_label'])
|
||||
self.btn_launch = gr.Button('Launch', elem_classes=['forge_space_btn'])
|
||||
self.btn_terminate = gr.Button('Terminate', elem_classes=['forge_space_btn'])
|
||||
self.btn_install = gr.Button('Install', elem_classes=['forge_space_btn'])
|
||||
self.btn_uninstall = gr.Button('Uninstall', elem_classes=['forge_space_btn'])
|
||||
|
||||
comps = [
|
||||
self.label,
|
||||
self.btn_install,
|
||||
self.btn_uninstall,
|
||||
self.btn_launch,
|
||||
self.btn_terminate
|
||||
]
|
||||
|
||||
self.btn_launch.click(self.run, outputs=comps)
|
||||
self.btn_terminate.click(self.terminate, outputs=comps)
|
||||
self.btn_install.click(self.install, outputs=comps)
|
||||
self.btn_uninstall.click(self.uninstall, outputs=comps)
|
||||
Context.root_block.load(self.refresh_gradio, outputs=comps, queue=False, show_progress=False)
|
||||
|
||||
return
|
||||
|
||||
def refresh_gradio(self):
|
||||
results = []
|
||||
|
||||
installed = os.path.exists(self.hf_path)
|
||||
|
||||
if isinstance(self.gradio_metas, tuple):
|
||||
results.append(build_html(title=self.title, installed=installed, url=self.gradio_metas[1]))
|
||||
else:
|
||||
results.append(build_html(title=self.title, installed=installed, url=None))
|
||||
|
||||
results.append(gr.update(interactive=not installed))
|
||||
results.append(gr.update(interactive=installed))
|
||||
results.append(gr.update(interactive=installed and not self.is_running))
|
||||
results.append(gr.update(interactive=installed and self.is_running))
|
||||
return results
|
||||
|
||||
def install(self):
|
||||
os.makedirs(self.hf_path, exist_ok=True)
|
||||
|
||||
if self.repo_id is None:
|
||||
return self.refresh_gradio()
|
||||
|
||||
downloaded = snapshot_download(
|
||||
repo_id=self.repo_id,
|
||||
repo_type=self.repo_type,
|
||||
revision=self.revision,
|
||||
local_dir=self.hf_path,
|
||||
force_download=True,
|
||||
)
|
||||
|
||||
print(f'Downloaded: {downloaded}')
|
||||
return self.refresh_gradio()
|
||||
|
||||
def uninstall(self):
|
||||
shutil.rmtree(self.hf_path)
|
||||
print(f'Deleted: {self.hf_path}')
|
||||
return self.refresh_gradio()
|
||||
|
||||
def terminate(self):
|
||||
self.is_running = False
|
||||
while self.gradio_metas is not None:
|
||||
time.sleep(0.1)
|
||||
return self.refresh_gradio()
|
||||
|
||||
def run(self):
|
||||
self.is_running = True
|
||||
Thread(target=self.gradio_worker).start()
|
||||
while self.gradio_metas is None:
|
||||
time.sleep(0.1)
|
||||
return self.refresh_gradio()
|
||||
|
||||
def gradio_worker(self):
|
||||
memory_management.unload_all_models()
|
||||
sys.path.insert(0, self.hf_path)
|
||||
file_path = os.path.join(self.root_path, 'forge_app.py')
|
||||
module_name = 'forge_space_' + str(uuid.uuid4()).replace('-', '_')
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
demo = getattr(module, 'demo')
|
||||
|
||||
self.gradio_metas = demo.launch(inbrowser=True, prevent_thread_lock=True)
|
||||
|
||||
while self.is_running:
|
||||
time.sleep(0.1)
|
||||
|
||||
demo.close()
|
||||
self.gradio_metas = None
|
||||
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
return
|
||||
|
||||
|
||||
def main_entry():
|
||||
global spaces
|
||||
|
||||
from modules.extensions import extensions
|
||||
|
||||
tagged_extensions = {}
|
||||
|
||||
for ex in extensions:
|
||||
if ex.enabled and ex.is_forge_space:
|
||||
tag = ex.space_meta['tag']
|
||||
|
||||
if tag not in tagged_extensions:
|
||||
tagged_extensions[tag] = []
|
||||
|
||||
tagged_extensions[tag].append(ex)
|
||||
|
||||
for tag, exs in tagged_extensions.items():
|
||||
with gr.Accordion(tag, open=True):
|
||||
for ex in exs:
|
||||
with gr.Row(equal_height=True):
|
||||
space = ForgeSpace(root_path=ex.path, **ex.space_meta)
|
||||
spaces.append(space)
|
||||
|
||||
return
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import sys
|
||||
|
||||
|
||||
INITIALIZED = False
|
||||
MONITOR_MODEL_MOVING = False
|
||||
|
||||
|
||||
@@ -25,6 +26,13 @@ def monitor_module_moving():
|
||||
|
||||
|
||||
def initialize_forge():
|
||||
global INITIALIZED
|
||||
|
||||
if INITIALIZED:
|
||||
return
|
||||
|
||||
INITIALIZED = True
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'packages_3rdparty'))
|
||||
|
||||
bad_list = ['--lowvram', '--medvram', '--medvram-sdxl']
|
||||
@@ -60,9 +68,6 @@ def initialize_forge():
|
||||
from modules_forge.bnb_installer import try_install_bnb
|
||||
try_install_bnb()
|
||||
|
||||
import modules_forge.patch_basic
|
||||
modules_forge.patch_basic.patch_all_basics()
|
||||
|
||||
from backend import stream
|
||||
print('CUDA Using Stream:', stream.should_use_stream())
|
||||
|
||||
@@ -85,4 +90,8 @@ def initialize_forge():
|
||||
|
||||
if 'HF_HUB_CACHE' not in os.environ:
|
||||
os.environ['HF_HUB_CACHE'] = diffusers_dir
|
||||
|
||||
import modules_forge.patch_basic
|
||||
modules_forge.patch_basic.patch_all_basics()
|
||||
|
||||
return
|
||||
|
||||
@@ -6,6 +6,8 @@ import warnings
|
||||
import gradio.networking
|
||||
import safetensors.torch
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def gradio_url_ok_fix(url: str) -> bool:
|
||||
try:
|
||||
@@ -55,7 +57,20 @@ def build_loaded(module, loader_name):
|
||||
return
|
||||
|
||||
|
||||
def always_show_tqdm(*args, **kwargs):
|
||||
kwargs['disable'] = False
|
||||
if 'name' in kwargs:
|
||||
del kwargs['name']
|
||||
return tqdm(*args, **kwargs)
|
||||
|
||||
|
||||
def patch_all_basics():
|
||||
import logging
|
||||
from huggingface_hub import file_download
|
||||
file_download.tqdm = always_show_tqdm
|
||||
from transformers.dynamic_module_utils import logger
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
gradio.networking.url_ok = gradio_url_ok_fix
|
||||
build_loaded(safetensors.torch, 'load_file')
|
||||
build_loaded(torch, 'load')
|
||||
|
||||
Reference in New Issue
Block a user