Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.slf4j.{Logger, LoggerFactory, MDC}

import java.util.function.{BiConsumer, Supplier}
import scala.annotation.nowarn
import scala.compiletime.uninitialized
import scala.concurrent.duration.DurationLong
import scala.util.{Failure, Success, Try}

Expand Down Expand Up @@ -137,6 +138,147 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S

}

/** A pass in the style of [[ForkJoinParallelCpgPass]] that, next to the usual DiffGraph-based graph modifications,
  * maintains an accumulator of type `R` per fork and merges all accumulators once processing completes. This enables
  * map-reduce style aggregation alongside graph changes.
  *
  * Life cycle of one run (see [[runWithBuilder]]):
  *   1. [[init]] is called, then [[generateParts]].
  *   1. Every part is handed to [[runOnPartWithAccumulator]] together with a builder and an accumulator obtained from
  *      [[newAccumulator]].
  *   1. Per-fork accumulators are combined with [[mergeAccumulators]].
  *   1. [[finish]] is called, which by default forwards the merged accumulator to [[onAccumulatorComplete]].
  *
  * The implementation uses the `stream.collect` / `BiConsumer` API with a combined container that holds both a
  * [[DiffGraphBuilder]] and an accumulator per fork, so no `ThreadLocal` or `ConcurrentLinkedQueue` is needed.
  *
  * @tparam T
  *   the part type; every element returned by [[generateParts]] is cast to `T`
  * @tparam R
  *   the accumulator type
  */
abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, R](cpg: Cpg, @nowarn outName: String = "")
    extends CpgPassBase {
  type DiffGraphBuilder = io.shiftleft.codepropertygraph.generated.DiffGraphBuilder

  /** Generate Array of parts that can be processed in parallel. */
  def generateParts(): Array[? <: AnyRef]

  /** Setup large data structures, acquire external resources. */
  def init(): Unit = {}

  /** Override this to disable parallelism of passes. Useful for debugging. */
  def isParallel: Boolean = true

  /** Create a fresh, empty accumulator. Called once per fork; for zero or one part a single accumulator is created.
    * Also called by the default [[finish]] when no merged result is available (i.e. processing failed).
    */
  protected def newAccumulator(): R

  /** Merge two accumulators. Must be associative. The result may reuse either argument. */
  protected def mergeAccumulators(left: R, right: R): R

  /** Process a single part, writing graph changes to `builder` and aggregated data to `acc`. */
  protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: R, part: T): Unit

  /** Called after all parts are processed with the fully merged accumulator. Override `finish()` if you need to release
    * resources; `onAccumulatorComplete` is invoked from within the default `finish()` implementation.
    */
  protected def onAccumulatorComplete(acc: R): Unit = {}

  /** Container pairing a per-fork DiffGraphBuilder with a per-fork accumulator. `acc` is a `var` because
    * [[mergeAccumulators]] may return a new instance instead of mutating its left argument.
    */
  private class BuilderWithAccumulator(val builder: DiffGraphBuilder, var acc: R)

  // Hand-over slot between `runWithBuilder` (producer) and `finish` (consumer). `_accResult` is only meaningful while
  // `_hasResult` is true.
  @volatile private var _accResult: R      = uninitialized
  @volatile private var _hasResult: Boolean = false

  /** Release large data structures and external resources. The default implementation calls [[onAccumulatorComplete]]
    * with the merged accumulator (or a fresh one if processing failed). Subclasses that override this method must call
    * `super.finish()` to ensure the accumulator callback fires.
    *
    * The hand-over state is cleared even when the callback throws, so a stale accumulator is neither delivered twice
    * nor kept reachable by this pass after the run.
    */
  def finish(): Unit = {
    try {
      val acc = if (_hasResult) _accResult else newAccumulator()
      onAccumulatorComplete(acc)
    } finally {
      _hasResult = false
      // Drop the reference so a potentially large accumulator can be garbage collected.
      _accResult = null.asInstanceOf[R]
    }
  }

  /** Run the pass against a fresh diff graph, apply the diff to `cpg.graph` and log timing/size statistics.
    *
    * Exceptions are logged and rethrown; the completion line is logged in either case (with `-1` placeholders for
    * values that were never computed).
    */
  override def createAndApply(): Unit = {
    baseLogger.info(s"Start of pass: $name")
    val nanosStart = System.nanoTime()
    var nParts     = 0
    var nanosBuilt = -1L
    var nDiff      = -1
    var nDiffT     = -1
    try {
      val diffGraph = Cpg.newDiffGraphBuilder
      nParts = runWithBuilder(diffGraph)
      nanosBuilt = System.nanoTime()
      nDiff = diffGraph.size

      nDiffT = flatgraph.DiffGraphApplier.applyDiff(cpg.graph, diffGraph)
    } catch {
      case exc: Exception =>
        baseLogger.error(s"Pass ${name} failed", exc)
        throw exc
    } finally {
      val nanosStop = System.nanoTime()
      // Share of the total time spent applying the diff; the `+ 1` avoids a division by zero.
      val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1)
      baseLogger.info(
        f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms ($fracRun%.0f%% on mutations). $nDiff%d + ${nDiffT - nDiff}%d changes committed from $nParts%d parts."
      )
    }
  }

  /** Process all parts, writing every graph change into `externalBuilder`, and publish the merged accumulator for
    * [[finish]].
    *
    * [[finish]] runs in a `finally`, i.e. also when [[init]], [[generateParts]] or a part fails. Note that an exception
    * thrown by `finish()` in that situation replaces the original one.
    *
    * @return
    *   the number of parts that were generated
    */
  override def runWithBuilder(externalBuilder: DiffGraphBuilder): Int = {
    _hasResult = false
    try {
      init()
      val parts  = generateParts()
      val nParts = parts.length
      _accResult = nParts match {
        case 0 =>
          newAccumulator()
        case 1 =>
          // Single part: skip the stream machinery and write straight into the caller's builder.
          val acc = newAccumulator()
          runOnPartWithAccumulator(externalBuilder, acc, parts(0).asInstanceOf[T])
          acc
        case _ =>
          val stream =
            if (!isParallel)
              java.util.Arrays
                .stream(parts)
                .sequential()
            else
              java.util.Arrays
                .stream(parts)
                .parallel()
          val result = stream.collect(
            // Supplier: one builder/accumulator pair per fork.
            new Supplier[BuilderWithAccumulator] {
              override def get(): BuilderWithAccumulator =
                new BuilderWithAccumulator(Cpg.newDiffGraphBuilder, newAccumulator())
            },
            // Accumulator: process one part into the fork-local pair.
            new BiConsumer[BuilderWithAccumulator, AnyRef] {
              override def accept(bwa: BuilderWithAccumulator, part: AnyRef): Unit =
                runOnPartWithAccumulator(bwa.builder, bwa.acc, part.asInstanceOf[T])
            },
            // Combiner: fold the right pair into the left one.
            new BiConsumer[BuilderWithAccumulator, BuilderWithAccumulator] {
              override def accept(left: BuilderWithAccumulator, right: BuilderWithAccumulator): Unit = {
                left.builder.absorb(right.builder)
                left.acc = mergeAccumulators(left.acc, right.acc)
              }
            }
          )
          externalBuilder.absorb(result.builder)
          result.acc
      }
      // Only mark the result as available once every part succeeded.
      _hasResult = true
      nParts
    } finally {
      finish()
    }
  }

  /** Legacy entry point; ignores both arguments and delegates to [[createAndApply]]. */
  @deprecated("Please use createAndApply")
  override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = {
    createAndApply()
  }
}

trait CpgPassBase {

protected def baseLogger: Logger = LoggerFactory.getLogger(getClass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,114 @@ class CpgPassNewTests extends AnyWordSpec with Matchers {
}
}

"ForkJoinParallelCpgPassWithAccumulator" should {
"merge accumulators and invoke completion callback once" in {
  // With `isParallel = false` the stream is sequential and works on a single container, so `mergeAccumulators` would
  // not be exercised. Running the same scenario in parallel mode as well gives the merge step a chance to run while
  // keeping the sequential path covered.
  for (parallel <- Seq(false, true)) withClue(s"isParallel=$parallel: ") {
    val cpg       = Cpg.empty
    val completed = ArrayBuffer.empty[Int]

    val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] =
      new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-pass") {
        override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int]
        override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] =
          left ++= right
        // Each part contributes its length: 1 + 2 + 3 = 6.
        override protected def runOnPartWithAccumulator(
          builder: DiffGraphBuilder,
          acc: ArrayBuffer[Int],
          part: String
        ): Unit = acc += part.length
        override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum
        override def generateParts(): Array[String]                               = Array("a", "bb", "ccc")
        override def isParallel: Boolean                                          = parallel
      }

    pass.createAndApply()

    // Exactly one callback, carrying every part's contribution. Only the sum is asserted because the order in which
    // forks are merged is not fixed in parallel mode.
    completed.toSeq shouldBe Seq(6)
  }
}

"use a fresh accumulator when there are no parts" in {
  val observedSums = ArrayBuffer.empty[Int]

  // The accumulator is seeded with 42 so the assertion can tell that the callback received an accumulator produced
  // by `newAccumulator()` even though no part was ever processed.
  val emptyPass =
    new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](Cpg.empty, "acc-empty") {
      override def generateParts(): Array[String] = Array.empty[String]

      override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer(42)

      override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] =
        left.addAll(right)

      // Never expected to run: there are no parts.
      override protected def runOnPartWithAccumulator(
        builder: DiffGraphBuilder,
        acc: ArrayBuffer[Int],
        part: String
      ): Unit = ()

      override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit =
        observedSums += acc.sum
    }

  emptyPass.createAndApply()

  observedSums.toSeq shouldBe Seq(42)
}

"clear accumulator state between runs" in {
  val sumsPerRun = ArrayBuffer.empty[Int]

  val rerunnablePass =
    new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](Cpg.empty, "acc-rerun") {
      override def isParallel: Boolean = false

      override def generateParts(): Array[String] = Array("1", "2", "3")

      override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int]

      override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] =
        left.addAll(right)

      // Each part is a decimal number; one run therefore accumulates 1 + 2 + 3 = 6.
      override protected def runOnPartWithAccumulator(
        builder: DiffGraphBuilder,
        acc: ArrayBuffer[Int],
        part: String
      ): Unit = acc += part.toInt

      override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit =
        sumsPerRun += acc.sum
    }

  // Two back-to-back runs on the same instance: the second must not see values left over from the first.
  for (_ <- 1 to 2) rerunnablePass.createAndApply()

  sumsPerRun.toSeq shouldBe Seq(6, 6)
}

"invoke completion callback once when a part fails" in {
  // Records the order of life-cycle hooks as they fire.
  val trace = ArrayBuffer.empty[String]

  val failingPass =
    new ForkJoinParallelCpgPassWithAccumulator[String, Int](Cpg.empty, "acc-fail") {
      override def isParallel: Boolean = false

      override def generateParts(): Array[String] = Array("p1")

      override def init(): Unit = {
        trace += "init"
        super.init()
      }

      override def finish(): Unit = {
        trace += "finish"
        super.finish()
      }

      override protected def newAccumulator(): Int = 0

      override protected def mergeAccumulators(left: Int, right: Int): Int = left + right

      // The only part blows up right after being recorded.
      override protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: Int, part: String): Unit = {
        trace += "run"
        throw new RuntimeException("boom")
      }

      override protected def onAccumulatorComplete(acc: Int): Unit =
        trace += s"complete:$acc"
    }

  // The failure must surface to the caller ...
  assertThrows[RuntimeException] {
    failingPass.createAndApply()
  }

  // ... and the callback still fires exactly once, with a fresh (zero) accumulator, after `finish`.
  trace.toSeq shouldBe Seq("init", "run", "finish", "complete:0")
}
}

}
Loading