Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 85 additions & 28 deletions src/sortedcontainers/sortedlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from bisect import bisect_left, bisect_right, insort
from collections.abc import MutableSequence, Sequence
from functools import reduce
from itertools import chain, repeat, starmap
from itertools import accumulate, chain, repeat, starmap
from math import log
from operator import add, eq, ge, gt, iadd, le, lt, ne
from reprlib import recursive_repr
Expand Down Expand Up @@ -84,6 +84,8 @@ class SortedList(MutableSequence):

"""

__slots__ = ('_lists', '_maxes', '_index', '_offset', '_len', '_load', '_cumsum')

DEFAULT_LOAD_FACTOR = 1000

def __init__(self, iterable=None, key=None):
Expand Down Expand Up @@ -111,6 +113,7 @@ def __init__(self, iterable=None, key=None):
self._maxes = []
self._index = []
self._offset = 0
self._cumsum = ()

if iterable is not None:
self._update(iterable)
Expand Down Expand Up @@ -186,6 +189,7 @@ def clear(self):
del self._lists[:]
del self._maxes[:]
del self._index[:]
self._cumsum = ()
self._offset = 0

_clear = clear
Expand Down Expand Up @@ -218,7 +222,8 @@ def add(self, value):
else:
insort(_lists[pos], value)

self._expand(pos)
if self._index or len(_lists[pos]) > (self._load << 1):
self._expand(pos)
else:
_lists.append([value])
_maxes.append(value)
Expand All @@ -234,12 +239,12 @@ def _expand(self, pos):
``SortedList._loc``.

"""
_load = self._load
_lists = self._lists
_index = self._index

if len(_lists[pos]) > (_load << 1):
if len(_lists[pos]) > (self._load << 1):
_maxes = self._maxes
_load = self._load

_lists_pos = _lists[pos]
half = _lists_pos[_load:]
Expand All @@ -250,13 +255,15 @@ def _expand(self, pos):
_maxes.insert(pos + 1, half[-1])

del _index[:]
self._cumsum = ()
else:
if _index:
child = self._offset + pos
while child:
_index[child] += 1
child = (child - 1) >> 1
_index[0] += 1
self._cumsum = ()

def update(self, iterable):
"""Update sorted list by adding all values from `iterable`.
Expand Down Expand Up @@ -294,6 +301,7 @@ def update(self, iterable):
_maxes.extend(sublist[-1] for sublist in _lists)
self._len = len(values)
del self._index[:]
self._cumsum = ()

_update = update

Expand Down Expand Up @@ -412,7 +420,6 @@ def _delete(self, pos, idx):

"""
_lists = self._lists
_maxes = self._maxes
_index = self._index

_lists_pos = _lists[pos]
Expand All @@ -423,15 +430,18 @@ def _delete(self, pos, idx):
len_lists_pos = len(_lists_pos)

if len_lists_pos > (self._load >> 1):
_maxes[pos] = _lists_pos[-1]
if idx == len_lists_pos:
self._maxes[pos] = _lists_pos[-1]

if _index:
child = self._offset + pos
while child > 0:
_index[child] -= 1
child = (child - 1) >> 1
_index[0] -= 1
self._cumsum = ()
elif len(_lists) > 1:
_maxes = self._maxes
if not pos:
pos += 1

Expand All @@ -442,14 +452,18 @@ def _delete(self, pos, idx):
del _lists[pos]
del _maxes[pos]
del _index[:]
self._cumsum = ()

self._expand(prev)
elif len_lists_pos:
_maxes[pos] = _lists_pos[-1]
self._maxes[pos] = _lists_pos[-1]
self._cumsum = ()
else:
_maxes = self._maxes
del _lists[pos]
del _maxes[pos]
del _index[:]
self._cumsum = ()

def _loc(self, pos, idx):
"""Convert an index pair (lists index, sublist index) into a single
Expand Down Expand Up @@ -501,13 +515,19 @@ def _loc(self, pos, idx):
:return: index in sorted list

"""
_cumsum = self._cumsum

if _cumsum:
return _cumsum[pos] + idx

if not pos:
return idx

_index = self._index

if not _index:
self._build_index()
return self._cumsum[pos] + idx

total = 0

Expand Down Expand Up @@ -601,11 +621,24 @@ def _pos(self, idx):
if idx < len(self._lists[0]):
return 0, idx

_cumsum = self._cumsum

if _cumsum:
# Fast path: use cached prefix-sum + C-level bisect
pos = bisect_right(_cumsum, idx) - 1
idx -= _cumsum[pos]
return (pos, idx)

_index = self._index

if not _index:
self._build_index()
_cumsum = self._cumsum
pos = bisect_right(_cumsum, idx) - 1
idx -= _cumsum[pos]
return (pos, idx)

# Standard path: tree traversal (cumsum invalidated, index exists)
pos = 0
child = 1
len_index = len(_index)
Expand Down Expand Up @@ -664,6 +697,7 @@ def _build_index(self):
if len(row0) == 1:
self._index[:] = row0
self._offset = 0
self._cumsum = (0,) + tuple(row0)
return

head = iter(row0)
Expand All @@ -676,6 +710,7 @@ def _build_index(self):
if len(row1) == 1:
self._index[:] = row1 + row0
self._offset = 1
self._cumsum = tuple(accumulate(row0, initial=0))
return

size = 2 ** (int(log(len(row1) - 1, 2)) + 1)
Expand All @@ -691,6 +726,9 @@ def _build_index(self):
reduce(iadd, reversed(tree), self._index)
self._offset = size * 2 - 1

# Build prefix-sum for fast _pos lookups (leading zero for direct indexing)
self._cumsum = tuple(accumulate(row0, initial=0))

def __delitem__(self, index):
"""Remove value at `index` from sorted list.

Expand All @@ -712,7 +750,10 @@ def __delitem__(self, index):
:raises IndexError: if index out of range

"""
if isinstance(index, slice):
if index.__class__ is int:
pos, idx = self._pos(index)
self._delete(pos, idx)
elif isinstance(index, slice):
start, stop, step = index.indices(self._len)

if step == 1 and start < stop:
Expand Down Expand Up @@ -766,6 +807,10 @@ def __getitem__(self, index):
"""
_lists = self._lists

if index.__class__ is int:
pos, idx = self._pos(index)
return _lists[pos][idx]

if isinstance(index, slice):
start, stop, step = index.indices(self._len)

Expand Down Expand Up @@ -810,25 +855,10 @@ def __getitem__(self, index):

indices = range(start, stop, step)
return list(self._getitem(index) for index in indices)
else:
if self._len:
if index == 0:
return _lists[0][0]
elif index == -1:
return _lists[-1][-1]
else:
raise IndexError('list index out of range')

if 0 <= index < len(_lists[0]):
return _lists[0][index]

len_last = len(_lists[-1])

if -len_last < index < 0:
return _lists[-1][len_last + index]

pos, idx = self._pos(index)
return _lists[pos][idx]
# Support for int-like objects (numpy.int64, etc.)
pos, idx = self._pos(index)
return _lists[pos][idx]

_getitem = __getitem__

Expand Down Expand Up @@ -1572,6 +1602,13 @@ def _check(self):
else:
child_sum = self._index[child] + self._index[child + 1]
assert child_sum == self._index[pos]

if self._cumsum:
assert self._cumsum == tuple(
accumulate(
(len(sublist) for sublist in self._lists), initial=0
)
)
except:
traceback.print_exc(file=sys.stdout)
print('len', self._len)
Expand Down Expand Up @@ -1617,6 +1654,8 @@ class SortedKeyList(SortedList):

"""

__slots__ = ('_key', '_keys')

def __init__(self, iterable=None, key=identity):
"""Initialize sorted-key list instance.

Expand Down Expand Up @@ -1649,6 +1688,7 @@ def __init__(self, iterable=None, key=identity):
self._maxes = []
self._index = []
self._offset = 0
self._cumsum = ()

if iterable is not None:
self._update(iterable)
Expand All @@ -1672,6 +1712,7 @@ def clear(self):
del self._keys[:]
del self._maxes[:]
del self._index[:]
self._cumsum = ()

_clear = clear

Expand Down Expand Up @@ -1710,7 +1751,8 @@ def add(self, value):
_lists[pos].insert(idx, value)
_keys[pos].insert(idx, key)

self._expand(pos)
if self._index or len(_keys[pos]) > (self._load << 1):
self._expand(pos)
else:
_lists.append([value])
_keys.append([key])
Expand Down Expand Up @@ -1748,13 +1790,15 @@ def _expand(self, pos):
_maxes.insert(pos + 1, half_keys[-1])

del _index[:]
self._cumsum = ()
else:
if _index:
child = self._offset + pos
while child:
_index[child] += 1
child = (child - 1) >> 1
_index[0] += 1
self._cumsum = ()

def update(self, iterable):
"""Update sorted-key list by adding all values from `iterable`.
Expand Down Expand Up @@ -1795,6 +1839,7 @@ def update(self, iterable):
_maxes.extend(sublist[-1] for sublist in _keys)
self._len = len(values)
del self._index[:]
self._cumsum = ()

_update = update

Expand Down Expand Up @@ -1974,14 +2019,16 @@ def _delete(self, pos, idx):
len_keys_pos = len(keys_pos)

if len_keys_pos > (self._load >> 1):
_maxes[pos] = keys_pos[-1]
if idx == len_keys_pos:
_maxes[pos] = keys_pos[-1]

if _index:
child = self._offset + pos
while child > 0:
_index[child] -= 1
child = (child - 1) >> 1
_index[0] -= 1
self._cumsum = ()
elif len(_keys) > 1:
if not pos:
pos += 1
Expand All @@ -1995,15 +2042,18 @@ def _delete(self, pos, idx):
del _keys[pos]
del _maxes[pos]
del _index[:]
self._cumsum = ()

self._expand(prev)
elif len_keys_pos:
_maxes[pos] = keys_pos[-1]
self._cumsum = ()
else:
del _lists[pos]
del _keys[pos]
del _maxes[pos]
del _index[:]
self._cumsum = ()

def irange(self, minimum=None, maximum=None, inclusive=(True, True), reverse=False):
"""Create an iterator of values between `minimum` and `maximum`.
Expand Down Expand Up @@ -2510,6 +2560,13 @@ def _check(self):
else:
child_sum = self._index[child] + self._index[child + 1]
assert child_sum == self._index[pos]

if self._cumsum:
assert self._cumsum == tuple(
accumulate(
(len(sublist) for sublist in self._lists), initial=0
)
)
except:
traceback.print_exc(file=sys.stdout)
print('len', self._len)
Expand Down