Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 184 additions & 64 deletions codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,23 @@ import scala.annotation.nowarn
import scala.concurrent.duration.DurationLong
import scala.util.{Failure, Success, Try}

/* CpgPass
*
* Base class of a program which receives a CPG as input for the purpose of modifying it.
* */
/** A single-threaded CPG pass. This is the simplest pass to implement: override [[run]] and add desired graph
* modifications to the provided [[DiffGraphBuilder]].
*
* Internally implemented as a [[ForkJoinParallelCpgPass]] with a single part and parallelism disabled.
*
* @param cpg
* the code property graph to modify
* @param outName
* optional name for output
*/
abstract class CpgPass(cpg: Cpg, outName: String = "") extends ForkJoinParallelCpgPass[AnyRef](cpg, outName) {

/** The main method to implement. Add all desired graph changes (nodes, edges, properties) to the provided builder.
*
* @param builder
* the [[DiffGraphBuilder]] that accumulates graph modifications
*/
def run(builder: DiffGraphBuilder): Unit

final override def generateParts(): Array[? <: AnyRef] = Array[AnyRef](null)
Expand All @@ -26,42 +37,126 @@ abstract class CpgPass(cpg: Cpg, outName: String = "") extends ForkJoinParallelC
override def isParallel: Boolean = false
}

/** Legacy alias for [[CpgPass]], retained only for source compatibility.
  *
  * @deprecated Use [[CpgPass]] instead.
  */
@deprecated("Use CpgPass instead") abstract class SimpleCpgPass(cpg: Cpg, outName: String = "") extends CpgPass(cpg, outName)

/* ForkJoinParallelCpgPass is a possible replacement for CpgPass and ParallelCpgPass.
*
* Instead of returning an Iterator, generateParts() returns an Array. This means that the entire collection
* of parts must live on the heap at the same time; on the other hand, there are no possible issues with iterator invalidation,
* e.g. when running over all METHOD nodes and deleting some of them.
*
* Instead of streaming writes as ParallelCpgPass do, all `runOnPart` invocations read the initial state
* of the graph. Then all changes (accumulated in the DiffGraphBuilders) are merged into a single change, and applied in one go.
*
* In other words, the parallelism follows the fork/join parallel map-reduce (java: collect, scala: aggregate) model.
* The effect is identical as if one were to sequentially run `runOnParts` on all output elements of `generateParts()`
* in sequential order, with the same builder.
*
* This simplifies semantics and makes it easy to reason about possible races.
*
* Note that ForkJoinParallelCpgPass never writes intermediate results, so one must consider peak memory consumption when porting from ParallelCpgPass.
*
* Initialization and cleanup of external resources or large datastructures can be done in the `init()` and `finish()`
* methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct
* passes eagerly, and releases them only when the entire chain has run.
* */
abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: String = "") extends CpgPassBase {
/** A CPG pass whose parts are processed in parallel, following the fork/join (map-reduce) model.
  *
  * [[generateParts]] returns an Array rather than an Iterator: the entire collection of parts must therefore live on
  * the heap at once, but in exchange there are no possible iterator-invalidation issues, e.g. when running over all
  * METHOD nodes and deleting some of them.
  *
  * Instead of streaming writes as ParallelCpgPass does, every [[runOnPart]] invocation reads the initial state of
  * the graph. All changes (accumulated in the DiffGraphBuilders) are then merged into a single change and applied in
  * one go.
  *
  * In other words, the parallelism follows the fork/join parallel map-reduce (java: collect, scala: aggregate)
  * model. The effect is identical to sequentially running [[runOnPart]] on all output elements of [[generateParts]],
  * in order, with the same builder — which simplifies semantics and makes it easy to reason about possible races.
  *
  * Note that ForkJoinParallelCpgPass never writes intermediate results, so peak memory consumption must be
  * considered when porting from ParallelCpgPass.
  *
  * Initialization and cleanup of external resources or large data structures belong in [[init]] and [[finish]]
  * rather than in the constructor or GC: e.g. SCPG chains of passes construct passes eagerly and release them only
  * when the entire chain has run.
  *
  * This is a simplified form of [[ForkJoinParallelCpgPassWithAccumulator]] that does not use an accumulator.
  *
  * @tparam T
  *   the type of each part produced by [[generateParts]]
  * @param cpg
  *   the code property graph to modify
  * @param outname
  *   optional output name
  */
abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outname: String = "")
    extends ForkJoinParallelCpgPassWithAccumulator[T, Null](cpg, outname) {

  /** Process a single part, adding all desired graph changes to the provided builder.
    *
    * @param builder
    *   the [[DiffGraphBuilder]] that accumulates graph modifications
    * @param part
    *   one element of the array returned by [[generateParts]]
    */
  def runOnPart(builder: DiffGraphBuilder, part: T): Unit

  // The accumulator machinery of the parent class is unused here: the accumulator
  // type is pinned to Null and every accumulator hook is a no-op.
  override def createAccumulator(): Null = null
  override def runOnPart(builder: DiffGraphBuilder, part: T, acc: Null): Unit = runOnPart(builder, part)
  override def mergeAccumulator(left: Null, accumulator: Null): Unit = ()
  override def onAccumulatorComplete(builder: DiffGraphBuilder, accumulator: Null): Unit = ()
}

/** A parallel CPG pass with an accumulator for aggregating side results.
*
 * This is the most general form of the fork/join pass framework. It generalizes [[ForkJoinParallelCpgPass]] (which
 * is implemented as a subclass of this class) by adding an accumulator of type `Accumulator` that each parallel
 * worker maintains locally. After all parts are processed, worker accumulators are merged via [[mergeAccumulator]],
 * and the final merged accumulator is passed to [[onAccumulatorComplete]], where additional graph changes can be
 * recorded.
*
* @tparam T
* the type of each part produced by [[generateParts]]
* @tparam Accumulator
* the type of the accumulator used during parallel execution
* @param cpg
* the code property graph to modify
* @param outName
* optional output name
*/
abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, Accumulator <: AnyRef](
cpg: Cpg,
@nowarn outName: String = ""
) extends CpgPassBase {
type DiffGraphBuilder = io.shiftleft.codepropertygraph.generated.DiffGraphBuilder
// generate Array of parts that can be processed in parallel

/** Generate an array of parts to be processed in parallel by [[runOnPart]]. */
def generateParts(): Array[? <: AnyRef]
// setup large data structures, acquire external resources

/** Called once before [[generateParts]]. Use to set up large data structures or acquire external resources. */
def init(): Unit = {}
// release large data structures and external resources

/** Called once after all parts have been processed (in a `finally` block). Use to release resources acquired in
* [[init]].
*/
def finish(): Unit = {}
// main function: add desired changes to builder
def runOnPart(builder: DiffGraphBuilder, part: T): Unit
// Override this to disable parallelism of passes. Useful for debugging.

/** Process a single part, recording graph changes in `builder` and side results in `accumulator`.
*
* @param builder
* the [[DiffGraphBuilder]] that accumulates graph modifications
* @param part
* the part to process
* @param accumulator
* the thread-local accumulator for this worker
*/
def runOnPart(builder: DiffGraphBuilder, part: T, accumulator: Accumulator): Unit

/** Override and return `false` to disable parallel execution. Useful for debugging. */
def isParallel: Boolean = true

/** Create a fresh accumulator instance. Called once per parallel worker thread. */
def createAccumulator(): Accumulator

/** Merge the `accumulator` (right) into `left`. Called during the combine phase of fork/join. */
def mergeAccumulator(left: Accumulator, accumulator: Accumulator): Unit

/** Called once after all parts are processed and accumulators are merged. Use to record additional graph changes
* based on the fully merged accumulator.
*
* @param builder
* the [[DiffGraphBuilder]] for any additional modifications
* @param accumulator
* the final merged accumulator
*/
def onAccumulatorComplete(builder: DiffGraphBuilder, accumulator: Accumulator): Unit

/** Creates a new [[DiffGraphBuilder]], runs the pass (init, generateParts, runOnPart, finish), applies all
* accumulated changes to the graph, and logs timing information. Exceptions during execution are logged and
* re-thrown.
*/
override def createAndApply(): Unit = {
baseLogger.info(s"Start of pass: $name")
val nanosStart = System.nanoTime()
Expand Down Expand Up @@ -89,41 +184,50 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
}
}

/** Runs the full pass lifecycle (init, generateParts, parallel runOnPart, accumulator merge, finish) and absorbs all
* changes into `externalBuilder` without applying them to the graph. The caller is responsible for applying the
* builder.
*
* @param externalBuilder
* the builder to absorb all generated changes into
* @return
* the number of parts that were processed
*/
override def runWithBuilder(externalBuilder: DiffGraphBuilder): Int = {
try {
init()

val parts = generateParts()
val nParts = parts.size
nParts match {
case 0 =>
case 1 =>
runOnPart(externalBuilder, parts(0).asInstanceOf[T])
case _ =>
val stream =
if (!isParallel)
java.util.Arrays
.stream(parts)
.sequential()
else
java.util.Arrays
.stream(parts)
.parallel()
val diff = stream.collect(
new Supplier[DiffGraphBuilder] {
override def get(): DiffGraphBuilder =
Cpg.newDiffGraphBuilder
},
new BiConsumer[DiffGraphBuilder, AnyRef] {
override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit =
runOnPart(builder, part.asInstanceOf[T])
},
new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] {
override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit =
leftBuilder.absorb(rightBuilder)
}
)
externalBuilder.absorb(diff)
}
val stream =
if (!isParallel) java.util.Arrays.stream(parts).sequential()
else java.util.Arrays.stream(parts).parallel()

val (diff, acc) = stream.collect(
new Supplier[(DiffGraphBuilder, Accumulator)] {
override def get(): (DiffGraphBuilder, Accumulator) =
(Cpg.newDiffGraphBuilder, createAccumulator())
},
new BiConsumer[(DiffGraphBuilder, Accumulator), AnyRef] {
override def accept(consumedArg: (DiffGraphBuilder, Accumulator), part: AnyRef): Unit = {
val (diff, acc) = consumedArg
runOnPart(diff, part.asInstanceOf[T], acc)
}
},
new BiConsumer[(DiffGraphBuilder, Accumulator), (DiffGraphBuilder, Accumulator)] {
override def accept(
  leftConsumedArg: (DiffGraphBuilder, Accumulator),
  rightConsumedArg: (DiffGraphBuilder, Accumulator)
): Unit = {
  // Fork/join combine step: fold the right worker's state into the left.
  val (leftDiff, leftAcc) = leftConsumedArg
  // BUG FIX: previously destructured `leftConsumedArg` twice, silently dropping the
  // right worker's diff graph and accumulator (and absorbing the left diff into itself).
  val (rightDiff, rightAcc) = rightConsumedArg
  leftDiff.absorb(rightDiff)
  mergeAccumulator(leftAcc, rightAcc)
}
}
)
onAccumulatorComplete(diff, acc)
externalBuilder.absorb(diff)
nParts
} finally {
finish()
Expand All @@ -137,6 +241,9 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S

}

/** Base trait for all CPG passes. Defines the lifecycle methods that every pass must implement: [[createAndApply]] for
* standalone execution, and [[runWithBuilder]] for composing passes that share a single [[DiffGraphBuilder]].
*/
trait CpgPassBase {

protected def baseLogger: Logger = LoggerFactory.getLogger(getClass)
Expand All @@ -156,8 +263,12 @@ trait CpgPassBase {
*/
def runWithBuilder(builder: DiffGraphBuilder): Int

/** Wraps runWithBuilder with logging, and swallows raised exceptions. Use with caution -- API is unstable. A return
* value of -1 indicates failure, otherwise the return value of runWithBuilder is passed through.
/** Wraps [[runWithBuilder]] with logging and exception handling. Use with caution — API is unstable.
*
* @param builder
* the [[DiffGraphBuilder]] to absorb changes into
* @return
* the number of parts processed, or `-1` if the pass threw an exception
*/
def runWithBuilderLogged(builder: DiffGraphBuilder): Int = {
baseLogger.info(s"Start of pass: $name")
Expand Down Expand Up @@ -189,6 +300,15 @@ trait CpgPassBase {
/** No-op retained for API compatibility with the legacy protobuf overlay serialization mechanism. */
@deprecated("store is a no-op; passes no longer serialize overlays via this method")
protected def store(overlay: GeneratedMessageV3, name: String, serializedCpg: SerializedCpg): Unit = {}

/** Executes `fun` while logging the pass start and completion time (including duration via MDC).
*
* @tparam A
* the return type of the wrapped computation
* @param fun
* the computation to execute
* @return
* the result of `fun`
*/
protected def withStartEndTimesLogged[A](fun: => A): A = {
baseLogger.info(s"Running pass: $name")
val startTime = System.currentTimeMillis
Expand Down
Loading
Loading