-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path test.py
More file actions
26 lines (21 loc) · 742 Bytes
/
test.py
File metadata and controls
26 lines (21 loc) · 742 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import numpy as np
import xeno
# Train a small two-layer dense network (xeno) on scikit-learn's handwritten
# digits dataset and report progress via xeno's Model.fit.
try:
    from sklearn.datasets import load_digits
except ImportError:
    # Hint for the known setup problem (sklearn built against an incompatible
    # NumPy wheel). Re-raise afterwards: the script cannot continue without
    # load_digits, and swallowing the error would only surface later as a
    # confusing NameError on the load_digits() call below.
    print("your sklearn library needs to be install with whl of numpy+MKL :(\n")
    raise

# prepare: presumably seeds xeno's RNG so runs are reproducible.
xeno.utils.random.set_seed(1234)

# data
digits = load_digits()
# Normalise features by the global maximum value. Build a new array rather
# than dividing in place: in-place true division would mutate digits.data
# through the alias and would fail with a casting error on integer input.
X_train = digits.data / np.max(digits.data)
Y_train = digits.target
n_classes = np.unique(Y_train).size
# Input width taken from the data (64 for the 8x8 digits images) instead of
# being hard-coded.
n_features = X_train.shape[1]

# model: Dense(500, ReLU) -> Dense(n_classes, Softmax)
model = xeno.model.Model()
model.add(xeno.layers.Dense(n_out=500, n_in=n_features, activation=xeno.activations.ReLU()))
model.add(xeno.layers.Dense(n_out=n_classes, activation=xeno.activations.Softmax()))
model.compile(loss=xeno.objectives.SCCE(), optimizer=xeno.optimizers.SGD(lr=0.005))

# train: labels are one-hot encoded for the softmax output layer;
# validation_split=0.1 asks xeno to hold out a fraction of the data for
# validation (exact split semantics are defined by xeno's fit).
model.fit(X_train, xeno.utils.data.one_hot(Y_train), max_iter=150, validation_split=0.1)