revise stream logics

This commit is contained in:
layerdiffusion
2024-08-08 18:45:36 -07:00
parent d3b81924df
commit 60c5aea11b
4 changed files with 30 additions and 24 deletions
+20 -13
View File
@@ -340,8 +340,6 @@ class LoadedModel:
raise e raise e
if not disable_async_load: if not disable_async_load:
flag = 'ASYNC' if stream.using_stream else 'SYNC'
print(f"[Memory Management] Requested {flag} Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
real_async_memory = 0 real_async_memory = 0
mem_counter = 0 mem_counter = 0
for m in self.real_model.modules(): for m in self.real_model.modules():
@@ -360,9 +358,14 @@ class LoadedModel:
elif hasattr(m, "weight"): elif hasattr(m, "weight"):
m.to(self.device) m.to(self.device)
mem_counter += module_size(m) mem_counter += module_size(m)
print(f"[Memory Management] {flag} Loader Disabled for", type(m).__name__) print(f"[Memory Management] Swap disabled for", type(m).__name__)
print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024)) if stream.should_use_stream():
print(f"[Memory Management] Loaded to CPU Swap: {real_async_memory / (1024 * 1024):.2f} MB (asynchronous method)")
else:
print(f"[Memory Management] Loaded to CPU Swap: {real_async_memory / (1024 * 1024):.2f} MB (blocked method)")
print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")
self.model_accelerated = True self.model_accelerated = True
@@ -390,8 +393,12 @@ class LoadedModel:
return self.model is other.model # and self.memory_required == other.memory_required return self.model is other.model # and self.memory_required == other.memory_required
current_inference_memory = 1024 * 1024 * 1024
def minimum_inference_memory(): def minimum_inference_memory():
return 1024 * 1024 * 1024 global current_inference_memory
return current_inference_memory
def unload_model_clones(model): def unload_model_clones(model):
@@ -487,17 +494,17 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_memory = loaded_model.model_memory_required(torch_dev) model_memory = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev) current_free_mem = get_free_memory(torch_dev)
minimal_inference_memory = minimum_inference_memory() inference_memory = minimum_inference_memory()
estimated_remaining_memory = current_free_mem - model_memory - minimal_inference_memory estimated_remaining_memory = current_free_mem - model_memory - inference_memory
print("[Memory Management] Current Free GPU Memory (MB) = ", current_free_mem / (1024 * 1024)) print(f"[Memory Management] Current Free GPU Memory: {current_free_mem / (1024 * 1024):.2f} MB")
print("[Memory Management] Model Memory (MB) = ", model_memory / (1024 * 1024)) print(f"[Memory Management] Required Model Memory: {model_memory / (1024 * 1024):.2f} MB")
print("[Memory Management] Minimal Inference Memory (MB) = ", minimal_inference_memory / (1024 * 1024)) print(f"[Memory Management] Required Inference Memory: {inference_memory / (1024 * 1024):.2f} MB")
print("[Memory Management] Estimated Remaining GPU Memory (MB) = ", estimated_remaining_memory / (1024 * 1024)) print(f"[Memory Management] Estimated Remaining GPU Memory: {estimated_remaining_memory / (1024 * 1024):.2f} MB")
if estimated_remaining_memory < 0: if estimated_remaining_memory < 0:
vram_set_state = VRAMState.LOW_VRAM vram_set_state = VRAMState.LOW_VRAM
async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3 async_kept_memory = (current_free_mem - inference_memory) / 1.3
async_kept_memory = int(max(0, async_kept_memory)) async_kept_memory = int(max(0, async_kept_memory))
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
+3 -3
View File
@@ -21,7 +21,7 @@ def weights_manual_cast(layer, x, skip_dtype=False):
if skip_dtype: if skip_dtype:
target_dtype = None target_dtype = None
if stream.using_stream: if stream.should_use_stream():
with stream.stream_context()(stream.mover_stream): with stream.stream_context()(stream.mover_stream):
if layer.weight is not None: if layer.weight is not None:
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking) weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
@@ -39,7 +39,7 @@ def weights_manual_cast(layer, x, skip_dtype=False):
@contextlib.contextmanager @contextlib.contextmanager
def main_stream_worker(weight, bias, signal): def main_stream_worker(weight, bias, signal):
if not stream.using_stream or signal is None: if signal is None or not stream.should_use_stream():
yield yield
return return
@@ -60,7 +60,7 @@ def main_stream_worker(weight, bias, signal):
def cleanup_cache(): def cleanup_cache():
if not stream.using_stream: if not stream.should_use_stream():
return return
stream.current_stream.synchronize() stream.current_stream.synchronize()
+6 -7
View File
@@ -52,11 +52,10 @@ def get_new_stream():
return None return None
current_stream = None def should_use_stream():
mover_stream = None return stream_activated and current_stream is not None and mover_stream is not None
using_stream = False
if args.cuda_stream:
current_stream = get_current_stream() current_stream = get_current_stream()
mover_stream = get_new_stream() mover_stream = get_new_stream()
using_stream = current_stream is not None and mover_stream is not None stream_activated = args.cuda_stream
+1 -1
View File
@@ -58,7 +58,7 @@ def initialize_forge():
modules_forge.patch_basic.patch_all_basics() modules_forge.patch_basic.patch_all_basics()
from backend import stream from backend import stream
print('CUDA Stream Activated: ', stream.using_stream) print('CUDA Using Stream:', stream.should_use_stream())
from modules_forge.shared import diffusers_dir from modules_forge.shared import diffusers_dir