Skip to content

Commit 54c7f07

Browse files
committed
Add integer support to math instructions
1 parent 5cf8bfd commit 54c7f07

File tree

5 files changed

+162
-34
lines changed

5 files changed

+162
-34
lines changed

src/engine/internal/llvm/instructions/math.cpp

Lines changed: 128 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,15 @@ LLVMInstruction *Math::buildAdd(LLVMInstruction *ins)
113113
assert(ins->args.size() == 2);
114114
const auto &arg1 = ins->args[0];
115115
const auto &arg2 = ins->args[1];
116-
llvm::Value *num1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first));
117-
llvm::Value *num2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first));
118-
ins->functionReturnReg->value = m_builder.CreateFAdd(num1, num2);
116+
117+
llvm::Value *double1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Double));
118+
llvm::Value *double2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Double));
119+
ins->functionReturnReg->value = m_builder.CreateFAdd(double1, double2);
120+
121+
llvm::Value *int1 = m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Int);
122+
llvm::Value *int2 = m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Int);
123+
ins->functionReturnReg->isInt = m_builder.CreateAnd(arg1.second->isInt, arg2.second->isInt);
124+
ins->functionReturnReg->intValue = m_builder.CreateAdd(int1, int2);
119125

120126
return ins->next;
121127
}
@@ -125,9 +131,15 @@ LLVMInstruction *Math::buildSub(LLVMInstruction *ins)
125131
assert(ins->args.size() == 2);
126132
const auto &arg1 = ins->args[0];
127133
const auto &arg2 = ins->args[1];
128-
llvm::Value *num1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first));
129-
llvm::Value *num2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first));
130-
ins->functionReturnReg->value = m_builder.CreateFSub(num1, num2);
134+
135+
llvm::Value *double1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Double));
136+
llvm::Value *double2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Double));
137+
ins->functionReturnReg->value = m_builder.CreateFSub(double1, double2);
138+
139+
llvm::Value *int1 = m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Int);
140+
llvm::Value *int2 = m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Int);
141+
ins->functionReturnReg->isInt = m_builder.CreateAnd(arg1.second->isInt, arg2.second->isInt);
142+
ins->functionReturnReg->intValue = m_builder.CreateSub(int1, int2);
131143

132144
return ins->next;
133145
}
@@ -137,9 +149,15 @@ LLVMInstruction *Math::buildMul(LLVMInstruction *ins)
137149
assert(ins->args.size() == 2);
138150
const auto &arg1 = ins->args[0];
139151
const auto &arg2 = ins->args[1];
140-
llvm::Value *num1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first));
141-
llvm::Value *num2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first));
142-
ins->functionReturnReg->value = m_builder.CreateFMul(num1, num2);
152+
153+
llvm::Value *double1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Double));
154+
llvm::Value *double2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Double));
155+
ins->functionReturnReg->value = m_builder.CreateFMul(double1, double2);
156+
157+
llvm::Value *int1 = m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Int);
158+
llvm::Value *int2 = m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Int);
159+
ins->functionReturnReg->isInt = m_builder.CreateAnd(arg1.second->isInt, arg2.second->isInt);
160+
ins->functionReturnReg->intValue = m_builder.CreateMul(int1, int2);
143161

144162
return ins->next;
145163
}
@@ -197,7 +215,10 @@ LLVMInstruction *Math::buildRandomInt(LLVMInstruction *ins)
197215
const auto &arg2 = ins->args[1];
198216
llvm::Value *from = m_builder.CreateFPToSI(m_utils.castValue(arg1.second, arg1.first), m_builder.getInt64Ty());
199217
llvm::Value *to = m_builder.CreateFPToSI(m_utils.castValue(arg2.second, arg2.first), m_builder.getInt64Ty());
200-
ins->functionReturnReg->value = m_builder.CreateCall(m_utils.functions().resolve_llvm_random_long(), { m_utils.executionContextPtr(), from, to });
218+
llvm::Value *intValue = m_builder.CreateCall(m_utils.functions().resolve_llvm_random_int64(), { m_utils.executionContextPtr(), from, to });
219+
ins->functionReturnReg->value = m_builder.CreateSIToFP(intValue, m_builder.getDoubleTy());
220+
ins->functionReturnReg->intValue = intValue;
221+
ins->functionReturnReg->isInt = m_builder.getInt1(true);
201222

202223
return ins->next;
203224
}
@@ -207,13 +228,49 @@ LLVMInstruction *Math::buildMod(LLVMInstruction *ins)
207228
assert(ins->args.size() == 2);
208229
const auto &arg1 = ins->args[0];
209230
const auto &arg2 = ins->args[1];
210-
// rem(a, b) / b < 0.0 ? rem(a, b) + b : rem(a, b)
211-
llvm::Constant *zero = llvm::ConstantFP::get(m_utils.llvmCtx(), llvm::APFloat(0.0));
212-
llvm::Value *num1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first));
213-
llvm::Value *num2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first));
214-
llvm::Value *value = m_builder.CreateFRem(num1, num2); // rem(a, b)
215-
llvm::Value *cond = m_builder.CreateFCmpOLT(m_builder.CreateFDiv(value, num2), zero); // rem(a, b) / b < 0.0 // rem(a, b)
216-
ins->functionReturnReg->value = m_builder.CreateSelect(cond, m_builder.CreateFAdd(value, num2), value);
231+
232+
// double: rem(a, b) / b < 0.0 ? rem(a, b) + b : rem(a, b)
233+
llvm::Constant *doubleZero = llvm::ConstantFP::get(m_utils.llvmCtx(), llvm::APFloat(0.0));
234+
llvm::Value *double1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first));
235+
llvm::Value *double2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first));
236+
llvm::Value *doubleRem = m_builder.CreateFRem(double1, double2); // rem(a, b)
237+
llvm::Value *doubleCond = m_builder.CreateFCmpOLT(m_builder.CreateFDiv(doubleRem, double2), doubleZero); // rem(a, b) / b < 0.0
238+
ins->functionReturnReg->value = m_builder.CreateSelect(doubleCond, m_builder.CreateFAdd(doubleRem, double2), doubleRem);
239+
240+
// int: b == 0 ? 0 (double fallback) : ((rem(a, b) < 0) != (b < 0) ? rem(a, b) + b : rem(a, b))
241+
llvm::Constant *intZero = llvm::ConstantInt::get(m_builder.getInt64Ty(), 0, true);
242+
llvm::Value *int1 = m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Int);
243+
llvm::Value *int2 = m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Int);
244+
llvm::Value *nanResult = m_builder.CreateICmpEQ(int2, intZero);
245+
246+
llvm::BasicBlock *nanBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function());
247+
llvm::BasicBlock *intBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function());
248+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function());
249+
m_builder.CreateCondBr(nanResult, nanBlock, intBlock);
250+
251+
m_builder.SetInsertPoint(nanBlock);
252+
llvm::Value *noInt = m_builder.getInt1(false);
253+
m_builder.CreateBr(nextBlock);
254+
255+
m_builder.SetInsertPoint(intBlock);
256+
llvm::Value *isInt = m_builder.CreateAnd(arg1.second->isInt, arg2.second->isInt);
257+
llvm::Value *intRem = m_builder.CreateSRem(int1, int2); // rem(a, b)
258+
llvm::Value *intCond = m_builder.CreateICmpSLT(m_builder.CreateSDiv(intRem, int2), intZero); // rem(a, b) / b < 0
259+
llvm::Value *intResult = m_builder.CreateSelect(intCond, m_builder.CreateAdd(intRem, int2), intRem);
260+
m_builder.CreateBr(nextBlock);
261+
262+
m_builder.SetInsertPoint(nextBlock);
263+
264+
llvm::PHINode *resultPhi = m_builder.CreatePHI(m_builder.getInt64Ty(), 2);
265+
resultPhi->addIncoming(intZero, nanBlock);
266+
resultPhi->addIncoming(intResult, intBlock);
267+
268+
llvm::PHINode *isIntPhi = m_builder.CreatePHI(m_builder.getInt1Ty(), 2);
269+
isIntPhi->addIncoming(noInt, nanBlock);
270+
isIntPhi->addIncoming(isInt, intBlock);
271+
272+
ins->functionReturnReg->intValue = resultPhi;
273+
ins->functionReturnReg->isInt = isIntPhi;
217274

218275
return ins->next;
219276
}
@@ -225,28 +282,45 @@ LLVMInstruction *Math::buildRound(LLVMInstruction *ins)
225282

226283
assert(ins->args.size() == 1);
227284
const auto &arg = ins->args[0];
228-
// x >= 0.0 ? round(x) : (x >= -0.5 ? -0.0 : floor(x + 0.5))
285+
286+
// double: x >= 0.0 ? round(x) : (x >= -0.5 ? -0.0 : floor(x + 0.5))
229287
llvm::Constant *zero = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0));
230288
llvm::Constant *negativeZero = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(-0.0));
231289
llvm::Function *roundFunc = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::round, m_builder.getDoubleTy());
232290
llvm::Function *floorFunc = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::floor, m_builder.getDoubleTy());
233-
llvm::Value *num = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first));
234-
llvm::Value *notNegative = m_builder.CreateFCmpOGE(num, zero); // num >= 0.0
235-
llvm::Value *roundNum = m_builder.CreateCall(roundFunc, num); // round(num)
236-
llvm::Value *negativeCond = m_builder.CreateFCmpOGE(num, llvm::ConstantFP::get(llvmCtx, llvm::APFloat(-0.5))); // num >= -0.5
237-
llvm::Value *negativeRound = m_builder.CreateCall(floorFunc, m_builder.CreateFAdd(num, llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.5)))); // floor(x + 0.5)
291+
llvm::Value *doubleValue = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first));
292+
llvm::Value *notNegative = m_builder.CreateFCmpOGE(doubleValue, zero); // num >= 0.0
293+
llvm::Value *roundNum = m_builder.CreateCall(roundFunc, doubleValue); // round(num)
294+
llvm::Value *negativeCond = m_builder.CreateFCmpOGE(doubleValue, llvm::ConstantFP::get(llvmCtx, llvm::APFloat(-0.5))); // num >= -0.5
295+
llvm::Value *negativeRound = m_builder.CreateCall(floorFunc, m_builder.CreateFAdd(doubleValue, llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.5)))); // floor(x + 0.5)
238296
ins->functionReturnReg->value = m_builder.CreateSelect(notNegative, roundNum, m_builder.CreateSelect(negativeCond, negativeZero, negativeRound));
239297

298+
// int: doubleX == inf || doubleX == -inf ? doubleX : intX
299+
llvm::Constant *posInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), false);
300+
llvm::Constant *negInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), true);
301+
llvm::Value *isInt = arg.second->isInt;
302+
llvm::Value *intValue = arg.second->intValue;
303+
llvm::Value *isNotInf = m_builder.CreateAnd(m_builder.CreateFCmpONE(doubleValue, posInf), m_builder.CreateFCmpONE(doubleValue, negInf));
304+
llvm::Value *cast = m_builder.CreateFPToSI(ins->functionReturnReg->value, m_builder.getInt64Ty());
305+
ins->functionReturnReg->isInt = isNotInf;
306+
ins->functionReturnReg->intValue = m_builder.CreateSelect(isInt, intValue, cast);
307+
240308
return ins->next;
241309
}
242310

243311
LLVMInstruction *Math::buildAbs(LLVMInstruction *ins)
244312
{
245313
assert(ins->args.size() == 1);
246314
const auto &arg = ins->args[0];
247-
llvm::Function *absFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::fabs, m_builder.getDoubleTy());
248-
llvm::Value *num = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first));
249-
ins->functionReturnReg->value = m_builder.CreateCall(absFunc, num);
315+
316+
llvm::Function *fabsFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::fabs, m_builder.getDoubleTy());
317+
llvm::Value *doubleValue = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first));
318+
ins->functionReturnReg->value = m_builder.CreateCall(fabsFunc, doubleValue);
319+
320+
llvm::Function *absFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::abs, m_builder.getInt64Ty());
321+
llvm::Value *intValue = arg.second->intValue;
322+
ins->functionReturnReg->isInt = arg.second->isInt;
323+
ins->functionReturnReg->intValue = m_builder.CreateCall(absFunc, { intValue, m_builder.getInt1(false) });
250324

251325
return ins->next;
252326
}
@@ -255,9 +329,21 @@ LLVMInstruction *Math::buildFloor(LLVMInstruction *ins)
255329
{
256330
assert(ins->args.size() == 1);
257331
const auto &arg = ins->args[0];
332+
333+
// double: floor(doubleX)
258334
llvm::Function *floorFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::floor, m_builder.getDoubleTy());
259-
llvm::Value *num = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first));
260-
ins->functionReturnReg->value = m_builder.CreateCall(floorFunc, num);
335+
llvm::Value *doubleValue = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first));
336+
ins->functionReturnReg->value = m_builder.CreateCall(floorFunc, doubleValue);
337+
338+
// int: doubleX == inf || doubleX == -inf ? doubleX : intX
339+
llvm::Constant *posInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), false);
340+
llvm::Constant *negInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), true);
341+
llvm::Value *isInt = arg.second->isInt;
342+
llvm::Value *intValue = arg.second->intValue;
343+
llvm::Value *isNotInf = m_builder.CreateAnd(m_builder.CreateFCmpONE(doubleValue, posInf), m_builder.CreateFCmpONE(doubleValue, negInf));
344+
llvm::Value *cast = m_builder.CreateFPToSI(ins->functionReturnReg->value, m_builder.getInt64Ty());
345+
ins->functionReturnReg->isInt = isNotInf;
346+
ins->functionReturnReg->intValue = m_builder.CreateSelect(isInt, intValue, cast);
261347

262348
return ins->next;
263349
}
@@ -266,9 +352,21 @@ LLVMInstruction *Math::buildCeil(LLVMInstruction *ins)
266352
{
267353
assert(ins->args.size() == 1);
268354
const auto &arg = ins->args[0];
355+
356+
// double: ceil(doubleX)
269357
llvm::Function *ceilFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::ceil, m_builder.getDoubleTy());
270-
llvm::Value *num = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first));
271-
ins->functionReturnReg->value = m_builder.CreateCall(ceilFunc, num);
358+
llvm::Value *doubleValue = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first));
359+
ins->functionReturnReg->value = m_builder.CreateCall(ceilFunc, doubleValue);
360+
361+
// int: doubleX == inf || doubleX == -inf ? doubleX : intX
362+
llvm::Constant *posInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), false);
363+
llvm::Constant *negInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), true);
364+
llvm::Value *isInt = arg.second->isInt;
365+
llvm::Value *intValue = arg.second->intValue;
366+
llvm::Value *isNotInf = m_builder.CreateAnd(m_builder.CreateFCmpONE(doubleValue, posInf), m_builder.CreateFCmpONE(doubleValue, negInf));
367+
llvm::Value *cast = m_builder.CreateFPToSI(ins->functionReturnReg->value, m_builder.getInt64Ty());
368+
ins->functionReturnReg->isInt = isNotInf;
369+
ins->functionReturnReg->intValue = m_builder.CreateSelect(isInt, intValue, cast);
272370

273371
return ins->next;
274372
}

src/engine/internal/llvm/llvmfunctions.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ extern "C"
2121
return value_doubleIsInt(from) && value_doubleIsInt(to) ? ctx->rng()->randint(from, to) : ctx->rng()->randintDouble(from, to);
2222
}
2323

24-
double llvm_random_long(ExecutionContext *ctx, long from, long to)
24+
int64_t llvm_random_int64(ExecutionContext *ctx, int64_t from, int64_t to)
2525
{
2626
return ctx->rng()->randint(from, to);
2727
}
@@ -240,10 +240,10 @@ llvm::FunctionCallee LLVMFunctions::resolve_llvm_random_double()
240240
return resolveFunction("llvm_random_double", llvm::FunctionType::get(m_builder->getDoubleTy(), { pointerType, m_builder->getDoubleTy(), m_builder->getDoubleTy() }, false));
241241
}
242242

243-
llvm::FunctionCallee LLVMFunctions::resolve_llvm_random_long()
243+
llvm::FunctionCallee LLVMFunctions::resolve_llvm_random_int64()
244244
{
245245
llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_ctx->llvmCtx()), 0);
246-
return resolveFunction("llvm_random_long", llvm::FunctionType::get(m_builder->getDoubleTy(), { pointerType, m_builder->getInt64Ty(), m_builder->getInt64Ty() }, false));
246+
return resolveFunction("llvm_random_int64", llvm::FunctionType::get(m_builder->getInt64Ty(), { pointerType, m_builder->getInt64Ty(), m_builder->getInt64Ty() }, false));
247247
}
248248

249249
llvm::FunctionCallee LLVMFunctions::resolve_llvm_random_bool()

src/engine/internal/llvm/llvmfunctions.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class LLVMFunctions
4444
llvm::FunctionCallee resolve_list_to_string();
4545
llvm::FunctionCallee resolve_llvm_random();
4646
llvm::FunctionCallee resolve_llvm_random_double();
47-
llvm::FunctionCallee resolve_llvm_random_long();
47+
llvm::FunctionCallee resolve_llvm_random_int64();
4848
llvm::FunctionCallee resolve_llvm_random_bool();
4949
llvm::FunctionCallee resolve_string_pool_new();
5050
llvm::FunctionCallee resolve_string_pool_free();

test/llvm/operators/math/mod_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,26 @@ TEST_F(LLVMModTest, NegativeFiveModZero_Const)
159159
ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, true, -5, 0);
160160
}
161161

162+
TEST_F(LLVMModTest, PositiveDecimalModZero)
163+
{
164+
ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, false, 5.8, 0);
165+
}
166+
167+
TEST_F(LLVMModTest, PositiveDecimalModZero_Const)
168+
{
169+
ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, true, 5.8, 0);
170+
}
171+
172+
TEST_F(LLVMModTest, NegativeDecimalModZero)
173+
{
174+
ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, false, -5.8, 0);
175+
}
176+
177+
TEST_F(LLVMModTest, NegativeDecimalModZero_Const)
178+
{
179+
ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, true, -5.8, 0);
180+
}
181+
162182
TEST_F(LLVMModTest, NegativeDecimalModInfinity)
163183
{
164184
ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, false, -2.5, "Infinity");

test/llvm/operators/math/round_test.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@ TEST_F(LLVMRoundTest, FourPointZero_Const)
1919
ASSERT_EQ(m_utils.getOpResult(LLVMTestUtils::OpType::Round, true, 4.0).toDouble(), 4.0);
2020
}
2121

22+
TEST_F(LLVMRoundTest, NegativeFourPointZero)
23+
{
24+
ASSERT_EQ(m_utils.getOpResult(LLVMTestUtils::OpType::Round, false, -4.0).toDouble(), -4.0);
25+
}
26+
27+
TEST_F(LLVMRoundTest, NegativeFourPointZero_Const)
28+
{
29+
ASSERT_EQ(m_utils.getOpResult(LLVMTestUtils::OpType::Round, true, -4.0).toDouble(), -4.0);
30+
}
31+
2232
TEST_F(LLVMRoundTest, ThreePointTwo)
2333
{
2434
ASSERT_EQ(m_utils.getOpResult(LLVMTestUtils::OpType::Round, false, 3.2).toDouble(), 3.0);

0 commit comments

Comments
 (0)