Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,17 @@ For compatibility the default is to convert field names to `camelCase`. You can
MyMessage().to_dict(casing=betterproto.Casing.SNAKE)
```

#### Proto3 canonical JSON

By default, enum values are serialized using their stripped Python names (e.g. `"HEARTS"` for a proto enum value `SUIT_HEARTS`). To use the original `.proto` enum names as required by the [proto3 JSON spec](https://protobuf.dev/programming-guides/json/), pass `proto3_json=True`:

```python
Card(suit=Suit.HEARTS).to_dict(proto3_json=True)
# {"suit": "SUIT_HEARTS"} instead of {"suit": "HEARTS"}
```

When deserializing, `from_dict()` accepts all three forms: stripped Python names, full proto names, and integer values.

### Determining if a message was sent

Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example.
Expand Down
48 changes: 38 additions & 10 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,10 @@ def FromString(cls: Type[T], data: bytes) -> T:
return cls().parse(data)

def to_dict(
self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
self,
casing: Casing = Casing.CAMEL,
include_default_values: bool = False,
proto3_json: bool = False,
) -> Dict[str, Any]:
"""
Returns a JSON serializable dict representation of this object.
Expand All @@ -1419,6 +1422,10 @@ def to_dict(
If ``True`` will include the default values of fields. Default is ``False``.
E.g. an ``int32`` field will be included with a value of ``0`` if this is
set to ``True``, otherwise this would be ignored.
proto3_json: :class:`bool`
If ``True`` will use proto3 canonical JSON format for enum values,
serializing them with their original .proto names (e.g. "MARK_TYPE_BOLD")
instead of the stripped Python names (e.g. "BOLD"). Default is ``False``.

Returns
--------
Expand Down Expand Up @@ -1466,7 +1473,8 @@ def to_dict(
value = [_Duration.delta_to_json(i) for i in value]
else:
value = [
i.to_dict(casing, include_default_values) for i in value
i.to_dict(casing, include_default_values, proto3_json)
for i in value
]
if value or include_default_values:
output[cased_name] = value
Expand All @@ -1480,12 +1488,16 @@ def to_dict(
field_name=field_name, meta=meta
)
):
output[cased_name] = value.to_dict(casing, include_default_values)
output[cased_name] = value.to_dict(
casing, include_default_values, proto3_json
)
elif meta.proto_type == TYPE_MAP:
output_map = {**value}
for k in value:
if hasattr(value[k], "to_dict"):
output_map[k] = value[k].to_dict(casing, include_default_values)
output_map[k] = value[k].to_dict(
casing, include_default_values, proto3_json
)

if value or include_default_values:
output[cased_name] = output_map
Expand Down Expand Up @@ -1514,24 +1526,32 @@ def to_dict(
else:
output[cased_name] = b64encode(value).decode("utf8")
elif meta.proto_type == TYPE_ENUM:

def _enum_name(member):
if proto3_json:
return getattr(member, "proto_name", None) or member.name
return member.name

if field_is_repeated:
enum_class = field_types[field_name].__args__[0]
if isinstance(value, typing.Iterable) and not isinstance(
value, str
):
output[cased_name] = [enum_class(el).name for el in value]
output[cased_name] = [
_enum_name(enum_class(el)) for el in value
]
else:
# transparently upgrade single value to repeated
output[cased_name] = [enum_class(value).name]
output[cased_name] = [_enum_name(enum_class(value))]
elif value is None:
if include_default_values:
output[cased_name] = value
elif meta.optional:
enum_class = field_types[field_name].__args__[0]
output[cased_name] = enum_class(value).name
output[cased_name] = _enum_name(enum_class(value))
else:
enum_class = field_types[field_name] # noqa
output[cased_name] = enum_class(value).name
output[cased_name] = _enum_name(enum_class(value))
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if field_is_repeated:
output[cased_name] = [_dump_float(n) for n in value]
Expand Down Expand Up @@ -1591,10 +1611,18 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
)
elif meta.proto_type == TYPE_ENUM:
enum_cls = cls._betterproto.cls_by_field[field_name]

def _parse_enum(e, ec=enum_cls):
if isinstance(e, int):
return ec.try_value(e)
return ec.from_string(e)

if isinstance(value, list):
value = [enum_cls.from_string(e) for e in value]
value = [_parse_enum(e) for e in value]
elif isinstance(value, int):
value = _parse_enum(value)
elif isinstance(value, str):
value = enum_cls.from_string(value)
value = _parse_enum(value)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
value = (
[_parse_float(n) for n in value]
Expand Down
41 changes: 33 additions & 8 deletions src/betterproto/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ def _is_descriptor(obj: object) -> bool:
class EnumType(EnumMeta if TYPE_CHECKING else type):
_value_map_: Mapping[int, Enum]
_member_map_: Mapping[str, Enum]
_proto_names_: Dict[str, str] # Maps Python name -> original proto name

def __new__(
mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
) -> Self:
value_map = {}
member_map = {}
proto_names = namespace.pop("_proto_names_", {})

new_mcs = type(
f"{name}Type",
Expand All @@ -50,7 +52,11 @@ def __new__(
+ [EnumType, type]
)
), # reorder the bases so EnumType and type are last to avoid conflicts
{"_value_map_": value_map, "_member_map_": member_map},
{
"_value_map_": value_map,
"_member_map_": member_map,
"_proto_names_": proto_names,
},
)

members = {
Expand All @@ -71,7 +77,8 @@ def __new__(
for name, value in members.items():
member = value_map.get(value)
if member is None:
member = cls.__new__(cls, name=name, value=value) # type: ignore
proto_name = proto_names.get(name, name)
member = cls.__new__(cls, name=name, value=value, proto_name=proto_name) # type: ignore
value_map[value] = member
member_map[name] = member
type.__setattr__(new_mcs, name, member)
Expand Down Expand Up @@ -123,17 +130,27 @@ class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):

name: Optional[str]
value: int
proto_name: Optional[str]

if not TYPE_CHECKING:

def __new__(cls, *, name: Optional[str], value: int) -> Self:
def __new__(
cls, *, name: Optional[str], value: int, proto_name: Optional[str] = None
) -> Self:
self = super().__new__(cls, value)
super().__setattr__(self, "name", name)
super().__setattr__(self, "value", value)
# proto_name is the original name from the .proto file (e.g. "MARK_TYPE_BOLD")
# used for proto3 canonical JSON serialization
super().__setattr__(self, "proto_name", proto_name or name)
return self

def __getnewargs_ex__(self) -> Tuple[Tuple[()], Dict[str, Any]]:
return (), {"name": self.name, "value": self.value}
return (), {
"name": self.name,
"value": self.value,
"proto_name": self.proto_name,
}

def __str__(self) -> str:
return self.name or "None"
Expand Down Expand Up @@ -181,6 +198,9 @@ def try_value(cls, value: int = 0) -> Self:
def from_string(cls, name: str) -> Self:
"""Return the value which corresponds to the string name.

Accepts both the Python member name (e.g. "BOLD") and the original
proto name (e.g. "MARK_TYPE_BOLD") per the proto3 JSON spec.

Parameters
-----------
name: :class:`str`
Expand All @@ -191,7 +211,12 @@ def from_string(cls, name: str) -> Self:
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name]
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
# Try stripped Python name first
member = cls._member_map_.get(name)
if member is not None:
return member
# Try original proto name
for m in cls._member_map_.values():
if getattr(m, "proto_name", None) == name:
return m
raise ValueError(f"Unknown value {name} for enum {cls.__name__}")
2 changes: 2 additions & 0 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ class EnumEntry:
name: str
value: int
comment: str
proto_name: str # Original name from .proto file

def __post_init__(self) -> None:
# Get entries/allowed values for this Enum
Expand All @@ -662,6 +663,7 @@ def __post_init__(self) -> None:
comment=get_comment(
proto_file=self.source_file, path=self.path + [2, entry_number]
),
proto_name=entry_proto_value.name,
)
for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
]
Expand Down
3 changes: 3 additions & 0 deletions src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ class {{ enum.py_name }}(betterproto.Enum):
{{ enum.comment }}

{% endif %}
# Mapping from Python member names to original proto names (for canonical JSON)
_proto_names_ = { {% for entry in enum.entries %}"{{ entry.name }}": "{{ entry.proto_name }}", {% endfor %}}

{% for entry in enum.entries %}
{{ entry.name }} = {{ entry.value }}
{% if entry.comment %}
Expand Down
101 changes: 101 additions & 0 deletions tests/test_proto3_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Test proto3 canonical JSON enum serialization.

Verifies that to_dict(proto3_json=True) uses the original .proto enum
names per the proto3 JSON spec, while to_dict() (default) preserves
the existing stripped Python names for backwards compatibility.
"""

from typing import List

import pytest

import betterproto


# Test enum whose member names are the stripped Python names; the
# ``_proto_names_`` mapping supplies the corresponding ``SUIT_``-prefixed
# names that the tests expect when ``proto3_json=True`` is used.
class Suit(betterproto.Enum):
    # Python member name -> original .proto enum value name.
    _proto_names_ = {
        "UNSPECIFIED": "SUIT_UNSPECIFIED",
        "HEARTS": "SUIT_HEARTS",
        "DIAMONDS": "SUIT_DIAMONDS",
        "CLUBS": "SUIT_CLUBS",
        "SPADES": "SUIT_SPADES",
    }

    UNSPECIFIED = 0
    HEARTS = 1
    DIAMONDS = 2
    CLUBS = 3
    SPADES = 4


from dataclasses import dataclass


# Message with one enum field and one int32 field; the tests below use it to
# check enum serialization via ``to_dict`` and parsing via ``from_dict``.
@dataclass(eq=False, repr=False)
class Card(betterproto.Message):
    suit: "Suit" = betterproto.enum_field(1)  # enum field under test
    value: int = betterproto.int32_field(2)


# Message holding a list of ``Card`` messages; used to check that
# ``proto3_json`` is passed down to nested messages.
@dataclass(eq=False, repr=False)
class Hand(betterproto.Message):
    cards: List["Card"] = betterproto.message_field(1)


class TestProto3JsonSerialization:
    """Enum output of ``to_dict`` with and without ``proto3_json=True``."""

    def test_single_enum(self):
        # Opting in yields the original .proto enum name.
        serialized = Card(suit=Suit.HEARTS, value=10).to_dict(proto3_json=True)
        assert serialized["suit"] == "SUIT_HEARTS"

    def test_default_uses_stripped_name(self):
        # Without the flag the stripped Python member name is emitted.
        serialized = Card(suit=Suit.HEARTS, value=10).to_dict()
        assert serialized["suit"] == "HEARTS"

    def test_nested_propagates(self):
        # The flag must reach messages nested inside a repeated field.
        spade = Card(suit=Suit.SPADES, value=1)
        serialized = Hand(cards=[spade]).to_dict(proto3_json=True)
        first_card = serialized["cards"][0]
        assert first_card["suit"] == "SUIT_SPADES"


class TestProto3JsonDeserialization:
    """``from_dict`` parses enum fields given as proto name, stripped name or int."""

    def test_accept_proto_name(self):
        payload = {"suit": "SUIT_HEARTS", "value": 10}
        parsed = Card().from_dict(payload)
        assert parsed.suit == Suit.HEARTS

    def test_accept_stripped_name(self):
        payload = {"suit": "HEARTS", "value": 10}
        parsed = Card().from_dict(payload)
        assert parsed.suit == Suit.HEARTS

    def test_accept_integer(self):
        payload = {"suit": 1, "value": 10}
        parsed = Card().from_dict(payload)
        assert parsed.suit == Suit.HEARTS

    def test_round_trip_proto3(self):
        # Serialize with full proto names, then parse the result back.
        as_dict = Card(suit=Suit.DIAMONDS, value=7).to_dict(proto3_json=True)
        parsed = Card().from_dict(as_dict)
        assert parsed.suit == Suit.DIAMONDS
        assert parsed.value == 7

    def test_round_trip_default(self):
        # Serialize with the default (stripped) names, then parse back.
        as_dict = Card(suit=Suit.CLUBS, value=3).to_dict()
        parsed = Card().from_dict(as_dict)
        assert parsed.suit == Suit.CLUBS


class TestEnumWithoutProtoNames:
    """Enums declared without ``_proto_names_`` (backwards compatibility)."""

    def test_proto3_json_falls_back_to_name(self):
        """A member's ``proto_name`` equals its Python name when no mapping exists."""
        from tests.test_enum import Colour

        # Per the original test, Colour declares no _proto_names_ mapping.
        red = Colour.RED
        assert red.proto_name == "RED"