Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 77 additions & 4 deletions mkdocstrings_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import griffe
import yaml
from griffe import Docstring
from griffe2md import ConfigDict, render_object_docs

# Suppress griffe warnings
Expand All @@ -15,6 +16,78 @@ class MkDocstringsParser:
def __init__(self):
    # Stateless parser: no configuration or instance attributes are needed.
    pass

def inherit_docstrings(self, obj):
    """Inherit docstrings from the immediate (first) parent class.

    Members of ``obj`` that have no docstring of their own receive a copy
    of the docstring found on the same-named member of the first base
    class, if one exists. Resolution is best-effort: when the base class
    cannot be resolved, the method returns silently and ``obj`` is left
    unchanged.

    Args:
        obj: A griffe object (typically a class) to process.
    """
    # Only classes with at least one base can inherit anything.
    if not getattr(obj, 'bases', None):
        return

    first_base = obj.bases[0]
    # A base given as a plain string (no griffe expression) cannot be resolved.
    if not hasattr(first_base, 'canonical_path'):
        return

    # Walk up to the root package object.
    root = obj
    while getattr(root, 'parent', None):
        root = root.parent

    # Navigate down from the root along the canonical path. The first path
    # part is the package name itself (the root), so it is skipped.
    parent_obj = root
    for part in first_base.canonical_path.split('.')[1:]:
        members = getattr(parent_obj, 'members', None)
        if members is not None and part in members:
            parent_obj = members[part]
        elif hasattr(parent_obj, '__getitem__'):
            try:
                parent_obj = parent_obj[part]
            except (KeyError, AttributeError):
                # Path component not found: inheritance cannot be resolved.
                return
        else:
            return

    parent_members = getattr(parent_obj, 'members', None)
    if not parent_members:
        return

    # Copy docstrings onto members that lack one of their own.
    for member_name, member in getattr(obj, 'members', {}).items():
        own_doc = getattr(member, 'docstring', None)
        if own_doc and own_doc.value:
            continue
        if member_name not in parent_members:
            continue
        parent_doc = getattr(parent_members[member_name], 'docstring', None)
        if parent_doc and parent_doc.value:
            try:
                # Re-anchor the inherited docstring at the child's location
                # so downstream rendering points at the right lines.
                member.docstring = Docstring(
                    parent_doc.value,
                    lineno=member.lineno or 1,
                    endlineno=member.endlineno,
                )
            except Exception:
                # Best-effort: skip members we cannot update rather than
                # aborting documentation generation. (Previously a single
                # outer except swallowed ALL errors, hiding real bugs.)
                continue

def parse_docstring_block(
self, block_content: str
) -> tuple[str, str, Dict[str, Any]]:
Expand Down Expand Up @@ -58,6 +131,9 @@ def generate_documentation(self, module_path: str, options: Dict[str, Any]) -> s
else:
obj = package

# Apply docstring inheritance from immediate parent class
self.inherit_docstrings(obj)

# Ensure the docstring is properly parsed with Google parser
# For functions, we might need to get the actual runtime docstring
if hasattr(obj, "kind") and obj.kind.value == "function":
Expand Down Expand Up @@ -213,7 +289,4 @@ def get_args(self):
print(parser.process_markdown(test_class))
print("\n" + "=" * 50 + "\n")
print("Function documentation:")
print(parser.process_markdown(test_function))

# args = parser.get_args()
# parser.process_file(args.input_file, args.output_file)
print(parser.process_markdown(test_function))
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ description = "A simple parser for mkdocstrings signature blocks"
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"griffe2md",
"griffe2md<1.3",
"pyyaml",
"rich",
]
Expand Down
31 changes: 31 additions & 0 deletions tests/test_coreforecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def test_inherited_fn(setup_parser):
show_root_heading: true
show_source: true"""
output = setup_parser.process_markdown(inherited_fn)
print("hola")
print(output)
print("adios")
assert output == """### `Lag`

```python
Expand Down Expand Up @@ -98,9 +101,37 @@ def test_inherited_fn(setup_parser):
transform(ga)
```

Apply the transformation by group.

**Parameters:**

Name | Type | Description | Default
---- | ---- | ----------- | -------
`ga` | <code>GroupedArray</code> | Array with the grouped data. | *required*

**Returns:**

Type | Description
---- | -----------
| np.ndarray: Array with the transformed data.

#### `Lag.update`

```python
update(ga)
```

Compute the most recent value of the transformation for each group.

**Parameters:**

Name | Type | Description | Default
---- | ---- | ----------- | -------
`ga` | <code>GroupedArray</code> | Array with the grouped data. | *required*

**Returns:**

Type | Description
---- | -----------
| np.ndarray: Array with the updates for each group.
"""
47 changes: 47 additions & 0 deletions tests/test_neuralforecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,51 @@ def test_timeseriesloader(setup_parser):
`sampler` | <code>[Sampler](#Sampler) or [Iterable](#Iterable)</code> | Defines the strategy to draw samples from the dataset. | *required*
`drop_last` | <code>[bool](#bool)</code> | Set to True to drop the last incomplete batch. Defaults to False. | *required*
`**kwargs` | | Additional keyword arguments for DataLoader. | <code>{}</code>
"""

def test_autolstm(setup_parser):
    """AutoLSTM docs render with bases line, summary, and full parameter table.

    Exercises a class whose long docstring and many ``__init__`` parameters
    come from a real-world package (neuralforecast), pinning the exact
    rendered markdown.
    """
    # NOTE(review): indentation of this mkdocstrings block was reconstructed
    # from convention — the scraped diff stripped leading whitespace; confirm
    # against the repository file.
    fn = """::: neuralforecast.auto.AutoLSTM
    options:
        members: [__init__]
        heading_level: 3"""
    output = setup_parser.process_markdown(fn)
    assert output == """### `AutoLSTM`

```python
AutoLSTM(h, loss=MAE(), valid_loss=None, config=None, search_alg=BasicVariantGenerator(random_state=1), num_samples=10, refit_with_val=False, cpus=cpu_count(), gpus=torch.cuda.device_count(), verbose=False, alias=None, backend='ray', callbacks=None)
```

Bases: <code>[BaseAuto](#neuralforecast.common._base_auto.BaseAuto)</code>

Class for Automatic Hyperparameter Optimization, it builds on top of `ray` to
give access to a wide variety of hyperparameter optimization tools ranging
from classic grid search, to Bayesian optimization and HyperBand algorithm.

The validation loss to be optimized is defined by the `config['loss']` dictionary
value, the config also contains the rest of the hyperparameter search space.

It is important to note that the success of this hyperparameter optimization
heavily relies on a strong correlation between the validation and test periods.

**Parameters:**

Name | Type | Description | Default
---- | ---- | ----------- | -------
`cls_model` | <code>PyTorch/PyTorchLightning model</code> | See `neuralforecast.models` [collection here](./models). | *required*
`h` | <code>int</code> | Forecast horizon | *required*
`loss` | <code>PyTorch module</code> | Instantiated train loss class from [losses collection](./losses.pytorch). | *required*
`valid_loss` | <code>PyTorch module</code> | Instantiated valid loss class from [losses collection](./losses.pytorch). | *required*
`config` | <code>dict or callable</code> | Dictionary with ray.tune defined search space or function that takes an optuna trial and returns a configuration dict. | *required*
`search_alg` | <code>ray.tune.search variant or optuna.sampler</code> | For ray see https://docs.ray.io/en/latest/tune/api_docs/suggestion.html | *required*
`For optuna see https` | | //optuna.readthedocs.io/en/stable/reference/samplers/index.html. | *required*
`num_samples` | <code>int</code> | Number of hyperparameter optimization steps/samples. | *required*
`cpus` | <code>int</code> | Number of cpus to use during optimization. Only used with ray tune. | *required*
`gpus` | <code>int</code> | Number of gpus to use during optimization, default all available. Only used with ray tune. | *required*
`refit_with_val` | <code>bool</code> | Refit of best model should preserve val_size. | *required*
`verbose` | <code>bool</code> | Track progress. | *required*
`alias` | <code>str</code> | Custom name of the model. | *required*
`backend` | <code>str</code> | Backend to use for searching the hyperparameter space, can be either 'ray' or 'optuna'. | *required*
`callbacks` | <code>list of callable</code> | List of functions to call during the optimization process. | *required*
`ray reference` | | https://docs.ray.io/en/latest/tune/tutorials/tune-metrics.html | *required*
`optuna reference` | | https://optuna.readthedocs.io/en/stable | *required*
"""
Loading