Skip to content

Commit cdbca80

Browse files
committed
ch22: examples
1 parent 1702717 commit cdbca80

File tree

1 file changed

+127
-0
lines changed

1 file changed

+127
-0
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#!/usr/bin/env python3
2+
3+
"""Download flags of countries (with error handling).
4+
5+
asyncio async/await version using run_in_executor for save_flag.
6+
7+
"""
8+
9+
import asyncio
10+
from collections import Counter
11+
12+
import aiohttp
13+
import tqdm # type: ignore
14+
15+
from flags2_common import main, HTTPStatus, Result, save_flag
16+
17+
# default set low to avoid errors from remote site, such as
18+
# 503 - Service Temporarily Unavailable
19+
DEFAULT_CONCUR_REQ = 5
20+
MAX_CONCUR_REQ = 1000
21+
22+
23+
class FetchError(Exception):
24+
def __init__(self, country_code: str):
25+
self.country_code = country_code
26+
27+
async def get_flag(session: aiohttp.ClientSession,
28+
base_url: str,
29+
cc: str) -> bytes:
30+
url = f'{base_url}/{cc}/{cc}.gif'
31+
async with session.get(url) as resp:
32+
if resp.status == 200:
33+
return await resp.read()
34+
else:
35+
resp.raise_for_status()
36+
return bytes()
37+
38+
# tag::FLAGS3_ASYNCIO_GET_COUNTRY[]
39+
async def get_country(session: aiohttp.ClientSession, # <1>
40+
base_url: str,
41+
cc: str) -> str:
42+
url = f'{base_url}/{cc}/metadata.json'
43+
async with session.get(url) as resp:
44+
if resp.status == 200:
45+
metadata = await resp.json() # <2>
46+
return metadata.get('country', 'no name') # <3>
47+
else:
48+
resp.raise_for_status()
49+
return ''
50+
# end::FLAGS3_ASYNCIO_GET_COUNTRY[]
51+
52+
# tag::FLAGS3_ASYNCIO_DOWNLOAD_ONE[]
53+
async def download_one(session: aiohttp.ClientSession,
54+
cc: str,
55+
base_url: str,
56+
semaphore: asyncio.Semaphore,
57+
verbose: bool) -> Result:
58+
try:
59+
async with semaphore:
60+
image = await get_flag(session, base_url, cc) # <1>
61+
async with semaphore:
62+
country = await get_country(session, base_url, cc) # <2>
63+
except aiohttp.ClientResponseError as exc:
64+
if exc.status == 404:
65+
status = HTTPStatus.not_found
66+
msg = 'not found'
67+
else:
68+
raise FetchError(cc) from exc
69+
else:
70+
filename = country.replace(' ', '_') # <3>
71+
filename = f'{filename}.gif'
72+
loop = asyncio.get_running_loop()
73+
loop.run_in_executor(None,
74+
save_flag, image, filename)
75+
status = HTTPStatus.ok
76+
msg = 'OK'
77+
if verbose and msg:
78+
print(cc, msg)
79+
return Result(status, cc)
80+
# end::FLAGS3_ASYNCIO_DOWNLOAD_ONE[]
81+
82+
async def supervisor(cc_list: list[str],
83+
base_url: str,
84+
verbose: bool,
85+
concur_req: int) -> Counter[HTTPStatus]:
86+
counter: Counter[HTTPStatus] = Counter()
87+
semaphore = asyncio.Semaphore(concur_req)
88+
async with aiohttp.ClientSession() as session:
89+
to_do = [download_one(session, cc, base_url,
90+
semaphore, verbose)
91+
for cc in sorted(cc_list)]
92+
93+
to_do_iter = asyncio.as_completed(to_do)
94+
if not verbose:
95+
to_do_iter = tqdm.tqdm(to_do_iter, total=len(cc_list))
96+
for coro in to_do_iter:
97+
try:
98+
res = await coro
99+
except FetchError as exc:
100+
country_code = exc.country_code
101+
try:
102+
error_msg = exc.__cause__.message # type: ignore
103+
except AttributeError:
104+
error_msg = 'Unknown cause'
105+
if verbose and error_msg:
106+
print(f'*** Error for {country_code}: {error_msg}')
107+
status = HTTPStatus.error
108+
else:
109+
status = res.status
110+
111+
counter[status] += 1
112+
113+
return counter
114+
115+
116+
def download_many(cc_list: list[str],
117+
base_url: str,
118+
verbose: bool,
119+
concur_req: int) -> Counter[HTTPStatus]:
120+
coro = supervisor(cc_list, base_url, verbose, concur_req)
121+
counts = asyncio.run(coro) # <14>
122+
123+
return counts
124+
125+
126+
if __name__ == '__main__':
127+
main(download_many, DEFAULT_CONCUR_REQ, MAX_CONCUR_REQ)

0 commit comments

Comments
 (0)