Skip to content

Commit c68efc4

Browse files
Implement matrix multiplication function
This function performs matrix multiplication on two valid matrices, raising ValueErrors for invalid structures or incompatible sizes.
1 parent ae68a78 commit c68efc4

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

matrix/matrix_diagonal_sum.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
Matrix Multiplication Algorithm
3+
4+
This function performs matrix multiplication on two valid matrices.
5+
It follows the mathematical definition:
6+
If a is an m x n matrix and b is an n x p matrix,
7+
then their product c is an m x p matrix.
8+
9+
Raises:
10+
ValueError: if matrices have invalid structure or incompatible sizes.
11+
12+
Sources:
13+
https://en.wikipedia.org/wiki/Matrix_multiplication
14+
15+
Examples:
16+
>>> A = [[1, 2], [3, 4]]
17+
>>> B = [[5, 6], [7, 8]]
18+
>>> matrix_multiply(A, B)
19+
[[19, 22], [43, 50]]
20+
21+
>>> matrix_multiply([[1, 2, 3]], [[4], [5], [6]])
22+
[[32]]
23+
24+
# Invalid structure
25+
>>> matrix_multiply([[1, 2], [3]], [[1, 2]])
26+
Traceback (most recent call last):
27+
...
28+
ValueError: Invalid matrix structure
29+
30+
# Incompatible sizes
31+
>>> matrix_multiply([[1, 2]], [[1, 2]])
32+
Traceback (most recent call last):
33+
...
34+
ValueError: Incompatible matrix sizes
35+
"""
36+
37+
def matrix_multiply(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
38+
if not _is_valid_matrix(a) or not _is_valid_matrix(b):
39+
raise ValueError("Invalid matrix structure")
40+
41+
rows_a = len(a)
42+
cols_a = len(a[0])
43+
rows_b = len(b)
44+
cols_b = len(b[0])
45+
46+
if cols_a != rows_b:
47+
raise ValueError("Incompatible matrix sizes")
48+
49+
result = [[0.0 for _ in range(cols_b)] for _ in range(rows_a)]
50+
51+
for i in range(rows_a):
52+
for j in range(cols_b):
53+
for k in range(cols_a):
54+
result[i][j] += a[i][k] * b[k][j]
55+
56+
return result
57+
58+
59+
def _is_valid_matrix(m: list[list[float]]) -> bool:
60+
if not isinstance(m, list) or not m:
61+
return False
62+
first_length = len(m[0])
63+
return all(isinstance(row, list) and len(row) == first_length for row in m)

0 commit comments

Comments
 (0)