-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDariusController.cpp
More file actions
158 lines (133 loc) · 5.91 KB
/
DariusController.cpp
File metadata and controls
158 lines (133 loc) · 5.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include "MPI.h"
#include "ap_axi_sdata.h"
// Build switch: when defined, dariusController skips the
// `darius_driver[0] = 0` write in its INIT state (guarded by #ifndef there).
#define CONTROLLER_ONLY
// Number of ints in data_mem_info:
// {dma_in byte offset in mem, dma_in byte size, dma_out byte offset in mem, dma_out byte size}
#define MEM_INFO_SIZE 4
// Number of ints in parameter_mem_info: {parameter byte offset in mem, parameter byte size}
#define PARAMETER_MEM_INFO_SIZE 2
#define DARIUS_INFO_SIZE 35 // {ind_0 = num_commands, ind_1-32 = command, ind_33 = batch_size, ind_34 = num_ranks}
// Byte offsets into the darius_driver register window; the controller divides
// them by sizeof(int) to obtain word indices.
// Where num_commands (darius_info[0]) is written.
#define NUM_COMMANDS_OFFSET 0x60
// Start of the command words (darius_info[1..]).
#define COMMAND_OFFSET 0x70
// Cycle count read back once darius_driver[0] has reached DONE.
#define CYCLE_COUNT_OFFSET 0xd0
// START is written to darius_driver[0] to launch DARIUS; the controller then
// polls darius_driver[0] until it equals DONE.
// NOTE(review): these look like HLS ap_ctrl bits (0x1 = ap_start,
// 0x6 = ap_done | ap_idle) — confirm against the accelerator's register map.
#define START 0x1
#define DONE 0x6
// Declared extents of the two bus ports: mem (matches the `depth =` value in
// its INTERFACE pragma) and darius_driver.
#define DEPTH 2147483648
#define DARIUS_DEPTH 1024
// States of the dariusController state machine.
#define INIT 0
#define DMA_IN 1
#define RUN_DARIUS 2
#define WAIT_FOR_DARIUS 3
#define DMA_OUT 4
/*
 * dariusController
 *
 * Free-running (ap_ctrl_none) controller that relays data between MPI ranks
 * and a DARIUS accelerator driven through the darius_driver register window.
 * Each call advances a static state machine by exactly one state:
 *
 *   INIT            Receive from rank 0: the parameter offset/size, the
 *                   parameters themselves (written into mem), the data
 *                   offset/size info, and darius_info (num_commands, command,
 *                   batch_size, num_ranks). Compute prev_rank / next_rank.
 *   DMA_IN          Receive the running cycle count and the input data block
 *                   (written into mem) from prev_rank.
 *   RUN_DARIUS      Write num_commands and the command words to darius_driver,
 *                   then write START to darius_driver[0].
 *   WAIT_FOR_DARIUS Poll darius_driver[0] until it equals DONE.
 *   DMA_OUT         Add DARIUS's cycle count to the running total, send the
 *                   total and the output data block to next_rank, then return
 *                   to DMA_IN.
 *
 * All MPI payloads travel as MPI_FLOAT; integer metadata is recovered with
 * (int) casts.
 *
 * mem           - global memory. Offsets/sizes received over MPI are in bytes
 *                 and are divided by sizeof(float) to index this float array.
 * darius_driver - accelerator register window. The *_OFFSET macros are byte
 *                 offsets and are divided by sizeof(int) to index it.
 */
void dariusController(
    float mem[DEPTH],               // global memory pointer
    int darius_driver[DARIUS_DEPTH] // DARIUS register window
    )
{
//Needed for MPI support
#pragma HLS resource core = AXI4Stream variable = stream_out
#pragma HLS resource core = AXI4Stream variable = stream_in
#pragma HLS DATA_PACK variable = stream_out
#pragma HLS DATA_PACK variable = stream_in
// Global memory interface
#pragma HLS INTERFACE ap_bus port = mem depth = 2147483648
#pragma HLS RESOURCE core = AXI4M variable = mem
#pragma HLS INTERFACE ap_bus port = darius_driver
#pragma HLS resource core = AXI4M variable = darius_driver
#pragma HLS INTERFACE ap_ctrl_none port = return
    // id_in is not declared in this function; presumably provided by MPI.h — confirm.
    int rank = id_in;

    // State that must survive between calls (one state is executed per call).
    static int parameter_mem_info[PARAMETER_MEM_INFO_SIZE]; // {byte offset in mem, byte size} of the parameters
    static int data_mem_info[MEM_INFO_SIZE];                // {dma_in byte offset, dma_in byte size, dma_out byte offset, dma_out byte size}
    static int darius_info[DARIUS_INFO_SIZE];               // {num_commands, command[32], batch_size, num_ranks}
    static int cumulative_cycle_count[1];
    // Assigned in INIT. A constant initializer avoids guarded run-time static
    // initialization; the value matches the old initializer, because
    // darius_info is still all zeros on the first call.
    static unsigned int batch_size = 0;
    static unsigned int num_ranks = 0;
    static int prev_rank;
    static int next_rank;

    // MPI transfer buffers (all payloads are exchanged as MPI_FLOAT).
    float parameter_mem_info_float[PARAMETER_MEM_INFO_SIZE];
    float data_mem_info_float[MEM_INFO_SIZE];
    float darius_info_float[DARIUS_INFO_SIZE];
    float cumulative_cycle_count_float[1];

    static ap_uint<3> state = INIT;
    switch (state)
    {
    case INIT:
        // Control information and parameters all come from rank 0.
        // 1) Where the parameters live in mem and how large they are.
        while (!MPI_Recv(parameter_mem_info_float, PARAMETER_MEM_INFO_SIZE, MPI_FLOAT, 0, 0 /*not used*/, MPI_COMM_WORLD /*not used*/))
            ;
        for (int i = 0; i < PARAMETER_MEM_INFO_SIZE; i++)
            parameter_mem_info[i] = (int)parameter_mem_info_float[i];
#ifndef CONTROLLER_ONLY
        // Clear the word that START is later written to and DONE is polled from.
        darius_driver[0] = 0;
#endif
        // 2) The parameters themselves, received straight into mem
        //    (byte offset/size -> float-element offset/count).
        while (!MPI_Recv(mem + parameter_mem_info[0] / sizeof(float), parameter_mem_info[1] / sizeof(float), MPI_FLOAT, 0, 0 /*not used*/, MPI_COMM_WORLD /*not used*/))
            ;
        // 3) Locations and sizes of the input/output data blocks.
        while (!MPI_Recv(data_mem_info_float, MEM_INFO_SIZE, MPI_FLOAT, 0, 0 /*not used*/, MPI_COMM_WORLD /*not used*/))
            ;
        for (int i = 0; i < MEM_INFO_SIZE; i++)
            data_mem_info[i] = (int)data_mem_info_float[i];
        // 4) DARIUS command description plus batch_size / num_ranks.
        while (!MPI_Recv(darius_info_float, DARIUS_INFO_SIZE, MPI_FLOAT, 0, 0 /*not used*/, MPI_COMM_WORLD /*not used*/))
            ;
        for (int i = 0; i < DARIUS_INFO_SIZE; i++)
            darius_info[i] = (int)darius_info_float[i];
        batch_size = darius_info[DARIUS_INFO_SIZE - 2];
        num_ranks = darius_info[DARIUS_INFO_SIZE - 1];
        // Ranks form chains with stride batch_size. The first batch_size ranks
        // receive from rank 0, and the last ones send to rank 0.
        // NOTE(review): rank (int) is compared with unsigned values, so it is
        // converted to unsigned; this assumes rank >= 0 and batch_size <= num_ranks.
        if (rank <= batch_size)
            prev_rank = 0;
        else
            prev_rank = rank - batch_size;
        if (rank > (num_ranks - batch_size))
            next_rank = 0;
        else
            next_rank = rank + batch_size;
        state = DMA_IN;
        break;

    case DMA_IN:
        // Running cycle count, then the input data block, from the previous rank.
        // The count is received into the float buffer and then converted,
        // like every other MPI_FLOAT receive here.
        while (!MPI_Recv(cumulative_cycle_count_float, 1, MPI_FLOAT, prev_rank, 0 /*not used*/, MPI_COMM_WORLD /*not used*/))
            ;
        cumulative_cycle_count[0] = (int)cumulative_cycle_count_float[0];
        while (!MPI_Recv(mem + data_mem_info[0] / sizeof(float), data_mem_info[1] / sizeof(float), MPI_FLOAT, prev_rank, 0 /*not used*/, MPI_COMM_WORLD /*not used*/))
            ;
        state = RUN_DARIUS;
        break;

    case RUN_DARIUS:
        darius_driver[NUM_COMMANDS_OFFSET / sizeof(int)] = darius_info[0]; // num_commands
        // NOTE(review): this copies darius_info[1..34] (34 words), which includes
        // batch_size and num_ranks, although the documented command is only
        // darius_info[1..32]. Confirm that the two extra register writes are
        // intended.
        for (int i = 0; i < (DARIUS_INFO_SIZE - 1); i++)
            darius_driver[COMMAND_OFFSET / sizeof(int) + i] = darius_info[i + 1]; // command
        darius_driver[0] = START;
        state = WAIT_FOR_DARIUS;
        break;

    case WAIT_FOR_DARIUS:
        // Stay in this state (one poll per call) until DARIUS reports DONE.
        // NOTE(review): exact equality fails if other status bits are also set.
        if (darius_driver[0] == DONE)
            state = DMA_OUT;
        break;

    case DMA_OUT:
        cumulative_cycle_count[0] += darius_driver[CYCLE_COUNT_OFFSET / sizeof(int)]; // cycles spent in this run
        cumulative_cycle_count_float[0] = (float)cumulative_cycle_count[0];
        // Running cycle count, then the output data block, to the next rank.
        while (!MPI_Send(cumulative_cycle_count_float, 1, MPI_FLOAT, next_rank, 0, MPI_COMM_WORLD))
            ;
        while (!MPI_Send(mem + data_mem_info[2] / sizeof(float), data_mem_info[3] / sizeof(float), MPI_FLOAT, next_rank, 0, MPI_COMM_WORLD))
            ;
        state = DMA_IN;
        break;

    default:
        // Unreachable: ap_uint<3> values 5-7 are never assigned. Recover by
        // restarting rather than stalling forever.
        state = INIT;
        break;
    }
}