重构: 模型类别配置动态化,适配 v2+ (5 类) 模型

- settings.py: 删除 80 类 COCO_CLASS_NAMES 硬编码,改为 MODEL_CLASS_NAMES
  + 新增 MODEL_NUM_CLASSES / MODEL_OUTPUT_CHANNELS 辅助常量
  + 新增 ALGO_INTENT_CLASSES 字典 + get_algo_target_classes() 辅助函数
  + COCO_CLASS_NAMES 保留名称向后兼容,指向 MODEL_CLASS_NAMES

- postprocessor.py: 4 处硬编码 84 → MODEL_OUTPUT_CHANNELS
  + 支持不同类别数模型切换不改代码

- algorithms.py: 4 处硬编码 target_classes 默认值 → get_algo_target_classes()
  + IllegalParkingAlgorithm / VehicleCongestionAlgorithm /
    NonMotorVehicleParkingAlgorithm / GarbageDetectionAlgorithm
  + 自动过滤当前模型不支持的类(truck/bus 等)

以后换模型只需修改 settings.py 一处 MODEL_CLASS_NAMES。
This commit is contained in:
2026-04-21 14:46:53 +08:00
parent a891deba00
commit 003c2885b9
3 changed files with 52 additions and 29 deletions

View File

@@ -11,6 +11,8 @@ import cv2
import numpy as np
import redis
from config.settings import get_algo_target_classes
logger = logging.getLogger(__name__)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -800,7 +802,7 @@ class IllegalParkingAlgorithm(BaseAlgorithm):
self.parking_countdown_sec = parking_countdown_sec
self.confirm_clear_sec = confirm_clear_sec
self.cooldown_sec = cooldown_sec
self.target_classes = target_classes or ["car", "truck", "bus", "motorcycle"]
self.target_classes = target_classes or get_algo_target_classes("illegal_parking")
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
# 状态变量
@@ -1070,7 +1072,7 @@ class VehicleCongestionAlgorithm(BaseAlgorithm):
self.confirm_congestion_sec = confirm_congestion_sec
self.confirm_clear_sec = confirm_clear_sec
self.cooldown_sec = cooldown_sec
self.target_classes = target_classes or ["car", "truck", "bus", "motorcycle"]
self.target_classes = target_classes or get_algo_target_classes("vehicle_congestion")
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
self.dissipation_ratio = dissipation_ratio # Step 6
@@ -1336,7 +1338,7 @@ class NonMotorVehicleParkingAlgorithm(BaseAlgorithm):
self.parking_countdown_sec = parking_countdown_sec
self.confirm_clear_sec = confirm_clear_sec
self.cooldown_sec = cooldown_sec
self.target_classes = target_classes or ["bicycle", "motorcycle"]
self.target_classes = target_classes or get_algo_target_classes("non_motor_vehicle_parking")
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
# 状态变量
@@ -1599,7 +1601,7 @@ class GarbageDetectionAlgorithm(BaseAlgorithm):
self.confirm_garbage_sec = confirm_garbage_sec
self.confirm_clear_sec = confirm_clear_sec
self.cooldown_sec = cooldown_sec
self.target_classes = target_classes or ["garbage"]
self.target_classes = target_classes or get_algo_target_classes("garbage")
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
# 状态变量

View File

@@ -143,19 +143,40 @@ class DebugConfig:
local_config_path: str = "./config/local_config.json"
# COCO 数据集类别名称YOLO 模型使用
COCO_CLASS_NAMES = [
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
"umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
"kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
"bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
"mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
"refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
]
# 模型类别配置(支持不同模型切换,修改此处即可
# 当前: yolo11s_v2plus_20260421 (5 类定制模型)
# 历史: yolo11n 官方 (80 类 COCO),列表见 git 历史
MODEL_CLASS_NAMES = ["garbage", "person", "car", "bicycle", "motorcycle"]
MODEL_NUM_CLASSES = len(MODEL_CLASS_NAMES) # 模型类别数
MODEL_OUTPUT_CHANNELS = 4 + MODEL_NUM_CLASSES # YOLO 输出通道 = 4(xywh) + nc
# 向后兼容:保留 COCO_CLASS_NAMES 名称,指向当前模型类别
COCO_CLASS_NAMES = MODEL_CLASS_NAMES
# 各算法的业务关注类别(全集,不限定当前模型是否支持)
# 部署时通过 get_algo_target_classes() 自动过滤出当前模型支持的子集
ALGO_INTENT_CLASSES = {
"leave_post": ["person"],
"intrusion": ["person"],
"illegal_parking": ["car", "truck", "bus"],
"vehicle_congestion": ["car", "truck", "bus", "motorcycle"],
"non_motor_vehicle_parking": ["bicycle", "motorcycle"],
"garbage": ["garbage"],
}
def get_algo_target_classes(algo_code: str) -> list:
"""获取算法的目标类别,自动过滤出当前模型支持的类
Args:
algo_code: 算法代码 (leave_post / intrusion / illegal_parking / ...)
Returns:
目标类名列表,保证每个类都在 MODEL_CLASS_NAMES 中
"""
intent = ALGO_INTENT_CLASSES.get(algo_code, [])
return [c for c in intent if c in MODEL_CLASS_NAMES]
@dataclass

View File

@@ -20,7 +20,7 @@ try:
except ImportError:
_HAS_TORCH = False
from config.settings import get_settings
from config.settings import get_settings, MODEL_OUTPUT_CHANNELS, MODEL_NUM_CLASSES
from config.config_models import ROIInfo, ROIType, AlertInfo, AlertLevel
from utils.logger import get_logger
from utils.common import generate_unique_id
@@ -667,14 +667,14 @@ class PostProcessor:
if output.ndim != 2:
return np.array([]), np.array([]), np.array([])
if output.shape[0] != 84:
if output.shape[0] != MODEL_OUTPUT_CHANNELS:
return np.array([]), np.array([]), np.array([])
boxes_xywh = output[0:4, :].T
# YOLO11 输出格式: [4+num_classes, 8400]
# 前4行是 xywh80行是各类别分数,没有单独的 objectness 行
class_scores = output[4:, :] # [num_classes, 8400]
# YOLO11 输出格式: [4+num_classes, num_anchors]
# 前4行是 xywh num_classes 行是各类别分数,没有单独的 objectness 行
class_scores = output[4:, :] # [num_classes, num_anchors]
scores = class_scores.max(axis=0) # 取各类别最大分数
class_ids = class_scores.argmax(axis=0) # 对应类别ID
@@ -734,15 +734,15 @@ class PostProcessor:
first_output = batch_outputs[0]
if isinstance(first_output, np.ndarray):
if first_output.ndim == 3 and first_output.shape[0] == batch_size:
# 已经是 (batch, 84, anchors) 格式
# 已经是 (batch, 4+nc, anchors) 格式
outputs_array = first_output
elif first_output.ndim == 1:
# TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 84, anchors)
# TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 4+nc, anchors)
per_image = first_output.shape[0] // batch_size
num_anchors = per_image // 84
outputs_array = first_output.reshape(batch_size, 84, num_anchors)
num_anchors = per_image // MODEL_OUTPUT_CHANNELS
outputs_array = first_output.reshape(batch_size, MODEL_OUTPUT_CHANNELS, num_anchors)
elif first_output.ndim == 2:
# (84, anchors) 单张图的输出
# (4+nc, anchors) 单张图的输出
outputs_array = first_output.reshape(1, first_output.shape[0], first_output.shape[1])
else:
outputs_array = first_output
@@ -771,7 +771,7 @@ class PostProcessor:
if output.ndim == 3:
output = output[0]
if output.shape[0] != 84:
if output.shape[0] != MODEL_OUTPUT_CHANNELS:
results.append((np.array([]), np.array([]), np.array([])))
continue