""" 后处理模块单元测试 """ import unittest from unittest.mock import MagicMock, patch from datetime import datetime import sys import os import numpy as np sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) class TestNMSProcessor(unittest.TestCase): """测试NMS处理器""" def test_nms_single_box(self): """测试单个检测框""" from core.postprocessor import NMSProcessor nms = NMSProcessor(nms_threshold=0.45) boxes = np.array([[100, 100, 200, 200]]) scores = np.array([0.9]) class_ids = np.array([0]) keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids) self.assertEqual(len(keep_boxes), 1) def test_nms_multiple_boxes(self): """测试多个检测框""" from core.postprocessor import NMSProcessor nms = NMSProcessor(nms_threshold=0.45) boxes = np.array([ [100, 100, 200, 200], [150, 150, 250, 250], [300, 300, 400, 400] ]) scores = np.array([0.9, 0.85, 0.8]) class_ids = np.array([0, 0, 0]) keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids) self.assertLessEqual(len(keep_boxes), 2) def test_nms_empty_boxes(self): """测试空检测框""" from core.postprocessor import NMSProcessor nms = NMSProcessor() boxes = np.array([]).reshape(0, 4) scores = np.array([]) class_ids = np.array([]) keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids) self.assertEqual(len(keep_boxes), 0) class TestCoordinateMapper(unittest.TestCase): """测试坐标映射""" def test_map_from_letterbox(self): """测试从Letterbox空间映射""" from core.postprocessor import CoordinateMapper mapper = CoordinateMapper() box = [120, 120, 360, 360] scale_info = (0.375, 60, 60, 0.375) original_size = (1280, 720) mapped = mapper.map_from_letterbox(box, scale_info, original_size) self.assertEqual(len(mapped), 4) self.assertGreater(mapped[0], 0) def test_get_box_center(self): """测试获取中心点""" from core.postprocessor import CoordinateMapper mapper = CoordinateMapper() center = mapper.get_box_center([100, 100, 200, 200]) self.assertEqual(center[0], 150) self.assertEqual(center[1], 150) def test_get_box_bottom_center(self): """测试获取底部中心点""" from core.postprocessor import CoordinateMapper mapper = CoordinateMapper() bottom = mapper.get_box_bottom_center([100, 100, 200, 200]) self.assertEqual(bottom[0], 150) self.assertEqual(bottom[1], 200) class TestROIAnalyzer(unittest.TestCase): """测试ROI分析器""" def test_is_point_in_roi(self): """测试点在ROI内""" from config.config_models import ROIInfo, ROIType, AlgorithmType from core.postprocessor import ROIAnalyzer roi = ROIInfo( roi_id="roi001", camera_id="cam001", roi_type=ROIType.RECTANGLE, coordinates=[[100, 100], [200, 200]], algorithm_type=AlgorithmType.LEAVE_POST, ) analyzer = ROIAnalyzer() self.assertTrue(analyzer.is_point_in_roi((150, 150), roi)) self.assertFalse(analyzer.is_point_in_roi((250, 250), roi)) def test_is_detection_in_roi(self): """测试检测在ROI内""" from config.config_models import ROIInfo, ROIType, AlgorithmType from core.postprocessor import ROIAnalyzer roi = ROIInfo( roi_id="roi001", camera_id="cam001", roi_type=ROIType.RECTANGLE, coordinates=[[100, 100], [200, 200]], algorithm_type=AlgorithmType.LEAVE_POST, ) analyzer = ROIAnalyzer() box = [120, 120, 180, 180] self.assertTrue(analyzer.is_detection_in_roi(box, roi, "center")) box_outside = [250, 250, 300, 300] self.assertFalse(analyzer.is_detection_in_roi(box_outside, roi, "center")) class TestAlarmStateMachine(unittest.TestCase): """测试告警状态机""" def test_state_machine_creation(self): """测试状态机创建""" from core.postprocessor import AlarmStateMachine machine = AlarmStateMachine( alert_threshold=3, alert_cooldown=300 ) self.assertEqual(machine.alert_threshold, 3) self.assertEqual(machine.alert_cooldown, 300) def test_update_detection(self): """测试更新检测状态""" from core.postprocessor import AlarmStateMachine machine = AlarmStateMachine(alert_threshold=3) for i in range(3): result = machine.update("roi001", True) self.assertTrue(result["should_alert"]) self.assertEqual(result["reason"], "threshold_reached") def test_update_no_detection(self): """测试无检测更新""" from core.postprocessor import AlarmStateMachine machine = AlarmStateMachine(alert_threshold=3) result = machine.update("roi001", False) self.assertFalse(result["should_alert"]) def test_reset(self): """测试重置""" from core.postprocessor import AlarmStateMachine machine = AlarmStateMachine(alert_threshold=3) for i in range(3): machine.update("roi001", True) machine.reset("roi001") state = machine.get_state("roi001") self.assertEqual(state.detection_count, 0) class TestPostProcessor(unittest.TestCase): """测试后处理器""" def test_process_detections(self): """测试处理检测结果""" from core.postprocessor import PostProcessor processor = PostProcessor() outputs = [np.random.randn(1, 10, 100).astype(np.float32)] boxes, scores, class_ids = processor.process_detections(outputs) self.assertEqual(len(boxes.shape), 2) def test_check_alarm_condition(self): """测试检查告警条件""" from core.postprocessor import PostProcessor processor = PostProcessor() result = processor.check_alarm_condition("roi001", True) self.assertIn("should_alert", result) self.assertIn("detection_count", result) def test_create_alert_info(self): """测试创建告警信息""" from core.postprocessor import PostProcessor processor = PostProcessor() alert = processor.create_alert_info( roi_id="roi001", camera_id="cam001", detection_results={ "class_name": "person", "confidence": 0.95, "bbox": [100, 100, 200, 200] }, message="离岗告警" ) self.assertEqual(alert.roi_id, "roi001") self.assertEqual(alert.camera_id, "cam001") def test_get_statistics(self): """测试获取统计""" from core.postprocessor import PostProcessor processor = PostProcessor() stats = processor.get_statistics() self.assertIn("nms_threshold", stats) self.assertIn("conf_threshold", stats) if __name__ == "__main__": unittest.main()