Skip to content

Commit 49b69dd

Browse files
committed
new threshold and consistency edits
1 parent 2a8e6a4 commit 49b69dd

File tree

1 file changed

+130
-100
lines changed

1 file changed

+130
-100
lines changed

Lib/difflib.py

Lines changed: 130 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@
3838
Match = _namedtuple('Match', 'a b size')
3939

4040

41+
def _adjust_indices(seq, start, stop):
42+
assert start >= 0
43+
size = len(seq)
44+
if stop is None or stop > size:
45+
stop = size
46+
return start, stop
47+
48+
4149
class _LCSUBSimple:
4250
"""Simple dict method for finding longest common substring.
4351
@@ -46,51 +54,61 @@ class _LCSUBSimple:
4654
S: O(n2)
4755
4856
Members:
49-
pos2 for x in seq2, pos2[x] is a list of the indices (into seq2)
57+
b2j for x in b, b2j[x] is a list of the indices (into b)
5058
at which x appears; junk elements do not appear
5159
"""
5260

53-
def __init__(self, seq2, junk=()):
61+
def __init__(self, b, junk=()):
5462
if not isinstance(junk, frozenset):
5563
junk = frozenset(junk)
56-
self.seq2 = seq2
64+
self.b = b
5765
self.junk = junk
58-
self.pos2 = None
59-
60-
def _build(self):
61-
if self.pos2 is None:
62-
self.pos2 = pos2 = {} # positions of each element in seq2
63-
for i, elt in enumerate(self.seq2):
64-
indices = pos2.setdefault(elt, [])
66+
self._b2j = None
67+
68+
def isbuilt(self, blo, bhi):
69+
blo, bhi = _adjust_indices(self.b, blo, bhi)
70+
if blo >= bhi:
71+
return True
72+
return self._b2j is not None
73+
74+
def _get_b2j(self):
75+
b2j = self._b2j
76+
if b2j is None:
77+
b2j = {} # positions of each element in b
78+
for i, elt in enumerate(self.b):
79+
indices = b2j.setdefault(elt, [])
6580
indices.append(i)
6681
junk = self.junk
6782
if junk:
6883
for elt in junk:
69-
del pos2[elt]
70-
71-
def find(self, seq1, start1=0, stop1=None, start2=0, stop2=None):
72-
if stop1 is None:
73-
stop1 = len(seq1)
74-
if stop2 is None:
75-
stop2 = len(self.seq2)
76-
self._build()
77-
pos2 = self.pos2
84+
del b2j[elt]
85+
self._b2j = b2j
86+
87+
return b2j
88+
89+
def find(self, a, alo=0, ahi=None, blo=0, bhi=None):
90+
alo, ahi = _adjust_indices(a, alo, ahi)
91+
blo, bhi = _adjust_indices(self.b, blo, bhi)
92+
if alo >= ahi or blo >= bhi:
93+
return (alo, blo, 0)
94+
95+
b2j = self._get_b2j()
7896
j2len = {}
7997
nothing = []
80-
besti, bestj, bestsize = start1, start2, 0
98+
besti, bestj, bestsize = alo, blo, 0
8199
# find longest junk-free match
82100
# during an iteration of the loop, j2len[j] = length of longest
83-
# junk-free match ending with seq1[i-1] and seq2[j]
84-
for i in range(start1, stop1):
85-
# look at all instances of seq1[i] in seq2; note that because
86-
# pos2 has no junk keys, the loop is skipped if seq1[i] is junk
101+
# junk-free match ending with a[i-1] and b[j]
102+
for i in range(alo, ahi):
103+
# look at all instances of a[i] in b; note that because
104+
# b2j has no junk keys, the loop is skipped if a[i] is junk
87105
j2lenget = j2len.get
88106
newj2len = {}
89-
for j in pos2.get(seq1[i], nothing):
90-
# seq1[i] matches seq2[j]
91-
if j < start2:
107+
for j in b2j.get(a[i], nothing):
108+
# a[i] matches b[j]
109+
if j < blo:
92110
continue
93-
if j >= stop2:
111+
if j >= bhi:
94112
break
95113
k = newj2len[j] = j2lenget(j-1, 0) + 1
96114
if k > bestsize:
@@ -123,90 +141,96 @@ class _LCSUBAutomaton:
123141
end_pos - end position of first occurrence (used for result)
124142
"""
125143

126-
def __init__(self, seq2, junk=()):
144+
def __init__(self, b, junk=()):
127145
if not isinstance(junk, frozenset):
128146
junk = frozenset(junk)
129-
self.seq2 = seq2
147+
self.b = b
130148
self.junk = junk
131-
self.root = None
132-
self.cache = (None, None)
149+
self._root = None
150+
self._cache = (None, None)
151+
152+
def isbuilt(self, blo, bhi):
153+
blo, bhi = _adjust_indices(self.b, blo, bhi)
154+
if blo >= bhi:
155+
return True
156+
return self._root is not None and self._cache == (blo, bhi)
133157

134-
def _build(self, start2, stop2):
158+
def _get_root(self, blo, bhi):
135159
"""
136-
Automaton needs to rebuild for every (start2, stop2)
160+
Automaton needs to rebuild for every (blo, bhi)
137161
This is made to cache the last one and only rebuild on new values
138162
139163
Note that to construct Automaton that can be queried for any
140-
(start2, stop2), each node would need to store a set of
164+
(blo, bhi), each node would need to store a set of
141165
indices. And this is prone to O(n^2) memory explosion.
142166
Current approach maintains reasonable memory guarantees
143167
and is also much simpler in comparison.
144168
"""
145-
key = (start2, stop2)
146-
if self.root is not None and self.cache == key:
147-
return
169+
key = (blo, bhi)
170+
root = self._root
171+
if root is None or self._cache != key:
172+
root = [0, None, {}, -1]
173+
b = self.b
174+
junk = self.junk
175+
last_len = 0
176+
last = root
177+
for j in range(blo, bhi):
178+
c = b[j]
179+
if c in junk:
180+
last_len = 0
181+
last = root
182+
else:
183+
last_len += 1
184+
curr = [last_len, None, {}, j]
148185

149-
self.root = root = [0, None, {}, -1]
150-
seq2 = self.seq2
151-
junk = self.junk
152-
last_len = 0
153-
last = root
154-
for j in range(start2, stop2):
155-
c = seq2[j]
156-
if c in junk:
157-
last_len = 0
158-
last = root
159-
else:
160-
last_len += 1
161-
curr = [last_len, None, {}, j]
162-
163-
p = last
164-
p_next = p[_NEXT]
165-
while c not in p_next:
166-
p_next[c] = curr
167-
if p is root:
168-
curr[_LINK] = root
169-
break
170-
p = p[_LINK]
186+
p = last
171187
p_next = p[_NEXT]
172-
else:
173-
q = p_next[c]
174-
p_length_p1 = p[_LENGTH] + 1
175-
if p_length_p1 == q[_LENGTH]:
176-
curr[_LINK] = q
188+
while c not in p_next:
189+
p_next[c] = curr
190+
if p is root:
191+
curr[_LINK] = root
192+
break
193+
p = p[_LINK]
194+
p_next = p[_NEXT]
177195
else:
178-
# Copy `q[_POS]` to ensure leftmost match in seq2
179-
clone = [p_length_p1, q[_LINK], q[_NEXT].copy(), q[_POS]]
180-
while (p_next := p[_NEXT]).get(c) is q:
181-
p_next[c] = clone
182-
if p is root:
183-
break
184-
p = p[_LINK]
185-
186-
q[_LINK] = curr[_LINK] = clone
187-
188-
last = curr
189-
190-
self.cache = key
191-
192-
def find(self, seq1, start1=0, stop1=None, start2=0, stop2=None):
193-
size1 = len(seq1)
194-
size2 = len(self.seq2)
195-
if stop1 is None or stop1 > size1:
196-
stop1 = size1
197-
if stop2 is None or stop2 > size2:
198-
stop2 = size2
199-
self._build(start2, stop2)
200-
root = self.root
196+
q = p_next[c]
197+
p_length_p1 = p[_LENGTH] + 1
198+
if p_length_p1 == q[_LENGTH]:
199+
curr[_LINK] = q
200+
else:
201+
# Copy `q[_POS]` to ensure leftmost match in b
202+
clone = [p_length_p1, q[_LINK], q[_NEXT].copy(), q[_POS]]
203+
while (p_next := p[_NEXT]).get(c) is q:
204+
p_next[c] = clone
205+
if p is root:
206+
break
207+
p = p[_LINK]
208+
209+
q[_LINK] = curr[_LINK] = clone
210+
211+
last = curr
212+
213+
self._root = root
214+
self._cache = key
215+
216+
return root
217+
218+
def find(self, a, alo=0, ahi=None, blo=0, bhi=None):
219+
alo, ahi = _adjust_indices(a, alo, ahi)
220+
blo, bhi = _adjust_indices(self.b, blo, bhi)
221+
if alo >= ahi or blo >= bhi:
222+
return (alo, blo, 0)
223+
224+
root = self._get_root(blo, bhi)
201225
junk = self.junk
202226
v = root
203227
l = 0
204228
best_len = 0
205229
best_state = None
206230
best_pos = 0
207231

208-
for i in range(start1, stop1):
209-
c = seq1[i]
232+
for i in range(alo, ahi):
233+
c = a[i]
210234
if c in junk:
211235
v = root
212236
l = 0
@@ -225,7 +249,7 @@ def find(self, seq1, start1=0, stop1=None, start2=0, stop2=None):
225249
best_pos = i
226250

227251
if not best_len:
228-
return (start1, start2, 0)
252+
return (alo, blo, 0)
229253

230254
start_in_s1 = best_pos + 1 - best_len
231255
end_in_s2 = best_state[_POS]
@@ -492,18 +516,22 @@ def __chain_b(self):
492516
for elt in popular: # ditto; as fast for 1% deletion
493517
del bcounts[elt]
494518

495-
self._max_bcount = max(bcounts.values()) if bcounts else 0
519+
if not bcounts:
520+
self._bcount_thres = 0
521+
else:
522+
sum_bcount = sum(bcounts.values())
523+
avg_bcount = sum(c * c for c in bcounts.values()) / sum_bcount
524+
max_bcount = max(bcounts.values())
525+
self._bcount_thres = avg_bcount * 0.8 + max_bcount * 0.2
526+
496527
self._all_junk = all_junk = frozenset(junk | popular)
497528
self._lcsub_simple = _LCSUBSimple(b, all_junk)
498529
self._lcsub_automaton = _LCSUBAutomaton(b, all_junk)
499530

500531
@property
501532
def b2j(self):
502533
# NOTE: For backwards compatibility
503-
simple_calc = self._lcsub_simple
504-
if simple_calc.pos2 is None:
505-
simple_calc._build()
506-
return simple_calc.pos2
534+
return self._lcsub_simple._get_b2j()
507535

508536
def find_longest_match(self, alo=0, ahi=None, blo=0, bhi=None):
509537
"""Find longest matching block in a[alo:ahi] and b[blo:bhi].
@@ -596,12 +624,14 @@ def find_longest_match(self, alo=0, ahi=None, blo=0, bhi=None):
596624
simple_calc = self._lcsub_simple
597625
automaton = self._lcsub_automaton
598626

627+
simple_cost = self._bcount_thres * tmp_asize
628+
if not simple_calc.isbuilt(blo, bhi):
629+
simple_cost += bsize
630+
599631
automaton_cost = tmp_asize
600-
if automaton.cache != (blo, bhi):
632+
if not automaton.isbuilt(blo, bhi):
601633
automaton_cost += bsize * 6
602-
simple_cost = self._max_bcount * tmp_asize
603-
if simple_calc.pos2 is None:
604-
simple_cost += bsize
634+
605635
if simple_cost < automaton_cost:
606636
calc = simple_calc
607637
else:

0 commit comments

Comments
 (0)