# qwen-test/vsp/qwen3.5-9b/test_accuracy.py
"""精度评估 - 测试模型在常见任务上的准确性"""
import json
import os
import sys
import glob
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datetime import datetime
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model_utils import load_model, apply_chat
# Test dataset: each case pairs a user prompt with keywords the model's answer
# must contain (a response passes if it contains ANY listed keyword).
ACCURACY_TESTS = [
    # Factual QA
    {
        "category": "知识问答",
        "prompt": "中国的首都是哪个城市?请只回答城市名。",
        "expected_contains": ["北京"],
    },
    {
        "category": "知识问答",
        "prompt": "水的化学式是什么?请只回答化学式。",
        "expected_contains": ["H2O"],
    },
    {
        "category": "知识问答",
        "prompt": "地球到太阳的平均距离大约是多少公里A. 1.5亿 B. 3亿 C. 5亿 D. 1亿。请只回答选项字母。",
        "expected_contains": ["A"],
    },
    # Mathematical reasoning
    {
        "category": "数学推理",
        "prompt": "计算 15 * 23 = ? 请只回答数字。",
        "expected_contains": ["345"],
    },
    {
        "category": "数学推理",
        "prompt": "一个三角形三边分别是3、4、5它是什么三角形请只回答类型。",
        "expected_contains": ["直角"],
    },
    # Logical reasoning
    {
        "category": "逻辑推理",
        "prompt": "所有的狗都是动物。小白是一只狗。所以小白是什么?请只回答一个词。",
        "expected_contains": ["动物"],
    },
    # Code comprehension
    {
        "category": "代码理解",
        "prompt": "以下Python代码的输出是什么\n```python\nprint(len([1, 2, 3, 4, 5]))\n```\n请只回答数字。",
        "expected_contains": ["5"],
    },
    # Translation
    {
        "category": "翻译",
        "prompt": "'Hello World'翻译成中文,请只回答翻译结果。",
        "expected_contains": ["你好", "世界"],
    },
    # Summarization
    {
        "category": "摘要",
        "prompt": "用一句话总结人工智能AI是指由人工制造出来的系统所展现出来的智能。AI的核心问题包括推理、知识表示、规划、学习、自然语言处理、感知和移动与操作物体的能力。",
        "expected_contains": ["人工智能", "AI"],
    },
    # Classification
    {
        "category": "情感分类",
        "prompt": "判断以下文本的情感是正面还是负面:'这个产品太糟糕了,完全不值这个价格'。请只回答'正面''负面'",
        "expected_contains": ["负面"],
    },
]
def evaluate_accuracy(model, tokenizer, tests=None):
    """Run the accuracy evaluation suite and print a per-category report.

    Each test case is a dict with ``"category"``, ``"prompt"``, and
    ``"expected_contains"``; a response passes when it contains ANY of the
    expected keywords (case-sensitive substring match).

    Args:
        model: a causal LM exposing ``.generate`` and ``.device``.
        tokenizer: the paired tokenizer (applied via ``apply_chat``).
        tests: optional list of test-case dicts; defaults to the
            module-level ``ACCURACY_TESTS``.

    Returns:
        dict with ``total``/``passed`` counts, overall ``accuracy`` (percent,
        rounded to 1 decimal; 0.0 when no tests ran), ``category_stats``,
        and per-test ``details``.
    """
    if tests is None:
        tests = ACCURACY_TESTS
    print("=" * 60)
    print("Qwen3.5-9B 精度评估")
    print("=" * 60)
    results = []
    category_stats = {}
    for i, test in enumerate(tests):
        messages = [{"role": "user", "content": test["prompt"]}]
        text = apply_chat(tokenizer, messages)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            # Greedy decoding (do_sample=False) keeps results reproducible.
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        input_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
        # Pass if any expected keyword appears in the response.
        passed = any(kw in response for kw in test["expected_contains"])
        cat = test["category"]
        stats = category_stats.setdefault(cat, {"total": 0, "passed": 0})
        stats["total"] += 1
        if passed:
            stats["passed"] += 1
        status = "PASS" if passed else "FAIL"
        print(f"\n[{status}] 测试 {i+1} ({cat})")
        print(f" 问题: {test['prompt'][:50]}...")
        print(f" 回答: {response[:80]}")
        print(f" 预期包含: {test['expected_contains']}")
        results.append({
            "category": cat,
            "prompt": test["prompt"],
            "response": response,
            "expected": test["expected_contains"],
            "passed": passed,
        })
    # Summary; guard against an empty test list to avoid ZeroDivisionError.
    total = len(results)
    n_passed = sum(1 for r in results if r["passed"])
    accuracy = round(n_passed / total * 100, 1) if total else 0.0
    print(f"\n{'='*60}")
    print(f"精度评估汇总")
    print(f"{'='*60}")
    print(f" 总计: {total} 题, 通过: {n_passed} 题, 准确率: {accuracy:.1f}%")
    print(f"\n 分类统计:")
    for cat, stats in category_stats.items():
        rate = stats["passed"] / stats["total"] * 100
        print(f" {cat}: {stats['passed']}/{stats['total']} ({rate:.0f}%)")
    return {
        "total": total,
        "passed": n_passed,
        "accuracy": accuracy,
        "category_stats": category_stats,
        "details": results,
    }
def save_results(accuracy_results, output_dir="vsp/qwen3.5-9b/results"):
    """Persist the evaluation report as pretty-printed UTF-8 JSON.

    Args:
        accuracy_results: dict returned by ``evaluate_accuracy``.
        output_dir: destination directory, created if missing; defaults to
            the project's results folder (relative to the repo root that
            ``__main__`` chdirs into).
    """
    os.makedirs(output_dir, exist_ok=True)
    report = {
        "timestamp": datetime.now().isoformat(),
        "model": "Qwen3.5-9B",
        "quantization": "4-bit NF4",
        "accuracy": accuracy_results,
    }
    path = os.path.join(output_dir, "accuracy_results.json")
    # ensure_ascii=False keeps Chinese text human-readable in the JSON file.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(report, f, ensure_ascii=False, indent=2)
    print(f"\n结果已保存到 {path}")
if __name__ == "__main__":
    # Run from the repository root (two levels above this script) so the
    # relative results path used by save_results resolves correctly.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(os.path.join(script_dir, "..", ".."))
    model, tokenizer = load_model()
    accuracy = evaluate_accuracy(model, tokenizer)
    save_results(accuracy)