Skip to content

Commit 6a3d28c

Browse files
authored
fix: fallback to old version of login (#537)
* fix: fallback to old version of login * fix: bug and add test
1 parent fbf1434 commit 6a3d28c

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

roborock/web_api.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def _get_iot_login_info(self) -> IotLoginInfo:
9898
raise RoborockException(f"{response.get('msg')} - response code: {response_code}")
9999
country_code = response["data"]["countrycode"]
100100
country = response["data"]["country"]
101-
if country_code is not None and country is not None:
101+
if country_code is not None or country is not None:
102102
self._iot_login_info = IotLoginInfo(
103103
base_url=response["data"]["url"],
104104
country=country,
@@ -234,6 +234,9 @@ async def request_code(self) -> None:
234234

235235
async def request_code_v4(self) -> None:
236236
"""Request a code using the v4 endpoint."""
237+
if await self.country_code is None or await self.country is None:
238+
_LOGGER.info("No country code or country found, trying old version of request code.")
239+
return await self.request_code()
237240
try:
238241
self._login_limiter.try_acquire("login")
239242
except BucketFullException as ex:
@@ -304,6 +307,9 @@ async def code_login_v4(
304307
country = await self.country
305308
if country_code is None:
306309
country_code = await self.country_code
310+
if country_code is None or country is None:
311+
_LOGGER.info("No country code or country found, trying old version of code login.")
312+
return await self.code_login(code)
307313
header_clientid = self._get_header_client_id()
308314
x_mercy_ks = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(16))
309315
x_mercy_k = await self._sign_key_v3(x_mercy_ks)

tests/test_web_api.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,36 @@ async def test_url_cycling(mock_rest) -> None:
170170
)
171171
# Make sure we just have the three we tested for above.
172172
assert len(mock_rest.requests) == 3
173+
174+
175+
async def test_missing_country_login(mock_rest) -> None:
176+
"""Test that we cycle through the URLs correctly."""
177+
mock_rest.clear()
178+
# Make country None, but country code set.
179+
mock_rest.post(
180+
re.compile("https://usiot.roborock.com/api/v1/getUrlByEmail.*"),
181+
status=200,
182+
payload={
183+
"code": 200,
184+
"data": {"url": "https://usiot.roborock.com", "country": None, "countrycode": 1},
185+
"msg": "Success",
186+
},
187+
)
188+
# v4 is not mocked, so it would fail it were called.
189+
mock_rest.post(
190+
re.compile(r"https://.*iot\.roborock\.com/api/v1/loginWithCode.*"),
191+
status=200,
192+
payload={"code": 200, "data": USER_DATA, "msg": "success"},
193+
)
194+
mock_rest.post(
195+
re.compile(r"https://.*iot\.roborock\.com/api/v1/sendEmailCode.*"),
196+
status=200,
197+
payload={"code": 200, "data": None, "msg": "success"},
198+
)
199+
200+
client = RoborockApiClient("test@example.com")
201+
await client.request_code_v4()
202+
ud = await client.code_login_v4(4123)
203+
assert ud is not None
204+
# Ensure we have no surprise REST calls.
205+
assert len(mock_rest.requests) == 3

0 commit comments

Comments
 (0)