diff --git a/config/settings.py b/config/settings.py
index 7e6746d..7fc58af 100644
--- a/config/settings.py
+++ b/config/settings.py
@@ -5,7 +5,7 @@
 
 import os
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 
 @dataclass
@@ -117,10 +117,17 @@ class InferenceConfig:
     input_width: int = 480
     input_height: int = 480
     batch_size: int = 1
-    conf_threshold: float = 0.4
-    nms_threshold: float = 0.45
+    conf_threshold: float = 0.45
+    nms_threshold: float = 0.5
     device_id: int = 0
     fp16_mode: bool = True
+    # Per-algorithm confidence-threshold overrides: key=algo_code, value=threshold.
+    # Keys are stored lowercase (see Settings._parse_algo_conf_thresholds).
+    algo_conf_thresholds: Dict[str, float] = field(default_factory=dict)
+
+    def get_conf_threshold(self, algo_code: str) -> float:
+        """Return the confidence threshold for algo_code (case-insensitive); falls back to conf_threshold."""
+        return self.algo_conf_thresholds.get(algo_code.lower(), self.conf_threshold)
 
 
 # ===================== Debug / Local Sync =====================
@@ -273,8 +280,9 @@ class Settings:
             input_width=int(os.getenv("INPUT_WIDTH", "480")),
             input_height=int(os.getenv("INPUT_HEIGHT", "480")),
             batch_size=int(os.getenv("BATCH_SIZE", "4")),
-            conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.4")),
-            nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.45")),
+            conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.45")),
+            nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.5")),
+            algo_conf_thresholds=self._parse_algo_conf_thresholds(),
         )
 
         self.config_sync_mode = os.getenv("CONFIG_SYNC_MODE", "LOCAL").upper()
@@ -297,6 +305,25 @@
         # 使用 COCO 类别名称
         self.class_names = COCO_CLASS_NAMES
 
+    @staticmethod
+    def _parse_algo_conf_thresholds() -> Dict[str, float]:
+        """Parse ALGO_CONF_* environment variables into an {algo_code: threshold} dict.
+
+        Naming convention: ALGO_CONF_{ALGO_CODE}, e.g.:
+        ALGO_CONF_LEAVE_POST=0.35
+        ALGO_CONF_INTRUSION=0.55
+        """
+        prefix = "ALGO_CONF_"
+        result = {}
+        for key, value in os.environ.items():
+            if key.startswith(prefix) and value:
+                algo_code = key[len(prefix):].lower()
+                try:
+                    result[algo_code] = float(value)
+                except ValueError:
+                    pass
+        return result
+
     def _parse_working_hours(self) -> List[dict]:
         """解析工作时间配置"""
         working_hours_str = os.getenv("WORKING_HOURS", "")
diff --git a/core/postprocessor.py b/core/postprocessor.py
index 87a87f2..1c67d01 100644
--- a/core/postprocessor.py
+++ b/core/postprocessor.py
@@ -35,7 +35,7 @@ class NMSProcessor:
     可选 GPU 加速 (torchvision.ops.nms)
     """
 
-    def __init__(self, nms_threshold: float = 0.45, use_gpu: bool = False):
+    def __init__(self, nms_threshold: float = 0.5, use_gpu: bool = False):
         self.nms_threshold = nms_threshold
         self.use_gpu = use_gpu and _HAS_TORCH
         self._logger = get_logger("postprocessor")
@@ -162,7 +162,7 @@
         boxes: np.ndarray,
         scores: np.ndarray,
         class_ids: np.ndarray,
-        conf_threshold: float = 0.5,
+        conf_threshold: float = 0.45,
         max_output_size: int = 300
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
@@ -579,8 +579,8 @@
             "alert_cooldown": 300,
         }
 
-        self.nms_threshold = config.get("nms_threshold", 0.45)
-        self.conf_threshold = config.get("conf_threshold", 0.5)
+        self.nms_threshold = config.get("nms_threshold", 0.5)
+        self.conf_threshold = config.get("conf_threshold", 0.45)
 
         MAX_DETECTIONS = 8400
         self._buffer_xyxy = np.zeros((MAX_DETECTIONS, 4), dtype=np.float32)
@@ -707,22 +707,24 @@
         batch_outputs: List[np.ndarray],
         batch_size: int,
         conf_threshold: Optional[float] = None,
-        nms_threshold: Optional[float] = None
+        nms_threshold: Optional[float] = None,
+        per_item_conf_thresholds: Optional[List[float]] = None,
     ) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
         """
         批量处理多 ROI 检测结果
-        
+
         Args:
             batch_outputs: 模型输出 (可能是 [output] 或单个 batch 数组)
             batch_size: batch 大小
             conf_threshold: 置信度阈值
             nms_threshold: NMS阈值
-        
+            per_item_conf_thresholds: per-batch-item confidence thresholds; length
+                must equal batch_size. When given, conf_threshold is ignored.
+
         Returns:
             List of (检测框, 置信度, 类别ID) per ROI
         """
-        if conf_threshold is None:
-            conf_threshold = self.conf_threshold
+        default_conf = conf_threshold if conf_threshold is not None else self.conf_threshold
 
         if nms_threshold is None:
             nms_threshold = self.nms_threshold
@@ -793,11 +795,16 @@
             self._buffer_xyxy[:valid_count, 2] = boxes[:, 0] + boxes[:, 2] / 2
             self._buffer_xyxy[:valid_count, 3] = boxes[:, 1] + boxes[:, 3] / 2
 
+            item_conf = (
+                per_item_conf_thresholds[batch_idx]
+                if per_item_conf_thresholds is not None
+                else default_conf
+            )
             keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter(
                 self._buffer_xyxy[:valid_count],
                 scores_coarse.astype(np.float32),
                 class_ids,
-                conf_threshold
+                item_conf
             )
 
             results.append((keep_boxes, keep_scores, keep_classes))
diff --git a/main.py b/main.py
index 149f1e3..f7db103 100644
--- a/main.py
+++ b/main.py
@@ -598,14 +598,24 @@
             self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, outputs={shapes}, 耗时={inference_time_ms:.1f}ms")
 
             batch_size = len(chunk)
+
+            # Per-algorithm confidence threshold for each batch item.
+            per_item_conf = [
+                self._settings.inference.get_conf_threshold(item[2].algo_code)
+                for item in chunk
+            ]
+
             batch_results = self._postprocessor.batch_process_detections(
                 outputs,
                 batch_size,
-                conf_threshold=self._settings.inference.conf_threshold
+                per_item_conf_thresholds=per_item_conf,
             )
 
             total_detections = sum(len(r[0]) for r in batch_results)
-            self._logger.debug(f"[推理] batch_size={batch_size}, 总检测数={total_detections}, conf_thresh={self._settings.inference.conf_threshold}")
+            self._logger.debug(
+                f"[推理] batch_size={batch_size}, 总检测数={total_detections}, "
+                f"conf_thresholds={per_item_conf}"
+            )
 
             for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(chunk):
                 boxes, scores, class_ids = batch_results[idx]