重构: 模型类别配置动态化,适配 v2+ (5 类) 模型
- settings.py: 删除 80 类 COCO_CLASS_NAMES 硬编码,改为 MODEL_CLASS_NAMES
+ 新增 MODEL_NUM_CLASSES / MODEL_OUTPUT_CHANNELS 辅助常量
+ 新增 ALGO_INTENT_CLASSES 字典 + get_algo_target_classes() 辅助函数
+ COCO_CLASS_NAMES 保留名称向后兼容,指向 MODEL_CLASS_NAMES
- postprocessor.py: 4 处硬编码 84 → MODEL_OUTPUT_CHANNELS
+ 支持不同类别数模型切换不改代码
- algorithms.py: 4 处硬编码 target_classes 默认值 → get_algo_target_classes()
+ IllegalParkingAlgorithm / VehicleCongestionAlgorithm /
NonMotorVehicleParkingAlgorithm / GarbageDetectionAlgorithm
+ 自动过滤当前模型不支持的类(truck/bus 等)
以后换模型只需修改 settings.py 一处 MODEL_CLASS_NAMES。
This commit is contained in:
@@ -20,7 +20,7 @@ try:
|
||||
except ImportError:
|
||||
_HAS_TORCH = False
|
||||
|
||||
from config.settings import get_settings
|
||||
from config.settings import get_settings, MODEL_OUTPUT_CHANNELS, MODEL_NUM_CLASSES
|
||||
from config.config_models import ROIInfo, ROIType, AlertInfo, AlertLevel
|
||||
from utils.logger import get_logger
|
||||
from utils.common import generate_unique_id
|
||||
@@ -667,14 +667,14 @@ class PostProcessor:
|
||||
if output.ndim != 2:
|
||||
return np.array([]), np.array([]), np.array([])
|
||||
|
||||
if output.shape[0] != 84:
|
||||
if output.shape[0] != MODEL_OUTPUT_CHANNELS:
|
||||
return np.array([]), np.array([]), np.array([])
|
||||
|
||||
|
||||
boxes_xywh = output[0:4, :].T
|
||||
|
||||
# YOLO11 输出格式: [4+num_classes, 8400]
|
||||
# 前4行是 xywh,后80行是各类别分数,没有单独的 objectness 行
|
||||
class_scores = output[4:, :] # [num_classes, 8400]
|
||||
# YOLO11 输出格式: [4+num_classes, num_anchors]
|
||||
# 前4行是 xywh,后 num_classes 行是各类别分数,没有单独的 objectness 行
|
||||
class_scores = output[4:, :] # [num_classes, num_anchors]
|
||||
scores = class_scores.max(axis=0) # 取各类别最大分数
|
||||
class_ids = class_scores.argmax(axis=0) # 对应类别ID
|
||||
|
||||
@@ -734,15 +734,15 @@ class PostProcessor:
|
||||
first_output = batch_outputs[0]
|
||||
if isinstance(first_output, np.ndarray):
|
||||
if first_output.ndim == 3 and first_output.shape[0] == batch_size:
|
||||
# 已经是 (batch, 84, anchors) 格式
|
||||
# 已经是 (batch, 4+nc, anchors) 格式
|
||||
outputs_array = first_output
|
||||
elif first_output.ndim == 1:
|
||||
# TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 84, anchors)
|
||||
# TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 4+nc, anchors)
|
||||
per_image = first_output.shape[0] // batch_size
|
||||
num_anchors = per_image // 84
|
||||
outputs_array = first_output.reshape(batch_size, 84, num_anchors)
|
||||
num_anchors = per_image // MODEL_OUTPUT_CHANNELS
|
||||
outputs_array = first_output.reshape(batch_size, MODEL_OUTPUT_CHANNELS, num_anchors)
|
||||
elif first_output.ndim == 2:
|
||||
# (84, anchors) 单张图的输出
|
||||
# (4+nc, anchors) 单张图的输出
|
||||
outputs_array = first_output.reshape(1, first_output.shape[0], first_output.shape[1])
|
||||
else:
|
||||
outputs_array = first_output
|
||||
@@ -771,7 +771,7 @@ class PostProcessor:
|
||||
if output.ndim == 3:
|
||||
output = output[0]
|
||||
|
||||
if output.shape[0] != 84:
|
||||
if output.shape[0] != MODEL_OUTPUT_CHANNELS:
|
||||
results.append((np.array([]), np.array([]), np.array([])))
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user