Skip to content
Open
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c0b66f9
Add ci case for min token and max token
Aug 5, 2025
5000259
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Aug 6, 2025
d88cb71
【CI case】include total_tokens in the last packet of completion interf…
Aug 8, 2025
9a3ec54
Merge branch 'develop' into develop
xjkmfa Aug 8, 2025
8eef006
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Aug 11, 2025
f8eddd0
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Aug 12, 2025
1ee4618
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Aug 19, 2025
a2b0a59
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Aug 24, 2025
cfa6540
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Aug 26, 2025
e6c048b
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Sep 15, 2025
fa9f8a9
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Oct 29, 2025
e8d2d12
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Nov 26, 2025
acfda94
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Nov 27, 2025
fae92ee
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Dec 2, 2025
34dc3ce
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Dec 16, 2025
c0921f6
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Dec 25, 2025
656a6b5
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Dec 26, 2025
cab36ec
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Dec 30, 2025
fca4ce2
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Dec 30, 2025
da7bbfd
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Jan 9, 2026
9e1caa7
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Jan 12, 2026
42f162a
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Jan 30, 2026
d426eec
Merge branch 'PaddlePaddle:develop' into develop
xjkmfa Feb 6, 2026
51042fe
[ci] prompt_logprobs precision case
Feb 6, 2026
7901bb6
[ci] prompt_logprobs precision case
Feb 6, 2026
bcf8995
[ci] prompt_logprobs precision case
Feb 6, 2026
84faa92
[ci] prompt_logprobs precision case
Feb 6, 2026
61baac3
[ci] prompt_logprobs precision case
Feb 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 159 additions & 11 deletions tests/ci_use/Prompt_logprobs/test_prompt_logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import sys
import time

import numpy
import pytest
import requests

Expand Down Expand Up @@ -339,21 +340,25 @@ def test_unstream_with_prompt_logprobs_zero_completions():
assert token_id in resp_json["choices"][0]["prompt_token_ids"]


def test_unstream_with_prompt_logprobs_chunk():
def test_unstream_with_prompt_logprobs_chunk_chat():
"""
测试chunk切分的能力是否正常
"""
data = {
"stream": False,
"prompt": [10] * (32 * 1024),
"messages": [
{"role": "user", "content": "!hello! " * (8 * 1024)},
],
"max_tokens": 1,
"prompt_logprobs": 1,
}
response = send_request(COMPLETIONS_URL, data)
# 构建请求并发送
response = send_request(URL, data)
resp_json = response.json()
# print(json.dumps(resp_json, ensure_ascii=False))

# 校验返回内容与概率信息
assert resp_json["choices"][0]["text"] is not None
assert resp_json["choices"][0]["message"]["content"] is not None
# assert resp_json["usage"]["prompt_tokens"] == 7
assert resp_json["usage"]["completion_tokens"] == 1
for i, prompt_logprobs in enumerate(resp_json["choices"][0]["prompt_logprobs"]):
Expand All @@ -368,24 +373,21 @@ def test_unstream_with_prompt_logprobs_chunk():
assert top[i]["decoded_token"].encode("utf-8")


def test_unstream_with_prompt_logprobs_chunk_chat():
def test_unstream_with_prompt_logprobs_chunk():
"""
测试chunk切分的能力是否正常
"""
data = {
"stream": False,
"messages": [
{"role": "user", "content": "!hello! " * (8 * 1024)},
],
"prompt": [10] * (32 * 1024),
"max_tokens": 1,
"prompt_logprobs": 1,
}
# 构建请求并发送
response = send_request(URL, data)
response = send_request(COMPLETIONS_URL, data)
resp_json = response.json()

# 校验返回内容与概率信息
assert resp_json["choices"][0]["message"]["content"] is not None
assert resp_json["choices"][0]["text"] is not None
# assert resp_json["usage"]["prompt_tokens"] == 7
assert resp_json["usage"]["completion_tokens"] == 1
for i, prompt_logprobs in enumerate(resp_json["choices"][0]["prompt_logprobs"]):
Expand Down Expand Up @@ -620,5 +622,151 @@ def send_request(url, payload, timeout=600, stream=False):
return None


def test_logprobs_with_prompt_logprobs_diff():
    """
    Verify consistency between ``logprobs`` and ``prompt_logprobs``.

    Strategy:
      1. Send a normal chat request and collect the per-token ``logprobs``
         of the generated completion, along with the prompt/completion
         token ids (``return_token_ids=True``).
      2. Replay ``prompt_token_ids + completion_token_ids`` through the
         ``prompt_token_ids`` field with ``prompt_logprobs=0`` so the server
         scores the completion tokens as prompt tokens.
      3. Check that, for every completion token, the replayed
         prompt_logprob matches the original logprob within tolerance
         (relative via numpy.testing.assert_allclose, absolute <= 1.0).

    Side effect: writes a per-token comparison log to ``output_logprobs.log``.
    """
    data = {
        "stream": False,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 1024,
        "logprobs": True,
        "top_logprobs": 0,
        "return_token_ids": True,
        "temperature": 1,
        "top_p": 1.0,
        "top_k": 0,
        "seed": 33,
    }

    # Build and send the first (normal generation) request.
    response_short = send_request(URL, data)
    resp_json_short = response_short.json()
    print(json.dumps(resp_json_short, ensure_ascii=False))
    prompt_token_ids = resp_json_short["choices"][0]["message"]["prompt_token_ids"]
    completion_token_ids = resp_json_short["choices"][0]["message"]["completion_token_ids"]
    logprobs = resp_json_short["choices"][0]["logprobs"]["content"]

    # One logprob entry is needed for every completion token; fail early
    # with a clear message instead of an IndexError below.
    assert len(logprobs) >= len(completion_token_ids), (
        f"logprobs length {len(logprobs)} < completion tokens {len(completion_token_ids)}"
    )

    # Second request: replay prompt + completion ids as the prompt so the
    # server returns prompt_logprobs covering the original completion.
    data2 = {
        "stream": False,
        "messages": [
            {"role": "user", "content": ""},
        ],
        "max_tokens": 1,
        "prompt_logprobs": 0,
        "return_token_ids": True,
        "temperature": 1,
        "top_p": 1.0,
        "top_k": 0,
        "seed": 33,
        "prompt_token_ids": prompt_token_ids + completion_token_ids,
    }

    # Build and send the replay request.
    response_long = send_request(URL, data2)
    resp_json_long = response_long.json()
    print(json.dumps(resp_json_long, ensure_ascii=False))
    prompt_logprobs = resp_json_long["choices"][0].get("prompt_logprobs")
    # Keep only the tail that corresponds to the original completion tokens.
    completion_prompt_logprobs = prompt_logprobs[len(prompt_token_ids) :]

    print("======对比1请求的logprob和2请求的后半部分prompt_logprobs======>")

    # Dump the full per-token comparison before asserting, so the log is
    # complete even when a later check fails.
    with open("output_logprobs.log", "w", encoding="utf-8") as f:
        for i in range(len(completion_token_ids)):
            output_token_ids = completion_token_ids[i]
            line = (
                f"{i}, {output_token_ids}, "
                f'logprob={logprobs[i]["logprob"]}, '
                f'prompt_logprob={completion_prompt_logprobs[i][str(output_token_ids)]["logprob"]}\n'
            )
            f.write(line)

    # NOTE: banner now matches the enforced bound (MAX_ABS_ERROR = 1.0);
    # it previously claimed "<= 10".
    print("====== 校验绝对误差 abs(logprob - prompt_logprob) <= 1 ======")

    MAX_ABS_ERROR = 1.0

    for i in range(len(completion_token_ids)):
        token_id = completion_token_ids[i]
        logprob = logprobs[i]["logprob"]
        prompt_logprob = completion_prompt_logprobs[i][str(token_id)]["logprob"]
        # Relative tolerance check first (30% rel, 1e-3 abs floor).
        numpy.testing.assert_allclose(
            numpy.array(prompt_logprob),
            numpy.array(logprob),
            rtol=3e-1,
            atol=1e-3,
        )
        # Hard absolute-error ceiling as a second, coarser guard.
        abs_error = abs(logprob - prompt_logprob)

        assert abs_error <= MAX_ABS_ERROR, (
            f"[ABS_ERROR_TOO_LARGE] "
            f"index={i}, token_id={token_id}, "
            f"logprob={logprob}, "
            f"prompt_logprob={prompt_logprob}, "
            f"abs_error={abs_error}"
        )

    print("✅ 所有 token 的绝对误差均 <= 1")


def test_prompt_logprobs_accuracy():
    """
    Check that prompt_logprobs are computed consistently.

    Runs one request over the original chat prompt, then a second request
    whose prompt is the first request's prompt + completion token ids.
    The prompt_logprobs for the shared prefix must be identical in both
    responses.
    """
    short_request = {
        "stream": False,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "top_p": 1.0,
        "temperature": 0,
        "max_tokens": 10,
        "n": 1,
        "seed": 1,
        "return_token_ids": True,
        "prompt_logprobs": 3,
        "top_k": -1,
    }

    # Fire the baseline request and pull out ids + prompt_logprobs.
    short_json = send_request(URL, short_request).json()
    print(json.dumps(short_json, ensure_ascii=False))
    short_choice = short_json["choices"][0]
    prompt_ids = short_choice["message"]["prompt_token_ids"]
    completion_ids = short_choice["message"]["completion_token_ids"]
    short_prompt_logprobs = short_choice["prompt_logprobs"]

    print("-----------------------prompt_short_logprobs------------------------------------")

    # Second request: the prompt is the baseline's prompt extended by its
    # completion, so its prompt_logprobs must agree on the shared prefix.
    long_request = {
        "stream": False,
        "messages": [
            {"role": "user", "content": ""},
        ],
        "top_p": 1.0,
        "temperature": 0,
        "max_tokens": 10,
        "n": 1,
        "seed": 1,
        "prompt_logprobs": 3,
        "top_k": -1,
        "prompt_token_ids": prompt_ids + completion_ids,
    }
    long_json = send_request(URL, long_request).json()
    long_prompt_logprobs = long_json["choices"][0]["prompt_logprobs"]
    print("-----------------------prompt_long_logprobs------------------------------------")
    print(json.dumps(long_prompt_logprobs, ensure_ascii=False))

    # Element-wise comparison over the shared prefix.
    for i, expected in enumerate(short_prompt_logprobs):
        assert long_prompt_logprobs[i] == expected, f"prompt_logprobs mismatch at token index {i}"


if __name__ == "__main__":
    # Allow running this test module directly: invoke pytest on this file
    # (-s: no output capture, -v: verbose) and propagate its exit code.
    sys.exit(pytest.main([__file__, "-sv"]))
Loading