-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer_manager.lua
More file actions
112 lines (93 loc) · 3.18 KB
/
trainer_manager.lua
File metadata and controls
112 lines (93 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
require "love.image"
require "love.math"
require "love.filesystem"
local love = require "love"
local neat = require "neat/neat"
-- Dataset location (scanned recursively for .png images) and the number of
-- samples handed to the worker for each training round.
local DATASET_ROOT = "dataset"
local BATCH_SIZE = 128

-- Named channels shared with the rest of the program:
--   "status"   - human-readable progress messages
--   "champion" - each newly accepted best genome
local status_channel = love.thread.getChannel("status")
local champion_neat_channel = love.thread.getChannel("champion")

-- Let listeners know this thread is alive before the dataset scan begins.
status_channel:push("initializing...")
--- Recursively collect every `.png` file under `path` into `file_list`.
-- Each appended entry is `{ path = <full path>, label = <number> }`. The label
-- is the name of the directory that directly contains the file, parsed as a
-- number (e.g. "dataset/7/a.png" gets label 7). Non-numeric directory names
-- fall back to label 0.
-- Paths for which `love.filesystem.getInfo` returns nil (nothing exists there
-- any more, e.g. removed between listing and stat) are skipped instead of
-- aborting the whole scan.
-- A status line is pushed on `status_channel` for each subdirectory entered.
-- @tparam string path directory to scan (LÖVE filesystem path)
-- @tparam table file_list sequence that matches are appended to (mutated)
local function gather_all_files(path, file_list)
  local items = love.filesystem.getDirectoryItems(path)
  -- Last path component, e.g. "dataset/3" -> "3"; it is the same for every
  -- file in this directory, so compute the label once.
  local current_dir = path:match("([^/]+)$") or path
  -- NOTE(review): files in non-numeric directories (including the dataset
  -- root itself) are labelled 0 here — confirm that is intended.
  local label = tonumber(current_dir) or 0

  for _, item in ipairs(items) do
    local full_path = path .. "/" .. item
    local info = love.filesystem.getInfo(full_path)
    -- getInfo returns nil when nothing exists at the path.
    local kind = info and info.type
    if kind == "directory" then
      status_channel:push('gathering files in ' .. full_path)
      gather_all_files(full_path, file_list)
    elseif kind == "file" and item:match("%.png$") then
      file_list[#file_list + 1] = { path = full_path, label = label }
    end
  end
end
--- Format a number as Lua source that evaluates back to exactly the same value.
-- `%.17g` keeps enough significant digits to round-trip any double. Non-finite
-- values, which `string.format` renders as "inf"/"nan" (not valid Lua
-- literals), are spelled as arithmetic expressions instead.
-- @tparam number x
-- @treturn string
local function format_lua_number(x)
  if x ~= x then
    return "0/0"
  elseif x == math.huge then
    return "1/0"
  elseif x == -math.huge then
    return "-1/0"
  end
  return string.format("%.17g", x)
end

--- Serialize a genome to Lua source of the form `return { ... }`.
-- Loading and running the result yields a table with `fitness` (0 when the
-- genome has none), `nodes` (`id`, `type`, `activation = { name = ... }`),
-- `connections` (`in_node`, `out_node`, `weight`, `innovation`, `enabled`)
-- and an empty `settings` table. Any settings on the genome are not persisted.
-- Numbers are written losslessly, and strings are emitted with `%q` so quotes
-- or backslashes in names cannot break the generated file.
-- @tparam table genome genome with `nodes` and `connections` sequences
-- @treturn string Lua source text
local function serialize_genome(genome)
  local lines = { "return {" }
  lines[#lines + 1] = " fitness = " .. format_lua_number(genome.fitness or 0) .. ","

  lines[#lines + 1] = " nodes = {"
  for _, node in ipairs(genome.nodes) do
    lines[#lines + 1] = string.format(" {id = %d, type = %q, activation = {name = %q}},",
      node.id, tostring(node.type), tostring(node.activation.name))
  end
  lines[#lines + 1] = " },"

  lines[#lines + 1] = " connections = {"
  for _, conn in ipairs(genome.connections) do
    lines[#lines + 1] = string.format(
      " {in_node = %d, out_node = %d, weight = %s, innovation = %d, enabled = %s},",
      conn.in_node, conn.out_node, format_lua_number(conn.weight), conn.innovation,
      tostring(conn.enabled))
  end
  lines[#lines + 1] = " },"

  lines[#lines + 1] = " settings = {}"
  lines[#lines + 1] = "}"
  return table.concat(lines, "\n")
end
--- Load the champion genome previously saved to "best_genome.lua", if any.
-- The file is expected to be Lua source that returns the genome table.
-- Returns nil in four cases, so the caller can fall back to a fresh genome
-- instead of this thread dying at startup:
--   * the file does not exist;
--   * it fails to compile, e.g. it was truncated by an interrupted write;
--   * running it raises an error;
--   * it does not return a table.
-- @treturn table|nil the genome table, or nil if none could be loaded
local function load_genome()
  local path = "best_genome.lua"
  if not love.filesystem.getInfo(path) then
    return nil
  end

  -- love.filesystem.load may raise (e.g. on a syntax error) rather than
  -- return nil, and running the chunk can raise too, so guard both steps.
  local ok, result = pcall(function()
    local chunk, err = love.filesystem.load(path)
    if not chunk then
      error(err or "could not read " .. path)
    end
    return chunk()
  end)

  if not ok then
    print("Ignoring unreadable " .. path .. ": " .. tostring(result))
    return nil
  end
  if type(result) ~= "table" then
    print("Ignoring " .. path .. ": it did not return a table")
    return nil
  end
  return result
end
-- Initial state: resume from the champion saved by a previous run (if any)
-- and index the dataset once up front.
local current_best_genome = load_genome()
local all_files = {}
gather_all_files(DATASET_ROOT, all_files)

-- Training batches are drawn by indexing all_files at random, so an empty
-- dataset could only ever produce nil samples. Stop here with a readable
-- message instead of feeding empty batches to the worker.
if #all_files == 0 then
  local message = "no .png files found under '" .. DATASET_ROOT .. "'"
  status_channel:push(message)
  error(message)
end

-- No saved genome: build a fresh starter network.
-- NOTE(review): input_count/output_count must match what worker_train.lua
-- feeds the network — confirm there.
if not current_best_genome then
  current_best_genome = neat.create_genome({
    input_count = 7,
    output_count = 1,
    hidden_layers_count = 2,
    nodes_per_layer = 8,
  })
  -- Only the freshly created genome is purified here; presumably
  -- purify_genome strips values that cannot cross thread boundaries —
  -- TODO confirm in neat/neat.lua.
  current_best_genome = neat.purify_genome(current_best_genome)
end
-- Private channel the worker reports back on, and the worker thread itself,
-- which is re-started once per training round.
local result_channel = love.thread.newChannel()
local worker = love.thread.newThread("worker_train.lua")

while true do
  -- 1. Draw a random batch (sampling with replacement) from the dataset.
  local file_count = #all_files
  local batch = {}
  for i = 1, BATCH_SIZE do
    batch[i] = all_files[love.math.random(file_count)]
  end

  -- 2. Run one training round on the worker and block until it finishes.
  worker:start(result_channel, current_best_genome, batch)
  worker:wait()

  -- 3. Collect results, unless the worker died with an error.
  local worker_error = worker:getError()
  if worker_error then
    -- Drop anything the failed run may have pushed, so it is not read as
    -- the next round's results, and surface the error instead of ignoring it.
    result_channel:clear()
    status_channel:push("worker error: " .. worker_error)
  else
    -- NOTE(review): assumes the worker pushes the new best genome first and
    -- a status string second — confirm against worker_train.lua.
    local new_best = result_channel:pop()
    local new_status = result_channel:pop()

    if new_best then
      current_best_genome = new_best
      champion_neat_channel:push(neat.purify_genome(new_best))
      local saved, write_err = love.filesystem.write("best_genome.lua", serialize_genome(new_best))
      if saved then
        print("Updated best genome from worker and saved to file.")
      else
        print("Updated best genome from worker, but failed to save it: " .. tostring(write_err))
      end
    end

    if new_status then
      status_channel:push(new_status)
    end
  end
end