重构: 模型类别配置动态化,适配 v2+ (5 类) 模型
- settings.py: 删除 80 类 COCO_CLASS_NAMES 硬编码,改为 MODEL_CLASS_NAMES
+ 新增 MODEL_NUM_CLASSES / MODEL_OUTPUT_CHANNELS 辅助常量
+ 新增 ALGO_INTENT_CLASSES 字典 + get_algo_target_classes() 辅助函数
+ COCO_CLASS_NAMES 保留名称向后兼容,指向 MODEL_CLASS_NAMES
- postprocessor.py: 4 处硬编码 84 → MODEL_OUTPUT_CHANNELS
+ 支持不同类别数模型切换不改代码
- algorithms.py: 4 处硬编码 target_classes 默认值 → get_algo_target_classes()
+ IllegalParkingAlgorithm / VehicleCongestionAlgorithm /
NonMotorVehicleParkingAlgorithm / GarbageDetectionAlgorithm
+ 自动过滤当前模型不支持的类(truck/bus 等)
以后换模型只需修改 settings.py 一处 MODEL_CLASS_NAMES。
This commit is contained in:
@@ -11,6 +11,8 @@ import cv2
|
||||
import numpy as np
|
||||
import redis
|
||||
|
||||
from config.settings import get_algo_target_classes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -800,7 +802,7 @@ class IllegalParkingAlgorithm(BaseAlgorithm):
|
||||
self.parking_countdown_sec = parking_countdown_sec
|
||||
self.confirm_clear_sec = confirm_clear_sec
|
||||
self.cooldown_sec = cooldown_sec
|
||||
self.target_classes = target_classes or ["car", "truck", "bus", "motorcycle"]
|
||||
self.target_classes = target_classes or get_algo_target_classes("illegal_parking")
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
|
||||
# 状态变量
|
||||
@@ -1070,7 +1072,7 @@ class VehicleCongestionAlgorithm(BaseAlgorithm):
|
||||
self.confirm_congestion_sec = confirm_congestion_sec
|
||||
self.confirm_clear_sec = confirm_clear_sec
|
||||
self.cooldown_sec = cooldown_sec
|
||||
self.target_classes = target_classes or ["car", "truck", "bus", "motorcycle"]
|
||||
self.target_classes = target_classes or get_algo_target_classes("vehicle_congestion")
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
self.dissipation_ratio = dissipation_ratio # Step 6
|
||||
|
||||
@@ -1336,7 +1338,7 @@ class NonMotorVehicleParkingAlgorithm(BaseAlgorithm):
|
||||
self.parking_countdown_sec = parking_countdown_sec
|
||||
self.confirm_clear_sec = confirm_clear_sec
|
||||
self.cooldown_sec = cooldown_sec
|
||||
self.target_classes = target_classes or ["bicycle", "motorcycle"]
|
||||
self.target_classes = target_classes or get_algo_target_classes("non_motor_vehicle_parking")
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
|
||||
# 状态变量
|
||||
@@ -1599,7 +1601,7 @@ class GarbageDetectionAlgorithm(BaseAlgorithm):
|
||||
self.confirm_garbage_sec = confirm_garbage_sec
|
||||
self.confirm_clear_sec = confirm_clear_sec
|
||||
self.cooldown_sec = cooldown_sec
|
||||
self.target_classes = target_classes or ["garbage"]
|
||||
self.target_classes = target_classes or get_algo_target_classes("garbage")
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
|
||||
# 状态变量
|
||||
|
||||
@@ -143,19 +143,40 @@ class DebugConfig:
|
||||
local_config_path: str = "./config/local_config.json"
|
||||
|
||||
|
||||
# COCO 数据集类别名称(YOLO 模型使用)
|
||||
COCO_CLASS_NAMES = [
|
||||
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
|
||||
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
|
||||
"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
|
||||
"umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
|
||||
"kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
|
||||
"bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
|
||||
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
|
||||
"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
|
||||
"mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
|
||||
"refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
|
||||
]
|
||||
# 模型类别配置(支持不同模型切换,修改此处即可)
|
||||
# 当前: yolo11s_v2plus_20260421 (5 类定制模型)
|
||||
# 历史: yolo11n 官方 (80 类 COCO),列表见 git 历史
|
||||
MODEL_CLASS_NAMES = ["garbage", "person", "car", "bicycle", "motorcycle"]
|
||||
MODEL_NUM_CLASSES = len(MODEL_CLASS_NAMES) # 模型类别数
|
||||
MODEL_OUTPUT_CHANNELS = 4 + MODEL_NUM_CLASSES # YOLO 输出通道 = 4(xywh) + nc
|
||||
|
||||
# 向后兼容:保留 COCO_CLASS_NAMES 名称,指向当前模型类别
|
||||
COCO_CLASS_NAMES = MODEL_CLASS_NAMES
|
||||
|
||||
|
||||
# 各算法的业务关注类别(全集,不限定当前模型是否支持)
|
||||
# 部署时通过 get_algo_target_classes() 自动过滤出当前模型支持的子集
|
||||
ALGO_INTENT_CLASSES = {
|
||||
"leave_post": ["person"],
|
||||
"intrusion": ["person"],
|
||||
"illegal_parking": ["car", "truck", "bus"],
|
||||
"vehicle_congestion": ["car", "truck", "bus", "motorcycle"],
|
||||
"non_motor_vehicle_parking": ["bicycle", "motorcycle"],
|
||||
"garbage": ["garbage"],
|
||||
}
|
||||
|
||||
|
||||
def get_algo_target_classes(algo_code: str) -> list:
|
||||
"""获取算法的目标类别,自动过滤出当前模型支持的类
|
||||
|
||||
Args:
|
||||
algo_code: 算法代码 (leave_post / intrusion / illegal_parking / ...)
|
||||
|
||||
Returns:
|
||||
目标类名列表,保证每个类都在 MODEL_CLASS_NAMES 中
|
||||
"""
|
||||
intent = ALGO_INTENT_CLASSES.get(algo_code, [])
|
||||
return [c for c in intent if c in MODEL_CLASS_NAMES]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -20,7 +20,7 @@ try:
|
||||
except ImportError:
|
||||
_HAS_TORCH = False
|
||||
|
||||
from config.settings import get_settings
|
||||
from config.settings import get_settings, MODEL_OUTPUT_CHANNELS, MODEL_NUM_CLASSES
|
||||
from config.config_models import ROIInfo, ROIType, AlertInfo, AlertLevel
|
||||
from utils.logger import get_logger
|
||||
from utils.common import generate_unique_id
|
||||
@@ -667,14 +667,14 @@ class PostProcessor:
|
||||
if output.ndim != 2:
|
||||
return np.array([]), np.array([]), np.array([])
|
||||
|
||||
if output.shape[0] != 84:
|
||||
if output.shape[0] != MODEL_OUTPUT_CHANNELS:
|
||||
return np.array([]), np.array([]), np.array([])
|
||||
|
||||
|
||||
boxes_xywh = output[0:4, :].T
|
||||
|
||||
# YOLO11 输出格式: [4+num_classes, 8400]
|
||||
# 前4行是 xywh,后80行是各类别分数,没有单独的 objectness 行
|
||||
class_scores = output[4:, :] # [num_classes, 8400]
|
||||
# YOLO11 输出格式: [4+num_classes, num_anchors]
|
||||
# 前4行是 xywh,后 num_classes 行是各类别分数,没有单独的 objectness 行
|
||||
class_scores = output[4:, :] # [num_classes, num_anchors]
|
||||
scores = class_scores.max(axis=0) # 取各类别最大分数
|
||||
class_ids = class_scores.argmax(axis=0) # 对应类别ID
|
||||
|
||||
@@ -734,15 +734,15 @@ class PostProcessor:
|
||||
first_output = batch_outputs[0]
|
||||
if isinstance(first_output, np.ndarray):
|
||||
if first_output.ndim == 3 and first_output.shape[0] == batch_size:
|
||||
# 已经是 (batch, 84, anchors) 格式
|
||||
# 已经是 (batch, 4+nc, anchors) 格式
|
||||
outputs_array = first_output
|
||||
elif first_output.ndim == 1:
|
||||
# TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 84, anchors)
|
||||
# TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 4+nc, anchors)
|
||||
per_image = first_output.shape[0] // batch_size
|
||||
num_anchors = per_image // 84
|
||||
outputs_array = first_output.reshape(batch_size, 84, num_anchors)
|
||||
num_anchors = per_image // MODEL_OUTPUT_CHANNELS
|
||||
outputs_array = first_output.reshape(batch_size, MODEL_OUTPUT_CHANNELS, num_anchors)
|
||||
elif first_output.ndim == 2:
|
||||
# (84, anchors) 单张图的输出
|
||||
# (4+nc, anchors) 单张图的输出
|
||||
outputs_array = first_output.reshape(1, first_output.shape[0], first_output.shape[1])
|
||||
else:
|
||||
outputs_array = first_output
|
||||
@@ -771,7 +771,7 @@ class PostProcessor:
|
||||
if output.ndim == 3:
|
||||
output = output[0]
|
||||
|
||||
if output.shape[0] != 84:
|
||||
if output.shape[0] != MODEL_OUTPUT_CHANNELS:
|
||||
results.append((np.array([]), np.array([]), np.array([])))
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user