27 changes: 21 additions & 6 deletions Doc/library/itertools.rst
@@ -845,7 +845,8 @@ and :term:`generators <generator>` which incur interpreter overhead.
from contextlib import suppress
from functools import reduce
from math import comb, isqrt, prod, sumprod
from operator import getitem, is_not, itemgetter, mul, neg
from operator import getitem, is_not, itemgetter, mul, neg, truediv


# ==== Basic one liners ====

@@ -858,9 +859,10 @@ and :term:`generators <generator>` which incur interpreter overhead.
# prepend(1, [2, 3, 4]) → 1 2 3 4
return chain([value], iterable)

def tabulate(function, start=0):
"Return function(0), function(1), ..."
return map(function, count(start))
def running_mean(iterable):
"Yield the average of all values seen so far."
# running_mean([8.5, 9.5, 7.5, 6.5]) → 8.5 9.0 8.5 8.0
return map(truediv, accumulate(iterable), count(1))

def repeatfunc(function, times=None, *args):
"Repeat calls to a function with specified arguments."
@@ -913,6 +915,7 @@ and :term:`generators <generator>` which incur interpreter overhead.
# all_equal('4٤௪౪໔', key=int) → True
return len(take(2, groupby(iterable, key))) <= 1


# ==== Data pipelines ====

def unique_justseen(iterable, key=None):
@@ -1021,6 +1024,7 @@ and :term:`generators <generator>` which incur interpreter overhead.
while True:
yield function()


# ==== Mathematical operations ====

def multinomial(*counts):
@@ -1040,6 +1044,7 @@ and :term:`generators <generator>` which incur interpreter overhead.
# sum_of_squares([10, 20, 30]) → 1400
return sumprod(*tee(iterable))


# ==== Matrix operations ====

def reshape(matrix, columns):
@@ -1058,6 +1063,7 @@ and :term:`generators <generator>` which incur interpreter overhead.
n = len(m2[0])
return batched(starmap(sumprod, product(m1, transpose(m2))), n)


# ==== Polynomial arithmetic ====

def convolve(signal, kernel):
@@ -1114,6 +1120,7 @@ and :term:`generators <generator>` which incur interpreter overhead.
powers = reversed(range(1, n))
return list(map(mul, coefficients, powers))


# ==== Number theory ====

def sieve(n):
@@ -1230,8 +1237,8 @@ and :term:`generators <generator>` which incur interpreter overhead.
[(0, 'a'), (1, 'b'), (2, 'c')]


>>> list(islice(tabulate(lambda x: 2*x), 4))
[0, 2, 4, 6]
>>> list(running_mean([8.5, 9.5, 7.5, 6.5]))
[8.5, 9.0, 8.5, 8.0]


>>> for _ in loops(5):
@@ -1798,6 +1805,10 @@ and :term:`generators <generator>` which incur interpreter overhead.

# Old recipes and their tests which are guaranteed to continue to work.

def tabulate(function, start=0):
"Return function(0), function(1), ..."
return map(function, count(start))

def old_sumprod_recipe(vec1, vec2):
"Compute a sum of products."
return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))
@@ -1877,6 +1888,10 @@ and :term:`generators <generator>` which incur interpreter overhead.
.. doctest::
:hide:

>>> list(islice(tabulate(lambda x: 2*x), 4))
[0, 2, 4, 6]


>>> dotproduct([1,2,3], [4,5,6])
32

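A minimal standalone sketch of the new running_mean recipe, with an illustrative print call that is not part of the recipe itself: accumulate supplies the cumulative sums, count(1) the 1-based positions, and operator.truediv divides them elementwise.

    from itertools import accumulate, count
    from operator import truediv

    def running_mean(iterable):
        "Yield the average of all values seen so far."
        # running_mean([8.5, 9.5, 7.5, 6.5]) → 8.5 9.0 8.5 8.0
        return map(truediv, accumulate(iterable), count(1))

    print(list(running_mean([8.5, 9.5, 7.5, 6.5])))  # [8.5, 9.0, 8.5, 8.0]
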
107 changes: 36 additions & 71 deletions Lib/test/test_free_threading/test_itertools.py
@@ -1,94 +1,59 @@
import unittest
from threading import Thread, Barrier
from itertools import batched, chain, cycle
from itertools import batched, chain, combinations_with_replacement, cycle, permutations
from test.support import threading_helper


threading_helper.requires_working_threading(module=True)

class ItertoolsThreading(unittest.TestCase):

@threading_helper.reap_threads
def test_batched(self):
number_of_threads = 10
number_of_iterations = 20
barrier = Barrier(number_of_threads)
def work(it):
barrier.wait()
while True:
try:
next(it)
except StopIteration:
break

data = tuple(range(1000))
for it in range(number_of_iterations):
batch_iterator = batched(data, 2)
worker_threads = []
for ii in range(number_of_threads):
worker_threads.append(
Thread(target=work, args=[batch_iterator]))
def work_iterator(it):
while True:
try:
next(it)
except StopIteration:
break

with threading_helper.start_threads(worker_threads):
pass

barrier.reset()
class ItertoolsThreading(unittest.TestCase):

@threading_helper.reap_threads
def test_cycle(self):
number_of_threads = 6
def test_batched(self):
number_of_iterations = 10
number_of_cycles = 400
for _ in range(number_of_iterations):
it = batched(tuple(range(1000)), 2)
threading_helper.run_concurrently(work_iterator, nthreads=10, args=[it])

barrier = Barrier(number_of_threads)
@threading_helper.reap_threads
def test_cycle(self):
def work(it):
barrier.wait()
for _ in range(number_of_cycles):
try:
next(it)
except StopIteration:
pass
for _ in range(400):
next(it)

data = (1, 2, 3, 4)
for it in range(number_of_iterations):
cycle_iterator = cycle(data)
worker_threads = []
for ii in range(number_of_threads):
worker_threads.append(
Thread(target=work, args=[cycle_iterator]))

with threading_helper.start_threads(worker_threads):
pass

barrier.reset()
number_of_iterations = 6
for _ in range(number_of_iterations):
it = cycle((1, 2, 3, 4))
threading_helper.run_concurrently(work, nthreads=6, args=[it])

@threading_helper.reap_threads
def test_chain(self):
number_of_threads = 6
number_of_iterations = 20

barrier = Barrier(number_of_threads)
def work(it):
barrier.wait()
while True:
try:
next(it)
except StopIteration:
break

data = [(1, )] * 200
for it in range(number_of_iterations):
chain_iterator = chain(*data)
worker_threads = []
for ii in range(number_of_threads):
worker_threads.append(
Thread(target=work, args=[chain_iterator]))

with threading_helper.start_threads(worker_threads):
pass
number_of_iterations = 10
for _ in range(number_of_iterations):
it = chain(*[(1,)] * 200)
threading_helper.run_concurrently(work_iterator, nthreads=6, args=[it])

barrier.reset()
@threading_helper.reap_threads
def test_combinations_with_replacement(self):
number_of_iterations = 6
for _ in range(number_of_iterations):
it = combinations_with_replacement(tuple(range(2)), 2)
threading_helper.run_concurrently(work_iterator, nthreads=6, args=[it])

@threading_helper.reap_threads
def test_permutations(self):
number_of_iterations = 6
for _ in range(number_of_iterations):
it = permutations(tuple(range(4)), 2)
threading_helper.run_concurrently(work_iterator, nthreads=6, args=[it])


if __name__ == "__main__":
@@ -0,0 +1 @@
Make concurrent iteration over :class:`itertools.combinations_with_replacement` and :class:`itertools.permutations` safe under free-threading.
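A minimal sketch of the scenario this entry covers, assuming a free-threaded (--disable-gil) build: several threads drain one shared iterator, mirroring the work_iterator helper in the tests above. The per-object critical sections added in this change are what make the concurrent next() calls safe.

    # Sketch, assuming a free-threaded CPython build; mirrors work_iterator above.
    from itertools import combinations_with_replacement
    from threading import Thread

    def drain(it):
        # Pull items from the shared iterator until it is exhausted.
        while True:
            try:
                next(it)
            except StopIteration:
                break

    shared = combinations_with_replacement(range(4), 2)
    threads = [Thread(target=drain, args=[shared]) for _ in range(6)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
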
24 changes: 22 additions & 2 deletions Modules/itertoolsmodule.c
@@ -2587,7 +2587,7 @@ cwr_traverse(PyObject *op, visitproc visit, void *arg)
}

static PyObject *
cwr_next(PyObject *op)
cwr_next_lock_held(PyObject *op)
{
cwrobject *co = cwrobject_CAST(op);
PyObject *elem;
@@ -2666,6 +2666,16 @@
return NULL;
}

static PyObject *
cwr_next(PyObject *op)
{
PyObject *result;
Py_BEGIN_CRITICAL_SECTION(op);
result = cwr_next_lock_held(op);
Py_END_CRITICAL_SECTION()
return result;
}

static PyMethodDef cwr_methods[] = {
{"__sizeof__", cwr_sizeof, METH_NOARGS, sizeof_doc},
{NULL, NULL} /* sentinel */
@@ -2846,7 +2856,7 @@
}

static PyObject *
permutations_next(PyObject *op)
permutations_next_lock_held(PyObject *op)
{
permutationsobject *po = permutationsobject_CAST(op);
PyObject *elem;
@@ -2936,6 +2946,16 @@
return NULL;
}

static PyObject *
permutations_next(PyObject *op)
{
PyObject *result;
Py_BEGIN_CRITICAL_SECTION(op);
result = permutations_next_lock_held(op);
Py_END_CRITICAL_SECTION()
return result;
}

static PyMethodDef permuations_methods[] = {
{"__sizeof__", permutations_sizeof, METH_NOARGS, sizeof_doc},
{NULL, NULL} /* sentinel */
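Both C changes follow the same free-threading pattern: the original tp_iternext body is renamed to a *_lock_held helper, and a thin wrapper takes the object's per-object critical section around the call (these macros are no-ops on builds with the GIL). A condensed sketch with illustrative names, not the module's actual code:

    static PyObject *
    example_next_lock_held(PyObject *op)
    {
        /* ...original iteration logic, now run with op's lock held... */
        return NULL;
    }

    static PyObject *
    example_next(PyObject *op)
    {
        PyObject *result;
        Py_BEGIN_CRITICAL_SECTION(op);   /* locks op's mutex on free-threaded builds */
        result = example_next_lock_held(op);
        Py_END_CRITICAL_SECTION()
        return result;
    }

Taking the lock in a thin wrapper leaves the long iteration bodies untouched and keeps the locking discipline easy to audit.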