- Moved all project files and directories (config, core, models, etc.) from edge_inference_service/ to the repository root ai_edge/
- Updated the model path in config/settings.py to reflect the new structure
- Revised the usage paths in the __init__.py documentation
260 lines
7.7 KiB
Python
260 lines
7.7 KiB
Python
"""
|
|
后处理模块单元测试
|
|
"""
|
|
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
from datetime import datetime
|
|
import sys
|
|
import os
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
|
|
class TestNMSProcessor(unittest.TestCase):
    """Tests for NMSProcessor (non-maximum suppression)."""

    def test_nms_single_box(self):
        """A lone detection box is always kept."""
        from core.postprocessor import NMSProcessor

        nms = NMSProcessor(nms_threshold=0.45)

        boxes = np.array([[100, 100, 200, 200]])
        scores = np.array([0.9])
        class_ids = np.array([0])

        keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids)

        self.assertEqual(len(keep_boxes), 1)

    def test_nms_multiple_boxes(self):
        """Two heavily overlapping boxes collapse to one; a disjoint box survives."""
        from core.postprocessor import NMSProcessor

        nms = NMSProcessor(nms_threshold=0.45)

        # Boxes 0 and 1 overlap strongly: IoU ~0.68 when read as
        # [x1, y1, x2, y2] (~0.75 if read as [x, y, w, h]), well above the
        # 0.45 threshold, so the lower-scoring one (0.85) must be suppressed.
        # Box 2 overlaps neither under either interpretation.
        boxes = np.array([
            [100, 100, 200, 200],
            [110, 110, 210, 210],
            [400, 400, 500, 500],
        ])
        scores = np.array([0.9, 0.85, 0.8])
        class_ids = np.array([0, 0, 0])

        keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids)

        self.assertEqual(len(keep_boxes), 2)
        # Compare with isclose rather than ==, in case scores are cast to
        # float32 inside the processor.
        kept = np.asarray(keep_scores, dtype=np.float64)
        self.assertTrue(np.any(np.isclose(kept, 0.9)))
        self.assertTrue(np.any(np.isclose(kept, 0.8)))
        self.assertFalse(np.any(np.isclose(kept, 0.85)))

    def test_nms_empty_boxes(self):
        """Empty input yields empty output rather than an error."""
        from core.postprocessor import NMSProcessor

        nms = NMSProcessor()

        boxes = np.empty((0, 4))
        scores = np.empty(0)
        class_ids = np.empty(0, dtype=int)

        keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids)

        self.assertEqual(len(keep_boxes), 0)
|
|
|
|
|
|
class TestCoordinateMapper(unittest.TestCase):
    """Tests for CoordinateMapper: letterbox back-mapping and box anchor points."""

    def test_map_from_letterbox(self):
        """A box mapped back from letterbox space has four coordinates and a positive x1."""
        from core.postprocessor import CoordinateMapper

        letterbox_box = [120, 120, 360, 360]
        # The scale-info tuple is passed through verbatim; its field layout is
        # defined by map_from_letterbox.
        mapped = CoordinateMapper().map_from_letterbox(
            letterbox_box,
            (0.375, 60, 60, 0.375),
            (1280, 720),
        )

        self.assertEqual(len(mapped), 4)
        self.assertGreater(mapped[0], 0)

    def test_get_box_center(self):
        """The centre of [100, 100, 200, 200] is (150, 150)."""
        from core.postprocessor import CoordinateMapper

        point = CoordinateMapper().get_box_center([100, 100, 200, 200])
        wanted = (150, 150)
        for axis in range(2):
            self.assertEqual(point[axis], wanted[axis])

    def test_get_box_bottom_center(self):
        """The bottom-centre of [100, 100, 200, 200] is (150, 200)."""
        from core.postprocessor import CoordinateMapper

        point = CoordinateMapper().get_box_bottom_center([100, 100, 200, 200])
        wanted = (150, 200)
        for axis in range(2):
            self.assertEqual(point[axis], wanted[axis])
|
|
|
|
|
|
class TestROIAnalyzer(unittest.TestCase):
    """Tests for ROIAnalyzer containment checks against a rectangular ROI."""

    @staticmethod
    def _rectangle_roi():
        """Build the rectangular ROI with corners (100, 100) and (200, 200) used by these tests."""
        from config.config_models import ROIInfo, ROIType, AlgorithmType

        return ROIInfo(
            roi_id="roi001",
            camera_id="cam001",
            roi_type=ROIType.RECTANGLE,
            coordinates=[[100, 100], [200, 200]],
            algorithm_type=AlgorithmType.LEAVE_POST,
        )

    def test_is_point_in_roi(self):
        """(150, 150) lies inside the ROI; (250, 250) lies outside."""
        region = self._rectangle_roi()
        from core.postprocessor import ROIAnalyzer

        checker = ROIAnalyzer()
        inside_point, outside_point = (150, 150), (250, 250)

        self.assertTrue(checker.is_point_in_roi(inside_point, region))
        self.assertFalse(checker.is_point_in_roi(outside_point, region))

    def test_is_detection_in_roi(self):
        """In "center" mode a box centred inside the ROI matches and a distant box does not."""
        region = self._rectangle_roi()
        from core.postprocessor import ROIAnalyzer

        checker = ROIAnalyzer()
        inside_box = [120, 120, 180, 180]
        outside_box = [250, 250, 300, 300]

        self.assertTrue(checker.is_detection_in_roi(inside_box, region, "center"))
        self.assertFalse(checker.is_detection_in_roi(outside_box, region, "center"))
|
|
|
|
|
|
class TestAlarmStateMachine(unittest.TestCase):
    """Tests for AlarmStateMachine."""

    def test_state_machine_creation(self):
        """Constructor arguments are exposed as attributes."""
        from core.postprocessor import AlarmStateMachine

        machine = AlarmStateMachine(
            alert_threshold=3,
            alert_cooldown=300
        )

        self.assertEqual(machine.alert_threshold, 3)
        self.assertEqual(machine.alert_cooldown, 300)

    def test_update_detection(self):
        """The alert fires on the alert_threshold-th detection update, and not before."""
        from core.postprocessor import AlarmStateMachine

        threshold = 3
        machine = AlarmStateMachine(alert_threshold=threshold)

        results = [machine.update("roi001", True) for _ in range(threshold)]

        # Checking only the final result would let a premature alert (e.g. on
        # the first detection) slip through unnoticed.
        for early in results[:-1]:
            self.assertFalse(early["should_alert"])

        self.assertTrue(results[-1]["should_alert"])
        self.assertEqual(results[-1]["reason"], "threshold_reached")

    def test_update_no_detection(self):
        """An update without a detection does not alert."""
        from core.postprocessor import AlarmStateMachine

        machine = AlarmStateMachine(alert_threshold=3)

        result = machine.update("roi001", False)

        self.assertFalse(result["should_alert"])

    def test_reset(self):
        """reset() clears the detection count for the given ROI."""
        from core.postprocessor import AlarmStateMachine

        machine = AlarmStateMachine(alert_threshold=3)

        for _ in range(3):
            machine.update("roi001", True)

        machine.reset("roi001")

        state = machine.get_state("roi001")
        self.assertEqual(state.detection_count, 0)
|
|
|
|
|
|
class TestPostProcessor(unittest.TestCase):
    """Tests for the PostProcessor facade."""

    def test_process_detections(self):
        """process_detections returns a 2-D boxes array for a raw model output."""
        from core.postprocessor import PostProcessor

        processor = PostProcessor()

        # Seeded generator: the input (and therefore the outcome) must be
        # identical on every run so a failure can be reproduced.
        rng = np.random.default_rng(0)
        # NOTE(review): (1, 10, 100) presumably mirrors the model's raw output
        # layout — confirm against the inference engine.
        outputs = [rng.standard_normal((1, 10, 100)).astype(np.float32)]

        boxes, scores, class_ids = processor.process_detections(outputs)

        self.assertEqual(len(boxes.shape), 2)

    def test_check_alarm_condition(self):
        """The alarm-check result carries should_alert and detection_count keys."""
        from core.postprocessor import PostProcessor

        processor = PostProcessor()

        result = processor.check_alarm_condition("roi001", True)

        self.assertIn("should_alert", result)
        self.assertIn("detection_count", result)

    def test_create_alert_info(self):
        """create_alert_info echoes the ROI and camera identifiers."""
        from core.postprocessor import PostProcessor

        processor = PostProcessor()

        alert = processor.create_alert_info(
            roi_id="roi001",
            camera_id="cam001",
            detection_results={
                "class_name": "person",
                "confidence": 0.95,
                "bbox": [100, 100, 200, 200]
            },
            message="离岗告警"
        )

        self.assertEqual(alert.roi_id, "roi001")
        self.assertEqual(alert.camera_id, "cam001")

    def test_get_statistics(self):
        """get_statistics reports the NMS and confidence thresholds."""
        from core.postprocessor import PostProcessor

        processor = PostProcessor()
        stats = processor.get_statistics()

        self.assertIn("nms_threshold", stats)
        self.assertIn("conf_threshold", stats)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|