Moved all project files and directories (config, core, models, etc.) from edge_inference_service/ to the repository root (ai_edge/); updated the model path in config/settings.py to reflect the new structure; revised the usage paths in the __init__.py documentation.
89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
"""
|
|
TensorRT模块单元测试
|
|
"""
|
|
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
import sys
|
|
import os
|
|
import numpy as np
|
|
|
|
# Make the project root (the parent of this test file's directory) importable,
# so the lazy `core.*` / `config.*` imports inside the tests resolve.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
|
|
class TestTensorRTEngine(unittest.TestCase):
|
|
"""测试TensorRT引擎类"""
|
|
|
|
def test_engine_creation(self):
|
|
"""测试引擎创建"""
|
|
from core.tensorrt_engine import TensorRTEngine
|
|
from config.settings import InferenceConfig
|
|
|
|
config = InferenceConfig(
|
|
model_path="./models/test.engine",
|
|
input_width=480,
|
|
input_height=480,
|
|
batch_size=1,
|
|
fp16_mode=True
|
|
)
|
|
|
|
engine = TensorRTEngine(config)
|
|
|
|
self.assertEqual(engine.config.input_width, 480)
|
|
self.assertEqual(engine.config.input_height, 480)
|
|
self.assertTrue(engine.config.fp16_mode)
|
|
|
|
def test_performance_stats_initial(self):
|
|
"""测试初始性能统计"""
|
|
from core.tensorrt_engine import TensorRTEngine
|
|
from config.settings import InferenceConfig
|
|
|
|
config = InferenceConfig()
|
|
engine = TensorRTEngine(config)
|
|
|
|
stats = engine.get_performance_stats()
|
|
|
|
self.assertIn("inference_count", stats)
|
|
self.assertIn("total_inference_time_ms", stats)
|
|
self.assertEqual(stats["inference_count"], 0)
|
|
|
|
def test_memory_usage(self):
|
|
"""测试显存使用查询"""
|
|
from core.tensorrt_engine import TensorRTEngine
|
|
from config.settings import InferenceConfig
|
|
|
|
config = InferenceConfig()
|
|
engine = TensorRTEngine(config)
|
|
|
|
memory = engine.get_memory_usage()
|
|
|
|
self.assertIn("total_mb", memory)
|
|
self.assertIn("used_mb", memory)
|
|
self.assertIn("free_mb", memory)
|
|
|
|
|
|
class TestEngineManager(unittest.TestCase):
|
|
"""测试引擎管理器"""
|
|
|
|
def test_manager_creation(self):
|
|
"""测试管理器创建"""
|
|
from core.tensorrt_engine import EngineManager
|
|
|
|
manager = EngineManager()
|
|
|
|
self.assertEqual(len(manager._engines), 0)
|
|
|
|
def test_get_nonexistent_engine(self):
|
|
"""测试获取不存在的引擎"""
|
|
from core.tensorrt_engine import EngineManager
|
|
|
|
manager = EngineManager()
|
|
|
|
engine = manager.get_engine("nonexistent")
|
|
|
|
self.assertIsNone(engine)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every TestCase defined in this module through the standard
    # unittest command-line runner ("__main__" is unittest.main's default).
    unittest.main(module="__main__")
|