Skip to content
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

+ Updated project structure to use `pyproject.toml`
+ Add `build` to `dev` dependencies
+ Support `argparse.ArgumentDefaultsHelpFormatter` by properly initializing
argparse Action default.
+ Relax the guarantee that default_factory is called "exactly once per parse"
but still guarantee that default_factory is called "for each parse".

## [2.0.1] - Unreleased

Expand Down
6 changes: 6 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ Using defaults:
>>> print(parser.parse_args([]))
Options(x=1, y=2, z=3.14)

Using ArgumentDefaultsHelpFormatter is supported. If a default_factory is used
in the dataclass it will be called for a fresh result on each parse. The default
value provided in --help is initialized at parser setup time.

Implementation of default_factory with side-effects should not be used.

Enabling choices for an option:

.. code-block:: pycon
Expand Down
40 changes: 32 additions & 8 deletions argparse_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@

import argparse
from argparse import BooleanOptionalAction
from argparse import Namespace
from typing import (
TypeVar,
Optional,
Expand Down Expand Up @@ -275,7 +276,8 @@ def parse_args(options_class: Type[OptionsType], args: ArgsType = None) -> Optio
"""Parse arguments and return as the dataclass type."""
parser = argparse.ArgumentParser()
_add_dataclass_options(options_class, parser)
kwargs = _get_kwargs(parser.parse_args(args))
initial_namespace = _init_namespace(options_class)
kwargs = _get_kwargs(parser.parse_args(args, initial_namespace))
return options_class(**kwargs)


Expand All @@ -287,7 +289,9 @@ def parse_known_args(
"""
parser = argparse.ArgumentParser()
_add_dataclass_options(options_class, parser)
namespace, others = parser.parse_known_args(args=args)
initial_namespace = _init_namespace(options_class)
namespace, others = parser.parse_known_args(args, initial_namespace)
assert namespace == initial_namespace
kwargs = _get_kwargs(namespace)
return options_class(**kwargs), others

Expand Down Expand Up @@ -366,7 +370,10 @@ def _add_dataclass_options(
if field.default == field.default_factory == MISSING and not positional:
kwargs["required"] = True
else:
kwargs["default"] = MISSING
if field.default_factory is not MISSING:
kwargs["default"] = field.default_factory()
else:
kwargs["default"] = field.default

if field.type is bool:
_handle_bool_type(field, args, kwargs)
Expand All @@ -389,6 +396,22 @@ def _add_dataclass_options(
parser.add_argument(*args, **kwargs)


def _init_namespace(options_class: Type[OptionsType]) -> Namespace:
"""Init a namespace for passing into `argparse.ArgumentParser.parse_args`

Assign a flag value (MISSING) for all fields which have a default at the
dataclass level, this prevents argparse from assigning to those fields.
"""
ns = Namespace()
assert is_dataclass(options_class)
for field in fields(options_class):
if field.default is not MISSING:
setattr(ns, field.name, field.default)
elif field.default_factory is not MISSING:
setattr(ns, field.name, field.default_factory())
return ns


def _get_kwargs(namespace: argparse.Namespace) -> dict[str, Any]:
"""Converts a Namespace to a dictionary containing the items that
to be used as keyword arguments to the Options class.
Expand Down Expand Up @@ -469,10 +492,9 @@ def __init__(self, options_class: Type[OptionsType], *args, **kwargs):

def parse_args(self, args: ArgsType = None, namespace=None) -> OptionsType:
"""Parse arguments and return as the dataclass type."""
if namespace is not None:
raise ValueError("supplying a namespace is not allowed")
kwargs = _get_kwargs(super().parse_args(args))
return self._options_type(**kwargs)
opts = super().parse_args(args, namespace)
assert isinstance(opts, self._options_type)
return opts

def parse_known_args(
self, args: ArgsType = None, namespace=None
Expand All @@ -482,7 +504,9 @@ def parse_known_args(
"""
if namespace is not None:
raise ValueError("supplying a namespace is not allowed")
namespace, others = super().parse_known_args(args=args)
initial_namespace = _init_namespace(self._options_type)
namespace, others = super().parse_known_args(args, initial_namespace)
assert namespace == initial_namespace
kwargs = _get_kwargs(namespace)
return self._options_type(**kwargs), others

Expand Down
35 changes: 28 additions & 7 deletions tests/test_argumentparser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from argparse import ArgumentDefaultsHelpFormatter
import sys
import unittest
import datetime as dt
Expand Down Expand Up @@ -170,27 +171,31 @@ class Parameters:

def test_default_factory_2(self):
factory_calls = 0
factory_result = "0"

def factory_func():
nonlocal factory_calls
factory_calls += 1
return f"Default Message: {factory_calls}"
return f"Default Message: {factory_result}"

@dataclass
class Parameters:
message: str = field(default_factory=factory_func)

params = ArgumentParser(Parameters).parse_args([])
parser = ArgumentParser(Parameters)
factory_result = "1"
params = parser.parse_args([])
self.assertEqual(params.message, "Default Message: 1")
self.assertEqual(factory_calls, 1)
self.assertGreaterEqual(factory_calls, 1)

params = ArgumentParser(Parameters).parse_args(["--message", "User message"])
params = parser.parse_args(["--message", "User message"])
self.assertEqual(params.message, "User message")
self.assertEqual(factory_calls, 1)
self.assertGreaterEqual(factory_calls, 1)

params = ArgumentParser(Parameters).parse_args([])
factory_result = "2"
params = parser.parse_args([])
self.assertEqual(params.message, "Default Message: 2")
self.assertEqual(factory_calls, 2)
self.assertGreaterEqual(factory_calls, 1)

def test_optional_args(self):
@dataclass
Expand Down Expand Up @@ -308,6 +313,22 @@ class Args:
self.assertEqual(10, params.num_of_foo)
self.assertFalse(params.is_fun)

def test_default_help(self):
@dataclass
class Opt:
answer: int = field(
default=42,
metadata=dict(help="answer"),
)

"""Test ArgumentsDefaultsHelpFormatter works as expected."""
parser = ArgumentParser(
Opt,
formatter_class=ArgumentDefaultsHelpFormatter,
)
help_message = parser.format_help()
assert "answer (default: 42)" in help_message


if __name__ == "__main__":
unittest.main()
11 changes: 7 additions & 4 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,27 +263,30 @@ class Parameters:

def test_default_factory_2(self):
factory_calls = 0
factory_result = "0"

def factory_func():
nonlocal factory_calls
factory_calls += 1
return f"Default Message: {factory_calls}"
return f"Default Message: {factory_result}"

@dataclass
class Parameters:
message: str = field(default_factory=factory_func)

factory_result = "1"
params = parse_args(Parameters, [])
self.assertEqual(params.message, "Default Message: 1")
self.assertEqual(factory_calls, 1)
self.assertGreaterEqual(factory_calls, 1)

params = parse_args(Parameters, ["--message", "User message"])
self.assertEqual(params.message, "User message")
self.assertEqual(factory_calls, 1)
self.assertGreaterEqual(factory_calls, 1)

factory_result = "2"
params = parse_args(Parameters, [])
self.assertEqual(params.message, "Default Message: 2")
self.assertEqual(factory_calls, 2)
self.assertGreaterEqual(factory_calls, 1)

def test_parse_known_args(self):
@dataclass
Expand Down