Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ transformers.LogitsProcessor(input_ids: torch.LongTensor, scores: torch.FloatTen

### [paddleformers.generation.LogitsProcessor](https://github.com/PaddlePaddle/PaddleNLP/blob/e336e78c338d2514ee6c937982ce5d8c960b85ff/paddlenlp/generation/logits_process.py#L26)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

```python
paddleformers.generation.LogitsProcessor(input_ids: paddle.Tensor, scores: paddle.Tensor)
paddleformers.generation.LogitsProcessor(input_ids: paddle.Tensor, logits: paddle.Tensor)
```

两者功能一致但参数名不一致,部分参数名不同,具体如下:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
## [ 仅 API 调用方式不一致 ]torch.allclose
### [torch.allclose](https://pytorch.org/docs/stable/generated/torch.allclose.html?highlight=allclose#torch.allclose)
```python
torch.allclose(input,
other,
rtol=1e-05,
atol=1e-08,
equal_nan=False)
```

### [paddle.compat.allclose](https://github.com/PaddlePaddle/Paddle/blob/304d5c293907f8620ac6e811097a2847514b863c/python/paddle/compat/__init__.py#L71)
```python
paddle.compat.allclose(input,
other,
rtol=1e-05,
atol=1e-08,
equal_nan=False)
```

两者功能一致,但调用 API 名称不一致,具体如下:

### 转写示例


```python
# PyTorch 写法
is_close = torch.allclose(a, b)
# Paddle 写法
is_close = paddle.compat.allclose(a, b)
```
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
### [transformers.StoppingCriteriaList](https://github.com/huggingface/transformers/blob/d625294d79341662784495551abdf45e6cb9372f/src/transformers/generation/stopping_criteria.py#L503)

```python
transformers.StoppingCriteriaList()
transformers.StoppingCriteriaList(*args)

```

### [paddleformers.generation.StoppingCriteriaList](https://github.com/PaddlePaddle/PaddleFormers/blob/ca66f8dd619a6b2e17fa901042277501b2ed3230/paddleformers/generation/stopping_criteria.py#L72)

```python
paddleformers.generation.StoppingCriteriaList()
paddleformers.generation.StoppingCriteriaList(*args)

```

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,11 @@ def parse_special_category_apis(md_content, category):
return apis


def generate_category1_table(
docs_mapping, no_need_convert_file_path, base_dir, existing_apis
):
def generate_category1_table(existing_apis):
"""
生成类别1(API完全一致)的Markdown表格
"""
no_need_convert_list = extract_no_need_convert_list(
no_need_convert_file_path
)
no_need_convert_list = extract_no_need_convert_list()

rows = [] # 存储表格行数据的列表
used_apis = set() # 用于记录已处理的API,避免重复
Expand Down Expand Up @@ -138,7 +134,6 @@ def generate_category1_table(
def generate_category2_table(
docs_mapping,
api_mapping_file_path,
no_need_convert_file_path,
base_dir,
existing_apis,
attribute_mapping_file_path,
Expand All @@ -156,9 +151,7 @@ def generate_category2_table(
# "torch.utils.data.RandomSampler",
]

no_need_convert_list = extract_no_need_convert_list(
no_need_convert_file_path
)
no_need_convert_list = extract_no_need_convert_list()

# 加载api_mapping.json文件
api_mapping_data = load_mapping_json(api_mapping_file_path)
Expand Down Expand Up @@ -549,18 +542,12 @@ def main():
json_file_path = os.path.join(
os.path.dirname(__file__), "api_difference_info.json"
)
no_need_convert_path = os.path.join(
os.path.dirname(__file__), "global_var.py"
)
api_mapping_path = os.path.join(
os.path.dirname(__file__), "api_mapping.json"
)
api_alias_mapping_path = os.path.join(
os.path.dirname(__file__), "api_alias_mapping.json"
)
no_implement_path = os.path.join(
os.path.dirname(__file__), "no_implement.md"
)
attribute_mapping_path = os.path.join(
os.path.dirname(__file__), "attribute_mapping.json"
)
Expand Down Expand Up @@ -606,13 +593,10 @@ def main():

# 生成类别1和类别2的表格
existing_apis = set()
category1_table = generate_category1_table(
docs_mapping, no_need_convert_path, base_dir, existing_apis
)
category1_table = generate_category1_table(existing_apis)
category2_table = generate_category2_table(
docs_mapping,
api_mapping_path,
no_need_convert_path,
base_dir,
existing_apis,
attribute_mapping_path,
Expand Down
37 changes: 16 additions & 21 deletions docs/guides/model_convert/convert_from_pytorch/tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ast
import json
import os
import re
Expand Down Expand Up @@ -269,7 +268,7 @@ def load_mapping_json(json_path):
return json.load(f)
except Exception as e:
print(f"错误: 读取JSON文件 {json_path} 时出错: {e!s}")
return []
return {}


def convert_to_github_url(local_path, base_dir):
Expand Down Expand Up @@ -331,24 +330,20 @@ def get_paddle_url(paddle_api: str) -> str:
return url + "#" + anchor


def extract_no_need_convert_list(file_path):
def extract_no_need_convert_list():
file_path = os.path.join(os.path.dirname(__file__), "api_mapping.json")
if not os.path.exists(file_path):
raise FileNotFoundError(
f"api_mapping.json should exist at {file_path} to extract no_need_convert_list"
)
with open(file_path, "r", encoding="utf-8") as file:
content = file.read()

tree = ast.parse(content)
no_need_list = None

for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == "GlobalManager":
for class_node in node.body:
if isinstance(class_node, ast.Assign) and any(
target.id == "NO_NEED_CONVERT_LIST"
for target in class_node.targets
):
# 提取列表字面量
list_source = ast.get_source_segment(
content, class_node.value
)
no_need_list = ast.literal_eval(list_source)
break
api_mapping_json = json.load(file)

no_need_list = [
k
for k, v in api_mapping_json.items()
if v.get("Matcher") == "ChangePrefixMatcher"
]
if len(no_need_list) == 0:
raise ValueError("no_need_list is empty")
return no_need_list
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def validate_api_mappings():
for api in ALLOW_MISSING_DIFF_DOCS:
api_map.pop(api, None)

no_need_list = extract_no_need_convert_list(
str(current_dir) + "/global_var.py"
)
no_need_list = extract_no_need_convert_list()

# 准备错误报告文件
error_file = current_dir / "validate_api_difference_consistency_error.txt"
Expand Down
Loading