重构: 模型类别配置动态化,适配 v2+ (5 类) 模型
- settings.py: 删除 80 类 COCO_CLASS_NAMES 硬编码,改为 MODEL_CLASS_NAMES
+ 新增 MODEL_NUM_CLASSES / MODEL_OUTPUT_CHANNELS 辅助常量
+ 新增 ALGO_INTENT_CLASSES 字典 + get_algo_target_classes() 辅助函数
+ COCO_CLASS_NAMES 保留名称向后兼容,指向 MODEL_CLASS_NAMES
- postprocessor.py: 4 处硬编码 84 → MODEL_OUTPUT_CHANNELS
+ 支持不同类别数模型切换不改代码
- algorithms.py: 4 处硬编码 target_classes 默认值 → get_algo_target_classes()
+ IllegalParkingAlgorithm / VehicleCongestionAlgorithm /
NonMotorVehicleParkingAlgorithm / GarbageDetectionAlgorithm
+ 自动过滤当前模型不支持的类(truck/bus 等)
以后换模型只需修改 settings.py 一处 MODEL_CLASS_NAMES。
This commit is contained in:
@@ -143,19 +143,40 @@ class DebugConfig:
|
||||
local_config_path: str = "./config/local_config.json"
|
||||
|
||||
|
||||
# COCO 数据集类别名称(YOLO 模型使用)
|
||||
COCO_CLASS_NAMES = [
|
||||
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
|
||||
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
|
||||
"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
|
||||
"umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
|
||||
"kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
|
||||
"bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
|
||||
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
|
||||
"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
|
||||
"mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
|
||||
"refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
|
||||
]
|
||||
# 模型类别配置(支持不同模型切换,修改此处即可)
|
||||
# 当前: yolo11s_v2plus_20260421 (5 类定制模型)
|
||||
# 历史: yolo11n 官方 (80 类 COCO),列表见 git 历史
|
||||
MODEL_CLASS_NAMES = ["garbage", "person", "car", "bicycle", "motorcycle"]
|
||||
MODEL_NUM_CLASSES = len(MODEL_CLASS_NAMES) # 模型类别数
|
||||
MODEL_OUTPUT_CHANNELS = 4 + MODEL_NUM_CLASSES # YOLO 输出通道 = 4(xywh) + nc
|
||||
|
||||
# 向后兼容:保留 COCO_CLASS_NAMES 名称,指向当前模型类别
|
||||
COCO_CLASS_NAMES = MODEL_CLASS_NAMES
|
||||
|
||||
|
||||
# 各算法的业务关注类别(全集,不限定当前模型是否支持)
|
||||
# 部署时通过 get_algo_target_classes() 自动过滤出当前模型支持的子集
|
||||
ALGO_INTENT_CLASSES = {
|
||||
"leave_post": ["person"],
|
||||
"intrusion": ["person"],
|
||||
"illegal_parking": ["car", "truck", "bus"],
|
||||
"vehicle_congestion": ["car", "truck", "bus", "motorcycle"],
|
||||
"non_motor_vehicle_parking": ["bicycle", "motorcycle"],
|
||||
"garbage": ["garbage"],
|
||||
}
|
||||
|
||||
|
||||
def get_algo_target_classes(algo_code: str) -> list:
|
||||
"""获取算法的目标类别,自动过滤出当前模型支持的类
|
||||
|
||||
Args:
|
||||
algo_code: 算法代码 (leave_post / intrusion / illegal_parking / ...)
|
||||
|
||||
Returns:
|
||||
目标类名列表,保证每个类都在 MODEL_CLASS_NAMES 中
|
||||
"""
|
||||
intent = ALGO_INTENT_CLASSES.get(algo_code, [])
|
||||
return [c for c in intent if c in MODEL_CLASS_NAMES]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user