Skip to content

Commit dcccd74

Browse files
Add matrix multiplication function with validation
Implement matrix multiplication function with validation.
1 parent ae68a78 commit dcccd74

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

matrix/matrix_diagonal_sum.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
Matrix Multiplication Algorithm
3+
4+
This function performs matrix multiplication on two valid matrices.
5+
It follows the mathematical definition:
6+
If A is an m×n matrix and B is an n×p matrix,
7+
then their product C is an m×p matrix.
8+
9+
Raises:
10+
ValueError: if matrices have invalid structure or incompatible sizes.
11+
12+
Sources:
13+
https://en.wikipedia.org/wiki/Matrix_multiplication
14+
15+
Examples:
16+
>>> A = [[1, 2], [3, 4]]
17+
>>> B = [[5, 6], [7, 8]]
18+
>>> matrix_multiply(A, B)
19+
[[19, 22], [43, 50]]
20+
21+
>>> matrix_multiply([[1, 2, 3]], [[4], [5], [6]])
22+
[[32]]
23+
24+
# Invalid structure
25+
>>> matrix_multiply([[1, 2], [3]], [[1, 2]])
26+
Traceback (most recent call last):
27+
...
28+
ValueError: Invalid matrix structure
29+
30+
# Incompatible sizes
31+
>>> matrix_multiply([[1, 2]], [[1, 2]])
32+
Traceback (most recent call last):
33+
...
34+
ValueError: Incompatible matrix sizes
35+
"""
36+
37+
from typing import List
38+
39+
40+
def matrix_multiply(A: List[List[float]], B: List[List[float]]) -> List[List[float]]:
41+
if not _is_valid_matrix(A) or not _is_valid_matrix(B):
42+
raise ValueError("Invalid matrix structure")
43+
44+
rows_A = len(A)
45+
cols_A = len(A[0])
46+
rows_B = len(B)
47+
cols_B = len(B[0])
48+
49+
if cols_A != rows_B:
50+
raise ValueError("Incompatible matrix sizes")
51+
52+
result = [[0.0 for _ in range(cols_B)] for _ in range(rows_A)]
53+
54+
for i in range(rows_A):
55+
for j in range(cols_B):
56+
for k in range(cols_A):
57+
result[i][j] += A[i][k] * B[k][j]
58+
59+
return result
60+
61+
62+
def _is_valid_matrix(M: List[List[float]]) -> bool:
63+
if not isinstance(M, list) or not M:
64+
return False
65+
first_length = len(M[0])
66+
return all(isinstance(row, list) and len(row) == first_length for row in M)

0 commit comments

Comments
 (0)