# Change log (from commit message):
# - Model loading switched to bitsandbytes 4-bit NF4 quantization, device_map={"":0} (GPU-only)
# - Disabled Qwen3.5 thinking mode (enable_thinking=False)
# - Accuracy improved from 60% to 90%; inference speed 1-2 tokens/s
# - GPU memory usage 7.13 GB / 8 GB; output quality normal
# - Updated all test results and the summary report
# Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
"""精度评估 - 测试模型在常见任务上的准确性"""
|
||
import json
|
||
import os
|
||
import sys
|
||
import glob
|
||
import torch
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
from datetime import datetime
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
from model_utils import load_model, apply_chat
|
||
|
||
|
||
# Evaluation dataset: each case pairs a prompt with keywords a correct reply
# must contain (a case passes when ANY keyword appears in the response).


def _case(category, prompt, *expected):
    """Build one evaluation case; *expected* lists acceptable answer keywords."""
    return {"category": category, "prompt": prompt, "expected_contains": list(expected)}


ACCURACY_TESTS = [
    # Factual QA
    _case("知识问答", "中国的首都是哪个城市?请只回答城市名。", "北京"),
    _case("知识问答", "水的化学式是什么?请只回答化学式。", "H2O"),
    _case("知识问答", "地球到太阳的平均距离大约是多少公里?A. 1.5亿 B. 3亿 C. 5亿 D. 1亿。请只回答选项字母。", "A"),
    # Math reasoning
    _case("数学推理", "计算 15 * 23 = ? 请只回答数字。", "345"),
    _case("数学推理", "一个三角形三边分别是3、4、5,它是什么三角形?请只回答类型。", "直角"),
    # Logical reasoning
    _case("逻辑推理", "所有的狗都是动物。小白是一只狗。所以小白是什么?请只回答一个词。", "动物"),
    # Code comprehension
    _case("代码理解", "以下Python代码的输出是什么?\n```python\nprint(len([1, 2, 3, 4, 5]))\n```\n请只回答数字。", "5"),
    # Translation
    _case("翻译", "将'Hello World'翻译成中文,请只回答翻译结果。", "你好", "世界"),
    # Summarization
    _case("摘要", "用一句话总结:人工智能(AI)是指由人工制造出来的系统所展现出来的智能。AI的核心问题包括推理、知识表示、规划、学习、自然语言处理、感知和移动与操作物体的能力。", "人工智能", "AI"),
    # Sentiment classification
    _case("情感分类", "判断以下文本的情感是正面还是负面:'这个产品太糟糕了,完全不值这个价格'。请只回答'正面'或'负面'。", "负面"),
]
def evaluate_accuracy(model, tokenizer, tests=None):
    """Run the accuracy evaluation and return a summary report.

    Args:
        model: causal LM exposing ``.device`` and ``.generate`` (HF-style).
        tokenizer: matching tokenizer used for encoding and decoding.
        tests: optional list of test cases (dicts with ``category``,
            ``prompt`` and ``expected_contains``); defaults to ACCURACY_TESTS.

    Returns:
        Dict with total/passed counts, overall accuracy percentage,
        per-category stats and per-case details.

    A case passes when the decoded response contains ANY of its expected
    keywords. Decoding is greedy (do_sample=False) for reproducibility.
    """
    if tests is None:
        tests = ACCURACY_TESTS

    print("=" * 60)
    print("Qwen3.5-9B 精度评估")
    print("=" * 60)

    results = []
    category_stats = {}

    for i, test in enumerate(tests):
        messages = [{"role": "user", "content": test["prompt"]}]
        text = apply_chat(tokenizer, messages)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # Greedy decoding, no gradients: deterministic and memory-light.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
            )

        # Drop the prompt tokens; keep only the newly generated reply.
        input_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()

        # Keyword containment check against the expected answers.
        passed = any(kw in response for kw in test["expected_contains"])

        cat = test["category"]
        stats = category_stats.setdefault(cat, {"total": 0, "passed": 0})
        stats["total"] += 1
        if passed:
            stats["passed"] += 1

        status = "PASS" if passed else "FAIL"
        print(f"\n[{status}] 测试 {i+1} ({cat})")
        print(f" 问题: {test['prompt'][:50]}...")
        print(f" 回答: {response[:80]}")
        print(f" 预期包含: {test['expected_contains']}")

        results.append({
            "category": cat,
            "prompt": test["prompt"],
            "response": response,
            "expected": test["expected_contains"],
            "passed": passed,
        })

    # Summary. Guard the division so an empty test set cannot raise
    # ZeroDivisionError (original code divided unconditionally).
    total = len(results)
    passed = sum(1 for r in results if r["passed"])
    accuracy = passed / total * 100 if total else 0.0
    print(f"\n{'='*60}")
    print(f"精度评估汇总")
    print(f"{'='*60}")
    print(f" 总计: {total} 题, 通过: {passed} 题, 准确率: {accuracy:.1f}%")
    print(f"\n 分类统计:")
    for cat, stats in category_stats.items():
        rate = stats["passed"] / stats["total"] * 100
        print(f" {cat}: {stats['passed']}/{stats['total']} ({rate:.0f}%)")

    return {
        "total": total,
        "passed": passed,
        "accuracy": round(accuracy, 1),
        "category_stats": category_stats,
        "details": results,
    }
def save_results(accuracy_results, output_dir="vsp/qwen3.5-9b/results"):
    """Persist the accuracy report as UTF-8 JSON.

    Args:
        accuracy_results: summary dict returned by evaluate_accuracy().
        output_dir: destination directory (created if missing); the default
            preserves the original hard-coded location.

    Returns:
        Path of the written JSON file (new: the original returned None).
    """
    os.makedirs(output_dir, exist_ok=True)

    report = {
        "timestamp": datetime.now().isoformat(),
        "model": "Qwen3.5-9B",
        "quantization": "4-bit NF4",
        "accuracy": accuracy_results,
    }

    path = os.path.join(output_dir, "accuracy_results.json")
    # ensure_ascii=False keeps Chinese prompts/answers human-readable in the file.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(report, f, ensure_ascii=False, indent=2)
    print(f"\n结果已保存到 {path}")
    return path
def main():
    """Script entry point: load the model, run the evaluation, save the report."""
    # Change to the repository root (two levels up from this script) so the
    # relative results path in save_results() resolves correctly.
    os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", ".."))
    model, tokenizer = load_model()
    results = evaluate_accuracy(model, tokenizer)
    save_results(results)


if __name__ == "__main__":
    main()