Skip to content

Commit af4f55e

Browse files
committed
Rust: Generalize certain type inference logic
1 parent ed3a33f commit af4f55e

File tree

4 files changed

+127
-75
lines changed

4 files changed

+127
-75
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 105 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,6 @@ private Type inferAnnotatedType(AstNode n, TypePath path) {
257257

258258
/** Module for inferring certain type information. */
259259
private module CertainTypeInference {
260-
/** Holds if the type mention does not contain any inferred types `_`. */
261-
predicate typeMentionIsComplete(TypeMention tm) {
262-
not exists(InferTypeRepr t | t.getParentNode*() = tm)
263-
}
264-
265260
/**
266261
* Holds if `ce` is a call where we can infer the type with certainty and if
267262
* `f` is the target of the call and `p` the path invoked by the call.
@@ -373,13 +368,46 @@ private module CertainTypeInference {
373368
Type inferCertainType(AstNode n, TypePath path) {
374369
exists(TypeMention tm |
375370
tm = getTypeAnnotation(n) and
376-
typeMentionIsComplete(tm) and
377371
result = tm.resolveTypeAt(path)
378372
)
379373
or
380374
result = inferCertainCallExprType(n, path)
381375
or
382376
result = inferCertainTypeEquality(n, path)
377+
or
378+
infersCertainTypeAt(n, path, result.getATypeParameter())
379+
}
380+
381+
/**
382+
* Holds if `n` has complete and certain type information at the type path
383+
* `prefix.tp`. This entails that the type at `prefix` must be the type
384+
* that declares `tp`.
385+
*/
386+
pragma[nomagic]
387+
private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) {
388+
exists(TypePath path |
389+
exists(inferCertainType(n, path)) and
390+
path.isSnoc(prefix, tp)
391+
)
392+
}
393+
394+
/**
395+
* Holds if `n` has complete and certain type information at _some_ type path.
396+
*/
397+
pragma[nomagic]
398+
predicate hasInferredCertainType(AstNode n) { exists(inferCertainType(n, _)) }
399+
400+
bindingset[n, path, t]
401+
pragma[inline_late]
402+
predicate certainTypeConflict(AstNode n, TypePath path, Type t) {
403+
inferCertainType(n, path) != t
404+
or
405+
exists(TypePath prefix, TypePath suffix, TypeParameter tp, Type certainType |
406+
path = prefix.appendInverse(suffix) and
407+
tp = suffix.getHead() and
408+
inferCertainType(n, prefix) = certainType and
409+
not certainType.getATypeParameter() = tp
410+
)
383411
}
384412
}
385413

@@ -531,9 +559,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
531559

532560
pragma[nomagic]
533561
private Type inferTypeEquality(AstNode n, TypePath path) {
534-
// Don't propagate type information into a node for which we already have
535-
// certain type information.
536-
not exists(CertainTypeInference::inferCertainType(n, _)) and
537562
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
538563
result = inferType(n2, prefix2.appendInverse(suffix)) and
539564
path = prefix1.append(suffix)
@@ -2282,62 +2307,72 @@ private module Cached {
22822307
Stages::TypeInferenceStage::ref() and
22832308
result = CertainTypeInference::inferCertainType(n, path)
22842309
or
2285-
result = inferAnnotatedType(n, path)
2286-
or
2287-
result = inferLogicalOperationType(n, path)
2288-
or
2289-
result = inferAssignmentOperationType(n, path)
2290-
or
2291-
result = inferTypeEquality(n, path)
2292-
or
2293-
result = inferImplicitSelfType(n, path)
2294-
or
2295-
result = inferStructExprType(n, path)
2296-
or
2297-
result = inferTupleRootType(n) and
2298-
path.isEmpty()
2299-
or
2300-
result = inferPathExprType(n, path)
2301-
or
2302-
result = inferCallExprBaseType(n, path)
2303-
or
2304-
result = inferFieldExprType(n, path)
2305-
or
2306-
result = inferTupleIndexExprType(n, path)
2307-
or
2308-
result = inferTupleContainerExprType(n, path)
2309-
or
2310-
result = inferRefNodeType(n) and
2311-
path.isEmpty()
2312-
or
2313-
result = inferTryExprType(n, path)
2314-
or
2315-
result = inferLiteralType(n, path)
2316-
or
2317-
result = inferAsyncBlockExprRootType(n) and
2318-
path.isEmpty()
2319-
or
2320-
result = inferAwaitExprType(n, path)
2321-
or
2322-
result = inferArrayExprType(n) and
2323-
path.isEmpty()
2324-
or
2325-
result = inferRangeExprType(n) and
2326-
path.isEmpty()
2327-
or
2328-
result = inferIndexExprType(n, path)
2329-
or
2330-
result = inferForLoopExprType(n, path)
2331-
or
2332-
result = inferDynamicCallExprType(n, path)
2333-
or
2334-
result = inferClosureExprType(n, path)
2335-
or
2336-
result = inferCastExprType(n, path)
2337-
or
2338-
result = inferStructPatType(n, path)
2339-
or
2340-
result = inferTupleStructPatType(n, path)
2310+
// Don't propagate type information into a node for which we already have
2311+
// certain type information.
2312+
(
2313+
if CertainTypeInference::hasInferredCertainType(n)
2314+
then not CertainTypeInference::certainTypeConflict(n, path, result)
2315+
else any()
2316+
) and
2317+
// not exists(CertainTypeInference::inferCertainType(n, path)) and
2318+
(
2319+
result = inferAnnotatedType(n, path)
2320+
or
2321+
result = inferLogicalOperationType(n, path)
2322+
or
2323+
result = inferAssignmentOperationType(n, path)
2324+
or
2325+
result = inferTypeEquality(n, path)
2326+
or
2327+
result = inferImplicitSelfType(n, path)
2328+
or
2329+
result = inferStructExprType(n, path)
2330+
or
2331+
result = inferTupleRootType(n) and
2332+
path.isEmpty()
2333+
or
2334+
result = inferPathExprType(n, path)
2335+
or
2336+
result = inferCallExprBaseType(n, path)
2337+
or
2338+
result = inferFieldExprType(n, path)
2339+
or
2340+
result = inferTupleIndexExprType(n, path)
2341+
or
2342+
result = inferTupleContainerExprType(n, path)
2343+
or
2344+
result = inferRefNodeType(n) and
2345+
path.isEmpty()
2346+
or
2347+
result = inferTryExprType(n, path)
2348+
or
2349+
result = inferLiteralType(n, path)
2350+
or
2351+
result = inferAsyncBlockExprRootType(n) and
2352+
path.isEmpty()
2353+
or
2354+
result = inferAwaitExprType(n, path)
2355+
or
2356+
result = inferArrayExprType(n) and
2357+
path.isEmpty()
2358+
or
2359+
result = inferRangeExprType(n) and
2360+
path.isEmpty()
2361+
or
2362+
result = inferIndexExprType(n, path)
2363+
or
2364+
result = inferForLoopExprType(n, path)
2365+
or
2366+
result = inferDynamicCallExprType(n, path)
2367+
or
2368+
result = inferClosureExprType(n, path)
2369+
or
2370+
result = inferCastExprType(n, path)
2371+
or
2372+
result = inferStructPatType(n, path)
2373+
or
2374+
result = inferTupleStructPatType(n, path)
2375+
)
23412376
}
23422377
}
23432378

@@ -2438,6 +2473,11 @@ private module Debug {
24382473
c = max(countTypePaths(_, _, _))
24392474
}
24402475

2476+
Type debugInferCertainType(AstNode n, TypePath path) {
2477+
n = getRelevantLocatable() and
2478+
result = CertainTypeInference::inferCertainType(n, path)
2479+
}
2480+
24412481
Type debugInferCertainNonUniqueType(AstNode n, TypePath path) {
24422482
n = getRelevantLocatable() and
24432483
Consistency::nonUniqueCertainType(n, path) and

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2074,7 +2074,7 @@ mod indexers {
20742074
// implicit dereference. We cannot currently handle a position that is
20752075
// both implicitly dereferenced and implicitly borrowed, so the extra
20762076
// type sneaks in.
2077-
let x = slice[0].foo(); // $ target=foo type=x:S target=index SPURIOUS: type=slice:[]
2077+
let x = slice[0].foo(); // $ target=foo type=x:S target=index
20782078
}
20792079

20802080
pub fn f() {

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,6 @@ inferType
401401
| dereference.rs:123:32:123:41 | key_to_key | K | file://:0:0:0:0 | & |
402402
| dereference.rs:123:32:123:41 | key_to_key | K.&T | dereference.rs:99:5:100:21 | Key |
403403
| dereference.rs:123:32:123:41 | key_to_key | S | {EXTERNAL LOCATION} | RandomState |
404-
| dereference.rs:123:32:123:41 | key_to_key | V | dereference.rs:99:5:100:21 | Key |
405404
| dereference.rs:123:32:123:41 | key_to_key | V | file://:0:0:0:0 | & |
406405
| dereference.rs:123:32:123:41 | key_to_key | V.&T | dereference.rs:99:5:100:21 | Key |
407406
| dereference.rs:123:32:123:50 | key_to_key.get(...) | | {EXTERNAL LOCATION} | Option |
@@ -425,13 +424,9 @@ inferType
425424
| dereference.rs:127:9:127:18 | key_to_key | | {EXTERNAL LOCATION} | HashMap |
426425
| dereference.rs:127:9:127:18 | key_to_key | K | file://:0:0:0:0 | & |
427426
| dereference.rs:127:9:127:18 | key_to_key | K.&T | dereference.rs:99:5:100:21 | Key |
428-
| dereference.rs:127:9:127:18 | key_to_key | K.&T | file://:0:0:0:0 | & |
429-
| dereference.rs:127:9:127:18 | key_to_key | K.&T.&T | dereference.rs:99:5:100:21 | Key |
430427
| dereference.rs:127:9:127:18 | key_to_key | S | {EXTERNAL LOCATION} | RandomState |
431428
| dereference.rs:127:9:127:18 | key_to_key | V | file://:0:0:0:0 | & |
432429
| dereference.rs:127:9:127:18 | key_to_key | V.&T | dereference.rs:99:5:100:21 | Key |
433-
| dereference.rs:127:9:127:18 | key_to_key | V.&T | file://:0:0:0:0 | & |
434-
| dereference.rs:127:9:127:18 | key_to_key | V.&T.&T | dereference.rs:99:5:100:21 | Key |
435430
| dereference.rs:127:9:127:35 | key_to_key.insert(...) | | {EXTERNAL LOCATION} | Option |
436431
| dereference.rs:127:9:127:35 | key_to_key.insert(...) | T | file://:0:0:0:0 | & |
437432
| dereference.rs:127:9:127:35 | key_to_key.insert(...) | T.&T | dereference.rs:99:5:100:21 | Key |
@@ -3941,7 +3936,6 @@ inferType
39413936
| main.rs:2070:22:2070:26 | slice | &T.[T] | main.rs:2037:5:2038:13 | S |
39423937
| main.rs:2077:13:2077:13 | x | | main.rs:2037:5:2038:13 | S |
39433938
| main.rs:2077:17:2077:21 | slice | | file://:0:0:0:0 | & |
3944-
| main.rs:2077:17:2077:21 | slice | | file://:0:0:0:0 | [] |
39453939
| main.rs:2077:17:2077:21 | slice | &T | file://:0:0:0:0 | [] |
39463940
| main.rs:2077:17:2077:21 | slice | &T.[T] | main.rs:2037:5:2038:13 | S |
39473941
| main.rs:2077:17:2077:24 | slice[0] | | main.rs:2037:5:2038:13 | S |
@@ -4031,7 +4025,6 @@ inferType
40314025
| main.rs:2144:16:2144:19 | self | T | main.rs:2139:10:2139:17 | T |
40324026
| main.rs:2144:16:2144:21 | self.0 | | main.rs:2139:10:2139:17 | T |
40334027
| main.rs:2144:31:2144:35 | other | | main.rs:2137:5:2137:19 | S |
4034-
| main.rs:2144:31:2144:35 | other | T | main.rs:2099:5:2104:5 | Self [trait MyAdd] |
40354028
| main.rs:2144:31:2144:35 | other | T | main.rs:2139:10:2139:17 | T |
40364029
| main.rs:2144:31:2144:37 | other.0 | | main.rs:2099:5:2104:5 | Self [trait MyAdd] |
40374030
| main.rs:2144:31:2144:37 | other.0 | | main.rs:2139:10:2139:17 | T |
@@ -4047,7 +4040,6 @@ inferType
40474040
| main.rs:2153:16:2153:19 | self | | main.rs:2137:5:2137:19 | S |
40484041
| main.rs:2153:16:2153:19 | self | T | main.rs:2148:10:2148:17 | T |
40494042
| main.rs:2153:16:2153:21 | self.0 | | main.rs:2148:10:2148:17 | T |
4050-
| main.rs:2153:31:2153:35 | other | | main.rs:2099:5:2104:5 | Self [trait MyAdd] |
40514043
| main.rs:2153:31:2153:35 | other | | main.rs:2148:10:2148:17 | T |
40524044
| main.rs:2164:19:2164:22 | SelfParam | | main.rs:2137:5:2137:19 | S |
40534045
| main.rs:2164:19:2164:22 | SelfParam | T | main.rs:2157:14:2157:14 | T |

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,10 +331,30 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
331331
bindingset[this, prefix]
332332
TypePath stripPrefix(TypePath prefix) { this = prefix + result }
333333

334+
/** Gets the path obtained by removing `suffix` from this path. */
335+
bindingset[this, suffix]
336+
TypePath stripSuffix(TypePath suffix) { this = result + suffix }
337+
338+
/** Gets the path obtained by removing `prefix` from this path. */
339+
bindingset[this]
340+
predicate startsWith(TypePath prefix) { prefix = this.prefix(this.indexOf(".") + 1) }
341+
334342
/** Holds if this path starts with `tp`, followed by `suffix`. */
335343
bindingset[this]
336344
predicate isCons(TypeParameter tp, TypePath suffix) {
337-
suffix = this.stripPrefix(TypePath::singleton(tp))
345+
exists(string regexp | regexp = "([0-9]+)\\.(.*)" |
346+
tp = TypeParameter::decode(this.regexpCapture(regexp, 1)) and
347+
suffix = this.regexpCapture(regexp, 2)
348+
)
349+
}
350+
351+
/** Holds if this path starts with `prefix`, followed by `tp`. */
352+
bindingset[this]
353+
predicate isSnoc(TypePath prefix, TypeParameter tp) {
354+
exists(string regexp | regexp = "(|.+\\.)([0-9]+)\\." |
355+
prefix = this.regexpCapture(regexp, 1) and
356+
tp = TypeParameter::decode(this.regexpCapture(regexp, 2))
357+
)
338358
}
339359

340360
/** Gets the head of this path, if any. */

0 commit comments

Comments
 (0)