Skip to content

Commit 9e447a1

Browse files
q and dtype
1 parent a707a49 commit 9e447a1

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

kron_torch/kron.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,10 @@ def _solve_triangular_right(X, A):
376376
orig_dtype = X.dtype
377377
X = X.to(dtype=torch.float32, non_blocking=True)
378378
A = A.to(dtype=torch.float32, non_blocking=True)
379-
return torch.linalg.solve_triangular(A, X.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(X)
380-
379+
out = torch.linalg.solve_triangular(
380+
A, X.reshape(-1, X.size(-1)), upper=True, left=False
381+
).reshape_as(X)
382+
return out.to(dtype=orig_dtype, non_blocking=True)
381383

382384
@torch.compile(fullgraph=True, dynamic=False)
383385
def _calc_A_and_conjB(exprA, G, Q, V):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
44

55
[project]
66
name = "kron-torch"
7-
version = "0.2.8"
7+
version = "0.2.9"
88
description = "An implementation of PSGD Kron optimizer in PyTorch."
99
readme = { file = "README.md", content-type = "text/markdown" }
1010
license = { file = "LICENSE" }

0 commit comments

Comments
 (0)