Files
security-ai-edge/tests/test_postprocessor.py
16337 98595402c6 fix: 修复10个关键bug，提升系统稳定性和性能
1. YOLO11输出解析错误: 移除不存在的objectness行,正确使用class_scores.max()
2. CPU NMS逻辑错误: keep_mask同时标记保留和抑制框导致NMS失效,改用独立suppressed集合
3. 坐标映射缺失: _build_tracks中scale_info未使用,添加revert_boxes还原到ROI裁剪空间
4. batch=1限制: 恢复真正的动态batch推理(1~8),BatchPreprocessor支持多图stack
5. 帧率控制缺失: _read_frame添加time.monotonic()间隔控制,按target_fps跳帧
6. 拉流推理耦合: 新增独立推理线程(InferenceWorker),生产者-消费者模式解耦
7. 攒批形同虚设: 添加50ms攒批窗口+max_batch阈值,替代原先“>=1 即立即处理”的逻辑
8. LeavePost双重等待: LEAVING确认后直接触发告警,不再进入OFF_DUTY二次等待
9. register_algorithm每帧调用: 添加_registered_keys缓存,O(1)快速路径跳过
10. GPU context线程安全: TensorRT infer()内部加锁,防止多线程CUDA context竞争

附带修复:
- reset_algorithm中未定义algorithm_type变量(NameError)
- update_roi_params中循环变量key覆盖外层key
- AlertInfo缺少bind_id字段(TypeError)
- _logger.log_alert在标准logger上不存在(AttributeError)
- AlarmStateMachine死锁(Lock改为RLock)
- ROICropper.create_mask坐标解析错误
- 更新测试用例适配新API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 16:47:26 +08:00

261 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
后处理模块单元测试
"""
import unittest
from unittest.mock import MagicMock, patch
from datetime import datetime
import sys
import os
import numpy as np
# Make the project root (the parent of this tests/ directory) importable so
# the in-test imports of `core.*` and `config.*` resolve when the file is run
# directly. The helper name is deleted so the module namespace stays clean.
_tests_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_tests_dir))
del _tests_dir
class TestNMSProcessor(unittest.TestCase):
    """Tests for NMSProcessor (non-maximum suppression)."""

    def test_nms_single_box(self):
        """A lone box survives NMS with its score and class id intact."""
        from core.postprocessor import NMSProcessor
        nms = NMSProcessor(nms_threshold=0.45)
        boxes = np.array([[100, 100, 200, 200]])
        scores = np.array([0.9])
        class_ids = np.array([0])
        keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids)
        self.assertEqual(len(keep_boxes), 1)
        self.assertAlmostEqual(float(keep_scores[0]), 0.9, places=5)
        self.assertEqual(int(keep_classes[0]), 0)

    def test_nms_multiple_boxes(self):
        """Of two heavily overlapping boxes only the higher-scoring one is kept."""
        from core.postprocessor import NMSProcessor
        nms = NMSProcessor(nms_threshold=0.45)
        # box1 and box2 overlap with IoU = 8100 / 11900 ~= 0.68 (> 0.45), so
        # box2 (lower score) must be suppressed; box3 is disjoint and survives.
        boxes = np.array([
            [100, 100, 200, 200],
            [110, 110, 210, 210],
            [300, 300, 400, 400],
        ])
        scores = np.array([0.9, 0.85, 0.8])
        class_ids = np.array([0, 0, 0])
        keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids)
        # Pin the exact survivors: an upper bound alone would also accept
        # over-suppression (0 or 1 boxes kept) or keeping the wrong box.
        kept_scores = sorted(float(s) for s in keep_scores)
        self.assertEqual(len(keep_boxes), 2)
        self.assertEqual(len(kept_scores), 2)
        self.assertAlmostEqual(kept_scores[0], 0.8, places=5)
        self.assertAlmostEqual(kept_scores[1], 0.9, places=5)

    def test_nms_empty_boxes(self):
        """Empty input yields empty output for boxes, scores and classes."""
        from core.postprocessor import NMSProcessor
        nms = NMSProcessor()
        boxes = np.array([]).reshape(0, 4)
        scores = np.array([])
        class_ids = np.array([])
        keep_boxes, keep_scores, keep_classes = nms.process(boxes, scores, class_ids)
        self.assertEqual(len(keep_boxes), 0)
        self.assertEqual(len(keep_scores), 0)
        self.assertEqual(len(keep_classes), 0)
class TestCoordinateMapper(unittest.TestCase):
    """Tests for CoordinateMapper's box-mapping and anchor-point helpers."""

    def setUp(self):
        # A fresh mapper per test, exactly as if each test built its own.
        from core.postprocessor import CoordinateMapper
        self.mapper = CoordinateMapper()

    def test_map_from_letterbox(self):
        """Mapping a letterbox-space box back yields four coordinates."""
        letterbox_box = [120, 120, 360, 360]
        # NOTE(review): the layout of scale_info is defined by
        # CoordinateMapper.map_from_letterbox; it is not verified here.
        scale_info = (0.375, 60, 60, 0.375)
        frame_size = (1280, 720)
        result = self.mapper.map_from_letterbox(letterbox_box, scale_info, frame_size)
        self.assertEqual(len(result), 4)
        self.assertGreater(result[0], 0)

    def test_get_box_center(self):
        """The center of [100, 100, 200, 200] is (150, 150)."""
        midpoint = self.mapper.get_box_center([100, 100, 200, 200])
        self.assertEqual(midpoint[0], 150)
        self.assertEqual(midpoint[1], 150)

    def test_get_box_bottom_center(self):
        """The bottom-center of [100, 100, 200, 200] is (150, 200)."""
        foot = self.mapper.get_box_bottom_center([100, 100, 200, 200])
        self.assertEqual(foot[0], 150)
        self.assertEqual(foot[1], 200)
class TestROIAnalyzer(unittest.TestCase):
    """Tests for ROIAnalyzer point and detection containment checks."""

    @staticmethod
    def _rect_roi():
        """Build the shared RECTANGLE ROI fixture with corners [[100, 100], [200, 200]]."""
        from config.config_models import ROIInfo, ROIType, AlgorithmType
        return ROIInfo(
            roi_id="roi001",
            camera_id="cam001",
            roi_type=ROIType.RECTANGLE,
            coordinates=[[100, 100], [200, 200]],
            algorithm_type=AlgorithmType.LEAVE_POST,
        )

    def test_is_point_in_roi(self):
        """A point inside the rectangle is contained; one outside is not."""
        from core.postprocessor import ROIAnalyzer
        region = self._rect_roi()
        checker = ROIAnalyzer()
        self.assertTrue(checker.is_point_in_roi((150, 150), region))
        self.assertFalse(checker.is_point_in_roi((250, 250), region))

    def test_is_detection_in_roi(self):
        """Box containment using the "center" anchor mode."""
        from core.postprocessor import ROIAnalyzer
        region = self._rect_roi()
        checker = ROIAnalyzer()
        inner_box = [120, 120, 180, 180]
        outer_box = [250, 250, 300, 300]
        self.assertTrue(checker.is_detection_in_roi(inner_box, region, "center"))
        self.assertFalse(checker.is_detection_in_roi(outer_box, region, "center"))
class TestAlarmStateMachine(unittest.TestCase):
    """Tests for AlarmStateMachine's threshold-based alert triggering."""

    def test_state_machine_creation(self):
        """Constructor arguments are exposed as attributes."""
        from core.postprocessor import AlarmStateMachine
        machine = AlarmStateMachine(
            alert_threshold=3,
            alert_cooldown=300
        )
        self.assertEqual(machine.alert_threshold, 3)
        self.assertEqual(machine.alert_cooldown, 300)

    def test_update_detection(self):
        """The alert fires on the update that reaches the threshold, not earlier."""
        from core.postprocessor import AlarmStateMachine
        machine = AlarmStateMachine(alert_threshold=3)
        # Checking only the final result would not pin down when the alert
        # first fires, so every pre-threshold update is asserted as well.
        for frame in range(2):
            with self.subTest(frame=frame):
                result = machine.update("roi001", True)
                self.assertFalse(result["should_alert"])
        result = machine.update("roi001", True)
        self.assertTrue(result["should_alert"])
        self.assertEqual(result["reason"], "threshold_reached")

    def test_update_no_detection(self):
        """A negative update does not alert."""
        from core.postprocessor import AlarmStateMachine
        machine = AlarmStateMachine(alert_threshold=3)
        result = machine.update("roi001", False)
        self.assertFalse(result["should_alert"])

    def test_reset(self):
        """reset() clears the accumulated detection count for the ROI."""
        from core.postprocessor import AlarmStateMachine
        machine = AlarmStateMachine(alert_threshold=3)
        for _ in range(3):
            machine.update("roi001", True)
        machine.reset("roi001")
        state = machine.get_state("roi001")
        self.assertEqual(state.detection_count, 0)
class TestPostProcessor(unittest.TestCase):
    """Tests for the PostProcessor facade."""

    def test_process_detections(self):
        """Raw model output decodes into a 2-D box array with aligned scores/classes."""
        from core.postprocessor import PostProcessor
        processor = PostProcessor()
        # Seeded so the number of decoded detections (and hence the code path
        # exercised) is identical on every run; an unseeded draw varies.
        # NOTE(review): (1, 10, 100) presumably is (batch, 4 box values +
        # 6 class scores, 100 candidates) -- confirm against process_detections.
        rng = np.random.default_rng(0)
        outputs = [rng.standard_normal((1, 10, 100)).astype(np.float32)]
        boxes, scores, class_ids = processor.process_detections(outputs)
        self.assertEqual(len(boxes.shape), 2)
        # Exactly one score and one class id per returned box.
        self.assertEqual(len(scores), len(boxes))
        self.assertEqual(len(class_ids), len(boxes))

    def test_check_alarm_condition(self):
        """check_alarm_condition returns a dict with alert flag and count."""
        from core.postprocessor import PostProcessor
        processor = PostProcessor()
        result = processor.check_alarm_condition("roi001", True)
        self.assertIn("should_alert", result)
        self.assertIn("detection_count", result)

    def test_create_alert_info(self):
        """create_alert_info carries through the ROI and camera identifiers."""
        from core.postprocessor import PostProcessor
        processor = PostProcessor()
        alert = processor.create_alert_info(
            roi_id="roi001",
            camera_id="cam001",
            detection_results={
                "class_name": "person",
                "confidence": 0.95,
                "bbox": [100, 100, 200, 200]
            },
            message="离岗告警"
        )
        self.assertEqual(alert.roi_id, "roi001")
        self.assertEqual(alert.camera_id, "cam001")

    def test_get_statistics(self):
        """get_statistics reports the configured NMS and confidence thresholds."""
        from core.postprocessor import PostProcessor
        processor = PostProcessor()
        stats = processor.get_statistics()
        self.assertIn("nms_threshold", stats)
        self.assertIn("conf_threshold", stats)
if __name__ == "__main__":
    # Support running this file directly (python tests/test_postprocessor.py);
    # verbosity=1 is unittest.main's documented default.
    unittest.main(verbosity=1)