-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
35 lines (27 loc) · 1023 Bytes
/
main.py
File metadata and controls
35 lines (27 loc) · 1023 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import json
import os
import pickle
import sys

from CNN import CNN
from LSTM import LSTM
# Directory that receives test_output.txt. Defaults to the original author's
# checkout; set the IE_HW2_DATA environment variable to use another location.
# os.path.join(..., "") guarantees a trailing separator, because filenames may
# be appended to this value by plain string concatenation.
data_folder = os.path.join(
    os.environ.get("IE_HW2_DATA", "/home/aida/Projects/IE-hw2/data/"), ""
)
def _write_predictions(predictions, path, first_id=8000):
    """Write one line per prediction to *path*: the ID, a tab, the prediction.

    IDs are consecutive integers starting at *first_id*. The default 8000 is
    the offset this script has always used (presumably the ID of the first
    test example -- confirm against the dataset). Predictions are formatted
    with ``str()``, so non-string labels are written as well.
    """
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(
            f"{first_id + i}\t{pred}\n" for i, pred in enumerate(predictions)
        )


def main():
    """Build and run the model selected in ``params.json``.

    Loads the configuration from ``params.json`` and the preprocessed data
    (vocabulary, embeddings, batches, labels) from ``data.pkl``; both paths
    are relative to the current working directory. If ``params["model"]`` is
    ``"CNN"``, a CNN is built and run on the train/dev batches (its return
    value is discarded). Any other value builds and runs an LSTM, and the
    predictions it returns are written to ``test_output.txt`` inside
    ``data_folder``.

    Exits with status 1 if ``data.pkl`` cannot be read, or if pretrained
    embeddings are requested (``params["pretrain"]``) but none are stored.
    """
    with open("params.json", "r", encoding="utf-8") as f:
        params = json.load(f)

    # NOTE: pickle.load can execute arbitrary code; only load a data.pkl
    # produced locally by make_vocab.py.
    try:
        with open("data.pkl", "rb") as f:
            data = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError):
        # Missing, unreadable or truncated file -- most likely make_vocab.py
        # has not been run. Other errors propagate with a full traceback.
        print("Error: run make_vocab.py first", file=sys.stderr)
        sys.exit(1)

    embeddings = data["embeddings"]
    # len() instead of `== []`: if embeddings is ever a NumPy array, comparing
    # it with [] is elementwise and can raise (assumption -- confirm the type
    # that make_vocab.py stores).
    if params["pretrain"] and (embeddings is None or len(embeddings) == 0):
        print("No embeddings found, please run make_vocab.py", file=sys.stderr)
        sys.exit(1)

    if params["model"] == "CNN":
        cnn = CNN(params, data["vocab"], embeddings)
        cnn.build()
        cnn.run_model(
            data["train_batches"],
            data["dev_batches"],
            data["true_dev_labels"],
            data["tag_dict"],
        )
    else:
        # Any model name other than "CNN" selects the LSTM.
        lstm = LSTM(params, data["vocab"], embeddings)
        lstm.build()
        predictions = lstm.run_model(data)
        _write_predictions(
            predictions, os.path.join(data_folder, "test_output.txt")
        )
# Script entry point: run the pipeline only when executed directly, never on
# import.
if __name__ == "__main__":
    main()