Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion sjsonnet/src/sjsonnet/Evaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,15 @@ class Evaluator(
case ExprTags.Lookup => visitLookup(e.asInstanceOf[Lookup])
case ExprTags.Function =>
val f = e.asInstanceOf[Function]
visitMethod(f.body, f.params, f.pos)
val func = visitMethod(f.body, f.params, f.pos)
// Forward StaticOptimizer-computed identity tags so apply1 can elide identity calls.
// Skip the writes when the optimizer didn't classify this Function (the common case),
// keeping zero overhead for non-identity lambdas.
if (f.staticIdentityShape != 0) {
func.staticIdentityShape = f.staticIdentityShape
func.staticIdentityCapturedIdx = f.staticIdentityCapturedIdx
}
func
case ExprTags.LocalExpr => visitLocalExpr(e.asInstanceOf[LocalExpr])
case ExprTags.Apply => visitApply(e.asInstanceOf[Apply])
case ExprTags.Apply3 => visitApply3(e.asInstanceOf[Apply3])
Expand Down
9 changes: 9 additions & 0 deletions sjsonnet/src/sjsonnet/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,15 @@ object Expr {
}
final case class Function(var pos: Position, params: Params, body: Expr) extends Expr {
final override private[sjsonnet] def tag = ExprTags.Function

// Set by StaticOptimizer; remains 0 unless the body is statically recognized.
// 1 = direct identity: body == ValidId(uniqueParam)
// 2 = self-composition: body == g(g(...g(x))) for the unique parameter `x` and a single
// captured `g` (any depth >= 1). Effective identity is decided at runtime by checking
// whether `g` is itself effectively identity.
var staticIdentityShape: Byte = 0
// When `staticIdentityShape == 2`, the absolute scope index of the captured `g`.
var staticIdentityCapturedIdx: Int = -1
}
final case class IfElse(var pos: Position, cond: Expr, `then`: Expr, `else`: Expr) extends Expr {
final override private[sjsonnet] def tag = ExprTags.IfElse
Expand Down
47 changes: 46 additions & 1 deletion sjsonnet/src/sjsonnet/StaticOptimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,14 @@ class StaticOptimizer(
case And(pos, _: Val.False, _) => Val.False(pos)
case Or(pos, _: Val.True, _) => Val.True(pos)
case Or(pos, _: Val.False, rhs: Val.Bool) => rhs.pos = pos; rhs
case e => e

// Identity-equivalent function recognition. Cheap pattern match here keeps the runtime
// fast path (`Val.Func.isEffectivelyIdentity`) at single-field cost.
case f: Function =>
identifyStaticIdentity(f)
f

case e => e
}
}

Expand Down Expand Up @@ -211,6 +218,44 @@ class StaticOptimizer(
}
} catch { case _: Exception => Double.NaN }

/**
* Statically classify a unary function as identity-equivalent so the runtime fast path can elide
* the call entirely. We recognize:
*
* - Direct identity: `function(x) x`.
* - Self-composition: `function(x) g(g(...g(x)))` for some captured `g` (depth >= 1). The
* runtime check then needs only to test whether `g` is itself effectively identity.
*
* The captured `g` index is the absolute scope index assigned by ScopedExprTransform; at the
* point we re-enter `transform` after `super.transform`, `scope.size` equals the index of the
* function's unique parameter (we've already exited `nestedNames` for the body).
*/
private def identifyStaticIdentity(f: Function): Unit = {
val params = f.params
if (params.names.length != 1) return
if (params.defaultExprs(0) != null) return
val paramIdx = scope.size
f.body match {
case ValidId(_, _, idx) if idx == paramIdx =>
f.staticIdentityShape = 1
case Apply1(_, ValidId(_, _, gIdx), inner, tail, strict)
if !tail && strict && gIdx != paramIdx &&
isSelfCompositionChain(inner, gIdx, paramIdx) =>
f.staticIdentityShape = 2
f.staticIdentityCapturedIdx = gIdx
case _ =>
}
}

private def isSelfCompositionChain(body: Expr, expectedFuncIdx: Int, paramIdx: Int): Boolean =
body match {
case ValidId(_, _, idx) if idx == paramIdx => true
case Apply1(_, ValidId(_, _, fIdx), inner, tail, strict)
if !tail && strict && fIdx == expectedFuncIdx =>
isSelfCompositionChain(inner, expectedFuncIdx, paramIdx)
case _ => false
}

private object ValidSuper {
def unapply(s: Super): Option[(Position, Int)] =
scope.get("self") match {
Expand Down
118 changes: 118 additions & 0 deletions sjsonnet/src/sjsonnet/Val.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2366,6 +2366,28 @@ object Val {
*/
def bodyExpr: Expr = null

/**
* Static identity shape from [[StaticOptimizer]]. The optimizer pattern-matches `Function`
* literals and tags them so the runtime can decide identity equivalence in O(1) plus a chain
* walk on the captured value.
*
* - 0 = unknown / not statically classified (default)
* - 1 = direct identity body: `function(x) x`
* - 2 = self-composition body: `function(x) g(g(...g(x)))` for a single captured `g`
*
* Stored as a plain `var` (set by `Evaluator.visitExpr`) rather than an overridable `def` so
* that all `Val.Func` instances created by the evaluator share a single anonymous subclass
* shape — this keeps `apply1`/`apply` call sites monomorphic for the JIT (avoids the bimorphic
* inlining cliff that occurred when an extra subclass was added per-Function-expr).
*/
var staticIdentityShape: Byte = 0

/**
* When [[staticIdentityShape]] == 2, the absolute scope index of the captured `g` in
* [[defSiteValScope]]. Otherwise -1.
*/
var staticIdentityCapturedIdx: Int = -1

// Convenience wrapper: evaluates the function body and resolves any TailCall sentinel.
// Use this instead of raw `evalRhs` at call sites that bypass `apply*` and consume
// the result directly (e.g. stdlib scope-reuse fast paths).
Expand All @@ -2391,6 +2413,93 @@ object Val {
}
}

/*
* Cached state for [[isEffectivelyIdentity]]. Mutually exclusive byte states (not flags).
* Encoded as a plain `var` because the predicate is idempotent, the cached values are
* stable Bytes (a single CPU word write), and the JIT can keep this on the hot path as a
* field load + constant comparison without any extra fences. Recursive cycles are guarded
* by the in-progress state.
*/
private var _effectiveIdentityState: Byte = Func.EffIdUnknown

/**
* True when this function is semantically equivalent to identity (`x => x`). Used by [[apply1]]
* to elide the call entirely. Recognizes:
* 1. Direct identity (delegates to [[isIdentityFunction]]).
* 2. Self-composition over an effectively-identity captured function, classified by
* [[StaticOptimizer]] via [[staticIdentityShape]]. The decision requires forcing the
* captured value, so the result is cached after first evaluation.
*
* Recursive captures (`local rec = f2(rec)`) are broken via an in-progress marker that yields
* `false`, letting normal application/max-stack logic surface the user-visible error.
*/
final private[sjsonnet] def isEffectivelyIdentity: Boolean = {
val s = _effectiveIdentityState
if (s == Func.EffIdYes) true
else if (s == Func.EffIdNo || s == Func.EffIdInProgress) false
else computeEffectiveIdentity()
}

private def computeEffectiveIdentity(): Boolean = {
if (isIdentityFunction) {
_effectiveIdentityState = Func.EffIdYes
return true
}
if (staticIdentityShape != 2) {
_effectiveIdentityState = Func.EffIdNo
return false
}
// Walk the static-composition chain iteratively. Each level marks itself InProgress,
// we record visited nodes so we can either propagate the final decision back up the
// chain or roll the marks back to Unknown if forcing a captured value throws.
// One small ArrayList allocation per cache miss; never executed on the hot path
// because subsequent `isEffectivelyIdentity` reads hit the cached Byte state.
val visited = new java.util.ArrayList[Func](4)
_effectiveIdentityState = Func.EffIdInProgress
visited.add(this)
var current: Func = this
var decided: Byte = 0
try {
while (decided == 0) {
val captured =
current.defSiteValScope.bindings(current.staticIdentityCapturedIdx).value
captured match {
case f: Func =>
f._effectiveIdentityState match {
case Func.EffIdYes => decided = Func.EffIdYes
case Func.EffIdNo => decided = Func.EffIdNo
case Func.EffIdInProgress => decided = Func.EffIdNo
case _ /* Unknown */ =>
if (f.isIdentityFunction) decided = Func.EffIdYes
else if (f.staticIdentityShape != 2) decided = Func.EffIdNo
else {
f._effectiveIdentityState = Func.EffIdInProgress
visited.add(f)
current = f
}
}
case _ => decided = Func.EffIdNo
}
}
var i = 0
val n = visited.size
while (i < n) {
visited.get(i)._effectiveIdentityState = decided
i += 1
}
decided == Func.EffIdYes
} catch {
case NonFatal(e) =>
var i = 0
val n = visited.size
while (i < n) {
visited.get(i)._effectiveIdentityState = Func.EffIdUnknown
i += 1
}
throw e
}
}

/** Override to provide a function name for error messages. Only called on error paths. */
def functionName: String = null

Expand Down Expand Up @@ -2523,6 +2632,7 @@ object Val {
ev: EvalScope,
tailstrictMode: TailstrictMode): Val = {
if (params.names.length != 1) apply(Array(argVal), null, outerPos)
else if (isEffectivelyIdentity) argVal.value
else {
val funDefFileScope: FileScope = pos match {
case null => outerPos.fileScope
Expand Down Expand Up @@ -2577,6 +2687,14 @@ object Val {
}
}

object Func {
// Mutually-exclusive byte states for the cached effective-identity probe; not flags.
private[Val] final val EffIdUnknown: Byte = 0
private[Val] final val EffIdYes: Byte = 1
private[Val] final val EffIdNo: Byte = 2
private[Val] final val EffIdInProgress: Byte = 3
}

/**
* Superclass for standard library functions.
*
Expand Down
74 changes: 74 additions & 0 deletions sjsonnet/test/src/sjsonnet/EvaluatorTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,80 @@ object EvaluatorTests extends TestSuite {
eval("{foo: (function() true)()}") ==> ujson.Obj {
"foo" -> ujson.True
}
eval(
"""
|local f2(f) = function(x) f(f(x));
|local g = f2(error "should stay lazy");
|std.type(g)
|""".stripMargin
) ==> ujson.Str("function")
assert(
evalErr(
"""
|local f2(f) = function(x) f(f(x));
|local g = f2(error "call should force base");
|g(1)
|""".stripMargin
).contains("call should force base")
)
assert(
evalErr(
"""
|local f2(f) = function(x) f(f(x));
|f2(1)(1)
|""".stripMargin
).contains("Expected function, found number")
)
assert(
evalErr(
"""
|local f2(f) = function(x) f(f(x));
|f2(error "tailstrict should force") tailstrict
|""".stripMargin
).contains("tailstrict should force")
)
eval(
"""
|local f2(f) = function(x) f(f(x));
|f2(function(x) x + 1)(1)
|""".stripMargin
) ==> ujson.Num(3)
eval(
"""
|local f2(f) = function(x) f(f(x));
|local plus1(x) = x + 1;
|local chain = std.makeArray(5, function(i) if i == 0 then plus1 else f2(chain[i - 1]));
|chain[4](1)
|""".stripMargin,
maxStack = 100000
) ==> ujson.Num(17)
eval(
"""
|local f2(f) = function(x) f(f(x));
|local id(x) = x;
|local slowId = std.makeArray(20, function(i) if i == 0 then id else f2(slowId[i - 1]));
|slowId[15](42)
|""".stripMargin,
maxStack = 100000
) ==> ujson.Num(42)
assert(
evalErr(
"""
|local f2(f) = function(x) f(f(x));
|local rec = f2(rec);
|rec(1)
|""".stripMargin
).contains("Max stack frames exceeded")
)
assert(
evalErr(
"""
|local f2(f) = function(x) f(f(x));
|local o = { a: f2(self.b), b: f2(self.a) };
|o.a(1)
|""".stripMargin
).contains("Max stack frames exceeded")
)
}
test("members") {
eval("{local x = 1, x: x}['x']") ==> ujson.Num(1)
Expand Down
Loading