"""DeepSeek Vision API 分析器测试"""

import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from vitals.vision.analyzer import get_analyzer, FoodAnalyzer
from vitals.vision.providers.deepseek import DeepSeekVisionAnalyzer, get_deepseek_analyzer


class TestDeepSeekVisionAnalyzer:
    """Tests for the DeepSeek Vision analyzer."""

    @staticmethod
    def _install_fake_client(mock_client_class, content):
        """Wire the patched ``httpx.Client`` class to return a canned chat completion.

        Args:
            mock_client_class: The ``MagicMock`` that replaced ``httpx.Client``.
            content: Text placed in ``choices[0].message.content`` of the fake
                JSON response.

        Returns:
            The fake client instance, so the caller can inspect its ``post`` calls.
        """
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": content}}]
        }
        # raise_for_status() does nothing, i.e. the fake request "succeeded".
        mock_response.raise_for_status = MagicMock()

        mock_client = MagicMock()
        mock_client.post.return_value = mock_response
        # If the code under test uses `with httpx.Client(...) as client:`,
        # __enter__ must hand back this same fake so its configured `post` is used.
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)
        mock_client_class.return_value = mock_client
        return mock_client

    def test_init_with_api_key(self):
        """An explicit api_key is stored and the base URL defaults to DeepSeek's v1 endpoint."""
        analyzer = DeepSeekVisionAnalyzer(api_key="test-key")
        assert analyzer.api_key == "test-key"
        assert analyzer.base_url == "https://api.deepseek.com/v1"

    def test_init_from_env(self):
        """Without an explicit key, DEEPSEEK_API_KEY is picked up from the environment."""
        with patch.dict("os.environ", {"DEEPSEEK_API_KEY": "env-key"}):
            analyzer = DeepSeekVisionAnalyzer()
            assert analyzer.api_key == "env-key"

    def test_init_no_key(self):
        """With no key anywhere, construction succeeds and api_key is None."""
        with patch.dict("os.environ", {}, clear=True):
            analyzer = DeepSeekVisionAnalyzer()
            assert analyzer.api_key is None

    def test_analyze_image_no_api_key(self):
        """analyze_image raises ValueError mentioning DEEPSEEK_API_KEY when no key is set."""
        with patch.dict("os.environ", {}, clear=True):
            analyzer = DeepSeekVisionAnalyzer()
            # The path does not exist, so this also requires the key check to
            # fire before the image file is opened.
            with pytest.raises(ValueError, match="DEEPSEEK_API_KEY"):
                analyzer.analyze_image(Path("/fake/image.jpg"))

    def test_analyze_text_no_api_key(self):
        """analyze_text raises ValueError mentioning DEEPSEEK_API_KEY when no key is set."""
        with patch.dict("os.environ", {}, clear=True):
            analyzer = DeepSeekVisionAnalyzer()
            with pytest.raises(ValueError, match="DEEPSEEK_API_KEY"):
                analyzer.analyze_text("一碗米饭")

    @patch("httpx.Client")
    def test_analyze_text_success(self, mock_client_class):
        """A text analysis backed by the (mocked) API returns a tagged result."""
        mock_client = self._install_fake_client(mock_client_class, "一碗米饭+两个鸡蛋")

        analyzer = DeepSeekVisionAnalyzer(api_key="test-key")
        result = analyzer.analyze_text("今天吃了一碗米饭和两个鸡蛋")

        # Guard against a false pass: the result must have gone through the mocked API.
        assert mock_client.post.called
        assert "description" in result
        assert "total_calories" in result
        assert result["provider"] == "deepseek"
        assert result["original_input"] == "今天吃了一碗米饭和两个鸡蛋"

    @patch("httpx.Client")
    def test_analyze_image_success(self, mock_client_class, tmp_path):
        """An image analysis backed by the (mocked) API returns a tagged result."""
        mock_client = self._install_fake_client(mock_client_class, "米饭+红烧肉+西兰花")

        # pytest's tmp_path is cleaned up by pytest; no manual unlink needed.
        image_path = tmp_path / "meal.jpg"
        image_path.write_bytes(b"fake image data")

        analyzer = DeepSeekVisionAnalyzer(api_key="test-key")
        result = analyzer.analyze_image(image_path)

        assert mock_client.post.called
        assert "description" in result
        assert "total_calories" in result
        assert result["provider"] == "deepseek"

    def test_encode_image(self, tmp_path):
        """_encode_image returns the file's raw bytes as a base64 string."""
        image_path = tmp_path / "image.jpg"
        image_path.write_bytes(b"test data")

        analyzer = DeepSeekVisionAnalyzer(api_key="test-key")
        assert analyzer._encode_image(image_path) == "dGVzdCBkYXRh"  # base64 of b"test data"

    def test_get_media_type(self):
        """_get_media_type maps known image suffixes and falls back to image/jpeg."""
        analyzer = DeepSeekVisionAnalyzer(api_key="test-key")
        expected = {
            "test.jpg": "image/jpeg",
            "test.jpeg": "image/jpeg",
            "test.png": "image/png",
            "test.gif": "image/gif",
            "test.webp": "image/webp",
            "test.unknown": "image/jpeg",  # unrecognised suffix -> JPEG fallback
        }
        for filename, media_type in expected.items():
            assert analyzer._get_media_type(Path(filename)) == media_type, filename


class TestGetDeepSeekAnalyzer:
    """Tests for the get_deepseek_analyzer factory function."""

    def test_get_analyzer(self):
        """The factory builds a DeepSeekVisionAnalyzer carrying the supplied key."""
        built = get_deepseek_analyzer(api_key="test-key")

        assert isinstance(built, DeepSeekVisionAnalyzer)
        assert built.api_key == "test-key"


class TestGetAnalyzerFactory:
    """Tests for the get_analyzer factory defined in vitals.vision.analyzer."""

    def test_get_deepseek_analyzer(self):
        """provider="deepseek" yields a DeepSeekVisionAnalyzer."""
        with patch.dict("os.environ", {"DEEPSEEK_API_KEY": "test-key"}):
            chosen = get_analyzer(provider="deepseek")
            assert isinstance(chosen, DeepSeekVisionAnalyzer)

    def test_get_claude_analyzer(self):
        """provider="claude" yields a ClaudeFoodAnalyzer."""
        from vitals.vision.analyzer import ClaudeFoodAnalyzer

        chosen = get_analyzer(provider="claude", api_key="test-key")
        assert isinstance(chosen, ClaudeFoodAnalyzer)

    def test_get_local_analyzer(self):
        """provider="local" yields a LocalFoodAnalyzer."""
        from vitals.vision.analyzer import LocalFoodAnalyzer

        chosen = get_analyzer(provider="local")
        assert isinstance(chosen, LocalFoodAnalyzer)

    def test_backward_compatibility(self):
        """The legacy use_claude=True flag still selects ClaudeFoodAnalyzer."""
        from vitals.vision.analyzer import ClaudeFoodAnalyzer

        chosen = get_analyzer(use_claude=True, api_key="test-key")
        assert isinstance(chosen, ClaudeFoodAnalyzer)

    def test_default_provider(self):
        """Without a provider argument, get_analyzer falls back to DeepSeek."""
        # NOTE(review): only DEEPSEEK_API_KEY is patched; the rest of the real
        # environment is left in place. If get_analyzer reads any other variable
        # to pick a provider, this test may depend on the caller's env — confirm.
        with patch.dict("os.environ", {"DEEPSEEK_API_KEY": "test-key"}):
            chosen = get_analyzer()
            assert isinstance(chosen, DeepSeekVisionAnalyzer)


class TestLocalFoodAnalyzer:
    """Tests for LocalFoodAnalyzer."""

    def test_analyze_returns_empty(self):
        """analyze() returns an empty description, zero calories and a "note" entry."""
        from vitals.vision.analyzer import LocalFoodAnalyzer

        # The image path does not exist; the test relies on analyze() succeeding anyway.
        outcome = LocalFoodAnalyzer().analyze(Path("/fake/image.jpg"))

        assert outcome["description"] == ""
        assert outcome["total_calories"] == 0
        assert "note" in outcome