Skip to content

Commit 23a01a6

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f1c0a25 commit 23a01a6

1 file changed

Lines changed: 9 additions & 13 deletions

File tree

linear_algebra/qr_decomposition.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
and is the basis for a particular eigenvalue algorithm, the QR algorithm.
77
This algorithm will simply attempt to perform QR decomposition on any square matrix.
88
Reference: https://en.wikipedia.org/wiki/QR_decomposition"""
9+
910
import numpy as np
1011
from scipy.linalg import qr
1112

13+
1214
def qr_decomposition(matrix_a: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
1315
"""
1416
Perform QR decomposition on a given matrix and raises an error if in
@@ -51,17 +53,12 @@ def qr_decomposition(matrix_a: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
5153
ValueError: row size should be greater than column size
5254
"""
5355

54-
5556
rows, columns = np.shape(matrix_a)
5657
if rows < columns:
57-
msg = (
58-
"row size should be greater than column size"
59-
)
58+
msg = "row size should be greater than column size"
6059
raise ValueError(msg)
6160
if rows < 2 or columns < 2:
62-
msg = (
63-
"row size and column size should be greater than 2"
64-
)
61+
msg = "row size and column size should be greater than 2"
6562
raise ValueError(msg)
6663
# Perform QR decomposition with pivoting
6764
# matrix_q: Orthogonal matrix
@@ -72,15 +69,14 @@ def qr_decomposition(matrix_a: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
7269

7370
# Verification: matrix_a[:, pivot] should equal matrix_q @ matrix_r
7471
permute_matrix = matrix_a[:, pivot]
75-
if(np.allclose(permute_matrix, matrix_q @ matrix_r)):
76-
return np.round(matrix_q,2), np.round(matrix_r,2)
72+
if np.allclose(permute_matrix, matrix_q @ matrix_r):
73+
return np.round(matrix_q, 2), np.round(matrix_r, 2)
7774
else:
78-
msg = (
79-
"No matrix found which decompose given matrix"
80-
)
75+
msg = "No matrix found which decompose given matrix"
8176
raise ValueError(msg)
8277

78+
8379
if __name__ == "__main__":
8480
import doctest
8581

86-
doctest.testmod()
82+
doctest.testmod()

0 commit comments

Comments
 (0)