-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.cpp
More file actions
180 lines (152 loc) · 6.12 KB
/
main.cpp
File metadata and controls
180 lines (152 loc) · 6.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include <Foundation/Foundation.hpp>
#include <Metal/Metal.hpp>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>
int main() {
// Initialize the Metal device
MTL::Device* device = MTL::CreateSystemDefaultDevice();
if (!device) {
std::cerr << "Failed to find a compatible Metal device." << std::endl;
return -1;
}
// Create a command queue
MTL::CommandQueue* commandQueue = device->newCommandQueue();
if (!commandQueue) {
std::cerr << "Failed to create a command queue." << std::endl;
device->release();
return -1;
}
// Load the default library (assumes operations.metal is part of the project)
// extra steps here are necessary since we are not in X-Code. We have to use
// xcrun -sdk to create a .metallib file. see tutorial for instructions
NS::Error* error = nullptr;
NS::String* filePath = NS::String::string("operations.metallib", NS::UTF8StringEncoding);
auto lib = device->newLibrary(filePath, &error);
if(!lib){
std::cerr << "Failed to Library\n";
std::exit(-1);
}
// Retrieve the compute function from the library
NS::String* functionName = NS::String::string("add_vector", NS::UTF8StringEncoding);
MTL::Function* computeFunction = lib->newFunction(functionName);
if (!computeFunction) {
std::cerr << "Failed to find the compute function 'add_vector'." << std::endl;
lib->release();
commandQueue->release();
device->release();
return -1;
}
// Create a compute pipeline state
MTL::ComputePipelineState* computePipelineState = device->newComputePipelineState(computeFunction, &error);
if (!computePipelineState) {
std::cerr << "Failed to create compute pipeline state: "
<< (error ? error->localizedDescription()->utf8String() : "Unknown error") << std::endl;
computeFunction->release();
lib->release();
commandQueue->release();
device->release();
return -1;
}
// Define the length of the vectors and buffer size
const uint32_t arrayLength = 1024;
const size_t bufferSize = arrayLength * sizeof(float);
// Initialize input data
std::vector<float> a(arrayLength);
std::vector<float> b(arrayLength);
for (uint32_t i = 0; i < arrayLength; ++i) {
a[i] = static_cast<float>(i);
b[i] = static_cast<float>(i * 2);
}
// Create buffers for the input and output data
MTL::Buffer* aBuffer = device->newBuffer(bufferSize, MTL::ResourceStorageModeManaged);
MTL::Buffer* bBuffer = device->newBuffer(bufferSize, MTL::ResourceStorageModeManaged);
MTL::Buffer* cBuffer = device->newBuffer(bufferSize, MTL::ResourceStorageModeManaged);
// Copy data into the Metal buffers
memcpy(aBuffer->contents(), a.data(), bufferSize);
memcpy(bBuffer->contents(), b.data(), bufferSize);
// Notify Metal that the buffers have been modified
aBuffer->didModifyRange(NS::Range::Make(0, aBuffer->length()));
bBuffer->didModifyRange(NS::Range::Make(0, bBuffer->length()));
// Create a command buffer to encode commands
MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer();
if (!commandBuffer) {
std::cerr << "Failed to create a command buffer." << std::endl;
// Release resources
aBuffer->release();
bBuffer->release();
cBuffer->release();
computePipelineState->release();
computeFunction->release();
lib->release();
commandQueue->release();
device->release();
return -1;
}
// Create a compute command encoder
MTL::ComputeCommandEncoder* computeEncoder = commandBuffer->computeCommandEncoder();
if (!computeEncoder) {
std::cerr << "Failed to create a compute command encoder." << std::endl;
// Release resources
commandBuffer->release();
aBuffer->release();
bBuffer->release();
cBuffer->release();
computePipelineState->release();
computeFunction->release();
lib->release();
commandQueue->release();
device->release();
return -1;
}
// Set the compute pipeline state and buffers
computeEncoder->setComputePipelineState(computePipelineState);
computeEncoder->setBuffer(aBuffer, 0, 0);
computeEncoder->setBuffer(bBuffer, 0, 1);
computeEncoder->setBuffer(cBuffer, 0, 2);
// Determine the grid and threadgroup sizes
MTL::Size gridSize = MTL::Size(arrayLength, 1, 1);
// Ensure the threadgroup size does not exceed the maximum threads per threadgroup
NS::UInteger threadgroup_Size = computePipelineState->maxTotalThreadsPerThreadgroup();
if (threadgroup_Size> arrayLength) {
threadgroup_Size = arrayLength;
}
MTL::Size threadgroupSize = MTL::Size(threadgroup_Size, 1, 1); // Adjust based on the device's capabilities
// Dispatch the compute kernel
computeEncoder->dispatchThreads(gridSize, threadgroupSize);
// End encoding
computeEncoder->endEncoding();
// Commit the command buffer and wait for it to complete
commandBuffer->commit();
commandBuffer->waitUntilCompleted();
// Read the output data from the GPU
float* cData = static_cast<float*>(cBuffer->contents());
// Verify the results
bool isCorrect = true;
for (uint32_t i = 0; i < arrayLength; ++i) {
float expected = a[i] + b[i];
if (cData[i] != expected) {
std::cerr << "Mismatch at index " << i << ": expected " << expected << ", got " << cData[i] << std::endl;
isCorrect = false;
break;
}
}
if (isCorrect) {
std::cout << "Computation successful! All results are correct." << std::endl;
}
// Release all allocated resources
computeEncoder->release();
commandBuffer->release();
aBuffer->release();
bBuffer->release();
cBuffer->release();
computePipelineState->release();
computeFunction->release();
lib->release();
functionName->release();
commandQueue->release();
device->release();
return 0;
}