Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 189 additions & 6 deletions src/compiler/evm_frontend/evm_mir_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
#include "compiler/evm_frontend/evm_mir_compiler.h"
#include "action/evm_bytecode_visitor.h"
#include "compiler/evm_frontend/evm_imported.h"
#include "compiler/mir/constants.h"
#include "compiler/mir/module.h"
#include "evm/gas_storage_cost.h"
#include "runtime/evm_instance.h"
#include "utils/hash_utils.h"
#include "utils/logging.h"
#include "llvm/Support/Casting.h"
#include <cstring>
#include <optional>

#ifdef ZEN_ENABLE_EVM_GAS_REGISTER
#include "compiler/llvm-prebuild/Target/X86/X86Subtarget.h"
Expand Down Expand Up @@ -2341,6 +2344,18 @@ EVMMirBuilder::handleClz(const Operand &ValueOp) {
RuntimeFunctions.GetClz, ValueOp);
}

namespace {
// Extract constant shift amount from MInstruction if it is a constant.
std::optional<uint64_t> getConstShiftAmount(MInstruction *Inst) {
if (auto *CI = llvm::dyn_cast<ConstantInstruction>(Inst)) {
if (auto *IntConst = llvm::dyn_cast<MConstantInt>(&CI->getConstant())) {
return IntConst->getValue().getZExtValue();
}
}
return std::nullopt;
}
} // namespace

EVMMirBuilder::U256Inst
EVMMirBuilder::handleLeftShift(const U256Inst &Value, MInstruction *ShiftAmount,
MInstruction *IsLargeShift) {
Expand All @@ -2349,6 +2364,59 @@ EVMMirBuilder::handleLeftShift(const U256Inst &Value, MInstruction *ShiftAmount,
U256Inst Result = {};

MInstruction *Zero = createIntConstInstruction(MirI64Type, 0);

// Fast path: constant shift amount — direct limb logic, no Select/cmp loops.
if (auto ShiftOpt = getConstShiftAmount(ShiftAmount)) {
uint64_t Shift = *ShiftOpt;
if (Shift >= 256) {
for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I)
Result[I] = Zero;
return Result;
}
uint64_t CompShift = Shift / 64;
uint64_t ShiftMod = Shift % 64;

// Hoist loop-invariant constant instructions out of the limb loop.
MInstruction *ShiftModConst = nullptr;
MInstruction *RemainingBitsConst = nullptr;
if (ShiftMod != 0) {
ShiftModConst = createIntConstInstruction(MirI64Type, ShiftMod);
}
if (ShiftMod != 0 && (64 - ShiftMod) > 0) {
RemainingBitsConst = createIntConstInstruction(MirI64Type, 64 - ShiftMod);
}

for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) {
MInstruction *R = Zero;
if (I >= CompShift) {
size_t SrcIdx = I - CompShift;
if (ShiftMod == 0) {
// Pure limb shift (multiple of 64): no intra-limb shift/carry needed.
R = Value[SrcIdx];
} else {
MInstruction *SrcVal = Value[SrcIdx];
MInstruction *Shifted = createInstruction<BinaryInstruction>(
false, OP_shl, MirI64Type, SrcVal, ShiftModConst);
if (SrcIdx > 0 && RemainingBitsConst) {
MInstruction *Carry = createInstruction<BinaryInstruction>(
false, OP_ushr, MirI64Type, Value[SrcIdx - 1],
RemainingBitsConst);
R = createInstruction<BinaryInstruction>(false, OP_or, MirI64Type,
Shifted, Carry);
} else {
R = Shifted;
}
}
}
// Guard with IsLargeShift: if the full 256-bit shift has high limbs set,
// the result must be zero per EVM spec.
R = createInstruction<SelectInstruction>(false, MirI64Type, IsLargeShift,
Zero, R);
Result[I] = protectUnsafeValue(R, MirI64Type);
}
return Result;
}
Comment thread
starwarfan marked this conversation as resolved.

MInstruction *One = createIntConstInstruction(MirI64Type, 1);
MInstruction *Const64 = createIntConstInstruction(MirI64Type, 64);

Expand Down Expand Up @@ -2484,6 +2552,65 @@ EVMMirBuilder::handleLogicalRightShift(const U256Inst &Value,
U256Inst Result = {};

MInstruction *Zero = createIntConstInstruction(MirI64Type, 0);

// Fast path: constant shift amount — direct limb logic, no Select/cmp loops.
if (auto ShiftOpt = getConstShiftAmount(ShiftAmount)) {
uint64_t Shift = *ShiftOpt;
if (Shift >= 256) {
for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I)
Result[I] = Zero;
return Result;
}
uint64_t CompShift = Shift / 64;
uint64_t ShiftMod = Shift % 64;

// If the shift is a multiple of 64, we only need to move whole limbs.
if (ShiftMod == 0) {
for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) {
MInstruction *R = Zero;
if (I + CompShift < EVM_ELEMENTS_COUNT) {
size_t SrcIdx = I + CompShift;
R = Value[SrcIdx];
}
// Guard with IsLargeShift for correctness with 256-bit shift values.
R = createInstruction<SelectInstruction>(false, MirI64Type,
IsLargeShift, Zero, R);
Result[I] = protectUnsafeValue(R, MirI64Type);
}
return Result;
}

// Hoist loop-invariant shift constants out of the limb loop.
MInstruction *ShiftModConst =
createIntConstInstruction(MirI64Type, ShiftMod);
uint64_t CarryShift = 64 - ShiftMod;
MInstruction *CarryShiftConst =
createIntConstInstruction(MirI64Type, CarryShift);

for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) {
Comment thread
starwarfan marked this conversation as resolved.
MInstruction *R = Zero;
if (I + CompShift < EVM_ELEMENTS_COUNT) {
size_t SrcIdx = I + CompShift;
MInstruction *SrcVal = Value[SrcIdx];
MInstruction *Shifted = createInstruction<BinaryInstruction>(
false, OP_ushr, MirI64Type, SrcVal, ShiftModConst);
if (SrcIdx + 1 < EVM_ELEMENTS_COUNT) {
MInstruction *Carry = createInstruction<BinaryInstruction>(
false, OP_shl, MirI64Type, Value[SrcIdx + 1], CarryShiftConst);
R = createInstruction<BinaryInstruction>(false, OP_or, MirI64Type,
Shifted, Carry);
} else {
R = Shifted;
}
}
// Guard with IsLargeShift for correctness with 256-bit shift values.
R = createInstruction<SelectInstruction>(false, MirI64Type, IsLargeShift,
Zero, R);
Result[I] = protectUnsafeValue(R, MirI64Type);
}
return Result;
}

MInstruction *One = createIntConstInstruction(MirI64Type, 1);
MInstruction *Const64 = createIntConstInstruction(MirI64Type, 64);

Expand Down Expand Up @@ -2611,25 +2738,81 @@ EVMMirBuilder::handleArithmeticRightShift(const U256Inst &Value,
EVMFrontendContext::getMIRTypeFromEVMType(EVMType::UINT64);
U256Inst Result = {};

// Arithmetic right shift: sign-extend when shift >= 256
MInstruction *Zero = createIntConstInstruction(MirI64Type, 0);
MInstruction *AllOnes = createIntConstInstruction(MirI64Type, ~0ULL);

// Check sign bit (bit 63 of highest component)
// Check sign bit (bit 63 of highest component) for large-shift result
MInstruction *HighComponent = Value[EVM_ELEMENTS_COUNT - 1];
MInstruction *Const63 = createIntConstInstruction(MirI64Type, 63);
MInstruction *SignBit = createInstruction<BinaryInstruction>(
false, OP_ushr, MirI64Type, HighComponent, Const63);

// Sign bit is 1 if negative
MInstruction *One = createIntConstInstruction(MirI64Type, 1);
MInstruction *IsNegative = createInstruction<CmpInstruction>(
false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, SignBit, One);

// Large shift result: all 1s if negative, all 0s if positive
MInstruction *LargeShiftResult = createInstruction<SelectInstruction>(
false, MirI64Type, IsNegative, AllOnes, Zero);

// Fast path: constant shift amount — direct limb logic, no Select/cmp loops.
if (auto ShiftOpt = getConstShiftAmount(ShiftAmount)) {
uint64_t Shift = *ShiftOpt;
if (Shift >= 256) {
for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I)
Result[I] = LargeShiftResult;
return Result;
}
uint64_t CompShift = Shift / 64;
uint64_t ShiftMod = Shift % 64;

// If the shift is a multiple of 64, we only need to move whole limbs.
if (ShiftMod == 0) {
for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) {
MInstruction *R = LargeShiftResult;
if (I + CompShift < EVM_ELEMENTS_COUNT) {
size_t SrcIdx = I + CompShift;
R = Value[SrcIdx];
}
// Guard with IsLargeShift for correctness with 256-bit shift values.
R = createInstruction<SelectInstruction>(
false, MirI64Type, IsLargeShift, LargeShiftResult, R);
Result[I] = protectUnsafeValue(R, MirI64Type);
}
return Result;
}

// Hoist loop-invariant shift constants out of the limb loop.
MInstruction *ShiftModConst =
createIntConstInstruction(MirI64Type, ShiftMod);
uint64_t CarryShift = 64 - ShiftMod;
MInstruction *CarryShiftConst =
createIntConstInstruction(MirI64Type, CarryShift);

for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) {
Comment thread
starwarfan marked this conversation as resolved.
MInstruction *R = LargeShiftResult;
if (I + CompShift < EVM_ELEMENTS_COUNT) {
size_t SrcIdx = I + CompShift;
MInstruction *SrcVal = Value[SrcIdx];
// Use arithmetic shift for the high component (contains sign bit)
bool UseArithShift = (SrcIdx == EVM_ELEMENTS_COUNT - 1);
MInstruction *Shifted = createInstruction<BinaryInstruction>(
false, UseArithShift ? OP_sshr : OP_ushr, MirI64Type, SrcVal,
ShiftModConst);
if (SrcIdx + 1 < EVM_ELEMENTS_COUNT) {
MInstruction *Carry = createInstruction<BinaryInstruction>(
false, OP_shl, MirI64Type, Value[SrcIdx + 1], CarryShiftConst);
R = createInstruction<BinaryInstruction>(false, OP_or, MirI64Type,
Shifted, Carry);
Comment thread
starwarfan marked this conversation as resolved.
} else {
R = Shifted;
}
}
// Guard with IsLargeShift for correctness with 256-bit shift values.
R = createInstruction<SelectInstruction>(false, MirI64Type, IsLargeShift,
LargeShiftResult, R);
Result[I] = protectUnsafeValue(R, MirI64Type);
}
return Result;
}

// intra-component shifts = shift % 64
// shift_comp = shift / 64 (which component index shift from)
MInstruction *Const64 = createIntConstInstruction(MirI64Type, 64);
Expand Down
Loading