|
5 | 5 | asyncio async/await version |
6 | 6 |
|
7 | 7 | """ |
8 | | -# BEGIN FLAGS2_ASYNCIO_TOP |
| 8 | +# tag::FLAGS2_ASYNCIO_TOP[] |
9 | 9 | import asyncio |
10 | 10 | from collections import Counter |
11 | 11 |
|
12 | 12 | import aiohttp |
13 | | -from aiohttp import web |
14 | | -from aiohttp.http_exceptions import HttpProcessingError |
15 | 13 | import tqdm # type: ignore |
16 | 14 |
|
17 | 15 | from flags2_common import main, HTTPStatus, Result, save_flag |
|
23 | 21 |
|
24 | 22 |
|
25 | 23 | class FetchError(Exception): # <1> |
26 | | - def __init__(self, country_code): |
| 24 | + def __init__(self, country_code: str): |
27 | 25 | self.country_code = country_code |
28 | 26 |
|
29 | 27 |
|
30 | | -async def get_flag(session, base_url, cc): # <2> |
31 | | - cc = cc.lower() |
32 | | - url = f'{base_url}/{cc}/{cc}.gif' |
| 28 | +async def get_flag(session: aiohttp.ClientSession, # <2> |
| 29 | + base_url: str, |
| 30 | + cc: str) -> bytes: |
| 31 | + url = f'{base_url}/{cc}/{cc}.gif'.lower() |
33 | 32 | async with session.get(url) as resp: |
34 | 33 | if resp.status == 200: |
35 | 34 | return await resp.read() |
36 | | - elif resp.status == 404: |
37 | | - raise web.HTTPNotFound() |
38 | 35 | else: |
39 | | - raise HttpProcessingError( |
40 | | - code=resp.status, message=resp.reason, |
41 | | - headers=resp.headers) |
42 | | - |
43 | | - |
44 | | -async def download_one(session, cc, base_url, semaphore, verbose): # <3> |
| 36 | + resp.raise_for_status() # <3> |
| 37 | + return bytes() |
| 38 | + |
| 39 | +async def download_one(session: aiohttp.ClientSession, # <4> |
| 40 | + cc: str, |
| 41 | + base_url: str, |
| 42 | + semaphore: asyncio.Semaphore, |
| 43 | + verbose: bool) -> Result: |
45 | 44 | try: |
46 | | - async with semaphore: # <4> |
47 | | - image = await get_flag(session, base_url, cc) # <5> |
48 | | - except web.HTTPNotFound: # <6> |
49 | | - status = HTTPStatus.not_found |
50 | | - msg = 'not found' |
51 | | - except Exception as exc: |
52 | | - raise FetchError(cc) from exc # <7> |
| 45 | + async with semaphore: # <5> |
| 46 | + image = await get_flag(session, base_url, cc) |
| 47 | + except aiohttp.ClientResponseError as exc: |
| 48 | + if exc.status == 404: # <6> |
| 49 | + status = HTTPStatus.not_found |
| 50 | + msg = 'not found' |
| 51 | + else: |
| 52 | + raise FetchError(cc) from exc # <7> |
53 | 53 | else: |
54 | | - save_flag(image, cc.lower() + '.gif') # <8> |
| 54 | + save_flag(image, f'{cc}.gif') |
55 | 55 | status = HTTPStatus.ok |
56 | 56 | msg = 'OK' |
57 | | - |
58 | 57 | if verbose and msg: |
59 | 58 | print(cc, msg) |
60 | | - |
61 | 59 | return Result(status, cc) |
62 | | -# END FLAGS2_ASYNCIO_TOP |
| 60 | +# end::FLAGS2_ASYNCIO_TOP[] |
63 | 61 |
|
64 | | -# BEGIN FLAGS2_ASYNCIO_DOWNLOAD_MANY |
65 | | -async def downloader_coro(cc_list: list[str], |
66 | | - base_url: str, |
67 | | - verbose: bool, |
68 | | - concur_req: int) -> Counter[HTTPStatus]: # <1> |
| 62 | +# tag::FLAGS2_ASYNCIO_START[] |
| 63 | +async def supervisor(cc_list: list[str], |
| 64 | + base_url: str, |
| 65 | + verbose: bool, |
| 66 | + concur_req: int) -> Counter[HTTPStatus]: # <1> |
69 | 67 | counter: Counter[HTTPStatus] = Counter() |
70 | 68 | semaphore = asyncio.Semaphore(concur_req) # <2> |
71 | | - async with aiohttp.ClientSession() as session: # <8> |
| 69 | + async with aiohttp.ClientSession() as session: |
72 | 70 | to_do = [download_one(session, cc, base_url, semaphore, verbose) |
73 | 71 | for cc in sorted(cc_list)] # <3> |
74 | | - |
75 | 72 | to_do_iter = asyncio.as_completed(to_do) # <4> |
76 | 73 | if not verbose: |
77 | 74 | to_do_iter = tqdm.tqdm(to_do_iter, total=len(cc_list)) # <5> |
78 | | - for future in to_do_iter: # <6> |
| 75 | + for coro in to_do_iter: # <6> |
79 | 76 | try: |
80 | | - res = await future # <7> |
| 77 | + res = await coro # <7> |
81 | 78 | except FetchError as exc: # <8> |
82 | 79 | country_code = exc.country_code # <9> |
83 | 80 | try: |
84 | | - if exc.__cause__ is None: |
85 | | - error_msg = 'Unknown cause' |
86 | | - else: |
87 | | - error_msg = exc.__cause__.args[0] # <10> |
88 | | - except IndexError: |
89 | | - error_msg = exc.__cause__.__class__.__name__ # <11> |
| 81 | + error_msg = exc.__cause__.message # type: ignore # <10> |
| 82 | + except AttributeError: |
| 83 | + error_msg = 'Unknown cause' # <11> |
90 | 84 | if verbose and error_msg: |
91 | 85 | print(f'*** Error for {country_code}: {error_msg}') |
92 | 86 | status = HTTPStatus.error |
93 | 87 | else: |
94 | 88 | status = res.status |
95 | | - |
96 | 89 | counter[status] += 1 # <12> |
97 | | - |
98 | 90 | return counter # <13> |
99 | 91 |
|
100 | | - |
101 | 92 | def download_many(cc_list: list[str], |
102 | 93 | base_url: str, |
103 | 94 | verbose: bool, |
104 | 95 | concur_req: int) -> Counter[HTTPStatus]: |
105 | | - coro = downloader_coro(cc_list, base_url, verbose, concur_req) |
| 96 | + coro = supervisor(cc_list, base_url, verbose, concur_req) |
106 | 97 | counts = asyncio.run(coro) # <14> |
107 | 98 |
|
108 | 99 | return counts |
109 | 100 |
|
110 | | - |
111 | 101 | if __name__ == '__main__': |
112 | 102 | main(download_many, DEFAULT_CONCUR_REQ, MAX_CONCUR_REQ) |
113 | | -# END FLAGS2_ASYNCIO_DOWNLOAD_MANY |
| 103 | +# end::FLAGS2_ASYNCIO_START[] |
0 commit comments