move to new backend - part 1
+18 -18
@@ -1,14 +1,14 @@
 import contextlib

 import torch

-import ldm_patched.modules.model_management as model_management
+from backend import memory_management


 def has_xpu() -> bool:
-    return model_management.xpu_available
+    return memory_management.xpu_available


 def has_mps() -> bool:
-    return model_management.mps_mode()
+    return memory_management.mps_mode()


 def cuda_no_autocast(device_id=None) -> bool:
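The probes above keep their signatures; only the delegation target moves from ldm_patched's model_management to the new backend.memory_management module. A minimal caller sketch, assuming this file stays importable as modules.devices (the import path and the caller are illustrative, not part of this commit):

import torch

from modules import devices  # assumed import path

def preferred_accelerator() -> str:
    # Branches only on the probes above; unaffected by the backend swap.
    if devices.has_mps():
        return "mps"
    if devices.has_xpu():
        return "xpu"
    return "cuda" if torch.cuda.is_available() else "cpu"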
@@ -16,27 +16,27 @@ def cuda_no_autocast(device_id=None) -> bool:


 def get_cuda_device_id():
-    return model_management.get_torch_device().index
+    return memory_management.get_torch_device().index


 def get_cuda_device_string():
-    return str(model_management.get_torch_device())
+    return str(memory_management.get_torch_device())


 def get_optimal_device_name():
-    return model_management.get_torch_device().type
+    return memory_management.get_torch_device().type


 def get_optimal_device():
-    return model_management.get_torch_device()
+    return memory_management.get_torch_device()


 def get_device_for(task):
-    return model_management.get_torch_device()
+    return memory_management.get_torch_device()


 def torch_gc():
-    model_management.soft_empty_cache()
+    memory_management.soft_empty_cache()


 def torch_npu_set_device():
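Every getter above now resolves through memory_management.get_torch_device(), so they all describe the same device. A rough usage sketch under the same assumptions as before (the caller is hypothetical, not part of this commit):

from modules import devices  # assumed import path

def run_and_release(fn):
    dev = devices.get_optimal_device()                        # torch.device, e.g. cuda:0
    print(f"running on {devices.get_cuda_device_string()}")   # str() of the same device
    try:
        return fn(dev)
    finally:
        devices.torch_gc()  # delegates to memory_management.soft_empty_cache()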
@@ -49,15 +49,15 @@ def enable_tf32():


 cpu: torch.device = torch.device("cpu")
 fp8: bool = False
-device: torch.device = model_management.get_torch_device()
-device_interrogate: torch.device = model_management.text_encoder_device()  # for backward compatibility, not used now
-device_gfpgan: torch.device = model_management.get_torch_device()  # will be managed by memory management system
-device_esrgan: torch.device = model_management.get_torch_device()  # will be managed by memory management system
-device_codeformer: torch.device = model_management.get_torch_device()  # will be managed by memory management system
-dtype: torch.dtype = model_management.unet_dtype()
-dtype_vae: torch.dtype = model_management.vae_dtype()
-dtype_unet: torch.dtype = model_management.unet_dtype()
-dtype_inference: torch.dtype = model_management.unet_dtype()
+device: torch.device = memory_management.get_torch_device()
+device_interrogate: torch.device = memory_management.text_encoder_device()  # for backward compatibility, not used now
+device_gfpgan: torch.device = memory_management.get_torch_device()  # will be managed by memory management system
+device_esrgan: torch.device = memory_management.get_torch_device()  # will be managed by memory management system
+device_codeformer: torch.device = memory_management.get_torch_device()  # will be managed by memory management system
+dtype: torch.dtype = memory_management.unet_dtype()
+dtype_vae: torch.dtype = memory_management.vae_dtype()
+dtype_unet: torch.dtype = memory_management.unet_dtype()
+dtype_inference: torch.dtype = memory_management.unet_dtype()
 unet_needs_upcast = False
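The module-level defaults above are the main consumer-facing surface; downstream code keeps reading devices.device and the devices.dtype_* values exactly as before, only their source changes to backend.memory_management. A minimal sketch, again assuming the modules.devices import path (the caller is illustrative, not part of this commit):

import torch

from modules import devices  # assumed import path

def prepare(model: torch.nn.Module, x: torch.Tensor):
    # Move work onto the backend-selected device and inference dtype.
    model = model.to(device=devices.device, dtype=devices.dtype_inference)
    x = x.to(device=devices.device, dtype=devices.dtype_inference)
    return model, x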