1. YOLO11输出解析错误: 移除不存在的objectness行,正确使用class_scores.max()
2. CPU NMS逻辑错误: keep_mask同时标记保留和抑制框导致NMS失效,改用独立suppressed集合
3. 坐标映射缺失: _build_tracks中scale_info未使用,添加revert_boxes还原到ROI裁剪空间
4. batch=1限制: 恢复真正的动态batch推理(1~8),BatchPreprocessor支持多图stack
5. 帧率控制缺失: _read_frame添加time.monotonic()间隔控制,按target_fps跳帧
6. 拉流推理耦合: 新增独立推理线程(InferenceWorker),生产者-消费者模式解耦
7. 攒批形同虚设: 添加50ms攒批窗口+max_batch阈值,替代>=1立即处理
8. LeavePost双重等待: LEAVING确认后直接触发告警,不再进入OFF_DUTY二次等待
9. register_algorithm每帧调用: 添加_registered_keys缓存,O(1)快速路径跳过
10. GPU context线程安全: TensorRT infer()内部加锁,防止多线程CUDA context竞争

附带修复:
- reset_algorithm中未定义algorithm_type变量(NameError)
- update_roi_params中循环变量key覆盖外层key
- AlertInfo缺少bind_id字段(TypeError)
- _logger.log_alert在标准logger上不存在(AttributeError)
- AlarmStateMachine死锁(Lock改为RLock)
- ROICropper.create_mask坐标解析错误
- 更新测试用例适配新API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
261 lines
7.8 KiB
Python
261 lines
7.8 KiB
Python
"""
后处理模块单元测试 (unit tests for the post-processing module).
"""
|
||
|
||
import unittest
|
||
from unittest.mock import MagicMock, patch
|
||
from datetime import datetime
|
||
import sys
|
||
import os
|
||
import numpy as np
|
||
|
||
# Prepend the parent of this file's directory (presumably the project root)
# to sys.path, so the project-local imports used by the tests, such as
# `core.postprocessor` and `config.config_models`, can be resolved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
|
||
class TestNMSProcessor(unittest.TestCase):
    """Unit tests for ``core.postprocessor.NMSProcessor``."""

    def test_nms_single_box(self):
        """A lone box has nothing that could suppress it, so it must be kept."""
        from core.postprocessor import NMSProcessor

        nms = NMSProcessor(nms_threshold=0.45)

        boxes = np.array([[100, 100, 200, 200]])
        scores = np.array([0.9])
        class_ids = np.array([0])

        keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids)

        self.assertEqual(len(keep_boxes), 1)

    def test_nms_multiple_boxes(self):
        """High-IoU duplicates are suppressed while disjoint boxes survive.

        box1 and box2 overlap with IoU = 8100 / 11900 ~= 0.68 (> 0.45), so only
        the higher-scoring box1 may survive that pair. box3 overlaps neither
        and must be kept. Pinning the exact survivors catches several failure
        modes that a bare ``<= 2`` bound would let through:

        * an NMS that suppresses everything;
        * an NMS that keeps the lower-scoring duplicate;
        * an NMS that drops the disjoint box.
        """
        from core.postprocessor import NMSProcessor

        nms = NMSProcessor(nms_threshold=0.45)

        boxes = np.array([
            [100, 100, 200, 200],  # box1, score 0.90
            [110, 110, 210, 210],  # box2, score 0.85 -- duplicate of box1
            [300, 300, 400, 400],  # box3, score 0.80 -- disjoint
        ])
        scores = np.array([0.9, 0.85, 0.8])
        class_ids = np.array([0, 0, 0])

        keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids)

        self.assertEqual(len(keep_boxes), 2)
        # This test does not depend on the order of the survivors, so compare
        # the sorted scores: box1 (0.9) and box3 (0.8) must remain.
        np.testing.assert_allclose(
            np.sort(np.asarray(keep_scores, dtype=float)),
            [0.8, 0.9],
        )

    def test_nms_empty_boxes(self):
        """Empty input produces empty output."""
        from core.postprocessor import NMSProcessor

        nms = NMSProcessor()

        boxes = np.array([]).reshape(0, 4)
        scores = np.array([])
        class_ids = np.array([])

        keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids)

        self.assertEqual(len(keep_boxes), 0)
|
||
|
||
|
||
class TestCoordinateMapper(unittest.TestCase):
    """Unit tests for ``core.postprocessor.CoordinateMapper``."""

    def test_map_from_letterbox(self):
        """Mapping a letterbox-space box back gives 4 values, the first positive."""
        from core.postprocessor import CoordinateMapper

        result = CoordinateMapper().map_from_letterbox(
            [120, 120, 360, 360],        # box in letterbox space
            (0.375, 60, 60, 0.375),      # scale_info
            (1280, 720),                 # original_size
        )

        self.assertEqual(len(result), 4)
        self.assertGreater(result[0], 0)

    def test_get_box_center(self):
        """The center of [100, 100, 200, 200] is (150, 150)."""
        from core.postprocessor import CoordinateMapper

        point = CoordinateMapper().get_box_center([100, 100, 200, 200])

        for axis, expected in ((0, 150), (1, 150)):
            self.assertEqual(point[axis], expected)

    def test_get_box_bottom_center(self):
        """The bottom-center of [100, 100, 200, 200] is (150, 200)."""
        from core.postprocessor import CoordinateMapper

        point = CoordinateMapper().get_box_bottom_center([100, 100, 200, 200])

        for axis, expected in ((0, 150), (1, 200)):
            self.assertEqual(point[axis], expected)
|
||
|
||
|
||
class TestROIAnalyzer(unittest.TestCase):
    """Unit tests for ``core.postprocessor.ROIAnalyzer`` against a rectangle ROI."""

    @staticmethod
    def _rect_roi_and_analyzer():
        """Return ``(roi, analyzer)``.

        The ROI is a LEAVE_POST rectangle with coordinates
        [[100, 100], [200, 200]]; the analyzer is a fresh ROIAnalyzer.
        """
        from config.config_models import ROIInfo, ROIType, AlgorithmType
        from core.postprocessor import ROIAnalyzer

        roi = ROIInfo(
            roi_id="roi001",
            camera_id="cam001",
            roi_type=ROIType.RECTANGLE,
            coordinates=[[100, 100], [200, 200]],
            algorithm_type=AlgorithmType.LEAVE_POST,
        )
        return roi, ROIAnalyzer()

    def test_is_point_in_roi(self):
        """(150, 150) lies inside the rectangle; (250, 250) does not."""
        roi, analyzer = self._rect_roi_and_analyzer()

        self.assertTrue(analyzer.is_point_in_roi((150, 150), roi))
        self.assertFalse(analyzer.is_point_in_roi((250, 250), roi))

    def test_is_detection_in_roi(self):
        """With the "center" mode, an inner box is in the ROI and an outer box is not."""
        roi, analyzer = self._rect_roi_and_analyzer()

        inside_box = [120, 120, 180, 180]
        outside_box = [250, 250, 300, 300]

        self.assertTrue(analyzer.is_detection_in_roi(inside_box, roi, "center"))
        self.assertFalse(analyzer.is_detection_in_roi(outside_box, roi, "center"))
|
||
|
||
|
||
class TestAlarmStateMachine(unittest.TestCase):
    """Unit tests for ``core.postprocessor.AlarmStateMachine``."""

    def test_state_machine_creation(self):
        """Constructor arguments are exposed unchanged as attributes."""
        from core.postprocessor import AlarmStateMachine

        sm = AlarmStateMachine(alert_threshold=3, alert_cooldown=300)

        self.assertEqual(sm.alert_threshold, 3)
        self.assertEqual(sm.alert_cooldown, 300)

    def test_update_detection(self):
        """Three positive updates on one ROI (threshold 3) request an alert."""
        from core.postprocessor import AlarmStateMachine

        sm = AlarmStateMachine(alert_threshold=3)

        outcomes = [sm.update("roi001", True) for _ in range(3)]
        last = outcomes[-1]

        self.assertTrue(last["should_alert"])
        self.assertEqual(last["reason"], "threshold_reached")

    def test_update_no_detection(self):
        """A single negative update does not request an alert."""
        from core.postprocessor import AlarmStateMachine

        outcome = AlarmStateMachine(alert_threshold=3).update("roi001", False)

        self.assertFalse(outcome["should_alert"])

    def test_reset(self):
        """reset() clears the accumulated detection count for the ROI."""
        from core.postprocessor import AlarmStateMachine

        sm = AlarmStateMachine(alert_threshold=3)
        for _ in range(3):
            sm.update("roi001", True)

        sm.reset("roi001")

        self.assertEqual(sm.get_state("roi001").detection_count, 0)
|
||
|
||
|
||
class TestPostProcessor(unittest.TestCase):
    """Unit tests for ``core.postprocessor.PostProcessor``."""

    def test_process_detections(self):
        """Parsing a raw output tensor yields 2-D boxes with parallel scores/class ids.

        The input is drawn from a fixed-seed generator so every run sees the
        same tensor. With unseeded ``np.random.randn`` the number of surviving
        detections changed from run to run, so a failure (for example on the
        empty-result path) could not be reproduced.
        """
        from core.postprocessor import PostProcessor

        processor = PostProcessor()

        # NOTE(review): shape presumably (batch, 4 box rows + 6 class rows,
        # 100 anchors) per the YOLO11 output layout -- confirm against the parser.
        rng = np.random.default_rng(0)
        outputs = [rng.standard_normal((1, 10, 100)).astype(np.float32)]

        boxes, scores, class_ids = processor.process_detections(outputs)

        self.assertEqual(len(boxes.shape), 2)
        # The three outputs describe the same detections, so their lengths agree.
        self.assertEqual(len(scores), len(boxes))
        self.assertEqual(len(class_ids), len(boxes))

    def test_check_alarm_condition(self):
        """check_alarm_condition returns a dict carrying the alert decision and count."""
        from core.postprocessor import PostProcessor

        processor = PostProcessor()

        result = processor.check_alarm_condition("roi001", True)

        self.assertIn("should_alert", result)
        self.assertIn("detection_count", result)

    def test_create_alert_info(self):
        """create_alert_info carries the ROI and camera identifiers through."""
        from core.postprocessor import PostProcessor

        processor = PostProcessor()

        alert = processor.create_alert_info(
            roi_id="roi001",
            camera_id="cam001",
            detection_results={
                "class_name": "person",
                "confidence": 0.95,
                "bbox": [100, 100, 200, 200],
            },
            message="离岗告警",
        )

        self.assertEqual(alert.roi_id, "roi001")
        self.assertEqual(alert.camera_id, "cam001")

    def test_get_statistics(self):
        """get_statistics reports the configured NMS and confidence thresholds."""
        from core.postprocessor import PostProcessor

        processor = PostProcessor()
        stats = processor.get_statistics()

        self.assertIn("nms_threshold", stats)
        self.assertIn("conf_threshold", stats)
|
||
|
||
|
||
if __name__ == "__main__":
    # Entry point for running this file directly: unittest loads and runs
    # every TestCase defined in this module (verbosity=1 is the default).
    unittest.main(verbosity=1)
|