Get Flux working on Apple Silicon (#1264)
Co-authored-by: Conor Nash <conor@nbs.consulting>
This commit is contained in:
@@ -19,6 +19,9 @@ def attention(q, k, v, pe):
|
|||||||
|
|
||||||
|
|
||||||
def rope(pos, dim, theta):
|
def rope(pos, dim, theta):
|
||||||
|
if pos.device.type == "mps":
|
||||||
|
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
|
||||||
|
else:
|
||||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||||
omega = 1.0 / (theta ** scale)
|
omega = 1.0 / (theta ** scale)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user