Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions src/rapids_pre_commit_hooks/hardcoded_version.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

import bisect
import re
from typing import TYPE_CHECKING

import tomlkit
import tomlkit.exceptions

from .lint import LintMain, Linter
from .utils.toml import find_value_location

if TYPE_CHECKING:
import argparse
import os
from collections.abc import Iterator
from collections.abc import Generator

# Matches any 2-part or 3-part numeric version string, and stores the
# components in named capture groups:
Expand All @@ -26,10 +31,45 @@
r"(?:^|\D)(?P<full>(?P<major>\d{1,2})\.(?P<minor>\d{1,2})(?:\.(?P<patch>\d{1,2}))?)(?=\D|$)"
)

# Matches a path that ends in "pyproject.toml", where that name is either the
# whole string or is directly preceded by a "/".
PYPROJECT_TOML_RE: re.Pattern = re.compile(r"(?:^|/)pyproject\.toml$")


def get_excluded_section_pyproject_toml(
    document: tomlkit.TOMLDocument, path: tuple[str, ...]
) -> "Generator[tuple[int, int]]":
    """Yield the character span of a single TOML entry, if it is present.

    Parameters
    ----------
    document : tomlkit.TOMLDocument
        Parsed TOML document to search.
    path : tuple[str, ...]
        Keys leading to the entry, e.g. ``("project", "dependencies")``.

    Yields
    ------
    tuple[int, int]
        The location returned by ``find_value_location`` for ``path``.
        Nothing is yielded when the lookup raises ``NonExistentKey``.
    """
    try:
        span = find_value_location(document, path, append=False)
        yield span
    except tomlkit.exceptions.NonExistentKey:
        # The entry is absent from the document, so there is no span to
        # report.
        return


def get_excluded_sections_pyproject_toml(
    linter: Linter,
) -> "Generator[tuple[int, int]]":
    """Yield the spans of the pyproject.toml entries that are excluded.

    The linter's content is parsed as TOML and each key path listed below is
    looked up in turn; paths that do not exist in the document contribute
    nothing.

    Parameters
    ----------
    linter : Linter
        Linter whose ``content`` holds the text of the pyproject.toml file.

    Yields
    ------
    tuple[int, int]
        One span per key path that is present in the document, in the order
        the paths are listed.
    """
    document = tomlkit.loads(linter.content)

    excluded_paths = (
        ("project", "dependencies"),
        ("project", "optional-dependencies"),
        ("build-system", "requires"),
        ("tool", "rapids-build-backend", "requires"),
    )
    for path in excluded_paths:
        yield from get_excluded_section_pyproject_toml(document, path)


def get_excluded_sections(linter: Linter) -> "Generator[tuple[int, int]]":
    """Yield the spans of ``linter``'s content that are excluded.

    Only files whose name matches ``PYPROJECT_TOML_RE`` have excluded
    sections; for any other file nothing is yielded.
    """
    if PYPROJECT_TOML_RE.search(linter.filename) is None:
        return
    yield from get_excluded_sections_pyproject_toml(linter)


def find_hardcoded_versions(
content: str, full_version: tuple[int, int, int]
) -> "Iterator[re.Match[str]]":
) -> "Generator[re.Match[str]]":
"""Detect all instances of a specific 2- or 3-part version in text
content."""

Expand Down Expand Up @@ -78,7 +118,19 @@ def check_hardcoded_version(
return

full_version = read_version_file(args.version_file)
excluded_sections = sorted(get_excluded_sections(linter))
for match in find_hardcoded_versions(linter.content, full_version):
section_index = bisect.bisect_right(
excluded_sections, match.span("full")
)
if section_index > 0:
section_start, section_end = excluded_sections[section_index - 1]
if (
match.start("full") >= section_start
and match.end("full") <= section_end
):
continue

linter.add_warning(
match.span("full"),
f"do not hard-code version, read from {args.version_file} "
Expand Down
64 changes: 1 addition & 63 deletions src/rapids_pre_commit_hooks/pyproject_license.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import copy
import uuid
import re

import tomlkit
import tomlkit.exceptions

from .lint import Linter, LintMain
from .utils.toml import find_value_location

RAPIDS_LICENSE: str = "Apache-2.0"
ACCEPTABLE_LICENSES: set[str] = {
Expand All @@ -18,67 +17,6 @@
}


# A (begin, end) pair of character offsets into a stringified TOML document.
_LocType = tuple[int, int]


def find_value_location(
    document: "tomlkit.TOMLDocument",
    key: tuple[str, ...],
    *,
    append: bool,
) -> _LocType:
    """
    Find the exact location of a key in a stringified TOML document.

    Parameters
    ----------
    document : tomlkit.TOMLDocument
        TOML content. It is deep-copied, so the caller's document is never
        modified.
    key : tuple[str, ...]
        Tuple of strings, of any length.
        Items are evaluated in order as keys to subset into ``document``.
        For example, to reference the 'license' value in the [project] table
        in a pyproject.toml, ``key = ("project", "license",)``.
    append : bool
        If ``True``, returns the location where new text will be added.
        If ``False``, returns the range of characters to be replaced.

    Returns
    -------
    loc : tuple[int, int]
        Location of the key and its value in the document.
        e.g., ``(20, 35)`` = "the 20th-35th characters, including newlines"
        * element 0: number of characters from beginning of the document to
          beginning of the section indicated by ``key``
        * element 1: final character to replace

        When ``append`` is ``True`` both elements are the same offset.

    Raises
    ------
    tomlkit.exceptions.NonExistentKey
        Propagated from the key lookups when an item of ``key`` is missing
        from the document.
    """
    # Imported locally so that this function does not rely on the
    # ``tomlkit.items`` submodule being reachable as an attribute of the
    # top-level ``tomlkit`` package.
    from tomlkit.items import Table

    copied_document = copy.deepcopy(document)
    placeholder = str(uuid.uuid4())
    placeholder_toml = tomlkit.string(placeholder)
    placeholder_repr = placeholder_toml.as_string()
    # Text of the placeholder when it is written out as its own
    # ``key = value`` entry.
    placeholder_entry = f"{placeholder} = {placeholder_repr}"

    # tomlkit does not provide "mark" information to determine where exactly in
    # the document a value is located, so instead we replace it with a
    # placeholder and look for that in the new document.
    node = copied_document
    while len(key) > (0 if append else 1):
        node = node[key[0]]  # type: ignore[assignment]
        key = key[1:]

    if append:
        node.add(placeholder, placeholder_toml)
        begin_loc = copied_document.as_string().find(placeholder_entry)
        return begin_loc, begin_loc

    # otherwise, if replacing without appending
    old_value = node[key[0]]
    if isinstance(old_value, Table):
        # Swapping a table for a plain string would change the kind of item
        # stored under this key, so the placeholder is wrapped in a one-entry
        # table instead and located through its ``key = value`` text.
        node[key[0]] = {placeholder: placeholder_toml}
        value_to_find = placeholder_entry
    else:
        node[key[0]] = placeholder
        value_to_find = placeholder_repr
    begin_loc = copied_document.as_string().find(value_to_find)
    end_loc = begin_loc + len(old_value.as_string())
    return begin_loc, end_loc


def check_pyproject_license(linter: Linter, _args: argparse.Namespace) -> None:
document = tomlkit.loads(linter.content)
try:
Expand Down
76 changes: 76 additions & 0 deletions src/rapids_pre_commit_hooks/utils/toml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

import copy
import uuid

import tomlkit


# A (begin, end) pair of character offsets into a stringified TOML document.
_LocType = tuple[int, int]


def find_value_location(
    document: "tomlkit.TOMLDocument",
    key: tuple[str, ...],
    *,
    append: bool,
) -> _LocType:
    """
    Locate a key's value inside the string form of a TOML document.

    Parameters
    ----------
    document : tomlkit.TOMLDocument
        TOML content. A deep copy is modified; the argument itself is not.
    key : tuple[str, ...]
        Keys applied in order to drill down into ``document``. For example,
        ``("project", "license")`` refers to the 'license' value in the
        [project] table of a pyproject.toml.
    append : bool
        If ``True``, return the location where new text will be added.
        If ``False``, return the range of characters to be replaced.

    Returns
    -------
    loc : tuple[int, int]
        Character offsets into the stringified document, counting newlines.
        * element 0: offset from the beginning of the document to the
          beginning of the section indicated by ``key``
        * element 1: final character to replace (equal to element 0 when
          ``append`` is ``True``)
    """
    scratch = copy.deepcopy(document)
    marker = uuid.uuid4()
    marker_name = str(marker)
    marker_item = tomlkit.string(marker_name)
    marker_text = marker_item.as_string()
    # How the marker looks when written out as its own ``key = value`` entry.
    marker_entry = f"{marker_name} = {marker_text}"

    # tomlkit offers no positional ("mark") information for values, so a
    # unique marker is planted in a copy of the document and then searched for
    # in the copy's string form.
    #
    # When appending, every key is a container to descend into; when
    # replacing, the last key names the value itself.
    parents = key if append else key[:-1]
    container = scratch
    for name in parents:
        container = container[name]  # type: ignore[assignment]

    if append:
        container.add(marker_name, marker_item)
        offset = scratch.as_string().find(marker_entry)
        return offset, offset

    leaf = key[-1]
    replaced = container[leaf]
    if isinstance(replaced, tomlkit.items.Table):
        # A table is swapped for a one-entry table holding the marker, which
        # is then found through its ``key = value`` text.
        container[leaf] = {marker_name: marker_item}
        needle = marker_entry
    else:
        container[leaf] = marker_name
        needle = marker_text
    start = scratch.as_string().find(needle)
    return start, start + len(replaced.as_string())
Loading