revise kernel
This commit is contained in:
@@ -619,6 +619,18 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
||||
return torch.float32
|
||||
|
||||
|
||||
def get_computation_dtype(inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    """Pick the dtype to run computation in on ``inference_device``.

    Candidates are examined in the order given by ``supported_dtypes``; the
    first half-precision candidate accepted by its capability helper wins.

    Args:
        inference_device: Device handle forwarded unchanged to
            ``should_use_fp16`` / ``should_use_bf16``.
        supported_dtypes: Iterable of ``torch.dtype`` candidates in priority
            order. The default is an immutable tuple so it cannot be shared
            and mutated across calls.

    Returns:
        ``torch.float16`` or ``torch.bfloat16`` when that candidate is listed
        and its helper returns a truthy value; otherwise ``torch.float32``.
    """
    for candidate in supported_dtypes:
        # fp16 is queried with prioritize_performance=False.
        # NOTE(review): presumably this makes the check about capability
        # rather than speed; confirm against should_use_fp16.
        if candidate == torch.float16 and should_use_fp16(inference_device, prioritize_performance=False):
            return candidate
        if candidate == torch.bfloat16 and should_use_bf16(inference_device):
            return candidate
        # Any other entry (including torch.float32) never short-circuits the
        # scan; float32 is only ever produced by the fallback below.

    return torch.float32
|
||||
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if args.always_gpu:
|
||||
return get_torch_device()
|
||||
|
||||
Reference in New Issue
Block a user