Skip to content

Commit 3207df5

Browse files
committed
refactor: add type hints to softmax function
1 parent 3c88735 commit 3207df5

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

maths/softmax.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,22 @@
1313
import numpy as np
1414

1515

16-
def softmax(vector):
16+
def softmax(vector: np.ndarray | list | tuple) -> np.ndarray:
1717
"""
1818
Implements the softmax function
1919
2020
Parameters:
21-
vector (np.array,list,tuple): A numpy array of shape (1,n)
22-
consisting of real values or a similar list,tuple
23-
21+
vector (np.array | list | tuple): A numpy array of shape (1,n)
22+
consisting of real values or a similar list, tuple
2423
2524
Returns:
26-
softmax_vec (np.array): The input numpy array after applying
27-
softmax.
25+
np.array: The input numpy array after applying softmax.
2826
29-
The softmax vector adds up to one. We need to ceil to mitigate for
30-
precision
31-
>>> float(np.ceil(np.sum(softmax([1,2,3,4]))))
27+
The softmax vector adds up to one. We need to ceil to mitigate for precision
28+
>>> float(np.ceil(np.sum(softmax([1, 2, 3, 4]))))
3229
1.0
3330
34-
>>> vec = np.array([5,5])
31+
>>> vec = np.array([5, 5])
3532
>>> softmax(vec)
3633
array([0.5, 0.5])
3734

0 commit comments

Comments
 (0)