feat: 添加精度评估脚本(知识/数学/逻辑/代码/翻译)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
185
vsp/qwen3.5-9b/test_accuracy.py
Normal file
185
vsp/qwen3.5-9b/test_accuracy.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""精度评估 - 测试模型在常见任务上的准确性"""
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# Test dataset: each case pairs a prompt with the keyword(s) the model's
# reply must contain for the case to count as passed (substring match,
# ANY keyword suffices — see evaluate_accuracy).
ACCURACY_TESTS = [
    # Knowledge QA
    {
        "category": "知识问答",
        "prompt": "中国的首都是哪个城市?请只回答城市名。",
        "expected_contains": ["北京"],
    },
    {
        "category": "知识问答",
        # NOTE(review): a reply written as "H₂O" would fail this substring
        # check even though it is correct — confirm this is acceptable.
        "prompt": "水的化学式是什么?请只回答化学式。",
        "expected_contains": ["H2O"],
    },
    {
        "category": "知识问答",
        # NOTE(review): the single-letter keyword "A" can match incidentally
        # inside a longer answer (e.g. "Answer: ..."), inflating the score.
        "prompt": "地球到太阳的平均距离大约是多少公里?A. 1.5亿 B. 3亿 C. 5亿 D. 1亿。请只回答选项字母。",
        "expected_contains": ["A"],
    },
    # Math reasoning
    {
        "category": "数学推理",
        "prompt": "计算 15 * 23 = ? 请只回答数字。",
        "expected_contains": ["345"],
    },
    {
        "category": "数学推理",
        "prompt": "一个三角形三边分别是3、4、5,它是什么三角形?请只回答类型。",
        "expected_contains": ["直角"],
    },
    # Logical reasoning
    {
        "category": "逻辑推理",
        "prompt": "所有的狗都是动物。小白是一只狗。所以小白是什么?请只回答一个词。",
        "expected_contains": ["动物"],
    },
    # Code comprehension
    {
        "category": "代码理解",
        # NOTE(review): the single-digit keyword "5" also matches loosely
        # (e.g. "15"); a stricter exact-answer check may be warranted.
        "prompt": "以下Python代码的输出是什么?\n```python\nprint(len([1, 2, 3, 4, 5]))\n```\n请只回答数字。",
        "expected_contains": ["5"],
    },
    # Translation
    {
        "category": "翻译",
        "prompt": "将'Hello World'翻译成中文,请只回答翻译结果。",
        "expected_contains": ["你好", "世界"],
    },
    # Summarization
    {
        "category": "摘要",
        "prompt": "用一句话总结:人工智能(AI)是指由人工制造出来的系统所展现出来的智能。AI的核心问题包括推理、知识表示、规划、学习、自然语言处理、感知和移动与操作物体的能力。",
        "expected_contains": ["人工智能", "AI"],
    },
    # Classification
    {
        "category": "情感分类",
        "prompt": "判断以下文本的情感是正面还是负面:'这个产品太糟糕了,完全不值这个价格'。请只回答'正面'或'负面'。",
        "expected_contains": ["负面"],
    },
]
|
||||
|
||||
|
||||
def load_model():
    """Load the Qwen model under 4-bit NF4 quantization.

    Looks for a locally downloaded snapshot (any ``config.json`` below
    ``vsp/qwen3.5-9b/model/``) and falls back to the hub id otherwise.

    Returns:
        Tuple ``(model, tokenizer)`` ready for chat-style generation.
    """
    # Resolve the checkpoint directory: local snapshot wins over the hub id.
    local_configs = glob.glob("vsp/qwen3.5-9b/model/**/config.json", recursive=True)
    if local_configs:
        model_path = os.path.dirname(local_configs[0])
    else:
        model_path = "Qwen/Qwen3.5-9B"

    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    lm = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=quant_cfg,
        device_map="auto",
        trust_remote_code=True,
    )
    return lm, tok
|
||||
|
||||
|
||||
def _generate_response(model, tokenizer, prompt):
    """Greedy-decode one chat turn and return only the newly generated text."""
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=False,  # deterministic decoding for reproducible scores
        )

    # Strip the echoed prompt tokens; keep only the completion.
    input_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()


def evaluate_accuracy(model, tokenizer, tests=None):
    """Run the accuracy evaluation suite and print a per-test report.

    Args:
        model: causal LM exposing ``.generate`` and ``.device``.
        tokenizer: matching tokenizer supporting ``apply_chat_template``.
        tests: optional list of case dicts with keys ``category``, ``prompt``
            and ``expected_contains``; defaults to ``ACCURACY_TESTS``.

    Returns:
        dict with ``total``, ``passed``, ``accuracy`` (percent, 1 decimal),
        ``category_stats`` and per-test ``details``.
    """
    if tests is None:
        tests = ACCURACY_TESTS

    print("=" * 60)
    print("Qwen3.5-9B 精度评估")
    print("=" * 60)

    results = []
    category_stats = {}

    for i, test in enumerate(tests):
        response = _generate_response(model, tokenizer, test["prompt"])

        # A case passes if ANY expected keyword appears in the reply.
        passed = any(kw in response for kw in test["expected_contains"])

        cat = test["category"]
        stats = category_stats.setdefault(cat, {"total": 0, "passed": 0})
        stats["total"] += 1
        if passed:
            stats["passed"] += 1

        status = "PASS" if passed else "FAIL"
        print(f"\n[{status}] 测试 {i+1} ({cat})")
        print(f" 问题: {test['prompt'][:50]}...")
        print(f" 回答: {response[:80]}")
        print(f" 预期包含: {test['expected_contains']}")

        results.append({
            "category": cat,
            "prompt": test["prompt"],
            "response": response,
            "expected": test["expected_contains"],
            "passed": passed,
        })

    total = len(results)
    passed = sum(1 for r in results if r["passed"])
    # Guard the empty-suite case instead of dividing by zero.
    accuracy = round(passed / total * 100, 1) if total else 0.0

    print(f"\n{'='*60}")
    print(f"精度评估汇总")
    print(f"{'='*60}")
    print(f" 总计: {total} 题, 通过: {passed} 题, 准确率: {accuracy:.1f}%")
    print(f"\n 分类统计:")
    for cat, stats in category_stats.items():
        rate = stats["passed"] / stats["total"] * 100
        print(f" {cat}: {stats['passed']}/{stats['total']} ({rate:.0f}%)")

    return {
        "total": total,
        "passed": passed,
        "accuracy": accuracy,
        "category_stats": category_stats,
        "details": results,
    }
|
||||
|
||||
|
||||
def save_results(accuracy_results):
    """Persist the accuracy report as pretty-printed UTF-8 JSON.

    Side effects: creates ``vsp/qwen3.5-9b/results`` if missing, writes
    ``accuracy_results.json`` inside it, and prints the saved path.
    """
    target_dir = "vsp/qwen3.5-9b/results"
    target_path = os.path.join(target_dir, "accuracy_results.json")
    os.makedirs(target_dir, exist_ok=True)

    payload = {
        "timestamp": datetime.now().isoformat(),
        "model": "Qwen3.5-9B",
        "quantization": "4-bit NF4",
        "accuracy": accuracy_results,
    }

    with open(target_path, "w", encoding="utf-8") as fp:
        json.dump(payload, fp, ensure_ascii=False, indent=2)

    print(f"\n结果已保存到 {target_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Change to the directory two levels above this script (presumably the
    # repo root) so the relative "vsp/qwen3.5-9b/..." paths used by
    # load_model/save_results resolve — TODO confirm layout.
    os.chdir(os.path.dirname(os.path.abspath(__file__)) + "/../..")
    model, tokenizer = load_model()
    results = evaluate_accuracy(model, tokenizer)
    save_results(results)
|
||||
Reference in New Issue
Block a user