Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/68678.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Handle corrupted grains cache msgpack data by refreshing the cache.

36 changes: 27 additions & 9 deletions salt/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,20 @@ def _format_cached_grains(cached_grains):
return cached_grains


def _invalidate_grains_cache(cfn, reason=None):
"""
Remove an invalid grains cache file to allow refresh.
"""
if reason:
log.warning("Invalid grains cache (%s). Removing %s and refreshing.", reason, cfn)
else:
log.warning("Invalid grains cache. Removing %s and refreshing.", cfn)
try:
os.remove(cfn)
except OSError as exc:
log.debug("Failed to remove grains cache file %s: %s", cfn, exc)


def _load_cached_grains(opts, cfn):
"""
Returns the grains cached in cfn, or None if the cache is too old or is
Expand Down Expand Up @@ -1070,17 +1084,21 @@ def _load_cached_grains(opts, cfn):
log.debug("Retrieving grains from cache")
try:
with salt.utils.files.fopen(cfn, "rb") as fp_:
cached_grains = salt.utils.data.decode(
salt.payload.load(fp_), preserve_tuples=True
)
if not cached_grains:
log.debug("Cached grains are empty, cache might be corrupted. Refreshing.")
return None

return _format_cached_grains(cached_grains)
except OSError:
cached_grains = salt.payload.load(fp_, raise_on_error=False)
except OSError as exc:
log.debug("Failed to read grains cache file %s: %s", cfn, exc)
return None
if not cached_grains:
_invalidate_grains_cache(cfn, "empty or unreadable")
return None
try:
cached_grains = salt.utils.data.decode(cached_grains, preserve_tuples=True)
except Exception as exc: # pylint: disable=broad-except
_invalidate_grains_cache(cfn, f"decode error: {exc}")
return None

return _format_cached_grains(cached_grains)


def grains(opts, force_refresh=False, proxy=None, context=None, loaded_base_name=None):
"""
Expand Down
32 changes: 20 additions & 12 deletions salt/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def format_payload(enc, **kwargs):
return package(payload)


def loads(msg, encoding=None, raw=False):
def loads(msg, encoding=None, raw=False, log_error=True):
"""
Run the correct loads serialization format

Expand All @@ -70,6 +70,7 @@ def loads(msg, encoding=None, raw=False):
been lost in this case) to what the encoding is
set as. In this case, it will fail if any of
the contents cannot be converted.
:param log_error: Log deserialization failures when True.
"""
try:

Expand Down Expand Up @@ -97,14 +98,15 @@ def ext_type_decoder(code, data):
if encoding is None and not raw:
ret = salt.transport.frame.decode_embedded_strs(ret)
except Exception as exc: # pylint: disable=broad-except
log.critical(
"Could not deserialize msgpack message. This often happens "
"when trying to read a file not in binary mode. "
"To see message payload, enable debug logging and retry. "
"Exception: %s",
exc,
)
log.debug("Msgpack deserialization failure on message: %s", msg)
if log_error:
log.critical(
"Could not deserialize msgpack message. This often happens "
"when trying to read a file not in binary mode. "
"To see message payload, enable debug logging and retry. "
"Exception: %s",
salt.utils.msgpack.format_exception(exc),
)
log.debug("Msgpack deserialization failure on message: %s", msg)
exc_msg = "Could not deserialize msgpack message. See log for more info."
raise SaltDeserializationError(exc_msg) from exc
finally:
Expand Down Expand Up @@ -200,14 +202,20 @@ def verylong_encoder(obj, context):
)


def load(fn_):
def load(fn_, raise_on_error=True):
"""
Run the correct serialization to load a file
"""
data = fn_.read()
fn_.close()
if data:
return loads(data, encoding="utf-8")
if not data:
return None
try:
return loads(data, encoding="utf-8", log_error=raise_on_error)
except SaltDeserializationError:
if raise_on_error:
raise
return None


def dump(msg, fn_):
Expand Down
23 changes: 23 additions & 0 deletions salt/utils/msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,29 @@ class _exceptions:

exceptions = _exceptions()


def is_extra_data_exception(exc):
"""
Return True if the exception is a msgpack ExtraData error.
"""
extra_data = getattr(exceptions, "ExtraData", None)
return extra_data is not None and isinstance(exc, extra_data)


def format_exception(exc):
"""
Return a helpful string for msgpack exceptions.
"""
if is_extra_data_exception(exc):
extra = getattr(exc, "extra", None)
try:
extra_len = len(extra) if extra is not None else None
except TypeError:
extra_len = None
if extra_len is not None:
return f"{exc} (extra {extra_len} bytes)"
return str(exc)

# One-to-one mappings
Packer = msgpack.Packer
ExtType = msgpack.ExtType
Expand Down
13 changes: 13 additions & 0 deletions tests/pytests/unit/loader/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import salt.exceptions
import salt.loader
import salt.loader.lazy
import salt.utils.msgpack
from tests.support.mock import MagicMock, patch


Expand Down Expand Up @@ -57,6 +58,18 @@ def test_custom_grain_with_annotations(minion_opts, grains_dir):
assert grains.get("example") == "42"


def test_load_cached_grains_invalid_msgpack(tmp_path):
cfn = tmp_path / "grains.cache.p"
payload = salt.utils.msgpack.dumps({"os": "Darwin"})
cfn.write_bytes(payload + payload)
opts = salt.config.DEFAULT_MINION_OPTS.copy()
opts["grains_cache_expiration"] = 300
opts["refresh_grains_cache"] = False
ret = salt.loader._load_cached_grains(opts, str(cfn))
assert ret is None
assert not cfn.exists()


def test_raw_mod_functions():
"Ensure functions loaded by raw_mod are LoaderFunc instances"
opts = {
Expand Down
16 changes: 16 additions & 0 deletions tests/pytests/unit/test_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
"""

import copy
import io
import datetime
import logging
from collections import OrderedDict

import pytest
import zmq

import salt.exceptions
Expand Down Expand Up @@ -236,6 +238,20 @@ def test_format_payload():
assert payload == expected


def test_load_returns_none_on_deserialization_error():
bad = salt.utils.msgpack.dumps({"a": 1}) + salt.utils.msgpack.dumps({"b": 2})
fp_ = io.BytesIO(bad)
ret = salt.payload.load(fp_, raise_on_error=False)
assert ret is None


def test_load_raises_on_deserialization_error():
bad = salt.utils.msgpack.dumps({"a": 1}) + salt.utils.msgpack.dumps({"b": 2})
fp_ = io.BytesIO(bad)
with pytest.raises(salt.exceptions.SaltDeserializationError):
salt.payload.load(fp_)


def test_SREQ_init():
req = salt.payload.SREQ(
"tcp://salt:3434", id_=b"id", serial="msgpack", linger=1, opts=None
Expand Down