Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import com.contrastsecurity.sarif.Result;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import io.codemodder.Codemod;
import io.codemodder.CodemodExecutionPriority;
import io.codemodder.CodemodFileScanningResult;
Expand All @@ -14,11 +12,9 @@
import io.codemodder.SarifFindingKeyUtil;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.providers.sarif.semgrep.ProvidedSemgrepScan;
import io.codemodder.remediation.FixCandidateSearcher;
import io.codemodder.remediation.GenericRemediationMetadata;
import io.codemodder.remediation.Remediator;
import io.codemodder.remediation.SearcherStrategyRemediator;
import io.codemodder.remediation.javadeserialization.JavaDeserializationFixStrategy;
import io.codemodder.remediation.javadeserialization.JavaDeserializationRemediator;
import java.util.Optional;
import javax.inject.Inject;

Expand All @@ -41,32 +37,7 @@ public SemgrepJavaDeserializationCodemod(
ruleId = "java.lang.security.audit.object-deserialization.object-deserialization")
final RuleSarif sarif) {
super(GenericRemediationMetadata.DESERIALIZATION.reporter(), sarif);
this.remediator =
new SearcherStrategyRemediator.Builder<Result>()
.withSearcherStrategyPair(
// matches declarations
new FixCandidateSearcher.Builder<Result>()
.withMatcher(
n ->
Optional.empty()
.or(
() ->
Optional.of(n)
.map(
m ->
m instanceof VariableDeclarationExpr vde
? vde
: null)
.filter(JavaDeserializationFixStrategy::match))
.or(
() ->
Optional.of(n)
.map(m -> m instanceof MethodCallExpr mce ? mce : null)
.filter(JavaDeserializationFixStrategy::match))
.isPresent())
.build(),
new JavaDeserializationFixStrategy())
.build();
this.remediator = new JavaDeserializationRemediator<>();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.codemodder.remediation.javadeserialization;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import io.codemodder.CodemodFileScanningResult;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.remediation.*;
Expand All @@ -21,8 +23,26 @@ public JavaDeserializationRemediator() {
this.searchStrategyRemediator =
new SearcherStrategyRemediator.Builder<T>()
.withSearcherStrategyPair(
// matches declarations
new FixCandidateSearcher.Builder<T>()
.withMatcher(JavaDeserializationFixStrategy::match)
.withMatcher(
n ->
Optional.empty()
.or(
() ->
Optional.of(n)
.map(
m ->
m instanceof VariableDeclarationExpr vde
? vde
: null)
.filter(JavaDeserializationFixStrategy::match))
.or(
() ->
Optional.of(n)
.map(m -> m instanceof MethodCallExpr mce ? mce : null)
.filter(JavaDeserializationFixStrategy::match))
.isPresent())
.build(),
new JavaDeserializationFixStrategy())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import java.util.Optional;

/**
* Fix strategy for XSS vulnerabilities where a variable/expr is sent to a Spring ResponseEntity.
* Fix strategy for XSS vulnerabilities where a variable/expr is sent to a Spring ResponseEntity
* constructor.
*/
final class ResponseEntityFixStrategy implements RemediationStrategy {
final class ResponseEntityConstructorFixStrategy implements RemediationStrategy {

@Override
public SuccessOrReason fix(final CompilationUnit cu, final Node node) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.codemodder.remediation.xss;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.resolution.types.ResolvedType;
import io.codemodder.remediation.RemediationStrategy;
import io.codemodder.remediation.SuccessOrReason;
import java.util.Optional;

/**
* Fix strategy for XSS vulnerabilities where a variable/expr is sent to a Spring ResponseEntity
* write method like ok().
*/
final class ResponseEntityWriteFixStrategy implements RemediationStrategy {

@Override
public SuccessOrReason fix(final CompilationUnit cu, final Node node) {
var maybeCall =
Optional.of(node).map(n -> n instanceof MethodCallExpr ? (MethodCallExpr) n : null);
if (maybeCall.isEmpty()) {
return SuccessOrReason.reason("Not a method call.");
}

MethodCallExpr call = maybeCall.get();
return EncoderWrapping.fix(call, 0);
}

static boolean match(final Node node) {
return Optional.of(node)
.map(n -> n instanceof MethodCallExpr ? (MethodCallExpr) n : null)
.filter(m -> "ok".equals(m.getNameAsString()))
.filter(m -> !m.getArguments().isEmpty())
.filter(
c -> {
Expression firstArg = c.getArguments().getFirst().get();
try {
ResolvedType resolvedType = firstArg.calculateResolvedType();
return "java.lang.String".equals(resolvedType.describe());
} catch (Exception e) {
// this is expected often, and indicates its a non-String type anyway
return false;
}
})
.isPresent();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ public XSSRemediator() {
new PrintingMethodFixStrategy())
.withSearcherStrategyPair(
new FixCandidateSearcher.Builder<T>()
.withMatcher(ResponseEntityFixStrategy::match)
.withMatcher(ResponseEntityConstructorFixStrategy::match)
.build(),
new ResponseEntityFixStrategy())
new ResponseEntityConstructorFixStrategy())
.withSearcherStrategyPair(
new FixCandidateSearcher.Builder<T>()
.withMatcher(ResponseEntityWriteFixStrategy::match)
.build(),
new ResponseEntityWriteFixStrategy())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

final class ResponseEntityFixStrategyTest {
final class ResponseEntityConstructorFixStrategyTest {

private ResponseEntityFixStrategy fixer;
private ResponseEntityConstructorFixStrategy fixer;
private DetectorRule rule;
private JavaParser parser;

@BeforeEach
void setup() throws IOException {
this.fixer = new ResponseEntityFixStrategy();
this.fixer = new ResponseEntityConstructorFixStrategy();
this.parser = JavaParserFactory.newFactory().create(List.of());
this.rule = new DetectorRule("xss", "XSS", null);
}
Expand Down Expand Up @@ -86,7 +86,7 @@ private CodemodFileScanningResult scanAndFix(final CompilationUnit cu, final int
new SearcherStrategyRemediator.Builder<XSSFinding>()
.withSearcherStrategyPair(
new FixCandidateSearcher.Builder<XSSFinding>()
.withMatcher(ResponseEntityFixStrategy::match)
.withMatcher(ResponseEntityConstructorFixStrategy::match)
.build(),
fixer)
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package io.codemodder.remediation.xss;

import static org.assertj.core.api.Assertions.assertThat;

import com.github.javaparser.JavaParser;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.printer.lexicalpreservation.LexicalPreservingPrinter;
import io.codemodder.CodemodFileScanningResult;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.javaparser.JavaParserFactory;
import io.codemodder.remediation.FixCandidateSearcher;
import io.codemodder.remediation.SearcherStrategyRemediator;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

final class ResponseEntityWriteFixStrategyTest {

private ResponseEntityWriteFixStrategy fixer;
private DetectorRule rule;
private JavaParser parser;

@BeforeEach
void setup() throws IOException {
this.fixer = new ResponseEntityWriteFixStrategy();
this.parser = JavaParserFactory.newFactory().create(List.of());
this.rule = new DetectorRule("xss", "XSS", null);
}

private static Stream<Arguments> fixableSamples() {
return Stream.of(
Arguments.of(
"""
class Samples {
String should_be_fixed(String s) {
return ResponseEntity.ok("Value: " + s);
}
}
""",
"""
import org.owasp.encoder.Encode;
class Samples {
String should_be_fixed(String s) {
return ResponseEntity.ok("Value: " + Encode.forHtml(s));
}
}
"""),
Arguments.of(
"""
class Samples {
String should_be_fixed(Object s) {
return ResponseEntity.ok("Value: " + s.toString());
}
}
""",
"""
import org.owasp.encoder.Encode;
class Samples {
String should_be_fixed(Object s) {
return ResponseEntity.ok("Value: " + Encode.forHtml(s.toString()));
}
}
"""));
}

@ParameterizedTest
@MethodSource("fixableSamples")
void it_fixes_obvious_response_write_methods(final String beforeCode, final String afterCode) {
CompilationUnit cu = parser.parse(beforeCode).getResult().orElseThrow();
LexicalPreservingPrinter.setup(cu);

var result = scanAndFix(cu, 3);
assertThat(result.changes()).isNotEmpty();
String actualCode = LexicalPreservingPrinter.print(cu);
assertThat(actualCode).isEqualToIgnoringWhitespace(afterCode);
}

private CodemodFileScanningResult scanAndFix(final CompilationUnit cu, final int line) {
XSSFinding finding = new XSSFinding("should_be_fixed", line, null);
var remediator =
new SearcherStrategyRemediator.Builder<XSSFinding>()
.withSearcherStrategyPair(
new FixCandidateSearcher.Builder<XSSFinding>()
.withMatcher(ResponseEntityWriteFixStrategy::match)
.build(),
fixer)
.build();
return remediator.remediateAll(
cu,
"path",
rule,
List.of(finding),
XSSFinding::key,
XSSFinding::line,
x -> Optional.empty(),
x -> Optional.ofNullable(x.column()));
}

@ParameterizedTest
@MethodSource("unfixableSamples")
void it_does_not_fix_unfixable_samples(final String beforeCode, final int line) {
CompilationUnit cu = parser.parse(beforeCode).getResult().orElseThrow();
LexicalPreservingPrinter.setup(cu);
var result = scanAndFix(cu, line);
assertThat(result.changes()).isEmpty();
}

private static Stream<Arguments> unfixableSamples() {
return Stream.of(
// this is not a ResponseEntity, shouldn't touch it
Arguments.of(
// this is not a ResponseEntity, shouldn't touch it
"""
class Samples {
String should_be_fixed(String s) {
return ResponseEntity.something_besides_ok("Value: " + s);
}
}
""",
3));
}
}
Loading