Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions docs/api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ As explained in the [documentation](index.md), case functions have no requiremen
### `@case`

```python
@case(id=None, # type: str # noqa
tags=None, # type: Union[Any, Iterable[Any]]
marks=(), # type: Union[MarkDecorator, Iterable[MarkDecorator]]
@case(id: str = None, # noqa
tags: Union[Any, Iterable[Any]] = None,
marks: Union[MarkDecorator, Iterable[MarkDecorator]] = (),
)
```

Expand All @@ -54,8 +54,7 @@ def case_hi():
### `@with_case_tags`

```python
@with_case_tags(*tags, # type: Any
):
@with_case_tags(*tags: Any):
```

This decorator can be applied to a class defining cases to apply multiple
Expand Down Expand Up @@ -105,8 +104,8 @@ class CasesContainerClass:
### `copy_case_info`

```python
def copy_case_info(from_fun, # type: Callable
to_fun # type: Callable
def copy_case_info(from_fun: Callable,
to_fun: Callable
):
```

Expand All @@ -116,8 +115,8 @@ Copies all information from case function `from_fun` to `to_fun`.
### `set_case_id`

```python
def set_case_id(id, # type: str
case_func # type: Callable
def set_case_id(id: str,
case_func: Callable
):
```

Expand All @@ -127,8 +126,8 @@ Sets an explicit id on case function `case_func`.
### `get_case_id`

```python
def get_case_id(case_func, # type: Callable
prefix_for_default_ids='case_' # type: str
def get_case_id(case_func: Callable,
prefix_for_default_ids: str = 'case_'
):
```

Expand All @@ -147,9 +146,9 @@ If a custom id is not present, a case id is automatically created from the funct
### `get_case_marks`

```python
def get_case_marks(case_func, # type: Callable
concatenate_with_fun_marks=False, # type: bool
as_decorators=False # type: bool
def get_case_marks(case_func: Callable,
concatenate_with_fun_marks: bool = False,
as_decorators: bool = False
):
```

Expand All @@ -169,7 +168,7 @@ There are currently two ways to place a mark on a case function: either with `@p
### `get_case_tags`

```python
def get_case_tags(case_func # type: Callable
def get_case_tags(case_func: Callable
):
```

Expand All @@ -183,9 +182,9 @@ Return the tags on this case function or an empty tuple.
### `matches_tag_query`

```python
def matches_tag_query(case_fun, # type: Callable
has_tag=None, # type: Union[str, Iterable[str]]
filter=None, # type: Union[Callable[[Callable], bool], Iterable[Callable[[Callable], bool]]] # noqa
def matches_tag_query(case_fun: Callable,
has_tag: Union[str, Iterable[str]] = None,
filter: Union[Callable[[Callable], bool], Iterable[Callable[[Callable], bool]]] = None, # noqa
):
```

Expand All @@ -209,9 +208,9 @@ Returns True if the case function is selected by the query:
### `is_case_class`

```python
def is_case_class(cls, # type: Any
case_marker_in_name='Case', # type: str
check_name=True # type: bool
def is_case_class(cls: Any,
case_marker_in_name: str = 'Case',
check_name: bool = True
):
```

Expand All @@ -230,9 +229,9 @@ Returns True if the given object is a class and, if `check_name=True` (default),
### `is_case_function`

```python
def is_case_function(f, # type: Any
prefix='case_', # type: str
check_prefix=True # type: bool
def is_case_function(f: Any,
prefix: str = 'case_',
check_prefix: bool = True
):
```

Expand Down
99 changes: 40 additions & 59 deletions src/pytest_cases/case_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
# License: 3-clause BSD, <https://github.com/smarie/python-pytest-cases/blob/master/LICENSE>
from copy import copy
from decopatch import function_decorator, DECORATED

try: # python 3.5+
from typing import Callable, Union, Optional, Any, Tuple, Iterable, List, Set
except ImportError:
pass
from typing import Callable, Union, Optional, Any, Iterable

from .common_mini_six import string_types
from .common_pytest import safe_isclass
Expand Down Expand Up @@ -44,22 +40,22 @@ class _CaseInfo(object):
__slots__ = ('id', 'marks', 'tags')

def __init__(self,
id=None, # type: str
marks=(), # type: Tuple[MarkDecorator, ...]
tags=() # type: Tuple[Any]
id: str = None,
marks: tuple[MarkDecorator, ...] = (),
tags: tuple[Any, ...] = ()
):
self.id = id
self.marks = marks # type: Tuple[MarkDecorator, ...]
self.tags = ()
self.id: str = id
self.marks: tuple[MarkDecorator, ...] = marks
self.tags: tuple[Any, ...] = ()
self.add_tags(tags)

def __repr__(self):
return "_CaseInfo(id=%r,marks=%r,tags=%r)" % (self.id, self.marks, self.tags)

@classmethod
def get_from(cls,
case_func, # type: Callable
create_if_missing=False # type: bool
case_func: Callable,
create_if_missing: bool = False
):
""" Return the _CaseInfo associated with case_fun or None

Expand All @@ -73,15 +69,11 @@ def get_from(cls,
ci.attach_to(case_func)
return ci

def attach_to(self,
case_func # type: Callable
):
def attach_to(self, case_func: Callable):
"""attach this case_info to the given case function"""
setattr(case_func, CASE_FIELD, self)

def add_tags(self,
tags # type: Union[Any, Union[List, Set, Tuple]]
):
def add_tags(self, tags: Union[Any, Union[list, set, tuple]]):
"""add the given tag or tags"""
if tags:
if isinstance(tags, string_types) or not isinstance(tags, (set, list, tuple)):
Expand All @@ -90,9 +82,7 @@ def add_tags(self,

self.tags += tuple(tags)

def matches_tag_query(self,
has_tag=None, # type: Union[str, Iterable[str]]
):
def matches_tag_query(self, has_tag: Union[str, Iterable[str]] = None):
"""
Returns True if the case function with this case_info is selected by the query

Expand All @@ -103,17 +93,17 @@ def matches_tag_query(self,

@classmethod
def copy_info(cls,
from_case_func,
to_case_func):
from_case_func: Callable,
to_case_func: Callable):
case_info = cls.get_from(from_case_func)
if case_info is not None:
# there is something to copy: do it
cp = copy(case_info)
cp.attach_to(to_case_func)


def _tags_match_query(tags, # type: Iterable[str]
has_tag # type: Optional[Union[str, Iterable[str]]]
def _tags_match_query(tags: Iterable[str],
has_tag: Optional[Union[str, Iterable[str]]]
):
"""Internal routine to determine is all tags in `has_tag` are persent in `tags`
Note that `has_tag` can be a single tag, or none
Expand All @@ -127,23 +117,23 @@ def _tags_match_query(tags, # type: Iterable[str]
return all(t in tags for t in has_tag)


def copy_case_info(from_fun, # type: Callable
to_fun # type: Callable
def copy_case_info(from_fun: Callable,
to_fun: Callable
):
"""Copy all information from case function `from_fun` to `to_fun`."""
_CaseInfo.copy_info(from_fun, to_fun)


def set_case_id(id, # type: str
case_func # type: Callable
def set_case_id(id: str,
case_func: Callable
):
"""Set an explicit id on case function `case_func`."""
ci = _CaseInfo.get_from(case_func, create_if_missing=True)
ci.id = id


def get_case_id(case_func, # type: Callable
prefix_for_default_ids=CASE_PREFIX_FUN # type: str
def get_case_id(case_func: Callable,
prefix_for_default_ids: str = CASE_PREFIX_FUN
):
"""Return the case id associated with this case function.

Expand Down Expand Up @@ -176,11 +166,10 @@ def get_case_id(case_func, # type: Callable
# def add_case_marks: no need, equivalent of @case(marks) or @mark


def get_case_marks(case_func, # type: Callable
concatenate_with_fun_marks=False, # type: bool
as_decorators=False # type: bool
):
# type: (...) -> Union[Tuple[Mark, ...], Tuple[MarkDecorator, ...]]
def get_case_marks(case_func: Callable,
concatenate_with_fun_marks: bool = False,
as_decorators: bool = False
) -> Union[tuple[Mark, ...], tuple[MarkDecorator, ...]]:
"""Return the marks that are on the case function.

There are currently two ways to place a mark on a case function: either with `@pytest.mark.<name>` or in
Expand Down Expand Up @@ -218,16 +207,15 @@ def get_case_marks(case_func, # type: Callable
# ci.add_tags(tags)


def get_case_tags(case_func # type: Callable
):
def get_case_tags(case_func: Callable):
"""Return the tags on this case function or an empty tuple"""
ci = _CaseInfo.get_from(case_func)
return ci.tags if ci is not None else ()


def matches_tag_query(case_fun, # type: Callable
has_tag=None, # type: Union[str, Iterable[str]]
filter=None, # type: Union[Callable[[Callable], bool], Iterable[Callable[[Callable], bool]]] # noqa
def matches_tag_query(case_fun: Callable,
has_tag: Union[str, Iterable[str]] = None,
filter: Union[Callable[[Callable], bool], Iterable[Callable[[Callable], bool]]] = None,
):
"""
This function is the one used by `@parametrize_with_cases` to filter the case functions collected. It can be used
Expand Down Expand Up @@ -275,16 +263,10 @@ def matches_tag_query(case_fun, # type: Callable
return selected


try:
SeveralMarkDecorators = Union[Tuple[MarkDecorator, ...], List[MarkDecorator], Set[MarkDecorator]]
except: # noqa
pass


@function_decorator
def case(id=None, # type: str # noqa
tags=None, # type: Union[Any, Iterable[Any]]
marks=(), # type: Union[MarkDecorator, SeveralMarkDecorators]
def case(id: str = None,
tags: Union[Any, Iterable[Any]] = None,
marks: Union[MarkDecorator, tuple[MarkDecorator, ...], list[MarkDecorator], set[MarkDecorator]] = (),
case_func=DECORATED # noqa
):
"""
Expand All @@ -311,9 +293,9 @@ def case_hi():
return case_func


def is_case_class(cls, # type: Any
case_marker_in_name=CASE_PREFIX_CLS, # type: str
check_name=True # type: bool
def is_case_class(cls: Any,
case_marker_in_name: str = CASE_PREFIX_CLS,
check_name: bool = True
):
"""
This function is the one used by `@parametrize_with_cases` to collect cases within classes. It can be used manually
Expand All @@ -335,9 +317,9 @@ def is_case_class(cls, # type: Any
GEN_BY_US = '_pytestcases_gen'


def is_case_function(f, # type: Any
prefix=CASE_PREFIX_FUN, # type: str
check_prefix=True # type: bool
def is_case_function(f: Any,
prefix: str = CASE_PREFIX_FUN,
check_prefix: bool = True
):
"""
This function is the one used by `@parametrize_with_cases` to collect cases. It can be used manually for
Expand All @@ -363,7 +345,7 @@ def is_case_function(f, # type: Any
else:
try:
return f.__name__.startswith(prefix) if check_prefix else True
except:
except: # noqa
# GH#287: safe fallback
return False

Expand Down Expand Up @@ -397,4 +379,3 @@ def _decorator(cls):
case_info.add_tags(tags_to_add)
return cls
return _decorator

Loading
Loading