feat: 支持按算法独立配置置信度阈值
通过 ALGO_CONF_{ALGO_CODE} 环境变量为每个算法设置独立的 conf_threshold,
未配置的算法回退到全局 CONF_THRESHOLD。推理过程零改动,仅后处理过滤阶段
按 bind.algo_code 使用对应阈值。
当前配置:离岗=0.4(降低漏检),入侵=0.5(减少误报)。
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -117,10 +117,17 @@ class InferenceConfig:
|
||||
input_width: int = 480
|
||||
input_height: int = 480
|
||||
batch_size: int = 1
|
||||
conf_threshold: float = 0.4
|
||||
nms_threshold: float = 0.45
|
||||
conf_threshold: float = 0.45
|
||||
nms_threshold: float = 0.5
|
||||
device_id: int = 0
|
||||
fp16_mode: bool = True
|
||||
# 按算法类型覆盖置信度阈值,key=algo_code, value=threshold
|
||||
# 未命中时回退到 conf_threshold
|
||||
algo_conf_thresholds: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def get_conf_threshold(self, algo_code: str) -> float:
|
||||
"""获取指定算法的置信度阈值,未配置则回退全局值"""
|
||||
return self.algo_conf_thresholds.get(algo_code, self.conf_threshold)
|
||||
|
||||
|
||||
# ===================== Debug / Local Sync =====================
|
||||
@@ -273,8 +280,9 @@ class Settings:
|
||||
input_width=int(os.getenv("INPUT_WIDTH", "480")),
|
||||
input_height=int(os.getenv("INPUT_HEIGHT", "480")),
|
||||
batch_size=int(os.getenv("BATCH_SIZE", "4")),
|
||||
conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.4")),
|
||||
nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.45")),
|
||||
conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.45")),
|
||||
nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.5")),
|
||||
algo_conf_thresholds=self._parse_algo_conf_thresholds(),
|
||||
)
|
||||
|
||||
self.config_sync_mode = os.getenv("CONFIG_SYNC_MODE", "LOCAL").upper()
|
||||
@@ -297,6 +305,25 @@ class Settings:
|
||||
# 使用 COCO 类别名称
|
||||
self.class_names = COCO_CLASS_NAMES
|
||||
|
||||
@staticmethod
|
||||
def _parse_algo_conf_thresholds() -> Dict[str, float]:
|
||||
"""解析 ALGO_CONF_* 环境变量,返回 {algo_code: threshold} 字典
|
||||
|
||||
环境变量命名规则: ALGO_CONF_{ALGO_CODE},如:
|
||||
ALGO_CONF_LEAVE_POST=0.35
|
||||
ALGO_CONF_INTRUSION=0.55
|
||||
"""
|
||||
prefix = "ALGO_CONF_"
|
||||
result = {}
|
||||
for key, value in os.environ.items():
|
||||
if key.startswith(prefix) and value:
|
||||
algo_code = key[len(prefix):].lower()
|
||||
try:
|
||||
result[algo_code] = float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
return result
|
||||
|
||||
def _parse_working_hours(self) -> List[dict]:
|
||||
"""解析工作时间配置"""
|
||||
working_hours_str = os.getenv("WORKING_HOURS", "")
|
||||
|
||||
@@ -35,7 +35,7 @@ class NMSProcessor:
|
||||
可选 GPU 加速 (torchvision.ops.nms)
|
||||
"""
|
||||
|
||||
def __init__(self, nms_threshold: float = 0.45, use_gpu: bool = False):
|
||||
def __init__(self, nms_threshold: float = 0.5, use_gpu: bool = False):
|
||||
self.nms_threshold = nms_threshold
|
||||
self.use_gpu = use_gpu and _HAS_TORCH
|
||||
self._logger = get_logger("postprocessor")
|
||||
@@ -162,7 +162,7 @@ class NMSProcessor:
|
||||
boxes: np.ndarray,
|
||||
scores: np.ndarray,
|
||||
class_ids: np.ndarray,
|
||||
conf_threshold: float = 0.5,
|
||||
conf_threshold: float = 0.45,
|
||||
max_output_size: int = 300
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
@@ -579,8 +579,8 @@ class PostProcessor:
|
||||
"alert_cooldown": 300,
|
||||
}
|
||||
|
||||
self.nms_threshold = config.get("nms_threshold", 0.45)
|
||||
self.conf_threshold = config.get("conf_threshold", 0.5)
|
||||
self.nms_threshold = config.get("nms_threshold", 0.5)
|
||||
self.conf_threshold = config.get("conf_threshold", 0.45)
|
||||
|
||||
MAX_DETECTIONS = 8400
|
||||
self._buffer_xyxy = np.zeros((MAX_DETECTIONS, 4), dtype=np.float32)
|
||||
@@ -707,22 +707,24 @@ class PostProcessor:
|
||||
batch_outputs: List[np.ndarray],
|
||||
batch_size: int,
|
||||
conf_threshold: Optional[float] = None,
|
||||
nms_threshold: Optional[float] = None
|
||||
nms_threshold: Optional[float] = None,
|
||||
per_item_conf_thresholds: Optional[List[float]] = None,
|
||||
) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
|
||||
"""
|
||||
批量处理多 ROI 检测结果
|
||||
|
||||
|
||||
Args:
|
||||
batch_outputs: 模型输出 (可能是 [output] 或单个 batch 数组)
|
||||
batch_size: batch 大小
|
||||
conf_threshold: 置信度阈值
|
||||
nms_threshold: NMS阈值
|
||||
|
||||
per_item_conf_thresholds: 每个 batch item 独立的置信度阈值列表,
|
||||
长度必须等于 batch_size。传入时忽略 conf_threshold 参数。
|
||||
|
||||
Returns:
|
||||
List of (检测框, 置信度, 类别ID) per ROI
|
||||
"""
|
||||
if conf_threshold is None:
|
||||
conf_threshold = self.conf_threshold
|
||||
default_conf = conf_threshold if conf_threshold is not None else self.conf_threshold
|
||||
if nms_threshold is None:
|
||||
nms_threshold = self.nms_threshold
|
||||
|
||||
@@ -793,11 +795,16 @@ class PostProcessor:
|
||||
self._buffer_xyxy[:valid_count, 2] = boxes[:, 0] + boxes[:, 2] / 2
|
||||
self._buffer_xyxy[:valid_count, 3] = boxes[:, 1] + boxes[:, 3] / 2
|
||||
|
||||
item_conf = (
|
||||
per_item_conf_thresholds[batch_idx]
|
||||
if per_item_conf_thresholds is not None
|
||||
else default_conf
|
||||
)
|
||||
keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter(
|
||||
self._buffer_xyxy[:valid_count],
|
||||
scores_coarse.astype(np.float32),
|
||||
class_ids,
|
||||
conf_threshold
|
||||
item_conf
|
||||
)
|
||||
|
||||
results.append((keep_boxes, keep_scores, keep_classes))
|
||||
|
||||
14
main.py
14
main.py
@@ -598,14 +598,24 @@ class EdgeInferenceService:
|
||||
self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, outputs={shapes}, 耗时={inference_time_ms:.1f}ms")
|
||||
|
||||
batch_size = len(chunk)
|
||||
|
||||
# 按算法类型获取每个 item 的独立置信度阈值
|
||||
per_item_conf = [
|
||||
self._settings.inference.get_conf_threshold(item[2].algo_code)
|
||||
for item in chunk
|
||||
]
|
||||
|
||||
batch_results = self._postprocessor.batch_process_detections(
|
||||
outputs,
|
||||
batch_size,
|
||||
conf_threshold=self._settings.inference.conf_threshold
|
||||
per_item_conf_thresholds=per_item_conf,
|
||||
)
|
||||
|
||||
total_detections = sum(len(r[0]) for r in batch_results)
|
||||
self._logger.debug(f"[推理] batch_size={batch_size}, 总检测数={total_detections}, conf_thresh={self._settings.inference.conf_threshold}")
|
||||
self._logger.debug(
|
||||
f"[推理] batch_size={batch_size}, 总检测数={total_detections}, "
|
||||
f"conf_thresholds={per_item_conf}"
|
||||
)
|
||||
|
||||
for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(chunk):
|
||||
boxes, scores, class_ids = batch_results[idx]
|
||||
|
||||
Reference in New Issue
Block a user