From 003c2885b96a4238d933be31d11f8609528cf7a5 Mon Sep 17 00:00:00 2001 From: 16337 <1633794139@qq.com> Date: Tue, 21 Apr 2026 14:46:53 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84:=20=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E7=B1=BB=E5=88=AB=E9=85=8D=E7=BD=AE=E5=8A=A8=E6=80=81=E5=8C=96?= =?UTF-8?q?=EF=BC=8C=E9=80=82=E9=85=8D=20v2+=20(5=20=E7=B1=BB)=20=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - settings.py: 删除 80 类 COCO_CLASS_NAMES 硬编码,改为 MODEL_CLASS_NAMES + 新增 MODEL_NUM_CLASSES / MODEL_OUTPUT_CHANNELS 辅助常量 + 新增 ALGO_INTENT_CLASSES 字典 + get_algo_target_classes() 辅助函数 + COCO_CLASS_NAMES 保留名称向后兼容,指向 MODEL_CLASS_NAMES - postprocessor.py: 4 处硬编码 84 → MODEL_OUTPUT_CHANNELS + 支持不同类别数模型切换不改代码 - algorithms.py: 4 处硬编码 target_classes 默认值 → get_algo_target_classes() + IllegalParkingAlgorithm / VehicleCongestionAlgorithm / NonMotorVehicleParkingAlgorithm / GarbageDetectionAlgorithm + 自动过滤当前模型不支持的类(truck/bus 等) 以后换模型只需修改 settings.py 一处 MODEL_CLASS_NAMES。 --- algorithms.py | 10 +++++---- config/settings.py | 47 +++++++++++++++++++++++++++++++------------ core/postprocessor.py | 24 +++++++++++----------- 3 files changed, 52 insertions(+), 29 deletions(-) diff --git a/algorithms.py b/algorithms.py index 15983f2..79dc223 100644 --- a/algorithms.py +++ b/algorithms.py @@ -11,6 +11,8 @@ import cv2 import numpy as np import redis +from config.settings import get_algo_target_classes + logger = logging.getLogger(__name__) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -800,7 +802,7 @@ class IllegalParkingAlgorithm(BaseAlgorithm): self.parking_countdown_sec = parking_countdown_sec self.confirm_clear_sec = confirm_clear_sec self.cooldown_sec = cooldown_sec - self.target_classes = target_classes or ["car", "truck", "bus", "motorcycle"] + self.target_classes = target_classes or get_algo_target_classes("illegal_parking") self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL # 状态变量 @@ -1070,7 +1072,7 @@ class VehicleCongestionAlgorithm(BaseAlgorithm): self.confirm_congestion_sec = confirm_congestion_sec self.confirm_clear_sec = confirm_clear_sec self.cooldown_sec = cooldown_sec - self.target_classes = target_classes or ["car", "truck", "bus", "motorcycle"] + self.target_classes = target_classes or get_algo_target_classes("vehicle_congestion") self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL self.dissipation_ratio = dissipation_ratio # Step 6 @@ -1336,7 +1338,7 @@ class NonMotorVehicleParkingAlgorithm(BaseAlgorithm): self.parking_countdown_sec = parking_countdown_sec self.confirm_clear_sec = confirm_clear_sec self.cooldown_sec = cooldown_sec - self.target_classes = target_classes or ["bicycle", "motorcycle"] + self.target_classes = target_classes or get_algo_target_classes("non_motor_vehicle_parking") self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL # 状态变量 @@ -1599,7 +1601,7 @@ class GarbageDetectionAlgorithm(BaseAlgorithm): self.confirm_garbage_sec = confirm_garbage_sec self.confirm_clear_sec = confirm_clear_sec self.cooldown_sec = cooldown_sec - self.target_classes = target_classes or ["garbage"] + self.target_classes = target_classes or get_algo_target_classes("garbage") self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL # 状态变量 diff --git a/config/settings.py b/config/settings.py index d31ae3e..7ea4aec 100644 --- a/config/settings.py +++ b/config/settings.py @@ -143,19 +143,40 @@ class DebugConfig: local_config_path: str = "./config/local_config.json" -# COCO 数据集类别名称(YOLO 模型使用) -COCO_CLASS_NAMES = [ - "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", - "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", - "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", - "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", - "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", - "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", - "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", - "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", - "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", - "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush" -] +# 模型类别配置(支持不同模型切换,修改此处即可) +# 当前: yolo11s_v2plus_20260421 (5 类定制模型) +# 历史: yolo11n 官方 (80 类 COCO),列表见 git 历史 +MODEL_CLASS_NAMES = ["garbage", "person", "car", "bicycle", "motorcycle"] +MODEL_NUM_CLASSES = len(MODEL_CLASS_NAMES) # 模型类别数 +MODEL_OUTPUT_CHANNELS = 4 + MODEL_NUM_CLASSES # YOLO 输出通道 = 4(xywh) + nc + +# 向后兼容:保留 COCO_CLASS_NAMES 名称,指向当前模型类别 +COCO_CLASS_NAMES = MODEL_CLASS_NAMES + + +# 各算法的业务关注类别(全集,不限定当前模型是否支持) +# 部署时通过 get_algo_target_classes() 自动过滤出当前模型支持的子集 +ALGO_INTENT_CLASSES = { + "leave_post": ["person"], + "intrusion": ["person"], + "illegal_parking": ["car", "truck", "bus"], + "vehicle_congestion": ["car", "truck", "bus", "motorcycle"], + "non_motor_vehicle_parking": ["bicycle", "motorcycle"], + "garbage": ["garbage"], +} + + +def get_algo_target_classes(algo_code: str) -> list: + """获取算法的目标类别,自动过滤出当前模型支持的类 + + Args: + algo_code: 算法代码 (leave_post / intrusion / illegal_parking / ...) + + Returns: + 目标类名列表,保证每个类都在 MODEL_CLASS_NAMES 中 + """ + intent = ALGO_INTENT_CLASSES.get(algo_code, []) + return [c for c in intent if c in MODEL_CLASS_NAMES] @dataclass diff --git a/core/postprocessor.py b/core/postprocessor.py index 92590e2..76a3fc8 100644 --- a/core/postprocessor.py +++ b/core/postprocessor.py @@ -20,7 +20,7 @@ try: except ImportError: _HAS_TORCH = False -from config.settings import get_settings +from config.settings import get_settings, MODEL_OUTPUT_CHANNELS, MODEL_NUM_CLASSES from config.config_models import ROIInfo, ROIType, AlertInfo, AlertLevel from utils.logger import get_logger from utils.common import generate_unique_id @@ -667,14 +667,14 @@ class PostProcessor: if output.ndim != 2: return np.array([]), np.array([]), np.array([]) - if output.shape[0] != 84: + if output.shape[0] != MODEL_OUTPUT_CHANNELS: return np.array([]), np.array([]), np.array([]) - + boxes_xywh = output[0:4, :].T - # YOLO11 输出格式: [4+num_classes, 8400] - # 前4行是 xywh,后80行是各类别分数,没有单独的 objectness 行 - class_scores = output[4:, :] # [num_classes, 8400] + # YOLO11 输出格式: [4+num_classes, num_anchors] + # 前4行是 xywh,后 num_classes 行是各类别分数,没有单独的 objectness 行 + class_scores = output[4:, :] # [num_classes, num_anchors] scores = class_scores.max(axis=0) # 取各类别最大分数 class_ids = class_scores.argmax(axis=0) # 对应类别ID @@ -734,15 +734,15 @@ class PostProcessor: first_output = batch_outputs[0] if isinstance(first_output, np.ndarray): if first_output.ndim == 3 and first_output.shape[0] == batch_size: - # 已经是 (batch, 84, anchors) 格式 + # 已经是 (batch, 4+nc, anchors) 格式 outputs_array = first_output elif first_output.ndim == 1: - # TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 84, anchors) + # TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 4+nc, anchors) per_image = first_output.shape[0] // batch_size - num_anchors = per_image // 84 - outputs_array = first_output.reshape(batch_size, 84, num_anchors) + num_anchors = per_image // MODEL_OUTPUT_CHANNELS + outputs_array = first_output.reshape(batch_size, MODEL_OUTPUT_CHANNELS, num_anchors) elif first_output.ndim == 2: - # (84, anchors) 单张图的输出 + # (4+nc, anchors) 单张图的输出 outputs_array = first_output.reshape(1, first_output.shape[0], first_output.shape[1]) else: outputs_array = first_output @@ -771,7 +771,7 @@ class PostProcessor: if output.ndim == 3: output = output[0] - if output.shape[0] != 84: + if output.shape[0] != MODEL_OUTPUT_CHANNELS: results.append((np.array([]), np.array([]), np.array([]))) continue