Files
security-ai-edge/tests/test_tensorrt.py
16337 b0ddb6ee1a feat(project): move edge_inference_service contents to root and update paths
- Moved all project files and directories (config, core, models, etc.) from
  edge_inference_service/ to the repository root ai_edge/
- Updated model path in config/settings.py to reflect new structure
- Revised usage paths in __init__.py documentation
2026-01-29 18:43:19 +08:00

89 lines
2.5 KiB
Python

"""
TensorRT模块单元测试
"""
import unittest
from unittest.mock import MagicMock, patch
import sys
import os
import numpy as np
# Put the project root (the directory containing tests/) at the front of
# sys.path so the `core` and `config` packages imported by the tests resolve
# regardless of the current working directory.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)
class TestTensorRTEngine(unittest.TestCase):
"""测试TensorRT引擎类"""
def test_engine_creation(self):
"""测试引擎创建"""
from core.tensorrt_engine import TensorRTEngine
from config.settings import InferenceConfig
config = InferenceConfig(
model_path="./models/test.engine",
input_width=480,
input_height=480,
batch_size=1,
fp16_mode=True
)
engine = TensorRTEngine(config)
self.assertEqual(engine.config.input_width, 480)
self.assertEqual(engine.config.input_height, 480)
self.assertTrue(engine.config.fp16_mode)
def test_performance_stats_initial(self):
"""测试初始性能统计"""
from core.tensorrt_engine import TensorRTEngine
from config.settings import InferenceConfig
config = InferenceConfig()
engine = TensorRTEngine(config)
stats = engine.get_performance_stats()
self.assertIn("inference_count", stats)
self.assertIn("total_inference_time_ms", stats)
self.assertEqual(stats["inference_count"], 0)
def test_memory_usage(self):
"""测试显存使用查询"""
from core.tensorrt_engine import TensorRTEngine
from config.settings import InferenceConfig
config = InferenceConfig()
engine = TensorRTEngine(config)
memory = engine.get_memory_usage()
self.assertIn("total_mb", memory)
self.assertIn("used_mb", memory)
self.assertIn("free_mb", memory)
class TestEngineManager(unittest.TestCase):
"""测试引擎管理器"""
def test_manager_creation(self):
"""测试管理器创建"""
from core.tensorrt_engine import EngineManager
manager = EngineManager()
self.assertEqual(len(manager._engines), 0)
def test_get_nonexistent_engine(self):
"""测试获取不存在的引擎"""
from core.tensorrt_engine import EngineManager
manager = EngineManager()
engine = manager.get_engine("nonexistent")
self.assertIsNone(engine)
if __name__ == "__main__":
    # Executing this file directly hands control to the standard unittest
    # command-line runner, which discovers the TestCase classes in this module.
    unittest.main()