"""Accuracy evaluation - test the model's correctness on common tasks.

Runs a small fixed suite of Chinese-language prompts (knowledge QA, math,
logic, code reading, translation, summarization, sentiment classification)
against a 4-bit quantized Qwen3.5-9B model and scores each answer by
keyword matching. Results are printed and saved as JSON.
"""
import glob
import json
import os
from datetime import datetime

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


# Test suite: each entry pairs one prompt with the keyword(s) a correct
# answer must contain — matching ANY one keyword counts as a pass.
ACCURACY_TESTS = [
    # Knowledge QA
    {
        "category": "知识问答",
        "prompt": "中国的首都是哪个城市?请只回答城市名。",
        "expected_contains": ["北京"],
    },
    {
        "category": "知识问答",
        "prompt": "水的化学式是什么?请只回答化学式。",
        "expected_contains": ["H2O"],
    },
    {
        "category": "知识问答",
        "prompt": "地球到太阳的平均距离大约是多少公里?A. 1.5亿 B. 3亿 C. 5亿 D. 1亿。请只回答选项字母。",
        "expected_contains": ["A"],
    },
    # Math reasoning
    {
        "category": "数学推理",
        "prompt": "计算 15 * 23 = ? 请只回答数字。",
        "expected_contains": ["345"],
    },
    {
        "category": "数学推理",
        "prompt": "一个三角形三边分别是3、4、5,它是什么三角形?请只回答类型。",
        "expected_contains": ["直角"],
    },
    # Logical reasoning
    {
        "category": "逻辑推理",
        "prompt": "所有的狗都是动物。小白是一只狗。所以小白是什么?请只回答一个词。",
        "expected_contains": ["动物"],
    },
    # Code comprehension
    {
        "category": "代码理解",
        "prompt": "以下Python代码的输出是什么?\n```python\nprint(len([1, 2, 3, 4, 5]))\n```\n请只回答数字。",
        "expected_contains": ["5"],
    },
    # Translation
    {
        "category": "翻译",
        "prompt": "将'Hello World'翻译成中文,请只回答翻译结果。",
        "expected_contains": ["你好", "世界"],
    },
    # Summarization
    {
        "category": "摘要",
        "prompt": "用一句话总结:人工智能(AI)是指由人工制造出来的系统所展现出来的智能。AI的核心问题包括推理、知识表示、规划、学习、自然语言处理、感知和移动与操作物体的能力。",
        "expected_contains": ["人工智能", "AI"],
    },
    # Classification
    {
        "category": "情感分类",
        "prompt": "判断以下文本的情感是正面还是负面:'这个产品太糟糕了,完全不值这个价格'。请只回答'正面'或'负面'。",
        "expected_contains": ["负面"],
    },
]


def load_model():
    """Load the model with 4-bit NF4 quantization.

    Prefers a locally downloaded checkpoint (any ``config.json`` found under
    ``vsp/qwen3.5-9b/model/``); falls back to the Hub model id otherwise.

    Returns:
        tuple: ``(model, tokenizer)`` ready for generation.
    """
    paths = glob.glob("vsp/qwen3.5-9b/model/**/config.json", recursive=True)
    model_path = os.path.dirname(paths[0]) if paths else "Qwen/Qwen3.5-9B"

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    return model, tokenizer


def evaluate_accuracy(model, tokenizer):
    """Run the accuracy evaluation suite and print a per-test + summary report.

    Each prompt is wrapped in the chat template, generated greedily
    (``do_sample=False``) for determinism, and the decoded continuation is
    checked for the expected keywords.

    Args:
        model: causal LM supporting ``generate``.
        tokenizer: matching tokenizer with a chat template.

    Returns:
        dict: totals, overall accuracy percentage, per-category stats, and
        per-test details.
    """
    print("=" * 60)
    print("Qwen3.5-9B 精度评估")
    print("=" * 60)

    results = []
    category_stats = {}

    for i, test in enumerate(ACCURACY_TESTS):
        messages = [{"role": "user", "content": test["prompt"]}]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,  # greedy decoding keeps the evaluation deterministic
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        input_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()

        # A test passes if the response contains ANY expected keyword.
        passed = any(kw in response for kw in test["expected_contains"])

        cat = test["category"]
        if cat not in category_stats:
            category_stats[cat] = {"total": 0, "passed": 0}
        category_stats[cat]["total"] += 1
        if passed:
            category_stats[cat]["passed"] += 1

        status = "PASS" if passed else "FAIL"
        print(f"\n[{status}] 测试 {i+1} ({cat})")
        print(f" 问题: {test['prompt'][:50]}...")
        print(f" 回答: {response[:80]}")
        print(f" 预期包含: {test['expected_contains']}")

        results.append({
            "category": cat,
            "prompt": test["prompt"],
            "response": response,
            "expected": test["expected_contains"],
            "passed": passed,
        })

    # Summary. Use a distinct name for the pass count so it does not shadow
    # the per-test boolean above; guard the division for an empty suite.
    total = len(results)
    passed_count = sum(1 for r in results if r["passed"])
    accuracy_pct = passed_count / total * 100 if total else 0.0
    print(f"\n{'='*60}")
    print("精度评估汇总")
    print(f"{'='*60}")
    print(f" 总计: {total} 题, 通过: {passed_count} 题, 准确率: {accuracy_pct:.1f}%")
    print("\n 分类统计:")
    for cat, stats in category_stats.items():
        rate = stats["passed"] / stats["total"] * 100
        print(f" {cat}: {stats['passed']}/{stats['total']} ({rate:.0f}%)")

    return {
        "total": total,
        "passed": passed_count,
        "accuracy": round(accuracy_pct, 1),
        "category_stats": category_stats,
        "details": results,
    }


def save_results(accuracy_results):
    """Write the evaluation report as UTF-8 JSON under the results directory.

    Args:
        accuracy_results: the dict returned by :func:`evaluate_accuracy`.
    """
    output_dir = "vsp/qwen3.5-9b/results"
    os.makedirs(output_dir, exist_ok=True)

    report = {
        "timestamp": datetime.now().isoformat(),
        "model": "Qwen3.5-9B",
        "quantization": "4-bit NF4",
        "accuracy": accuracy_results,
    }

    path = os.path.join(output_dir, "accuracy_results.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(report, f, ensure_ascii=False, indent=2)
    print(f"\n结果已保存到 {path}")


if __name__ == "__main__":
    # Run from the repository root (two levels up from this file) so the
    # relative model/result paths used above resolve correctly.
    os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir))
    model, tokenizer = load_model()
    results = evaluate_accuracy(model, tokenizer)
    save_results(results)