@@ -60,25 +60,6 @@ void Runable::fini()
6060 _deferred.clear ();
6161}
6262
63- #if 0
64- class DepManager
65- {
66- private:
67- IdValueMap _ivm;
68- std::unordered_set<id_type> _args;
69- public:
70- ::mlir::Value getDependent(i::mlir::OpBuilder & builder, d_type guid)
71- {
72- if(auto d = _ivm.find(guid); d == _ivm.end()) {
73- _func.insertArg
74- _ivm[guid] = {val, {}}
75- } else {
76- return d->second.first;
77- }
78- }
79- };
80- #endif
81-
8263void process_promises ()
8364{
8465 bool done = false ;
@@ -87,7 +68,6 @@ void process_promises()
8768 do {
8869 ::mlir::OpBuilder builder (&jit._context );
8970 auto loc = builder.getUnknownLoc ();
90- jit::IdValueMap ivp;
9171
9272 // Create a MLIR module
9373 auto module = builder.create <::mlir::ModuleOp>(loc);
@@ -104,11 +84,13 @@ void process_promises()
10484 // we need to keep runables/deferred/futures alive until we set their values below
10585 std::vector<Runable::ptr_type> runables;
10686
87+ jit::DepManager dm (function);
88+
10789 while (true ) {
10890 Runable::ptr_type d;
10991 _deferred.pop (d);
11092 if (d) {
111- if (d->generate_mlir (builder, loc, ivp )) {
93+ if (d->generate_mlir (builder, loc, dm )) {
11294 d.reset ();
11395 break ;
11496 };
@@ -123,39 +105,8 @@ void process_promises()
123105
124106 if (runables.empty ()) continue ;
125107
126- // Now we have to define the return type as a ValueRange of all arrays which we have created
127- // (runnables have put them into ivp when generating mlir)
128- // We also compute the total size of the struct llvm created for this return type
129- // llvm will basically return a struct with all the arrays as members, each of type JIT::MemRefDescriptor
130-
131- // Need a container to put all return values, will be used to construct TypeRange
132- std::vector<::mlir::Type> ret_types;
133- ret_types.reserve (ivp.size ());
134- std::vector<::mlir::Value> ret_values;
135- ret_types.reserve (ivp.size ());
136- std::unordered_map<id_type, uint64_t > rank_map;
137- // here we store the total size of the llvm struct
138- uint64_t sz = 0 ;
139- for (auto & v : ivp) {
140- auto value = v.second .first ;
141- // append the type and array/value
142- ret_types.push_back (value.getType ());
143- ret_values.push_back (value);
144- auto ptt = value.getType ().dyn_cast <::imex::ptensor::PTensorType>();
145- assert (ptt);
146- auto rank = ptt.getRtensor ().getShape ().size ();
147- rank_map[v.first ] = rank;
148- // add sizeof(MemRefDescriptor<elementtype, rank>) to sz
149- sz += 3 + 2 * rank;
150- }
151- ::mlir::TypeRange ret_tr (ret_types);
152- ::mlir::ValueRange ret_vr (ret_values);
153-
154- // add return statement
155- auto ret_value = builder.create <::mlir::func::ReturnOp>(loc, ret_vr);
156- // Define and assign correct function type
157- auto funcTypeAttr = ::mlir::TypeAttr::get (builder.getFunctionType ({}, ret_tr));
158- function.setFunctionTypeAttr (funcTypeAttr);
108+ // create return statement and adjust function type
109+ uint64_t sz = dm.handleResult (builder);
159110 // also request generation of c-wrapper function
160111 function->setAttr (::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName (), ::mlir::UnitAttr::get (&jit._context ));
161112 // add the function to the module
@@ -165,22 +116,10 @@ void process_promises()
165116 // compile and run the module
166117 assert (sizeof (intptr_t ) == sizeof (void *));
167118 intptr_t * output = new intptr_t [sz];
168- std::cout << ivp.size () << " sz: " << sz << std::endl;
169119 if (jit.run (module , fname, output)) throw std::runtime_error (" failed running jit" );
170120
171- // push results to fulfill promises
172- size_t pos = 0 ;
173- for (auto & v : ivp) {
174- auto value = v.second .first ;
175- auto rank = rank_map[v.first ];
176- void * allocated = (void *)output[pos];
177- void * aligned = (void *)output[pos+1 ];
178- intptr_t offset = output[pos+2 ];
179- intptr_t * sizes = output + pos + 3 ;
180- intptr_t * stride = output + pos + 3 + rank;
181- pos += 3 + 2 * rank;
182- v.second .second (rank, allocated, aligned, offset, sizes, stride);
183- }
121+ // push results to deliver promises
122+ dm.deliver (output, sz);
184123 } while (!done);
185124}
186125
0 commit comments