Skip to content

Commit df2ccb4

Browse files
committed
WIP, Rust: Improve performance
1 parent 31c4c7a commit df2ccb4

File tree

4 files changed

+115
-53
lines changed

4 files changed

+115
-53
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -949,29 +949,46 @@ private module Cached {
949949
}
950950

951951
private class ReceiverExpr extends Expr {
952-
ReceiverExpr() { any(MethodCallExpr mce).getReceiver() = this }
952+
MethodCallExpr mce;
953+
954+
ReceiverExpr() { mce.getReceiver() = this }
955+
956+
string getField() { result = mce.getIdentifier().getText() }
953957

954958
Type resolveTypeAt(TypePath path) { result = inferTypeDeref(this, path) }
955959
}
956960

961+
private module IsInstantiationOfInput implements IsInstantiationOfSig<ReceiverExpr> {
962+
predicate potentialInstantiationOf(TypeAbstraction impl, ReceiverExpr receiver, TypeMention sub) {
963+
sub.resolveType() = receiver.resolveTypeAt(TypePath::nil()) and
964+
sub = impl.(ImplTypeAbstraction).getSelfTy().(TypeReprMention) and
965+
exists(impl.(ImplItemNode).getASuccessor(receiver.getField()))
966+
}
967+
}
968+
969+
bindingset[item, name]
970+
pragma[inline_late]
971+
private Function getMethodSuccessor(ItemNode item, string name) {
972+
result = item.getASuccessor(name)
973+
}
974+
957975
bindingset[tp, name]
976+
pragma[inline_late]
958977
private Function getTypeParameterMethod(TypeParameter tp, string name) {
959-
result = tp.(TypeParamTypeParameter).getTypeParam().(ItemNode).getASuccessor(name)
978+
result = getMethodSuccessor(tp.(TypeParamTypeParameter).getTypeParam(), name)
960979
or
961-
result = tp.(SelfTypeParameter).getTrait().(ItemNode).getASuccessor(name)
980+
result = getMethodSuccessor(tp.(SelfTypeParameter).getTrait(), name)
962981
}
963982

964983
/**
965984
* Gets an `impl` block with an implementing type that matches the type of
966985
* `mce`'s receiver.
967986
*/
968-
private predicate methodCallMatchingImpl(ReceiverExpr receiver, string name, Function function) {
969-
exists(MethodCallExpr mce, Impl impl |
970-
mce.getReceiver() = receiver and
971-
mce.getIdentifier().getText() = name and
972-
TypeTreeUtils<ReceiverExpr, TypeMention>::isInstantiationOf(impl, receiver,
973-
impl.getSelfTy().(TypeReprMention)) and
974-
function = impl.(ImplItemNode).getASuccessor(name)
987+
private predicate methodCallMatchingImpl(ReceiverExpr receiver, Function function) {
988+
exists(Impl impl |
989+
IsInstantiationOf<ReceiverExpr, IsInstantiationOfInput>::isInstantiationOf(receiver, impl,
990+
impl.(ImplTypeAbstraction).getSelfTy().(TypeReprMention)) and
991+
function = getMethodSuccessor(impl, receiver.getField())
975992
)
976993
}
977994

@@ -980,16 +997,14 @@ private module Cached {
980997
*/
981998
cached
982999
Function resolveMethodCallExpr(MethodCallExpr mce) {
983-
exists(ReceiverExpr receiver, string name |
984-
mce.getReceiver() = receiver and
985-
mce.getIdentifier().getText() = name
986-
|
1000+
exists(ReceiverExpr receiver | mce.getReceiver() = receiver |
9871001
// The method comes from an `impl` block targeting the type of `receiver`.
988-
methodCallMatchingImpl(receiver, name, result)
1002+
methodCallMatchingImpl(receiver, result)
9891003
or
9901004
// The type of `receiver` is a type parameter and the method comes from a
9911005
// trait bound on the type parameter.
992-
result = getTypeParameterMethod(inferTypeDeref(receiver, TypePath::nil()), name)
1006+
result =
1007+
getTypeParameterMethod(inferTypeDeref(receiver, TypePath::nil()), receiver.getField())
9931008
)
9941009
}
9951010

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ private import TypeInference
77

88
/** An AST node that may mention a type. */
99
abstract class TypeMention extends AstNode {
10-
TypeMention() { exists(this.getLocation()) }
11-
1210
/** Gets the `i`th type argument mention, if any. */
1311
abstract TypeMention getTypeArgument(int i);
1412

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ testFailures
22
| main.rs:132:26:132:31 | x.m1() | Fixed missing result: method=MyThing<S1>::m1 |
33
| main.rs:133:26:133:31 | y.m1() | Fixed missing result: method=MyThing<S2>::m1 |
44
| main.rs:133:26:133:33 | ... .a | Fixed missing result: fieldof=MyThing |
5+
| main.rs:272:28:272:91 | //... | Fixed spurious result: method=T::convert_to |
6+
| main.rs:276:28:276:52 | //... | Missing result: method=T::convert_to |
57
| main.rs:285:26:285:38 | thing_s1.m1() | Fixed missing result: method=MyThing<S1>::m1 |
68
| main.rs:286:26:286:38 | thing_s2.m1() | Fixed missing result: method=MyThing<S2>::m1 |
79
| main.rs:286:26:286:40 | ... .a | Fixed missing result: fieldof=MyThing |
@@ -20,6 +22,8 @@ testFailures
2022
| main.rs:333:29:333:62 | //... | Fixed spurious result: type=x:S2 |
2123
| main.rs:335:29:335:62 | //... | Fixed spurious result: type=y:S1 |
2224
| main.rs:342:33:342:66 | //... | Fixed spurious result: type=x:S2 |
25+
| main.rs:345:37:345:71 | //... | Missing result: method=T::convert_to |
26+
| main.rs:345:37:345:71 | //... | Missing result: type=i:S1 |
2327
| main.rs:346:13:346:13 | j | Fixed missing result: type=j:S1 |
2428
| main.rs:924:26:924:37 | x5.flatten() | Fixed missing result: method=flatten |
2529
inferType
@@ -171,10 +175,8 @@ inferType
171175
| main.rs:175:16:175:19 | SelfParam | | main.rs:173:5:178:5 | Self [trait MyProduct] |
172176
| main.rs:177:16:177:19 | SelfParam | | main.rs:173:5:178:5 | Self [trait MyProduct] |
173177
| main.rs:180:43:180:43 | x | | main.rs:180:26:180:40 | T2 |
174-
| main.rs:180:56:182:5 | { ... } | | main.rs:155:5:156:14 | S1 |
175178
| main.rs:180:56:182:5 | { ... } | | main.rs:180:22:180:23 | T1 |
176179
| main.rs:181:9:181:9 | x | | main.rs:180:26:180:40 | T2 |
177-
| main.rs:181:9:181:14 | x.m1() | | main.rs:155:5:156:14 | S1 |
178180
| main.rs:181:9:181:14 | x.m1() | | main.rs:180:22:180:23 | T1 |
179181
| main.rs:186:15:186:18 | SelfParam | | main.rs:144:5:147:5 | MyThing |
180182
| main.rs:186:15:186:18 | SelfParam | A | main.rs:155:5:156:14 | S1 |
@@ -275,10 +277,8 @@ inferType
275277
| main.rs:267:13:267:16 | self | | main.rs:264:10:264:23 | T |
276278
| main.rs:267:13:267:21 | self.m1() | | main.rs:155:5:156:14 | S1 |
277279
| main.rs:271:41:271:45 | thing | | main.rs:271:23:271:38 | T |
278-
| main.rs:271:57:273:5 | { ... } | | main.rs:155:5:156:14 | S1 |
279280
| main.rs:271:57:273:5 | { ... } | | main.rs:271:19:271:20 | TS |
280281
| main.rs:272:9:272:13 | thing | | main.rs:271:23:271:38 | T |
281-
| main.rs:272:9:272:26 | thing.convert_to() | | main.rs:155:5:156:14 | S1 |
282282
| main.rs:272:9:272:26 | thing.convert_to() | | main.rs:271:19:271:20 | TS |
283283
| main.rs:275:31:275:35 | thing | | main.rs:275:14:275:28 | TP |
284284
| main.rs:275:48:277:5 | { ... } | | main.rs:155:5:156:14 | S1 |
@@ -389,15 +389,12 @@ inferType
389389
| main.rs:319:31:319:38 | thing_s1 | A | main.rs:155:5:156:14 | S1 |
390390
| main.rs:320:26:320:26 | x | | main.rs:155:5:156:14 | S1 |
391391
| main.rs:321:13:321:13 | y | | main.rs:144:5:147:5 | MyThing |
392-
| main.rs:321:13:321:13 | y | | main.rs:155:5:156:14 | S1 |
393392
| main.rs:321:13:321:13 | y | A | main.rs:157:5:158:14 | S2 |
394393
| main.rs:321:17:321:39 | call_trait_m1(...) | | main.rs:144:5:147:5 | MyThing |
395-
| main.rs:321:17:321:39 | call_trait_m1(...) | | main.rs:155:5:156:14 | S1 |
396394
| main.rs:321:17:321:39 | call_trait_m1(...) | A | main.rs:157:5:158:14 | S2 |
397395
| main.rs:321:31:321:38 | thing_s2 | | main.rs:144:5:147:5 | MyThing |
398396
| main.rs:321:31:321:38 | thing_s2 | A | main.rs:157:5:158:14 | S2 |
399397
| main.rs:322:26:322:26 | y | | main.rs:144:5:147:5 | MyThing |
400-
| main.rs:322:26:322:26 | y | | main.rs:155:5:156:14 | S1 |
401398
| main.rs:322:26:322:26 | y | A | main.rs:157:5:158:14 | S2 |
402399
| main.rs:322:26:322:28 | y.a | | main.rs:157:5:158:14 | S2 |
403400
| main.rs:325:13:325:13 | a | | main.rs:149:5:153:5 | MyPair |
@@ -468,10 +465,8 @@ inferType
468465
| main.rs:344:21:344:37 | MyThing {...} | | main.rs:144:5:147:5 | MyThing |
469466
| main.rs:344:21:344:37 | MyThing {...} | A | main.rs:155:5:156:14 | S1 |
470467
| main.rs:344:34:344:35 | S1 | | main.rs:155:5:156:14 | S1 |
471-
| main.rs:345:13:345:13 | i | | main.rs:155:5:156:14 | S1 |
472468
| main.rs:345:17:345:21 | thing | | main.rs:144:5:147:5 | MyThing |
473469
| main.rs:345:17:345:21 | thing | A | main.rs:155:5:156:14 | S1 |
474-
| main.rs:345:17:345:34 | thing.convert_to() | | main.rs:155:5:156:14 | S1 |
475470
| main.rs:346:13:346:13 | j | | main.rs:155:5:156:14 | S1 |
476471
| main.rs:346:17:346:33 | convert_to(...) | | main.rs:155:5:156:14 | S1 |
477472
| main.rs:346:28:346:32 | thing | | main.rs:144:5:147:5 | MyThing |

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
368368
predicate potentialInstantiationOf(TypeAbstraction abs, App app, TypeMention term);
369369
}
370370

371-
module GenIsInstantiationOf<TypeTreeSig App, IsInstantiationOfSig<App> Input> {
371+
module IsInstantiationOf<TypeTreeSig App, IsInstantiationOfSig<App> Input> {
372372
private import Input
373373

374374
/** Gets the `i`th path in `term` per some arbitrary order. */
@@ -520,6 +520,28 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
520520
)
521521
}
522522

523+
/**
524+
* Holds if the type `type` can satisfy the constraint `constraint`
525+
* through `abs`, `sub`, and `constraintMention`.
526+
*/
527+
predicate typeToTypeMention(
528+
Type type, Type constraint, TypeAbstraction abs, TypeMention sub,
529+
TypeMention constraintMention
530+
) {
531+
typeSatisfiesConstraintTrans(abs, sub, constraintMention, _, _) and
532+
type = resolveTypeMentionRoot(sub) and
533+
constraint = resolveTypeMentionRoot(constraintMention)
534+
}
535+
536+
int countConstraintImplementations(Type type, Type constraint) {
537+
result =
538+
strictcount(TypeAbstraction abs, TypeMention tm, TypeMention constraintMention |
539+
typeToTypeMention(type, constraint, abs, tm, constraintMention)
540+
|
541+
constraintMention
542+
)
543+
}
544+
523545
/**
524546
* Holds if `baseMention` is a (transitive) base type mention of `sub`,
525547
* and `t` is mentioned (implicitly) at `path` inside `baseMention`. For
@@ -847,23 +869,23 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
847869

848870
private module AccessConstraint {
849871
/**
850-
* Holds if inferring types at `a` might depend on the type at `path` of
851-
* `apos` having `base` as a transitive base type.
872+
* If the access `a` for `apos` and `path` has the root type `type` and
873+
* type inference requires it to satisfy the constraint `constraint`.
852874
*/
853-
private predicate relevantAccess(Access a, AccessPosition apos, TypePath path, Type base) {
875+
private predicate relevantAccess(
876+
Access a, AccessPosition apos, TypePath path, Type type, Type constraint
877+
) {
854878
exists(Declaration target, DeclarationPosition dpos |
855-
adjustedAccessType(a, apos, target, _, _) and
856-
accessDeclarationPositionMatch(apos, dpos)
857-
|
858-
path.isEmpty() and declarationBaseType(target, dpos, base, _, _)
859-
or
860-
typeParameterConstraintHasTypeParameter(target, dpos, path, _, base, _, _)
879+
target = a.getTarget() and
880+
type = a.getInferredType(apos, path) and
881+
accessDeclarationPositionMatch(apos, dpos) and
882+
typeParameterConstraintHasTypeParameter(target, dpos, path, _, constraint, _, _)
861883
)
862884
}
863885

864886
newtype TTRelevantAccess =
865887
TRelevantAccess(Access a, AccessPosition apos, TypePath path) {
866-
relevantAccess(a, apos, path, _)
888+
relevantAccess(a, apos, path, _, _)
867889
}
868890

869891
class RelevantAccess extends TTRelevantAccess {
@@ -878,37 +900,69 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
878900
}
879901

880902
string toString() {
881-
result = a.toString() + "," + apos.toString() + "," + pathToSub.toString()
903+
result = a.toString() + ", " + apos.toString() + ", " + pathToSub.toString()
882904
}
883905

884906
Location getLocation() { result = a.getLocation() }
885907
}
886908

887909
module IsInstantiationOfInput implements IsInstantiationOfSig<RelevantAccess> {
888910
predicate potentialInstantiationOf(TypeAbstraction abs, RelevantAccess at, TypeMention sub) {
889-
exists(TypeMention sup, Access a, AccessPosition apos, TypePath pathToSub |
911+
// We only need to check instantiations where there are multiple candidates.
912+
exists(
913+
TypeMention constraintMention, Access a, AccessPosition apos, TypePath pathToSub,
914+
Type type
915+
|
916+
type = resolveTypeMentionRoot(sub) and
890917
at = TRelevantAccess(a, apos, pathToSub) and
891-
relevantAccess(a, apos, pathToSub, resolveTypeMentionRoot(sup)) and
892-
typeSatisfiesConstraintTrans(abs, sub, sup, _, _)
918+
relevantAccess(a, apos, pathToSub, type, resolveTypeMentionRoot(constraintMention)) and
919+
typeSatisfiesConstraintTrans(abs, sub, constraintMention, _, _) and
920+
countConstraintImplementations(type, resolveTypeMentionRoot(constraintMention)) > 1
893921
)
894922
}
895923
}
896924

897-
module IsInstantiationOf = GenIsInstantiationOf<RelevantAccess, IsInstantiationOfInput>;
925+
/**
926+
* The type at `a`, `apos`, `pathToSub` satisfies `constraint` through
927+
* `abs`, `sub`, and `constraintMention`.
928+
*/
929+
predicate hasConstraintMention(
930+
Access a, AccessPosition apos, TypePath pathToSub, Type constraint, TypeAbstraction abs,
931+
TypeMention sub, TypeMention constraintMention
932+
) {
933+
exists(Type type | relevantAccess(a, apos, pathToSub, type, constraint) |
934+
not exists(countConstraintImplementations(type, constraint)) and
935+
typeSatisfiesConstraintTrans(abs, sub, constraintMention, _, _) and
936+
resolveTypeMentionRoot(sub) = abs.getATypeParameter() and
937+
constraint = resolveTypeMentionRoot(constraintMention)
938+
or
939+
countConstraintImplementations(type, constraint) > 0 and
940+
typeToTypeMention(type, constraint, abs, sub, constraintMention) and
941+
// When there are multiple ways the type could implement the
942+
// constraint we need to find the right implementation, which is the
943+
// one where the type instantiates the precondition.
944+
if countConstraintImplementations(type, constraint) > 1
945+
then
946+
IsInstantiationOf<RelevantAccess, IsInstantiationOfInput>::isInstantiationOf(TRelevantAccess(a,
947+
apos, pathToSub), abs, sub)
948+
else any()
949+
)
950+
}
898951

899952
/**
900953
* Holds if the constraint is satisfied.
901954
*/
902955
pragma[nomagic]
903956
predicate satisfiesConstraintTypeMention(
904-
Access a, AccessPosition apos, TypePath pathToSub, TypeMention sup, TypePath path, Type t
957+
Access a, AccessPosition apos, TypePath pathToSub, Type constraint, TypePath path, Type t
905958
) {
906-
relevantAccess(a, apos, pathToSub, resolveTypeMentionRoot(sup)) and
907-
exists(TypeAbstraction abs, TypeMention sub, TypePath prefix, Type t0, RelevantAccess at |
959+
exists(
960+
RelevantAccess at, TypeAbstraction abs, TypeMention sub, Type t0, TypePath prefix,
961+
TypeMention constraintMention
962+
|
908963
at = TRelevantAccess(a, apos, pathToSub) and
909-
// The found sub type is more general than the inferred access type
910-
typeSatisfiesConstraintTrans(abs, sub, sup, prefix, t0) and
911-
IsInstantiationOf::isInstantiationOf(at, abs, sub) and
964+
hasConstraintMention(a, apos, pathToSub, constraint, abs, sub, constraintMention) and
965+
typeSatisfiesConstraintTrans(abs, sub, constraintMention, prefix, t0) and
912966
(
913967
not t0 = abs.getATypeParameter() and t = t0 and path = prefix
914968
or
@@ -1049,13 +1103,13 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
10491103
not exists(getTypeArgument(a, target, tp, _)) and
10501104
target = a.getTarget() and
10511105
exists(
1052-
TypeMention base, AccessPosition apos, DeclarationPosition dpos, TypePath pathToTp,
1106+
Type constraint, AccessPosition apos, DeclarationPosition dpos, TypePath pathToTp,
10531107
TypePath pathToTp2
10541108
|
10551109
accessDeclarationPositionMatch(apos, dpos) and
1056-
typeParameterConstraintHasTypeParameter(target, dpos, pathToTp2, _,
1057-
resolveTypeMentionRoot(base), pathToTp, tp) and
1058-
AccessConstraint::satisfiesConstraintTypeMention(a, apos, pathToTp2, base,
1110+
typeParameterConstraintHasTypeParameter(target, dpos, pathToTp2, _, constraint, pathToTp,
1111+
tp) and
1112+
AccessConstraint::satisfiesConstraintTypeMention(a, apos, pathToTp2, constraint,
10591113
pathToTp.append(path), t)
10601114
)
10611115
}

0 commit comments

Comments
 (0)