Skip to content

Commit 6d972e0

Browse files
authored
gh-130415: Narrow types to constants in branches involving specialized comparisons with a constant (GH-144150)
1 parent 2984024 commit 6d972e0

File tree

6 files changed

+360
-8
lines changed

6 files changed

+360
-8
lines changed

Include/internal/pycore_optimizer_types.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ typedef struct {
7676
typedef enum {
7777
JIT_PRED_IS,
7878
JIT_PRED_IS_NOT,
79+
JIT_PRED_EQ,
80+
JIT_PRED_NE,
7981
} JitOptPredicateKind;
8082

8183
typedef struct {

Lib/test/test_capi/test_opt.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,138 @@ def testfunc(n):
890890
self.assertLessEqual(len(guard_nos_unicode_count), 1)
891891
self.assertIn("_COMPARE_OP_STR", uops)
892892

893+
def test_compare_int_eq_narrows_to_constant(self):
894+
def f(n):
895+
def return_1():
896+
return 1
897+
898+
hits = 0
899+
v = return_1()
900+
for _ in range(n):
901+
if v == 1:
902+
if v == 1:
903+
hits += 1
904+
return hits
905+
906+
res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
907+
self.assertEqual(res, TIER2_THRESHOLD)
908+
self.assertIsNotNone(ex)
909+
uops = get_opnames(ex)
910+
911+
# Constant narrowing allows constant folding for second comparison
912+
self.assertLessEqual(count_ops(ex, "_COMPARE_OP_INT"), 1)
913+
914+
def test_compare_int_ne_narrows_to_constant(self):
915+
def f(n):
916+
def return_1():
917+
return 1
918+
919+
hits = 0
920+
v = return_1()
921+
for _ in range(n):
922+
if v != 1:
923+
hits += 1000
924+
else:
925+
if v == 1:
926+
hits += v + 1
927+
return hits
928+
929+
res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
930+
self.assertEqual(res, TIER2_THRESHOLD * 2)
931+
self.assertIsNotNone(ex)
932+
uops = get_opnames(ex)
933+
934+
# Constant narrowing allows constant folding for second comparison
935+
self.assertLessEqual(count_ops(ex, "_COMPARE_OP_INT"), 1)
936+
937+
def test_compare_float_eq_narrows_to_constant(self):
938+
def f(n):
939+
def return_tenth():
940+
return 0.1
941+
942+
hits = 0
943+
v = return_tenth()
944+
for _ in range(n):
945+
if v == 0.1:
946+
if v == 0.1:
947+
hits += 1
948+
return hits
949+
950+
res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
951+
self.assertEqual(res, TIER2_THRESHOLD)
952+
self.assertIsNotNone(ex)
953+
uops = get_opnames(ex)
954+
955+
# Constant narrowing allows constant folding for second comparison
956+
self.assertLessEqual(count_ops(ex, "_COMPARE_OP_FLOAT"), 1)
957+
958+
def test_compare_float_ne_narrows_to_constant(self):
959+
def f(n):
960+
def return_tenth():
961+
return 0.1
962+
963+
hits = 0
964+
v = return_tenth()
965+
for _ in range(n):
966+
if v != 0.1:
967+
hits += 1000
968+
else:
969+
if v == 0.1:
970+
hits += 1
971+
return hits
972+
973+
res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
974+
self.assertEqual(res, TIER2_THRESHOLD)
975+
self.assertIsNotNone(ex)
976+
uops = get_opnames(ex)
977+
978+
# Constant narrowing allows constant folding for second comparison
979+
self.assertLessEqual(count_ops(ex, "_COMPARE_OP_FLOAT"), 1)
980+
981+
def test_compare_str_eq_narrows_to_constant(self):
982+
def f(n):
983+
def return_hello():
984+
return "hello"
985+
986+
hits = 0
987+
v = return_hello()
988+
for _ in range(n):
989+
if v == "hello":
990+
if v == "hello":
991+
hits += 1
992+
return hits
993+
994+
res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
995+
self.assertEqual(res, TIER2_THRESHOLD)
996+
self.assertIsNotNone(ex)
997+
uops = get_opnames(ex)
998+
999+
# Constant narrowing allows constant folding for second comparison
1000+
self.assertLessEqual(count_ops(ex, "_COMPARE_OP_STR"), 1)
1001+
1002+
def test_compare_str_ne_narrows_to_constant(self):
1003+
def f(n):
1004+
def return_hello():
1005+
return "hello"
1006+
1007+
hits = 0
1008+
v = return_hello()
1009+
for _ in range(n):
1010+
if v != "hello":
1011+
hits += 1000
1012+
else:
1013+
if v == "hello":
1014+
hits += 1
1015+
return hits
1016+
1017+
res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
1018+
self.assertEqual(res, TIER2_THRESHOLD)
1019+
self.assertIsNotNone(ex)
1020+
uops = get_opnames(ex)
1021+
1022+
# Constant narrowing allows constant folding for second comparison
1023+
self.assertLessEqual(count_ops(ex, "_COMPARE_OP_STR"), 1)
1024+
8931025
@unittest.skip("gh-139109 WIP")
8941026
def test_combine_stack_space_checks_sequential(self):
8951027
def dummy12(x):

Python/optimizer_analysis.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,11 @@ add_op(JitOptContext *ctx, _PyUOpInstruction *this_instr,
250250
#define sym_new_predicate _Py_uop_sym_new_predicate
251251
#define sym_apply_predicate_narrowing _Py_uop_sym_apply_predicate_narrowing
252252

253+
/* Comparison oparg masks */
254+
#define COMPARE_LT_MASK 2
255+
#define COMPARE_GT_MASK 4
256+
#define COMPARE_EQ_MASK 8
257+
253258
#define JUMP_TO_LABEL(label) goto label;
254259

255260
static int

Python/optimizer_bytecodes.c

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -521,21 +521,51 @@ dummy_func(void) {
521521
}
522522

523523
op(_COMPARE_OP_INT, (left, right -- res, l, r)) {
524-
res = sym_new_type(ctx, &PyBool_Type);
524+
int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK);
525+
526+
if (cmp_mask == COMPARE_EQ_MASK) {
527+
res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
528+
}
529+
else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
530+
res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
531+
}
532+
else {
533+
res = sym_new_type(ctx, &PyBool_Type);
534+
}
525535
l = left;
526536
r = right;
527537
REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res);
528538
}
529539

530540
op(_COMPARE_OP_FLOAT, (left, right -- res, l, r)) {
531-
res = sym_new_type(ctx, &PyBool_Type);
541+
int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK);
542+
543+
if (cmp_mask == COMPARE_EQ_MASK) {
544+
res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
545+
}
546+
else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
547+
res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
548+
}
549+
else {
550+
res = sym_new_type(ctx, &PyBool_Type);
551+
}
532552
l = left;
533553
r = right;
534554
REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res);
535555
}
536556

537557
op(_COMPARE_OP_STR, (left, right -- res, l, r)) {
538-
res = sym_new_type(ctx, &PyBool_Type);
558+
int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK);
559+
560+
if (cmp_mask == COMPARE_EQ_MASK) {
561+
res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
562+
}
563+
else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
564+
res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
565+
}
566+
else {
567+
res = sym_new_type(ctx, &PyBool_Type);
568+
}
539569
l = left;
540570
r = right;
541571
REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res);

Python/optimizer_cases.c.h

Lines changed: 30 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)