Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion sjsonnet/src/sjsonnet/Val.scala
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,8 @@ object Val {
}
}

/**
 * Raw byte backing data for this array, if it has one. This default implementation returns
 * null; byte-backed subclasses override it, so callers must null-check the result.
 */
private[sjsonnet] def rawBytes: Array[Byte] = null

/**
* If both this and other are ConcatViews sharing the same left array, return the shared prefix
* length. Otherwise return 0. Used by compare/equal to skip identical prefix elements entirely,
Expand Down Expand Up @@ -1344,7 +1346,7 @@ object Val {
/** True when the inherited `arr` reference is non-null. */
@inline private def isMaterialized: Boolean = arr ne null

/** Raw byte backing data for zero-copy extraction (e.g. base64 encode). Always non-null. */
def rawBytes: Array[Byte] = byteData
override def rawBytes: Array[Byte] = byteData

override def value(i: Int): Val = {
if (isMaterialized || isConcatView) super.value(i)
Expand Down Expand Up @@ -1410,6 +1412,68 @@ object Val {
}
}

/**
 * Lazy array whose element at index `i` is the number `(i * multiplier + addend) & 0xff`.
 *
 * Elements are produced on demand through `Val.cachedNum`. A boxed `Eval` array is only built
 * when `asLazyArray` (or `reversed`) is called, and a primitive byte array is only built the
 * first time `rawBytes` is requested.
 */
private final class LinearModByteArr(pos0: Position, size: Int, multiplier: Int, addend: Int)
    extends Arr(pos0, null) {
  _length = size

  /** Cache for `rawBytes`; stays null until the bytes are first requested. */
  private var byteData: Array[Byte] = null

  /** True while elements are still computed from the formula instead of delegated to `super`. */
  @inline private def isVirtual: Boolean = (arr eq null) && !isConcatView

  /** Rejects indices outside `[0, _length)`. */
  @inline private def checkIndex(i: Int): Unit =
    if (i < 0 || i >= _length) throw new ArrayIndexOutOfBoundsException(i)

  /** Low 8 bits of `i * multiplier + addend`, computed in Long arithmetic. */
  @inline private def byteAt(i: Int): Int = {
    val linear = i.toLong * multiplier.toLong + addend.toLong
    (linear & 0xffL).toInt
  }

  /** Bounds-checks `i`, then returns the number for that index. */
  @inline private def checkedNum(i: Int) = {
    checkIndex(i)
    Val.cachedNum(pos, byteAt(i).toDouble)
  }

  override def value(i: Int): Val =
    if (isVirtual) checkedNum(i) else super.value(i)

  override def eval(i: Int): Eval =
    if (isVirtual) checkedNum(i) else super.eval(i)

  /** Returns the elements as primitive bytes, building and caching them on first use. */
  override def rawBytes: Array[Byte] = {
    if (byteData == null) {
      val filled = new Array[Byte](_length)
      var idx = 0
      while (idx < filled.length) {
        filled(idx) = byteAt(idx).toByte
        idx += 1
      }
      byteData = filled
    }
    byteData
  }

  /** Populates `arr` with one number per index (if still virtual), then defers to `super`. */
  override def asLazyArray: Array[Eval] = {
    if (isVirtual) {
      val position = pos
      val boxed = new Array[Eval](_length)
      var idx = 0
      while (idx < boxed.length) {
        boxed(idx) = Val.cachedNum(position, byteAt(idx).toDouble)
        idx += 1
      }
      arr = boxed
    }
    super.asLazyArray
  }

  override def reversed(newPos: Position): Arr = {
    // Materialise first; presumably the inherited reversal reads `arr` directly (confirm in Arr).
    asLazyArray
    super.reversed(newPos)
  }
}

object Arr {
/** Create a plain `Arr` at `pos` backed by the given element array. */
def apply(pos: Position, arr: Array[? <: Eval]): Arr = new Arr(pos, arr)

Expand Down Expand Up @@ -1521,6 +1585,9 @@ object Val {
/** Create a byte-backed array from raw bytes (e.g. base64DecodeBytes output). */
def fromBytes(pos: Position, bytes: Array[Byte]): Arr = new ByteArr(pos, bytes)

/**
 * Create a lazy array of `size` elements whose element at index `i` is the number
 * `(i * multiplier + addend) & 0xff`. Elements are computed on demand rather than stored.
 */
def linearModBytes(pos: Position, size: Int, multiplier: Int, addend: Int): Arr =
new LinearModByteArr(pos, size, multiplier, addend)

/**
* Create a lazy range array representing the integer sequence [from, from+1, ..., from+size-1].
* Elements are computed on demand via Val.cachedNum, avoiding upfront allocation of the full
Expand Down
88 changes: 87 additions & 1 deletion sjsonnet/src/sjsonnet/stdlib/ArrayModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,81 @@ object ArrayModule extends AbstractFunctionModule {
sum
}

// Sentinel meaning "no linear byte pattern was recognised"; returned by the pattern helpers below.
private final val NoBytePattern = Long.MinValue
// 2^53 - 1: upper bound applied to the largest pre-modulo value before a pattern is accepted,
// so that the value is still an exactly representable integer as a Double.
private final val MaxExactDoubleInt = 9007199254740991L

/** Packs `multiplier` into the high 32 bits and `addend` into the low 32 bits of one Long. */
@inline private def encodeBytePattern(multiplier: Int, addend: Int): Long = {
  val highWord = multiplier.toLong << 32
  val lowWord = addend.toLong & 0xffffffffL
  highWord | lowWord
}

/** Extracts the multiplier stored in the high 32 bits of an encoded pattern. */
@inline private def bytePatternMultiplier(pattern: Long): Int = {
  val highWord = pattern >>> 32
  highWord.toInt
}

/** Extracts the addend stored in the low 32 bits of an encoded pattern. */
@inline private def bytePatternAddend(pattern: Long): Int = {
  val lowWord = pattern & 0xffffffffL
  lowWord.toInt
}

/**
 * Returns the value of `expr` as an Int when it is a numeric literal holding a whole,
 * non-negative number that fits in an Int; returns -1 for anything else.
 */
private def literalNonNegativeInt(expr: Expr): Int = expr match {
  case num: Val.Num =>
    val raw = num.rawDouble
    val usable = raw >= 0.0 && raw.isWhole && raw.isValidInt
    if (usable) raw.toInt else -1
  case _ =>
    -1
}

/**
 * Multiplies both the multiplier and the addend of `pattern` by `factor`.
 *
 * Yields `NoBytePattern` when the input is already `NoBytePattern`, when `factor` is negative,
 * or when either scaled component no longer fits in an Int.
 */
private def scaleBytePattern(pattern: Long, factor: Int): Long = {
  val usable = pattern != NoBytePattern && factor >= 0
  if (!usable) NoBytePattern
  else {
    val k = factor.toLong
    val scaledMultiplier = bytePatternMultiplier(pattern).toLong * k
    val scaledAddend = bytePatternAddend(pattern).toLong * k
    val fits = scaledMultiplier <= Int.MaxValue && scaledAddend <= Int.MaxValue
    if (fits) encodeBytePattern(scaledMultiplier.toInt, scaledAddend.toInt)
    else NoBytePattern
  }
}

/**
 * Adds the constant `addend` to the addend component of `pattern`, leaving the multiplier as is.
 *
 * Yields `NoBytePattern` when the input is already `NoBytePattern`, when `addend` is negative,
 * or when the summed addend no longer fits in an Int.
 */
private def addBytePattern(pattern: Long, addend: Int): Long = {
  val usable = pattern != NoBytePattern && addend >= 0
  if (!usable) NoBytePattern
  else {
    val total = bytePatternAddend(pattern).toLong + addend.toLong
    if (total > Int.MaxValue) NoBytePattern
    else encodeBytePattern(bytePatternMultiplier(pattern), total.toInt)
  }
}

/**
 * Tries to express `expr` as `multiplier * param + addend`, where `param` is the identifier
 * whose `nameIdx` equals `paramIdx`, and returns the encoded pattern.
 *
 * Recognised shapes are the parameter itself, and `*` / `+` nodes where one operand is a
 * non-negative integer literal (the right operand is tried first) and the other operand is
 * itself recognised. Everything else yields `NoBytePattern`.
 */
private def linearBytePattern(expr: Expr, paramIdx: Int): Long = expr match {
  case ref: Expr.ValidId if ref.nameIdx == paramIdx =>
    // The bare parameter: 1 * param + 0.
    encodeBytePattern(1, 0)
  case Expr.BinaryOp(_, left, Expr.BinaryOp.OP_*, right) =>
    val rightFactor = literalNonNegativeInt(right)
    if (rightFactor < 0) {
      val leftFactor = literalNonNegativeInt(left)
      if (leftFactor < 0) NoBytePattern
      else scaleBytePattern(linearBytePattern(right, paramIdx), leftFactor)
    } else scaleBytePattern(linearBytePattern(left, paramIdx), rightFactor)
  case Expr.BinaryOp(_, left, Expr.BinaryOp.OP_+, right) =>
    val rightConst = literalNonNegativeInt(right)
    if (rightConst < 0) {
      val leftConst = literalNonNegativeInt(left)
      if (leftConst < 0) NoBytePattern
      else addBytePattern(linearBytePattern(right, paramIdx), leftConst)
    } else addBytePattern(linearBytePattern(left, paramIdx), rightConst)
  case _ =>
    NoBytePattern
}

/**
 * Recognises a body of the form `<linear expression in the parameter> % 256` and returns its
 * encoded pattern, or `NoBytePattern` when the body has any other shape.
 *
 * The pattern is also rejected when the largest value the linear expression reaches over
 * indices `0 until size` exceeds `MaxExactDoubleInt`.
 */
private def linearMod256BytePattern(body: Expr, paramIdx: Int, size: Int): Long = body match {
  case Expr.BinaryOp(_, dividend, Expr.BinaryOp.OP_%, divisor)
      if literalNonNegativeInt(divisor) == 256 =>
    val candidate = linearBytePattern(dividend, paramIdx)
    if (candidate != NoBytePattern) {
      val mul = bytePatternMultiplier(candidate).toLong
      val add = bytePatternAddend(candidate).toLong
      // The expression is non-decreasing in the index, so the last index gives the maximum.
      val largest = if (size > 0) mul * (size.toLong - 1L) + add else add
      if (largest > MaxExactDoubleInt) NoBytePattern else candidate
    } else NoBytePattern
  case _ =>
    NoBytePattern
}

private def removeAtView(arr: Val.Arr, removeIdx: Int): Val.Arr = {
val len = arr.length
if (len == 1) Val.Arr(arr.pos, Val.Arr.EMPTY_EVAL_ARRAY)
Expand Down Expand Up @@ -908,7 +983,18 @@ object ArrayModule extends AbstractFunctionModule {
builtin("makeArray", "sz", "func") { (pos, ev, size: Val, func: Val.Func) =>
val sz = size.cast[Val.Num].asPositiveInt
val body = func.bodyExpr
if (func.params.names.length == 1 && body != null && body.isInstanceOf[Val.Literal]) {
val bytePattern =
if (func.params.names.length == 1 && body != null)
linearMod256BytePattern(body, func.defSiteValScope.length, sz)
else NoBytePattern
if (bytePattern != NoBytePattern) {
Val.Arr.linearModBytes(
pos,
sz,
bytePatternMultiplier(bytePattern),
bytePatternAddend(bytePattern)
)
} else if (func.params.names.length == 1 && body != null && body.isInstanceOf[Val.Literal]) {
// Function body is a constant (e.g. `function(_) 'x'`).
// Keep the eager shared-value array: it is smaller and faster than a lazy view here.
val a = new Array[Eval](sz)
Expand Down
38 changes: 21 additions & 17 deletions sjsonnet/src/sjsonnet/stdlib/EncodingModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,31 @@ object EncodingModule extends AbstractFunctionModule {
case ba: Val.ByteArr =>
Val.Str.asciiSafe(pos, PlatformBase64.encodeToString(ba.rawBytes))
case arr: Val.Arr =>
val len = arr.length
val byteArr = new Array[Byte](len)
var i = 0
while (i < len) {
arr.value(i) match {
case v: Val.Num =>
val vInt = v.asInt
if (vInt < 0 || vInt > 255) {
val rawBytes = arr.rawBytes
if (rawBytes != null) Val.Str.asciiSafe(pos, PlatformBase64.encodeToString(rawBytes))
else {
val len = arr.length
val byteArr = new Array[Byte](len)
var i = 0
while (i < len) {
arr.value(i) match {
case v: Val.Num =>
val vInt = v.asInt
if (vInt < 0 || vInt > 255) {
Error.fail(
f"Found an invalid codepoint value at position $i (must be 0 <= X <= 255), got $vInt"
)
}
byteArr(i) = vInt.toByte
case v =>
Error.fail(
f"Found an invalid codepoint value at position $i (must be 0 <= X <= 255), got $vInt"
f"Expected an array of numbers, got a ${v.prettyName} at position $i"
)
}
byteArr(i) = vInt.toByte
case v =>
Error.fail(
f"Expected an array of numbers, got a ${v.prettyName} at position $i"
)
}
i += 1
}
i += 1
Val.Str.asciiSafe(pos, PlatformBase64.encodeToString(byteArr))
}
Val.Str.asciiSafe(pos, PlatformBase64.encodeToString(byteArr))
case x => Error.fail("Cannot base64 encode " + x.prettyName)
}): Val
},
Expand Down
7 changes: 7 additions & 0 deletions sjsonnet/test/src/sjsonnet/Base64Tests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,13 @@ object Base64Tests extends TestSuite {
assert(r == ujson.True)
}

test("makeArrayBytePattern") {
  // Identity pattern: bytes 0, 1, 2, 3, 4, 5.
  eval("""std.base64(std.makeArray(6, function(i) i % 256))""") ==>
    ujson.Str("AAECAwQF")
  // Multiplier plus addend, all values below 256: bytes 13, 20, 27, 34, 41, 48.
  eval("""std.base64(std.makeArray(6, function(i) (i * 7 + 13) % 256))""") ==>
    ujson.Str("DRQbIikw")
  // Linear value exceeds 255, so the modulo must wrap: bytes 0, 100, 200, 300 % 256 = 44.
  eval("""std.base64(std.makeArray(4, function(i) (i * 100) % 256))""") ==>
    ujson.Str("AGTILA==")
}

// ================================================================
// Multiple encode/decode cycles (stability test)
// ================================================================
Expand Down
Loading