Compare commits
3 Commits
f9c7f9018e
...
2c00b5afe3
| Author | SHA1 | Date | |
|---|---|---|---|
| 2c00b5afe3 | |||
| e965b10603 | |||
| 1e562798eb |
@@ -1,6 +1,7 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from db.crud import (
|
from db.crud import (
|
||||||
@@ -16,6 +17,14 @@ from inference.pipeline import get_pipeline
|
|||||||
router = APIRouter(prefix="/api/cameras", tags=["摄像头管理"])
|
router = APIRouter(prefix="/api/cameras", tags=["摄像头管理"])
|
||||||
|
|
||||||
|
|
||||||
|
class CameraUpdateRequest(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
rtsp_url: Optional[str] = None
|
||||||
|
fps_limit: Optional[int] = None
|
||||||
|
process_every_n_frames: Optional[int] = None
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[dict])
|
@router.get("", response_model=List[dict])
|
||||||
def list_cameras(
|
def list_cameras(
|
||||||
enabled_only: bool = True,
|
enabled_only: bool = True,
|
||||||
@@ -83,29 +92,25 @@ def add_camera(
|
|||||||
@router.put("/{camera_id}", response_model=dict)
|
@router.put("/{camera_id}", response_model=dict)
|
||||||
def modify_camera(
|
def modify_camera(
|
||||||
camera_id: int,
|
camera_id: int,
|
||||||
name: Optional[str] = None,
|
request: CameraUpdateRequest = Body(...),
|
||||||
rtsp_url: Optional[str] = None,
|
|
||||||
fps_limit: Optional[int] = None,
|
|
||||||
process_every_n_frames: Optional[int] = None,
|
|
||||||
enabled: Optional[bool] = None,
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
camera = update_camera(
|
camera = update_camera(
|
||||||
db,
|
db,
|
||||||
camera_id=camera_id,
|
camera_id=camera_id,
|
||||||
name=name,
|
name=request.name,
|
||||||
rtsp_url=rtsp_url,
|
rtsp_url=request.rtsp_url,
|
||||||
fps_limit=fps_limit,
|
fps_limit=request.fps_limit,
|
||||||
process_every_n_frames=process_every_n_frames,
|
process_every_n_frames=request.process_every_n_frames,
|
||||||
enabled=enabled,
|
enabled=request.enabled,
|
||||||
)
|
)
|
||||||
if not camera:
|
if not camera:
|
||||||
raise HTTPException(status_code=404, detail="摄像头不存在")
|
raise HTTPException(status_code=404, detail="摄像头不存在")
|
||||||
|
|
||||||
pipeline = get_pipeline()
|
pipeline = get_pipeline()
|
||||||
if enabled is True:
|
if request.enabled is True:
|
||||||
pipeline.add_camera(camera)
|
pipeline.add_camera(camera)
|
||||||
elif enabled is False:
|
elif request.enabled is False:
|
||||||
pipeline.remove_camera(camera_id)
|
pipeline.remove_camera(camera_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
20
config.py
20
config.py
@@ -23,14 +23,16 @@ class DatabaseConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
engine_path: str = "models/yolo11n_fp16_480.engine"
|
engine_path: str = "models/yolo11s.engine"
|
||||||
pt_model_path: str = "models/yolo11n.pt"
|
onnx_path: str = "models/yolo11s.onnx"
|
||||||
imgsz: List[int] = [480, 480]
|
pt_model_path: str = "models/yolo11s.pt"
|
||||||
|
imgsz: List[int] = [640, 640]
|
||||||
conf_threshold: float = 0.5
|
conf_threshold: float = 0.5
|
||||||
iou_threshold: float = 0.45
|
iou_threshold: float = 0.45
|
||||||
device: int = 0
|
device: int = 0
|
||||||
batch_size: int = 8
|
batch_size: int = 8
|
||||||
half: bool = True
|
half: bool = True
|
||||||
|
use_onnx: bool = True
|
||||||
|
|
||||||
|
|
||||||
class StreamConfig(BaseModel):
|
class StreamConfig(BaseModel):
|
||||||
@@ -78,6 +80,17 @@ class LoggingConfig(BaseModel):
|
|||||||
backup_count: int = 5
|
backup_count: int = 5
|
||||||
|
|
||||||
|
|
||||||
|
class CloudConfig(BaseModel):
|
||||||
|
enabled: bool = False
|
||||||
|
api_url: str = "https://api.example.com"
|
||||||
|
api_key: str = ""
|
||||||
|
device_id: str = "EDGE-001"
|
||||||
|
sync_interval: int = 60
|
||||||
|
alarm_retry_interval: int = 60
|
||||||
|
status_report_interval: int = 60
|
||||||
|
max_retries: int = 3
|
||||||
|
|
||||||
|
|
||||||
class MonitoringConfig(BaseModel):
|
class MonitoringConfig(BaseModel):
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
port: int = 9090
|
port: int = 9090
|
||||||
@@ -93,6 +106,7 @@ class LLMConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
|
cloud: CloudConfig = Field(default_factory=CloudConfig)
|
||||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
|
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
|
||||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||||
stream: StreamConfig = Field(default_factory=StreamConfig)
|
stream: StreamConfig = Field(default_factory=StreamConfig)
|
||||||
|
|||||||
25
config.yaml
25
config.yaml
@@ -1,5 +1,16 @@
|
|||||||
# 安保异常行为识别系统 - 核心配置
|
# 安保异常行为识别系统 - 核心配置
|
||||||
|
|
||||||
|
# 云端同步配置
|
||||||
|
cloud:
|
||||||
|
enabled: false # 启用云端同步(云端为主、本地为辅)
|
||||||
|
api_url: "https://api.example.com" # 云端API地址
|
||||||
|
api_key: "your-api-key" # API密钥
|
||||||
|
device_id: "EDGE-001" # 设备唯一标识
|
||||||
|
sync_interval: 60 # 配置同步间隔(秒)
|
||||||
|
alarm_retry_interval: 60 # 报警重试间隔(秒)
|
||||||
|
status_report_interval: 60 # 状态上报间隔(秒)
|
||||||
|
max_retries: 3 # 最大重试次数
|
||||||
|
|
||||||
# 数据库配置
|
# 数据库配置
|
||||||
database:
|
database:
|
||||||
dialect: "sqlite" # sqlite 或 mysql
|
dialect: "sqlite" # sqlite 或 mysql
|
||||||
@@ -12,21 +23,23 @@ database:
|
|||||||
|
|
||||||
# TensorRT模型配置
|
# TensorRT模型配置
|
||||||
model:
|
model:
|
||||||
engine_path: "models/yolo11n_fp16_480.engine"
|
engine_path: "models/yolo11n.engine"
|
||||||
|
onnx_path: "models/yolo11n.onnx"
|
||||||
pt_model_path: "models/yolo11n.pt"
|
pt_model_path: "models/yolo11n.pt"
|
||||||
imgsz: [480, 480]
|
imgsz: [640, 640]
|
||||||
conf_threshold: 0.5
|
conf_threshold: 0.5
|
||||||
iou_threshold: 0.45
|
iou_threshold: 0.45
|
||||||
device: 0 # GPU设备号
|
device: 0
|
||||||
batch_size: 8 # 最大batch size
|
batch_size: 8
|
||||||
half: true # FP16推理
|
half: false
|
||||||
|
use_onnx: true
|
||||||
|
|
||||||
# RTSP流配置
|
# RTSP流配置
|
||||||
stream:
|
stream:
|
||||||
buffer_size: 2 # 每路摄像头帧缓冲大小
|
buffer_size: 2 # 每路摄像头帧缓冲大小
|
||||||
reconnect_delay: 3.0 # 重连延迟(秒)
|
reconnect_delay: 3.0 # 重连延迟(秒)
|
||||||
timeout: 10.0 # 连接超时(秒)
|
timeout: 10.0 # 连接超时(秒)
|
||||||
fps_limit: 30 # 最大处理FPS
|
fps_limit: 10.0 # 最大处理FPS
|
||||||
|
|
||||||
# 推理队列配置
|
# 推理队列配置
|
||||||
inference:
|
inference:
|
||||||
|
|||||||
12
db/models.py
12
db/models.py
@@ -33,11 +33,15 @@ class Camera(Base):
|
|||||||
__tablename__ = "cameras"
|
__tablename__ = "cameras"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
cloud_id: Mapped[Optional[int]] = mapped_column(Integer, unique=True, nullable=True)
|
||||||
name: Mapped[str] = mapped_column(String(64), nullable=False)
|
name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
rtsp_url: Mapped[str] = mapped_column(Text, nullable=False)
|
rtsp_url: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
fps_limit: Mapped[int] = mapped_column(Integer, default=30)
|
fps_limit: Mapped[int] = mapped_column(Integer, default=30)
|
||||||
process_every_n_frames: Mapped[int] = mapped_column(Integer, default=3)
|
process_every_n_frames: Mapped[int] = mapped_column(Integer, default=3)
|
||||||
|
pending_sync: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
sync_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||||||
|
sync_retry_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||||
@@ -74,6 +78,7 @@ class ROI(Base):
|
|||||||
__tablename__ = "rois"
|
__tablename__ = "rois"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
cloud_id: Mapped[Optional[int]] = mapped_column(Integer, unique=True, nullable=True)
|
||||||
camera_id: Mapped[int] = mapped_column(
|
camera_id: Mapped[int] = mapped_column(
|
||||||
Integer, ForeignKey("cameras.id"), nullable=False
|
Integer, ForeignKey("cameras.id"), nullable=False
|
||||||
)
|
)
|
||||||
@@ -88,6 +93,8 @@ class ROI(Base):
|
|||||||
threshold_sec: Mapped[int] = mapped_column(Integer, default=360)
|
threshold_sec: Mapped[int] = mapped_column(Integer, default=360)
|
||||||
confirm_sec: Mapped[int] = mapped_column(Integer, default=30)
|
confirm_sec: Mapped[int] = mapped_column(Integer, default=30)
|
||||||
return_sec: Mapped[int] = mapped_column(Integer, default=5)
|
return_sec: Mapped[int] = mapped_column(Integer, default=5)
|
||||||
|
pending_sync: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
sync_version: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||||
@@ -100,6 +107,7 @@ class Alarm(Base):
|
|||||||
__tablename__ = "alarms"
|
__tablename__ = "alarms"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
cloud_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||||
camera_id: Mapped[int] = mapped_column(
|
camera_id: Mapped[int] = mapped_column(
|
||||||
Integer, ForeignKey("cameras.id"), nullable=False
|
Integer, ForeignKey("cameras.id"), nullable=False
|
||||||
)
|
)
|
||||||
@@ -107,6 +115,10 @@ class Alarm(Base):
|
|||||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
confidence: Mapped[float] = mapped_column(Float, default=0.0)
|
confidence: Mapped[float] = mapped_column(Float, default=0.0)
|
||||||
snapshot_path: Mapped[Optional[str]] = mapped_column(Text)
|
snapshot_path: Mapped[Optional[str]] = mapped_column(Text)
|
||||||
|
region_data: Mapped[Optional[str]] = mapped_column(Text)
|
||||||
|
upload_status: Mapped[str] = mapped_column(String(32), default='pending_upload')
|
||||||
|
upload_retry_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
error_message: Mapped[Optional[str]] = mapped_column(Text)
|
||||||
llm_checked: Mapped[bool] = mapped_column(Boolean, default=False)
|
llm_checked: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
llm_result: Mapped[Optional[str]] = mapped_column(Text)
|
llm_result: Mapped[Optional[str]] = mapped_column(Text)
|
||||||
processed: Mapped[bool] = mapped_column(Boolean, default=False)
|
processed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
|||||||
@@ -121,8 +121,10 @@ const CameraManagement: React.FC = () => {
|
|||||||
await axios.put(`/api/cameras/${camera.id}`, { enabled: !camera.enabled });
|
await axios.put(`/api/cameras/${camera.id}`, { enabled: !camera.enabled });
|
||||||
message.success(camera.enabled ? '已停用' : '已启用');
|
message.success(camera.enabled ? '已停用' : '已启用');
|
||||||
fetchCameras();
|
fetchCameras();
|
||||||
} catch (err) {
|
} catch (err: any) {
|
||||||
message.error('操作失败');
|
console.error('Toggle error:', err);
|
||||||
|
const errorMsg = err.response?.data?.detail || err.message || '操作失败';
|
||||||
|
message.error(`操作失败: ${errorMsg}`);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Card, Row, Col, Statistic, List, Tag, Button, Space, Timeline } from 'antd';
|
import { Card, Row, Col, Statistic, List, Tag, Button, Space } from 'antd';
|
||||||
import { AlertOutlined, VideoCameraOutlined, ClockCircleOutlined } from '@ant-design/icons';
|
import { AlertOutlined, VideoCameraOutlined, ClockCircleOutlined } from '@ant-design/icons';
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,14 @@
|
|||||||
import React, { useEffect, useState, useRef } from 'react';
|
import React, { useEffect, useState, useRef } from 'react';
|
||||||
import { Card, Button, Space, Select, message, Modal, Form, Input, InputNumber, Drawer } from 'antd';
|
import { Card, Button, Space, Select, message, Drawer, Form, Input, InputNumber, Switch } from 'antd';
|
||||||
import { Stage, Layer, Rect, Line, Circle, Text as KonvaText } from 'react-konva';
|
import { Stage, Layer, Rect, Line, Circle, Text as KonvaText } from 'react-konva';
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
|
|
||||||
interface ROI {
|
interface ROI {
|
||||||
id: number;
|
id: number;
|
||||||
roi_id: string;
|
|
||||||
name: string;
|
name: string;
|
||||||
type: string;
|
type: string;
|
||||||
points: number[][];
|
points: number[][];
|
||||||
rule: string;
|
rule: string;
|
||||||
direction: string | null;
|
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
threshold_sec: number;
|
threshold_sec: number;
|
||||||
confirm_sec: number;
|
confirm_sec: number;
|
||||||
@@ -27,12 +25,14 @@ const ROIEditor: React.FC = () => {
|
|||||||
const [selectedCamera, setSelectedCamera] = useState<number | null>(null);
|
const [selectedCamera, setSelectedCamera] = useState<number | null>(null);
|
||||||
const [rois, setRois] = useState<ROI[]>([]);
|
const [rois, setRois] = useState<ROI[]>([]);
|
||||||
const [snapshot, setSnapshot] = useState<string>('');
|
const [snapshot, setSnapshot] = useState<string>('');
|
||||||
const [loading, setLoading] = useState(false);
|
|
||||||
const [imageDim, setImageDim] = useState({ width: 800, height: 600 });
|
const [imageDim, setImageDim] = useState({ width: 800, height: 600 });
|
||||||
const [selectedROI, setSelectedROI] = useState<ROI | null>(null);
|
const [selectedROI, setSelectedROI] = useState<ROI | null>(null);
|
||||||
const [drawerVisible, setDrawerVisible] = useState(false);
|
const [drawerVisible, setDrawerVisible] = useState(false);
|
||||||
const [form] = Form.useForm();
|
const [form] = Form.useForm();
|
||||||
|
|
||||||
|
const [isDrawing, setIsDrawing] = useState(false);
|
||||||
|
const [tempPoints, setTempPoints] = useState<number[][]>([]);
|
||||||
|
const [backgroundImage, setBackgroundImage] = useState<HTMLImageElement | null>(null);
|
||||||
const stageRef = useRef<any>(null);
|
const stageRef = useRef<any>(null);
|
||||||
|
|
||||||
const fetchCameras = async () => {
|
const fetchCameras = async () => {
|
||||||
@@ -58,19 +58,25 @@ const ROIEditor: React.FC = () => {
|
|||||||
}
|
}
|
||||||
}, [selectedCamera]);
|
}, [selectedCamera]);
|
||||||
|
|
||||||
const fetchSnapshot = async () => {
|
useEffect(() => {
|
||||||
if (!selectedCamera) return;
|
if (snapshot) {
|
||||||
try {
|
|
||||||
const res = await axios.get(`/api/camera/${selectedCamera}/snapshot/base64`);
|
|
||||||
setSnapshot(res.data.image);
|
|
||||||
const img = new Image();
|
const img = new Image();
|
||||||
img.onload = () => {
|
img.onload = () => {
|
||||||
const maxWidth = 800;
|
const maxWidth = 800;
|
||||||
const maxHeight = 600;
|
const maxHeight = 600;
|
||||||
const scale = Math.min(maxWidth / img.width, maxHeight / img.height);
|
const scale = Math.min(maxWidth / img.width, maxHeight / img.height);
|
||||||
setImageDim({ width: img.width * scale, height: img.height * scale });
|
setImageDim({ width: img.width * scale, height: img.height * scale });
|
||||||
|
setBackgroundImage(img);
|
||||||
};
|
};
|
||||||
img.src = `data:image/jpeg;base64,${res.data.image}`;
|
img.src = `data:image/jpeg;base64,${snapshot}`;
|
||||||
|
}
|
||||||
|
}, [snapshot]);
|
||||||
|
|
||||||
|
const fetchSnapshot = async () => {
|
||||||
|
if (!selectedCamera) return;
|
||||||
|
try {
|
||||||
|
const res = await axios.get(`/api/camera/${selectedCamera}/snapshot/base64`);
|
||||||
|
setSnapshot(res.data.image);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
message.error('获取截图失败');
|
message.error('获取截图失败');
|
||||||
}
|
}
|
||||||
@@ -89,34 +95,79 @@ const ROIEditor: React.FC = () => {
|
|||||||
const handleSaveROI = async (values: any) => {
|
const handleSaveROI = async (values: any) => {
|
||||||
if (!selectedCamera || !selectedROI) return;
|
if (!selectedCamera || !selectedROI) return;
|
||||||
try {
|
try {
|
||||||
await axios.put(`/api/camera/${selectedCamera}/roi/${selectedROI.id}`, values);
|
await axios.put(`/api/camera/${selectedCamera}/roi/${selectedROI.id}`, {
|
||||||
|
name: values.name,
|
||||||
|
roi_type: values.roi_type,
|
||||||
|
rule_type: values.rule_type,
|
||||||
|
threshold_sec: values.threshold_sec,
|
||||||
|
confirm_sec: values.confirm_sec,
|
||||||
|
enabled: values.enabled,
|
||||||
|
});
|
||||||
message.success('保存成功');
|
message.success('保存成功');
|
||||||
setDrawerVisible(false);
|
setDrawerVisible(false);
|
||||||
fetchROIs();
|
fetchROIs();
|
||||||
} catch (err) {
|
} catch (err: any) {
|
||||||
message.error('保存失败');
|
message.error(`保存失败: ${err.response?.data?.detail || '未知错误'}`);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleAddROI = async () => {
|
const handleAddROI = () => {
|
||||||
if (!selectedCamera) return;
|
if (!selectedCamera) {
|
||||||
const roi_id = `roi_${Date.now()}`;
|
message.warning('请先选择摄像头');
|
||||||
try {
|
return;
|
||||||
await axios.post(`/api/camera/${selectedCamera}/roi`, {
|
|
||||||
roi_id,
|
|
||||||
name: '新区域',
|
|
||||||
roi_type: 'polygon',
|
|
||||||
points: [[100, 100], [300, 100], [300, 300], [100, 300]],
|
|
||||||
rule_type: 'leave_post',
|
|
||||||
threshold_sec: 360,
|
|
||||||
confirm_sec: 30,
|
|
||||||
return_sec: 5,
|
|
||||||
});
|
|
||||||
message.success('添加成功');
|
|
||||||
fetchROIs();
|
|
||||||
} catch (err) {
|
|
||||||
message.error('添加失败');
|
|
||||||
}
|
}
|
||||||
|
setIsDrawing(true);
|
||||||
|
setTempPoints([]);
|
||||||
|
setSelectedROI(null);
|
||||||
|
message.info('点击画布绘制ROI区域,双击完成绘制');
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleStageClick = (e: any) => {
|
||||||
|
if (!isDrawing) return;
|
||||||
|
|
||||||
|
const stage = e.target.getStage();
|
||||||
|
const pos = stage.getPointerPosition();
|
||||||
|
if (pos) {
|
||||||
|
setTempPoints(prev => [...prev, [pos.x, pos.y]]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleStageDblClick = () => {
|
||||||
|
if (!isDrawing || tempPoints.length < 3) {
|
||||||
|
if (tempPoints.length > 0 && tempPoints.length < 3) {
|
||||||
|
message.warning('至少需要3个点才能形成多边形');
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const roi_id = `roi_${Date.now()}`;
|
||||||
|
axios.post(`/api/camera/${selectedCamera}/roi`, {
|
||||||
|
roi_id,
|
||||||
|
name: `区域${rois.length + 1}`,
|
||||||
|
roi_type: 'polygon',
|
||||||
|
points: tempPoints,
|
||||||
|
rule_type: 'intrusion',
|
||||||
|
threshold_sec: 60,
|
||||||
|
confirm_sec: 5,
|
||||||
|
return_sec: 5,
|
||||||
|
})
|
||||||
|
.then(() => {
|
||||||
|
message.success('ROI添加成功');
|
||||||
|
setIsDrawing(false);
|
||||||
|
setTempPoints([]);
|
||||||
|
fetchROIs();
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
message.error(`添加失败: ${err.response?.data?.detail || '未知错误'}`);
|
||||||
|
setIsDrawing(false);
|
||||||
|
setTempPoints([]);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleCancelDrawing = () => {
|
||||||
|
setIsDrawing(false);
|
||||||
|
setTempPoints([]);
|
||||||
|
message.info('已取消绘制');
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleDeleteROI = async (roiId: number) => {
|
const handleDeleteROI = async (roiId: number) => {
|
||||||
@@ -124,9 +175,13 @@ const ROIEditor: React.FC = () => {
|
|||||||
try {
|
try {
|
||||||
await axios.delete(`/api/camera/${selectedCamera}/roi/${roiId}`);
|
await axios.delete(`/api/camera/${selectedCamera}/roi/${roiId}`);
|
||||||
message.success('删除成功');
|
message.success('删除成功');
|
||||||
|
if (selectedROI?.id === roiId) {
|
||||||
|
setSelectedROI(null);
|
||||||
|
setDrawerVisible(false);
|
||||||
|
}
|
||||||
fetchROIs();
|
fetchROIs();
|
||||||
} catch (err) {
|
} catch (err: any) {
|
||||||
message.error('删除失败');
|
message.error(`删除失败: ${err.response?.data?.detail || '未知错误'}`);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -137,162 +192,264 @@ const ROIEditor: React.FC = () => {
|
|||||||
const renderROI = (roi: ROI) => {
|
const renderROI = (roi: ROI) => {
|
||||||
const points = roi.points.flat();
|
const points = roi.points.flat();
|
||||||
const color = getROIStrokeColor(roi.rule);
|
const color = getROIStrokeColor(roi.rule);
|
||||||
|
const isSelected = selectedROI?.id === roi.id;
|
||||||
|
|
||||||
if (roi.type === 'polygon') {
|
return (
|
||||||
return (
|
<Line
|
||||||
<Line
|
key={roi.id}
|
||||||
key={roi.id}
|
points={points}
|
||||||
points={points}
|
closed={roi.type === 'polygon'}
|
||||||
closed
|
stroke={isSelected ? '#1890ff' : color}
|
||||||
stroke={color}
|
strokeWidth={isSelected ? 3 : 2}
|
||||||
strokeWidth={2}
|
fill={`${color}33`}
|
||||||
fill={`${color}33`}
|
onClick={() => {
|
||||||
onClick={() => {
|
setSelectedROI(roi);
|
||||||
setSelectedROI(roi);
|
form.setFieldsValue({
|
||||||
form.setFieldsValue(roi);
|
name: roi.name,
|
||||||
setDrawerVisible(true);
|
roi_type: roi.type,
|
||||||
}}
|
rule_type: roi.rule,
|
||||||
/>
|
threshold_sec: roi.threshold_sec,
|
||||||
);
|
confirm_sec: roi.confirm_sec,
|
||||||
} else if (roi.type === 'line') {
|
enabled: roi.enabled,
|
||||||
return (
|
});
|
||||||
<Line
|
setDrawerVisible(true);
|
||||||
key={roi.id}
|
}}
|
||||||
points={points}
|
onMouseEnter={(e) => {
|
||||||
stroke={color}
|
const container = e.target.getStage()?.container();
|
||||||
strokeWidth={3}
|
if (container) {
|
||||||
onClick={() => {
|
container.style.cursor = 'pointer';
|
||||||
setSelectedROI(roi);
|
}
|
||||||
form.setFieldsValue(roi);
|
}}
|
||||||
setDrawerVisible(true);
|
onMouseLeave={(e) => {
|
||||||
}}
|
const container = e.target.getStage()?.container();
|
||||||
/>
|
if (container) {
|
||||||
);
|
container.style.cursor = 'default';
|
||||||
}
|
}
|
||||||
return null;
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
<Card>
|
<Card>
|
||||||
<Space style={{ marginBottom: 16 }}>
|
<Space style={{ marginBottom: 16 }} wrap>
|
||||||
<Select
|
<Select
|
||||||
placeholder="选择摄像头"
|
placeholder="选择摄像头"
|
||||||
value={selectedCamera}
|
value={selectedCamera}
|
||||||
onChange={setSelectedCamera}
|
onChange={(value) => {
|
||||||
|
setSelectedCamera(value);
|
||||||
|
setSelectedROI(null);
|
||||||
|
}}
|
||||||
style={{ width: 200 }}
|
style={{ width: 200 }}
|
||||||
options={cameras.map((c) => ({ label: c.name, value: c.id }))}
|
options={cameras.map((c) => ({ label: c.name, value: c.id }))}
|
||||||
/>
|
/>
|
||||||
<Button type="primary" onClick={fetchSnapshot}>
|
<Button onClick={fetchSnapshot}>刷新截图</Button>
|
||||||
刷新截图
|
{isDrawing ? (
|
||||||
</Button>
|
<>
|
||||||
<Button onClick={handleAddROI}>添加ROI</Button>
|
<Button danger onClick={handleCancelDrawing}>取消绘制</Button>
|
||||||
|
<Button type="primary" disabled={tempPoints.length < 3} onClick={handleStageDblClick}>
|
||||||
|
完成绘制 ({tempPoints.length} 点)
|
||||||
|
</Button>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<Button type="primary" onClick={handleAddROI}>添加ROI</Button>
|
||||||
|
)}
|
||||||
</Space>
|
</Space>
|
||||||
|
|
||||||
<div className="roi-editor-container" style={{ display: 'flex', gap: 16 }}>
|
<div className="roi-editor-container" style={{ display: 'flex', gap: 16, flexDirection: 'row' }}>
|
||||||
<div style={{ flex: 1, background: '#f0f0f0', display: 'flex', justifyContent: 'center', alignItems: 'center' }}>
|
<div style={{
|
||||||
|
flex: 1,
|
||||||
|
background: '#f0f0f0',
|
||||||
|
display: 'flex',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
minHeight: 500,
|
||||||
|
border: isDrawing ? '2px solid #1890ff' : '1px solid #d9d9d9',
|
||||||
|
borderRadius: 4,
|
||||||
|
position: 'relative'
|
||||||
|
}}>
|
||||||
|
{isDrawing && (
|
||||||
|
<div style={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 10,
|
||||||
|
left: 10,
|
||||||
|
zIndex: 10,
|
||||||
|
background: 'rgba(24, 144, 255, 0.9)',
|
||||||
|
color: 'white',
|
||||||
|
padding: '8px 16px',
|
||||||
|
borderRadius: 4,
|
||||||
|
fontSize: 14
|
||||||
|
}}>
|
||||||
|
绘制模式 - 点击添加点,双击完成
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
{snapshot ? (
|
{snapshot ? (
|
||||||
<Stage width={imageDim.width} height={imageDim.height} ref={stageRef}>
|
<Stage
|
||||||
|
width={imageDim.width}
|
||||||
|
height={imageDim.height}
|
||||||
|
ref={stageRef}
|
||||||
|
onClick={handleStageClick}
|
||||||
|
onDblClick={handleStageDblClick}
|
||||||
|
style={{ cursor: isDrawing ? 'crosshair' : 'default' }}
|
||||||
|
>
|
||||||
<Layer>
|
<Layer>
|
||||||
<Rect
|
{backgroundImage && (
|
||||||
x={0}
|
<Rect
|
||||||
y={0}
|
x={0}
|
||||||
width={imageDim.width}
|
y={0}
|
||||||
height={imageDim.height}
|
width={imageDim.width}
|
||||||
fillPatternImage={
|
height={imageDim.height}
|
||||||
(() => {
|
fillPatternImage={backgroundImage}
|
||||||
const img = new Image();
|
fillPatternOffset={{ x: 0, y: 0 }}
|
||||||
img.src = `data:image/jpeg;base64,${snapshot}`;
|
fillPatternScale={{ x: 1, y: 1 }}
|
||||||
return img;
|
/>
|
||||||
})()
|
)}
|
||||||
}
|
|
||||||
fillPatternOffset={{ x: 0, y: 0 }}
|
|
||||||
fillPatternScale={{ x: 1, y: 1 }}
|
|
||||||
/>
|
|
||||||
{rois.map(renderROI)}
|
{rois.map(renderROI)}
|
||||||
|
{isDrawing && tempPoints.length > 0 && (
|
||||||
|
<>
|
||||||
|
<Line
|
||||||
|
points={tempPoints.flat()}
|
||||||
|
stroke="#1890ff"
|
||||||
|
strokeWidth={2}
|
||||||
|
dash={[5, 5]}
|
||||||
|
/>
|
||||||
|
{tempPoints.map((point, idx) => (
|
||||||
|
<Circle
|
||||||
|
key={idx}
|
||||||
|
x={point[0]}
|
||||||
|
y={point[1]}
|
||||||
|
radius={5}
|
||||||
|
fill="#1890ff"
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
{tempPoints.map((point, idx) => (
|
||||||
|
<KonvaText
|
||||||
|
key={`label-${idx}`}
|
||||||
|
x={point[0] + 10}
|
||||||
|
y={point[1] - 10}
|
||||||
|
text={`${idx + 1}`}
|
||||||
|
fontSize={14}
|
||||||
|
fill="#1890ff"
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</Layer>
|
</Layer>
|
||||||
</Stage>
|
</Stage>
|
||||||
) : (
|
) : (
|
||||||
<div>加载中...</div>
|
<div style={{ color: '#999' }}>加载中...</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div style={{ width: 300 }}>
|
<div style={{ width: 280, flexShrink: 0 }}>
|
||||||
<Card title="ROI列表" size="small">
|
<Card title="ROI列表" size="small" bodyStyle={{ maxHeight: 500, overflow: 'auto' }}>
|
||||||
{rois.map((roi) => (
|
{rois.length === 0 ? (
|
||||||
<div
|
<div style={{ color: '#999', textAlign: 'center', padding: 20 }}>
|
||||||
key={roi.id}
|
暂无ROI区域,点击"添加ROI"开始绘制
|
||||||
style={{
|
</div>
|
||||||
padding: 8,
|
) : (
|
||||||
marginBottom: 8,
|
rois.map((roi) => (
|
||||||
background: '#fafafa',
|
<div
|
||||||
borderRadius: 4,
|
key={roi.id}
|
||||||
cursor: 'pointer',
|
style={{
|
||||||
border: selectedROI?.id === roi.id ? '2px solid #1890ff' : '1px solid #d9d9d9',
|
padding: 8,
|
||||||
}}
|
marginBottom: 8,
|
||||||
onClick={() => {
|
background: selectedROI?.id === roi.id ? '#e6f7ff' : '#fafafa',
|
||||||
setSelectedROI(roi);
|
borderRadius: 4,
|
||||||
form.setFieldsValue(roi);
|
cursor: 'pointer',
|
||||||
setDrawerVisible(true);
|
border: selectedROI?.id === roi.id ? '2px solid #1890ff' : '1px solid #d9d9d9',
|
||||||
}}
|
}}
|
||||||
>
|
onClick={() => {
|
||||||
<div style={{ fontWeight: 'bold' }}>{roi.name}</div>
|
setSelectedROI(roi);
|
||||||
<div style={{ fontSize: 12, color: '#666' }}>
|
form.setFieldsValue({
|
||||||
类型: {roi.type} | 规则: {roi.rule}
|
name: roi.name,
|
||||||
</div>
|
roi_type: roi.type,
|
||||||
<Button
|
rule_type: roi.rule,
|
||||||
type="text"
|
threshold_sec: roi.threshold_sec,
|
||||||
danger
|
confirm_sec: roi.confirm_sec,
|
||||||
size="small"
|
enabled: roi.enabled,
|
||||||
onClick={(e) => {
|
});
|
||||||
e.stopPropagation();
|
setDrawerVisible(true);
|
||||||
handleDeleteROI(roi.id);
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
删除
|
<div style={{ fontWeight: 'bold', marginBottom: 4 }}>{roi.name}</div>
|
||||||
</Button>
|
<div style={{ fontSize: 12, color: '#666', marginBottom: 4 }}>
|
||||||
</div>
|
类型: {roi.type === 'polygon' ? '多边形' : '线段'} | 规则: {roi.rule === 'intrusion' ? '入侵检测' : '离岗检测'}
|
||||||
))}
|
</div>
|
||||||
|
<Space size={4}>
|
||||||
|
<Button
|
||||||
|
type="link"
|
||||||
|
size="small"
|
||||||
|
danger
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
handleDeleteROI(roi.id);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
删除
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
))
|
||||||
|
)}
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
<Drawer
|
<Drawer
|
||||||
title="编辑ROI"
|
title={selectedROI ? `编辑ROI - ${selectedROI.name}` : '编辑ROI'}
|
||||||
open={drawerVisible}
|
open={drawerVisible}
|
||||||
onClose={() => setDrawerVisible(false)}
|
onClose={() => {
|
||||||
|
setDrawerVisible(false);
|
||||||
|
setSelectedROI(null);
|
||||||
|
}}
|
||||||
width={400}
|
width={400}
|
||||||
>
|
>
|
||||||
<Form form={form} layout="vertical" onFinish={handleSaveROI}>
|
<Form form={form} layout="vertical" onFinish={handleSaveROI}>
|
||||||
<Form.Item name="name" label="名称" rules={[{ required: true }]}>
|
<Form.Item name="name" label="名称" rules={[{ required: true, message: '请输入名称' }]}>
|
||||||
<Input />
|
<Input placeholder="例如:入口入侵区域" />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item name="roi_type" label="类型">
|
<Form.Item name="roi_type" label="类型" rules={[{ required: true }]}>
|
||||||
<Select options={[{ label: '多边形', value: 'polygon' }, { label: '线段', value: 'line' }]} />
|
|
||||||
</Form.Item>
|
|
||||||
<Form.Item name="rule_type" label="规则">
|
|
||||||
<Select
|
<Select
|
||||||
options={[
|
options={[
|
||||||
{ label: '离岗检测', value: 'leave_post' },
|
{ label: '多边形区域', value: 'polygon' },
|
||||||
{ label: '周界入侵', value: 'intrusion' },
|
{ label: '线段', value: 'line' },
|
||||||
]}
|
]}
|
||||||
/>
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item name="threshold_sec" label="超时时间(秒)">
|
<Form.Item name="rule_type" label="检测规则" rules={[{ required: true }]}>
|
||||||
<InputNumber min={60} style={{ width: '100%' }} />
|
<Select
|
||||||
|
options={[
|
||||||
|
{ label: '周界入侵检测', value: 'intrusion' },
|
||||||
|
{ label: '离岗检测', value: 'leave_post' },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item name="confirm_sec" label="确认时间(秒)">
|
{selectedROI?.rule === 'leave_post' && (
|
||||||
<InputNumber min={5} style={{ width: '100%' }} />
|
<>
|
||||||
</Form.Item>
|
<Form.Item name="threshold_sec" label="超时时间(秒)" rules={[{ required: true }]}>
|
||||||
<Form.Item name="enabled" label="启用" valuePropName="checked">
|
<InputNumber min={60} style={{ width: '100%' }} />
|
||||||
<input type="checkbox" />
|
</Form.Item>
|
||||||
|
<Form.Item name="confirm_sec" label="确认时间(秒)" rules={[{ required: true }]}>
|
||||||
|
<InputNumber min={5} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
<Form.Item name="enabled" label="启用状态" valuePropName="checked">
|
||||||
|
<Switch checkedChildren="启用" unCheckedChildren="停用" />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item>
|
<Form.Item>
|
||||||
<Space>
|
<Space>
|
||||||
<Button type="primary" htmlType="submit">
|
<Button type="primary" htmlType="submit">
|
||||||
保存
|
保存
|
||||||
</Button>
|
</Button>
|
||||||
<Button onClick={() => setDrawerVisible(false)}>取消</Button>
|
<Button onClick={() => {
|
||||||
|
setDrawerVisible(false);
|
||||||
|
setSelectedROI(null);
|
||||||
|
}}>
|
||||||
|
取消
|
||||||
|
</Button>
|
||||||
</Space>
|
</Space>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</Form>
|
</Form>
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
os.environ["TENSORRT_DISABLE_MYELIN"] = "1"
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -6,12 +9,146 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorrt as trt
|
import tensorrt as trt
|
||||||
import torch
|
import torch
|
||||||
|
import onnxruntime as ort
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
from ultralytics.engine.results import Results
|
from ultralytics.engine.results import Results, Boxes as UltralyticsBoxes
|
||||||
|
|
||||||
from config import get_config
|
from config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXEngine:
|
||||||
|
def __init__(self, onnx_path: Optional[str] = None, device: int = 0):
|
||||||
|
config = get_config()
|
||||||
|
self.onnx_path = onnx_path or config.model.onnx_path
|
||||||
|
self.device = device
|
||||||
|
self.imgsz = tuple(config.model.imgsz)
|
||||||
|
self.conf_thresh = config.model.conf_threshold
|
||||||
|
self.iou_thresh = config.model.iou_threshold
|
||||||
|
|
||||||
|
self.session = None
|
||||||
|
self.input_names = None
|
||||||
|
self.output_names = None
|
||||||
|
self.load_model()
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
if not os.path.exists(self.onnx_path):
|
||||||
|
raise FileNotFoundError(f"ONNX模型文件不存在: {self.onnx_path}")
|
||||||
|
|
||||||
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.device >= 0 else ['CPUExecutionProvider']
|
||||||
|
self.session = ort.InferenceSession(self.onnx_path, providers=providers)
|
||||||
|
|
||||||
|
self.input_names = [inp.name for inp in self.session.get_inputs()]
|
||||||
|
self.output_names = [out.name for out in self.session.get_outputs()]
|
||||||
|
|
||||||
|
def preprocess(self, frame: np.ndarray) -> np.ndarray:
|
||||||
|
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
img = cv2.resize(img, self.imgsz)
|
||||||
|
|
||||||
|
img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
def postprocess(self, output: np.ndarray, orig_img: np.ndarray) -> List[Results]:
|
||||||
|
c, n = output.shape
|
||||||
|
output = output.T
|
||||||
|
|
||||||
|
boxes = output[:, :4]
|
||||||
|
scores = output[:, 4]
|
||||||
|
classes = output[:, 5:].argmax(axis=1) if output.shape[1] > 5 else np.zeros(len(output), dtype=np.int32)
|
||||||
|
|
||||||
|
mask = scores > self.conf_thresh
|
||||||
|
boxes = boxes[mask]
|
||||||
|
scores = scores[mask]
|
||||||
|
classes = classes[mask]
|
||||||
|
|
||||||
|
if len(boxes) == 0:
|
||||||
|
return [Results(orig_img=orig_img, path="", names={0: "person"})]
|
||||||
|
|
||||||
|
indices = cv2.dnn.NMSBoxes(
|
||||||
|
boxes.tolist(),
|
||||||
|
scores.tolist(),
|
||||||
|
self.conf_thresh,
|
||||||
|
self.iou_thresh,
|
||||||
|
)
|
||||||
|
|
||||||
|
orig_h, orig_w = orig_img.shape[:2]
|
||||||
|
scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0]
|
||||||
|
|
||||||
|
filtered_boxes = []
|
||||||
|
for idx in indices:
|
||||||
|
if idx >= len(boxes):
|
||||||
|
continue
|
||||||
|
box = boxes[idx]
|
||||||
|
x1, y1, x2, y2 = box
|
||||||
|
w, h = x2 - x1, y2 - y1
|
||||||
|
filtered_boxes.append([
|
||||||
|
int(x1 * scale_x),
|
||||||
|
int(y1 * scale_y),
|
||||||
|
int(w * scale_x),
|
||||||
|
int(h * scale_y),
|
||||||
|
float(scores[idx]),
|
||||||
|
int(classes[idx])
|
||||||
|
])
|
||||||
|
|
||||||
|
from ultralytics.engine.results import Boxes as BoxesObj
|
||||||
|
if filtered_boxes:
|
||||||
|
box_tensor = torch.tensor(filtered_boxes)
|
||||||
|
boxes_obj = BoxesObj(
|
||||||
|
box_tensor,
|
||||||
|
orig_shape=(orig_h, orig_w)
|
||||||
|
)
|
||||||
|
result = Results(
|
||||||
|
orig_img=orig_img,
|
||||||
|
path="",
|
||||||
|
names={0: "person"},
|
||||||
|
boxes=boxes_obj
|
||||||
|
)
|
||||||
|
return [result]
|
||||||
|
|
||||||
|
return [Results(orig_img=orig_img, path="", names={0: "person"})]
|
||||||
|
|
||||||
|
def inference(self, images: List[np.ndarray]) -> List[Results]:
|
||||||
|
if not images:
|
||||||
|
return []
|
||||||
|
|
||||||
|
batch_imgs = []
|
||||||
|
for frame in images:
|
||||||
|
img = self.preprocess(frame)
|
||||||
|
batch_imgs.append(img)
|
||||||
|
|
||||||
|
batch = np.stack(batch_imgs, axis=0)
|
||||||
|
|
||||||
|
inputs = {self.input_names[0]: batch}
|
||||||
|
outputs = self.session.run(self.output_names, inputs)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
output = outputs[0]
|
||||||
|
if output.shape[0] == 1:
|
||||||
|
result = self.postprocess(output[0], images[0])
|
||||||
|
results.extend(result)
|
||||||
|
else:
|
||||||
|
for i in range(output.shape[0]):
|
||||||
|
result = self.postprocess(output[i], images[i])
|
||||||
|
results.extend(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def inference_single(self, frame: np.ndarray) -> List[Results]:
|
||||||
|
return self.inference([frame])
|
||||||
|
|
||||||
|
def warmup(self, num_warmup: int = 10):
|
||||||
|
dummy_frame = np.zeros((640, 640, 3), dtype=np.uint8)
|
||||||
|
for _ in range(num_warmup):
|
||||||
|
self.inference_single(dummy_frame)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.session:
|
||||||
|
try:
|
||||||
|
self.session.end_profiling()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TensorRTEngine:
|
class TensorRTEngine:
|
||||||
def __init__(self, engine_path: Optional[str] = None, device: int = 0):
|
def __init__(self, engine_path: Optional[str] = None, device: int = 0):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
@@ -25,9 +162,11 @@ class TensorRTEngine:
|
|||||||
self.logger = trt.Logger(trt.Logger.INFO)
|
self.logger = trt.Logger(trt.Logger.INFO)
|
||||||
self.engine = None
|
self.engine = None
|
||||||
self.context = None
|
self.context = None
|
||||||
self.stream = None
|
self.stream = torch.cuda.Stream(device=self.device)
|
||||||
self.input_buffer = None
|
self.input_buffer = None
|
||||||
self.output_buffers = []
|
self.output_buffers = []
|
||||||
|
self.input_name = None
|
||||||
|
self.output_name = None
|
||||||
|
|
||||||
self._load_engine()
|
self._load_engine()
|
||||||
|
|
||||||
@@ -44,29 +183,39 @@ class TensorRTEngine:
|
|||||||
self.context = self.engine.create_execution_context()
|
self.context = self.engine.create_execution_context()
|
||||||
|
|
||||||
self.stream = torch.cuda.Stream(device=self.device)
|
self.stream = torch.cuda.Stream(device=self.device)
|
||||||
|
self.batch_size = 1
|
||||||
|
|
||||||
for i in range(self.engine.num_io_tensors):
|
for i in range(self.engine.num_io_tensors):
|
||||||
name = self.engine.get_tensor_name(i)
|
name = self.engine.get_tensor_name(i)
|
||||||
dtype = self.engine.get_tensor_dtype(name)
|
dtype = self.engine.get_tensor_dtype(name)
|
||||||
shape = self.engine.get_tensor_shape(name)
|
shape = list(self.engine.get_tensor_shape(name))
|
||||||
|
|
||||||
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
|
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
|
||||||
self.context.set_tensor_address(name, None)
|
if -1 in shape:
|
||||||
|
shape = [self.batch_size if d == -1 else d for d in shape]
|
||||||
|
if dtype == trt.float16:
|
||||||
|
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
|
||||||
|
else:
|
||||||
|
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
|
||||||
|
self.input_buffer = buffer
|
||||||
|
self.input_name = name
|
||||||
else:
|
else:
|
||||||
|
if -1 in shape:
|
||||||
|
shape = [self.batch_size if d == -1 else d for d in shape]
|
||||||
if dtype == trt.float16:
|
if dtype == trt.float16:
|
||||||
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
|
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
|
||||||
else:
|
else:
|
||||||
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
|
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
|
||||||
self.output_buffers.append(buffer)
|
self.output_buffers.append(buffer)
|
||||||
self.context.set_tensor_address(name, buffer.data_ptr())
|
if self.output_name is None:
|
||||||
|
self.output_name = name
|
||||||
|
|
||||||
self.context.set_optimization_profile_async(0, self.stream)
|
self.context.set_tensor_address(name, buffer.data_ptr())
|
||||||
|
|
||||||
self.input_buffer = torch.zeros(
|
stream_handle = torch.cuda.current_stream(self.device).cuda_stream
|
||||||
(1, 3, self.imgsz[0], self.imgsz[1]),
|
self.context.set_optimization_profile_async(0, stream_handle)
|
||||||
dtype=torch.float16 if self.half else torch.float32,
|
|
||||||
device=self.device,
|
self.batch_size = 1
|
||||||
)
|
|
||||||
|
|
||||||
def preprocess(self, frame: np.ndarray) -> torch.Tensor:
|
def preprocess(self, frame: np.ndarray) -> torch.Tensor:
|
||||||
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
@@ -95,16 +244,20 @@ class TensorRTEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.context.set_tensor_address(
|
self.context.set_tensor_address(
|
||||||
"input", input_tensor.contiguous().data_ptr()
|
self.input_name, input_tensor.contiguous().data_ptr()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
input_shape = list(input_tensor.shape)
|
||||||
|
self.context.set_input_shape(self.input_name, input_shape)
|
||||||
|
|
||||||
torch.cuda.synchronize(self.stream)
|
torch.cuda.synchronize(self.stream)
|
||||||
self.context.execute_async_v3(self.stream.handle)
|
self.context.execute_async_v3(self.stream.cuda_stream)
|
||||||
torch.cuda.synchronize(self.stream)
|
torch.cuda.synchronize(self.stream)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
pred = self.output_buffers[0][i].cpu().numpy()
|
pred = self.output_buffers[0][i].cpu().numpy()
|
||||||
|
pred = pred.T # 转置: (8400, 84)
|
||||||
boxes = pred[:, :4]
|
boxes = pred[:, :4]
|
||||||
scores = pred[:, 4]
|
scores = pred[:, 4]
|
||||||
classes = pred[:, 5].astype(np.int32)
|
classes = pred[:, 5].astype(np.int32)
|
||||||
@@ -142,7 +295,7 @@ class TensorRTEngine:
|
|||||||
orig_img=images[i],
|
orig_img=images[i],
|
||||||
path="",
|
path="",
|
||||||
names={0: "person"},
|
names={0: "person"},
|
||||||
boxes=Boxes(
|
boxes=UltralyticsBoxes(
|
||||||
torch.tensor([box_orig + [conf, cls]]),
|
torch.tensor([box_orig + [conf, cls]]),
|
||||||
orig_shape=(orig_h, orig_w),
|
orig_shape=(orig_h, orig_w),
|
||||||
),
|
),
|
||||||
@@ -161,9 +314,15 @@ class TensorRTEngine:
|
|||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.context:
|
if self.context:
|
||||||
self.context.synchronize()
|
try:
|
||||||
|
self.context.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
if self.stream:
|
if self.stream:
|
||||||
self.stream.synchronize()
|
try:
|
||||||
|
self.stream.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Boxes:
|
class Boxes:
|
||||||
@@ -196,6 +355,15 @@ class Boxes:
|
|||||||
return self.data[:, 5]
|
return self.data[:, 5]
|
||||||
|
|
||||||
|
|
||||||
|
def _check_pt_file_valid(pt_path: str) -> bool:
|
||||||
|
try:
|
||||||
|
with open(pt_path, 'rb') as f:
|
||||||
|
header = f.read(10)
|
||||||
|
return len(header) == 10
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class YOLOEngine:
|
class YOLOEngine:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -203,38 +371,61 @@ class YOLOEngine:
|
|||||||
device: int = 0,
|
device: int = 0,
|
||||||
use_trt: bool = True,
|
use_trt: bool = True,
|
||||||
):
|
):
|
||||||
self.use_trt = use_trt
|
self.use_trt = False
|
||||||
self.device = device
|
self.onnx_engine = None
|
||||||
self.trt_engine = None
|
self.trt_engine = None
|
||||||
|
self.device = device
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
if not use_trt:
|
if use_trt:
|
||||||
if model_path:
|
|
||||||
pt_path = model_path
|
|
||||||
elif hasattr(get_config().model, 'pt_model_path'):
|
|
||||||
pt_path = get_config().model.pt_model_path
|
|
||||||
else:
|
|
||||||
pt_path = get_config().model.engine_path.replace(".engine", ".pt")
|
|
||||||
self.model = YOLO(pt_path)
|
|
||||||
self.model.to(device)
|
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
self.trt_engine = TensorRTEngine(device=device)
|
self.trt_engine = TensorRTEngine(device=device)
|
||||||
self.trt_engine.warmup()
|
self.trt_engine.warmup()
|
||||||
|
self.use_trt = True
|
||||||
|
print("TensorRT引擎加载成功")
|
||||||
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"TensorRT加载失败,回退到PyTorch: {e}")
|
print(f"TensorRT加载失败: {e}")
|
||||||
self.use_trt = False
|
|
||||||
if model_path:
|
try:
|
||||||
pt_path = model_path
|
onnx_path = config.model.onnx_path
|
||||||
elif hasattr(get_config().model, 'pt_model_path'):
|
if os.path.exists(onnx_path):
|
||||||
pt_path = get_config().model.pt_model_path
|
self.onnx_engine = ONNXEngine(device=device)
|
||||||
else:
|
self.onnx_engine.warmup()
|
||||||
pt_path = get_config().model.engine_path.replace(".engine", ".pt")
|
print("ONNX引擎加载成功")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print(f"ONNX模型不存在: {onnx_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ONNX加载失败: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
pt_path = model_path or config.model.pt_model_path
|
||||||
|
if os.path.exists(pt_path) and _check_pt_file_valid(pt_path):
|
||||||
self.model = YOLO(pt_path)
|
self.model = YOLO(pt_path)
|
||||||
self.model.to(device)
|
self.model.to(device)
|
||||||
|
print(f"PyTorch模型加载成功: {pt_path}")
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"PT文件无效或不存在: {pt_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"PyTorch加载失败: {e}")
|
||||||
|
raise RuntimeError("所有模型加载方式均失败")
|
||||||
|
|
||||||
def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]:
|
def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]:
|
||||||
if self.use_trt:
|
if self.use_trt and self.trt_engine:
|
||||||
return self.trt_engine.inference_single(frame)
|
try:
|
||||||
|
return self.trt_engine.inference_single(frame)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"TensorRT推理失败,切换到ONNX: {e}")
|
||||||
|
self.use_trt = False
|
||||||
|
if self.onnx_engine:
|
||||||
|
return self.onnx_engine.inference_single(frame)
|
||||||
|
elif self.model:
|
||||||
|
return self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
elif self.onnx_engine:
|
||||||
|
return self.onnx_engine.inference_single(frame)
|
||||||
else:
|
else:
|
||||||
results = self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
|
results = self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
|
||||||
return results
|
return results
|
||||||
@@ -242,3 +433,5 @@ class YOLOEngine:
|
|||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.trt_engine:
|
if self.trt_engine:
|
||||||
del self.trt_engine
|
del self.trt_engine
|
||||||
|
if self.onnx_engine:
|
||||||
|
del self.onnx_engine
|
||||||
|
|||||||
60
logs/app.log
60
logs/app.log
@@ -73,3 +73,63 @@
|
|||||||
2026-01-20 18:24:01,952 - security_monitor - INFO - 启动安保异常行为识别系统
|
2026-01-20 18:24:01,952 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
2026-01-20 18:24:01,963 - security_monitor - INFO - 数据库初始化完成
|
2026-01-20 18:24:01,963 - security_monitor - INFO - 数据库初始化完成
|
||||||
2026-01-20 18:24:14,477 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
2026-01-20 18:24:14,477 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-20 18:29:40,275 - security_monitor - INFO - 正在关闭系统...
|
||||||
|
2026-01-20 18:29:40,454 - security_monitor - INFO - 系统已关闭
|
||||||
|
2026-01-21 09:00:02,704 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 09:00:02,720 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 09:01:24,719 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 09:01:24,732 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 09:05:29,103 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 09:05:29,117 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 09:09:47,194 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 09:09:47,209 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 09:16:43,336 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 09:16:43,350 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 09:18:58,020 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 09:18:58,032 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 09:27:51,761 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 09:27:51,776 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 09:31:31,676 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 09:31:31,690 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 09:31:44,902 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-21 09:32:04,038 - security_monitor - INFO - 正在关闭系统...
|
||||||
|
2026-01-21 09:32:04,282 - security_monitor - INFO - 系统已关闭
|
||||||
|
2026-01-21 09:33:24,297 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 09:33:24,308 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 09:33:37,369 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-21 09:36:31,696 - security_monitor - INFO - 正在关闭系统...
|
||||||
|
2026-01-21 09:36:31,901 - security_monitor - INFO - 系统已关闭
|
||||||
|
2026-01-21 10:27:59,314 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 10:27:59,327 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 10:28:11,999 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-21 10:33:43,512 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 10:33:43,523 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 10:33:56,202 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-21 10:34:45,507 - security_monitor - INFO - 正在关闭系统...
|
||||||
|
2026-01-21 10:34:45,707 - security_monitor - INFO - 系统已关闭
|
||||||
|
2026-01-21 10:39:53,562 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 10:39:53,572 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 10:40:06,255 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-21 10:44:16,822 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 10:44:16,835 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 10:44:29,517 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-21 10:47:23,643 - security_monitor - INFO - 正在关闭系统...
|
||||||
|
2026-01-21 10:47:23,837 - security_monitor - INFO - 系统已关闭
|
||||||
|
2026-01-21 10:49:25,601 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 10:49:25,612 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 10:49:38,298 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-21 10:49:38,299 - security_monitor - INFO - 正在关闭系统...
|
||||||
|
2026-01-21 10:49:47,607 - security_monitor - INFO - 系统已关闭
|
||||||
|
2026-01-21 10:50:25,579 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 10:50:25,592 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 10:50:38,256 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-21 10:52:30,478 - security_monitor - INFO - 正在关闭系统...
|
||||||
|
2026-01-21 10:52:30,687 - security_monitor - INFO - 系统已关闭
|
||||||
|
2026-01-21 13:17:45,812 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 13:17:45,826 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 13:17:58,479 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
2026-01-21 13:17:58,480 - security_monitor - INFO - 正在关闭系统...
|
||||||
|
2026-01-21 13:18:07,687 - security_monitor - INFO - 系统已关闭
|
||||||
|
2026-01-21 13:18:55,795 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||||
|
2026-01-21 13:18:55,809 - security_monitor - INFO - 数据库初始化完成
|
||||||
|
2026-01-21 13:19:08,492 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||||
|
|||||||
4
main.py
4
main.py
@@ -7,6 +7,8 @@ from contextlib import asynccontextmanager
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
os.environ["TENSORRT_DISABLE_MYELIN"] = "1"
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
@@ -19,6 +21,7 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|||||||
from api.alarm import router as alarm_router
|
from api.alarm import router as alarm_router
|
||||||
from api.camera import router as camera_router
|
from api.camera import router as camera_router
|
||||||
from api.roi import router as roi_router
|
from api.roi import router as roi_router
|
||||||
|
from api.sync import router as sync_router
|
||||||
from config import get_config, load_config
|
from config import get_config, load_config
|
||||||
from db.models import init_db
|
from db.models import init_db
|
||||||
from inference.pipeline import get_pipeline, start_pipeline, stop_pipeline
|
from inference.pipeline import get_pipeline, start_pipeline, stop_pipeline
|
||||||
@@ -81,6 +84,7 @@ app.add_middleware(
|
|||||||
app.include_router(camera_router)
|
app.include_router(camera_router)
|
||||||
app.include_router(roi_router)
|
app.include_router(roi_router)
|
||||||
app.include_router(alarm_router)
|
app.include_router(alarm_router)
|
||||||
|
app.include_router(sync_router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
|
|||||||
@@ -10,37 +10,35 @@ sys.path.insert(0, project_root)
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
import tensorrt as trt
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
|
||||||
def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True):
|
def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True, imgsz=640):
|
||||||
"""构建TensorRT引擎"""
|
"""构建TensorRT引擎"""
|
||||||
from tensorrt import Builder, NetworkDefinitionLayer, Runtime
|
|
||||||
from tensorrt.parsers import onnxparser
|
|
||||||
|
|
||||||
logger = trt.Logger(trt.Logger.INFO)
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
builder = trt.Builder(logger)
|
builder = trt.Builder(logger)
|
||||||
|
|
||||||
network_flags = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
network_flags = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||||
network = builder.create_network(network_flags)
|
network = builder.create_network(network_flags)
|
||||||
|
|
||||||
parser = onnxparser.create_onnx_parser(network)
|
parser = trt.OnnxParser(network, logger)
|
||||||
parser.parse(onnx_path)
|
with open(onnx_path, 'rb') as f:
|
||||||
parser.report_status()
|
if not parser.parse(f.read()):
|
||||||
|
for error in range(parser.num_errors):
|
||||||
# 动态形状配置
|
print(parser.get_error(error))
|
||||||
if dynamic_batch:
|
raise RuntimeError("ONNX 解析失败")
|
||||||
profile = builder.create_optimization_profile()
|
|
||||||
min_shape = (1, 3, 480, 480)
|
|
||||||
opt_shape = (4, 3, 480, 480)
|
|
||||||
max_shape = (8, 3, 480, 480)
|
|
||||||
|
|
||||||
profile.set_shape("input", min_shape, opt_shape, max_shape)
|
|
||||||
network.get_input(0).set_dynamic_range(-1.0, 1.0)
|
|
||||||
network.set_precision_constraints(trt.PrecisionConstraints.PREFER)
|
|
||||||
|
|
||||||
config = builder.create_builder_config()
|
config = builder.create_builder_config()
|
||||||
config.set_memory_allocator(trt.MemoryAllocator())
|
|
||||||
config.max_workspace_size = 4 << 30 # 4GB
|
if dynamic_batch:
|
||||||
|
profile = builder.create_optimization_profile()
|
||||||
|
min_shape = (1, 3, imgsz, imgsz)
|
||||||
|
opt_shape = (4, 3, imgsz, imgsz)
|
||||||
|
max_shape = (8, 3, imgsz, imgsz)
|
||||||
|
|
||||||
|
profile.set_shape("images", min_shape, opt_shape, max_shape)
|
||||||
|
config.add_optimization_profile(profile)
|
||||||
|
|
||||||
if fp16:
|
if fp16:
|
||||||
config.set_flag(trt.BuilderFlag.FP16)
|
config.set_flag(trt.BuilderFlag.FP16)
|
||||||
@@ -50,10 +48,10 @@ def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True):
|
|||||||
with open(engine_path, "wb") as f:
|
with open(engine_path, "wb") as f:
|
||||||
f.write(serialized_engine)
|
f.write(serialized_engine)
|
||||||
|
|
||||||
print(f"✅ TensorRT引擎已保存: {engine_path}")
|
print(f"TensorRT引擎已保存: {engine_path}")
|
||||||
|
|
||||||
|
|
||||||
def export_onnx(model_path, onnx_path, imgsz=480):
|
def export_onnx(model_path, onnx_path, imgsz=640):
|
||||||
"""导出ONNX模型"""
|
"""导出ONNX模型"""
|
||||||
model = YOLO(model_path)
|
model = YOLO(model_path)
|
||||||
model.export(
|
model.export(
|
||||||
@@ -63,17 +61,24 @@ def export_onnx(model_path, onnx_path, imgsz=480):
|
|||||||
opset=12,
|
opset=12,
|
||||||
dynamic=True,
|
dynamic=True,
|
||||||
)
|
)
|
||||||
print(f"✅ ONNX模型已导出: {onnx_path}")
|
import shutil
|
||||||
|
import glob
|
||||||
|
onnx_files = glob.glob("yolo11n*.onnx")
|
||||||
|
if onnx_files:
|
||||||
|
shutil.move(onnx_files[0], onnx_path)
|
||||||
|
print(f"ONNX模型已导出: {onnx_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="TensorRT Engine Builder")
|
parser = argparse.ArgumentParser(description="TensorRT Engine Builder")
|
||||||
parser.add_argument("--model", type=str, default="models/yolo11n.pt",
|
parser.add_argument("--model", type=str, default="models/yolo11n.pt",
|
||||||
help="YOLO模型路径")
|
help="YOLO模型路径")
|
||||||
parser.add_argument("--engine", type=str, default="models/yolo11n_fp16_480.engine",
|
parser.add_argument("--engine", type=str, default="models/yolo11n.engine",
|
||||||
help="输出引擎路径")
|
help="输出引擎路径")
|
||||||
parser.add_argument("--onnx", type=str, default="models/yolo11n_480.onnx",
|
parser.add_argument("--onnx", type=str, default="models/yolo11n.onnx",
|
||||||
help="临时ONNX路径")
|
help="ONNX模型路径")
|
||||||
|
parser.add_argument("--imgsz", type=int, default=640,
|
||||||
|
help="输入图像尺寸")
|
||||||
parser.add_argument("--fp16", action="store_true", default=True,
|
parser.add_argument("--fp16", action="store_true", default=True,
|
||||||
help="启用FP16")
|
help="启用FP16")
|
||||||
parser.add_argument("--no-dynamic", action="store_true",
|
parser.add_argument("--no-dynamic", action="store_true",
|
||||||
@@ -82,8 +87,10 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(args.engine), exist_ok=True)
|
os.makedirs(os.path.dirname(args.engine), exist_ok=True)
|
||||||
|
onnx_dir = os.path.dirname(args.onnx) if os.path.dirname(args.onnx) else '.'
|
||||||
|
os.makedirs(onnx_dir, exist_ok=True)
|
||||||
|
|
||||||
if not os.path.exists(args.onnx):
|
if not os.path.exists(args.onnx):
|
||||||
export_onnx(args.model, args.onnx)
|
export_onnx(args.model, args.onnx, args.imgsz)
|
||||||
|
|
||||||
build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic)
|
build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic, args.imgsz)
|
||||||
|
|||||||
Binary file not shown.
Reference in New Issue
Block a user