生成新engine

This commit is contained in:
2026-01-21 13:29:39 +08:00
parent e965b10603
commit 2c00b5afe3
6 changed files with 547 additions and 181 deletions

View File

@@ -33,11 +33,15 @@ class Camera(Base):
__tablename__ = "cameras"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
cloud_id: Mapped[Optional[int]] = mapped_column(Integer, unique=True, nullable=True)
name: Mapped[str] = mapped_column(String(64), nullable=False)
rtsp_url: Mapped[str] = mapped_column(Text, nullable=False)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
fps_limit: Mapped[int] = mapped_column(Integer, default=30)
process_every_n_frames: Mapped[int] = mapped_column(Integer, default=3)
pending_sync: Mapped[bool] = mapped_column(Boolean, default=False)
sync_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
sync_retry_count: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
@@ -74,6 +78,7 @@ class ROI(Base):
__tablename__ = "rois"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
cloud_id: Mapped[Optional[int]] = mapped_column(Integer, unique=True, nullable=True)
camera_id: Mapped[int] = mapped_column(
Integer, ForeignKey("cameras.id"), nullable=False
)
@@ -88,6 +93,8 @@ class ROI(Base):
threshold_sec: Mapped[int] = mapped_column(Integer, default=360)
confirm_sec: Mapped[int] = mapped_column(Integer, default=30)
return_sec: Mapped[int] = mapped_column(Integer, default=5)
pending_sync: Mapped[bool] = mapped_column(Boolean, default=False)
sync_version: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
@@ -100,6 +107,7 @@ class Alarm(Base):
__tablename__ = "alarms"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
cloud_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
camera_id: Mapped[int] = mapped_column(
Integer, ForeignKey("cameras.id"), nullable=False
)
@@ -107,6 +115,10 @@ class Alarm(Base):
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
confidence: Mapped[float] = mapped_column(Float, default=0.0)
snapshot_path: Mapped[Optional[str]] = mapped_column(Text)
region_data: Mapped[Optional[str]] = mapped_column(Text)
upload_status: Mapped[str] = mapped_column(String(32), default='pending_upload')
upload_retry_count: Mapped[int] = mapped_column(Integer, default=0)
error_message: Mapped[Optional[str]] = mapped_column(Text)
llm_checked: Mapped[bool] = mapped_column(Boolean, default=False)
llm_result: Mapped[Optional[str]] = mapped_column(Text)
processed: Mapped[bool] = mapped_column(Boolean, default=False)

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react';
import { Card, Row, Col, Statistic, List, Tag, Button, Space, Timeline } from 'antd';
import { Card, Row, Col, Statistic, List, Tag, Button, Space } from 'antd';
import { AlertOutlined, VideoCameraOutlined, ClockCircleOutlined } from '@ant-design/icons';
import axios from 'axios';

View File

@@ -1,16 +1,14 @@
import React, { useEffect, useState, useRef } from 'react';
import { Card, Button, Space, Select, message, Modal, Form, Input, InputNumber, Drawer } from 'antd';
import { Card, Button, Space, Select, message, Drawer, Form, Input, InputNumber, Switch } from 'antd';
import { Stage, Layer, Rect, Line, Circle, Text as KonvaText } from 'react-konva';
import axios from 'axios';
interface ROI {
id: number;
roi_id: string;
name: string;
type: string;
points: number[][];
rule: string;
direction: string | null;
enabled: boolean;
threshold_sec: number;
confirm_sec: number;
@@ -27,12 +25,14 @@ const ROIEditor: React.FC = () => {
const [selectedCamera, setSelectedCamera] = useState<number | null>(null);
const [rois, setRois] = useState<ROI[]>([]);
const [snapshot, setSnapshot] = useState<string>('');
const [loading, setLoading] = useState(false);
const [imageDim, setImageDim] = useState({ width: 800, height: 600 });
const [selectedROI, setSelectedROI] = useState<ROI | null>(null);
const [drawerVisible, setDrawerVisible] = useState(false);
const [form] = Form.useForm();
const [isDrawing, setIsDrawing] = useState(false);
const [tempPoints, setTempPoints] = useState<number[][]>([]);
const [backgroundImage, setBackgroundImage] = useState<HTMLImageElement | null>(null);
const stageRef = useRef<any>(null);
const fetchCameras = async () => {
@@ -58,19 +58,25 @@ const ROIEditor: React.FC = () => {
}
}, [selectedCamera]);
const fetchSnapshot = async () => {
if (!selectedCamera) return;
try {
const res = await axios.get(`/api/camera/${selectedCamera}/snapshot/base64`);
setSnapshot(res.data.image);
useEffect(() => {
if (snapshot) {
const img = new Image();
img.onload = () => {
const maxWidth = 800;
const maxHeight = 600;
const scale = Math.min(maxWidth / img.width, maxHeight / img.height);
setImageDim({ width: img.width * scale, height: img.height * scale });
setBackgroundImage(img);
};
img.src = `data:image/jpeg;base64,${res.data.image}`;
img.src = `data:image/jpeg;base64,${snapshot}`;
}
}, [snapshot]);
const fetchSnapshot = async () => {
if (!selectedCamera) return;
try {
const res = await axios.get(`/api/camera/${selectedCamera}/snapshot/base64`);
setSnapshot(res.data.image);
} catch (err) {
message.error('获取截图失败');
}
@@ -89,34 +95,79 @@ const ROIEditor: React.FC = () => {
const handleSaveROI = async (values: any) => {
if (!selectedCamera || !selectedROI) return;
try {
await axios.put(`/api/camera/${selectedCamera}/roi/${selectedROI.id}`, values);
await axios.put(`/api/camera/${selectedCamera}/roi/${selectedROI.id}`, {
name: values.name,
roi_type: values.roi_type,
rule_type: values.rule_type,
threshold_sec: values.threshold_sec,
confirm_sec: values.confirm_sec,
enabled: values.enabled,
});
message.success('保存成功');
setDrawerVisible(false);
fetchROIs();
} catch (err) {
message.error('保存失败');
} catch (err: any) {
message.error(`保存失败: ${err.response?.data?.detail || '未知错误'}`);
}
};
const handleAddROI = async () => {
if (!selectedCamera) return;
const roi_id = `roi_${Date.now()}`;
try {
await axios.post(`/api/camera/${selectedCamera}/roi`, {
roi_id,
name: '新区域',
roi_type: 'polygon',
points: [[100, 100], [300, 100], [300, 300], [100, 300]],
rule_type: 'leave_post',
threshold_sec: 360,
confirm_sec: 30,
return_sec: 5,
});
message.success('添加成功');
fetchROIs();
} catch (err) {
message.error('添加失败');
const handleAddROI = () => {
if (!selectedCamera) {
message.warning('请先选择摄像头');
return;
}
setIsDrawing(true);
setTempPoints([]);
setSelectedROI(null);
message.info('点击画布绘制ROI区域双击完成绘制');
};
const handleStageClick = (e: any) => {
if (!isDrawing) return;
const stage = e.target.getStage();
const pos = stage.getPointerPosition();
if (pos) {
setTempPoints(prev => [...prev, [pos.x, pos.y]]);
}
};
const handleStageDblClick = () => {
if (!isDrawing || tempPoints.length < 3) {
if (tempPoints.length > 0 && tempPoints.length < 3) {
message.warning('至少需要3个点才能形成多边形');
}
return;
}
const roi_id = `roi_${Date.now()}`;
axios.post(`/api/camera/${selectedCamera}/roi`, {
roi_id,
name: `区域${rois.length + 1}`,
roi_type: 'polygon',
points: tempPoints,
rule_type: 'intrusion',
threshold_sec: 60,
confirm_sec: 5,
return_sec: 5,
})
.then(() => {
message.success('ROI添加成功');
setIsDrawing(false);
setTempPoints([]);
fetchROIs();
})
.catch((err) => {
message.error(`添加失败: ${err.response?.data?.detail || '未知错误'}`);
setIsDrawing(false);
setTempPoints([]);
});
};
const handleCancelDrawing = () => {
setIsDrawing(false);
setTempPoints([]);
message.info('已取消绘制');
};
const handleDeleteROI = async (roiId: number) => {
@@ -124,9 +175,13 @@ const ROIEditor: React.FC = () => {
try {
await axios.delete(`/api/camera/${selectedCamera}/roi/${roiId}`);
message.success('删除成功');
if (selectedROI?.id === roiId) {
setSelectedROI(null);
setDrawerVisible(false);
}
fetchROIs();
} catch (err) {
message.error('删除失败');
} catch (err: any) {
message.error(`删除失败: ${err.response?.data?.detail || '未知错误'}`);
}
};
@@ -137,162 +192,264 @@ const ROIEditor: React.FC = () => {
const renderROI = (roi: ROI) => {
const points = roi.points.flat();
const color = getROIStrokeColor(roi.rule);
const isSelected = selectedROI?.id === roi.id;
if (roi.type === 'polygon') {
return (
<Line
key={roi.id}
points={points}
closed
stroke={color}
strokeWidth={2}
fill={`${color}33`}
onClick={() => {
setSelectedROI(roi);
form.setFieldsValue(roi);
setDrawerVisible(true);
}}
/>
);
} else if (roi.type === 'line') {
return (
<Line
key={roi.id}
points={points}
stroke={color}
strokeWidth={3}
onClick={() => {
setSelectedROI(roi);
form.setFieldsValue(roi);
setDrawerVisible(true);
}}
/>
);
}
return null;
return (
<Line
key={roi.id}
points={points}
closed={roi.type === 'polygon'}
stroke={isSelected ? '#1890ff' : color}
strokeWidth={isSelected ? 3 : 2}
fill={`${color}33`}
onClick={() => {
setSelectedROI(roi);
form.setFieldsValue({
name: roi.name,
roi_type: roi.type,
rule_type: roi.rule,
threshold_sec: roi.threshold_sec,
confirm_sec: roi.confirm_sec,
enabled: roi.enabled,
});
setDrawerVisible(true);
}}
onMouseEnter={(e) => {
const container = e.target.getStage()?.container();
if (container) {
container.style.cursor = 'pointer';
}
}}
onMouseLeave={(e) => {
const container = e.target.getStage()?.container();
if (container) {
container.style.cursor = 'default';
}
}}
/>
);
};
return (
<div>
<Card>
<Space style={{ marginBottom: 16 }}>
<Space style={{ marginBottom: 16 }} wrap>
<Select
placeholder="选择摄像头"
value={selectedCamera}
onChange={setSelectedCamera}
onChange={(value) => {
setSelectedCamera(value);
setSelectedROI(null);
}}
style={{ width: 200 }}
options={cameras.map((c) => ({ label: c.name, value: c.id }))}
/>
<Button type="primary" onClick={fetchSnapshot}>
</Button>
<Button onClick={handleAddROI}>ROI</Button>
<Button onClick={fetchSnapshot}></Button>
{isDrawing ? (
<>
<Button danger onClick={handleCancelDrawing}></Button>
<Button type="primary" disabled={tempPoints.length < 3} onClick={handleStageDblClick}>
({tempPoints.length} )
</Button>
</>
) : (
<Button type="primary" onClick={handleAddROI}>ROI</Button>
)}
</Space>
<div className="roi-editor-container" style={{ display: 'flex', gap: 16 }}>
<div style={{ flex: 1, background: '#f0f0f0', display: 'flex', justifyContent: 'center', alignItems: 'center' }}>
<div className="roi-editor-container" style={{ display: 'flex', gap: 16, flexDirection: 'row' }}>
<div style={{
flex: 1,
background: '#f0f0f0',
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
minHeight: 500,
border: isDrawing ? '2px solid #1890ff' : '1px solid #d9d9d9',
borderRadius: 4,
position: 'relative'
}}>
{isDrawing && (
<div style={{
position: 'absolute',
top: 10,
left: 10,
zIndex: 10,
background: 'rgba(24, 144, 255, 0.9)',
color: 'white',
padding: '8px 16px',
borderRadius: 4,
fontSize: 14
}}>
-
</div>
)}
{snapshot ? (
<Stage width={imageDim.width} height={imageDim.height} ref={stageRef}>
<Stage
width={imageDim.width}
height={imageDim.height}
ref={stageRef}
onClick={handleStageClick}
onDblClick={handleStageDblClick}
style={{ cursor: isDrawing ? 'crosshair' : 'default' }}
>
<Layer>
<Rect
x={0}
y={0}
width={imageDim.width}
height={imageDim.height}
fillPatternImage={
(() => {
const img = new Image();
img.src = `data:image/jpeg;base64,${snapshot}`;
return img;
})()
}
fillPatternOffset={{ x: 0, y: 0 }}
fillPatternScale={{ x: 1, y: 1 }}
/>
{backgroundImage && (
<Rect
x={0}
y={0}
width={imageDim.width}
height={imageDim.height}
fillPatternImage={backgroundImage}
fillPatternOffset={{ x: 0, y: 0 }}
fillPatternScale={{ x: 1, y: 1 }}
/>
)}
{rois.map(renderROI)}
{isDrawing && tempPoints.length > 0 && (
<>
<Line
points={tempPoints.flat()}
stroke="#1890ff"
strokeWidth={2}
dash={[5, 5]}
/>
{tempPoints.map((point, idx) => (
<Circle
key={idx}
x={point[0]}
y={point[1]}
radius={5}
fill="#1890ff"
/>
))}
{tempPoints.map((point, idx) => (
<KonvaText
key={`label-${idx}`}
x={point[0] + 10}
y={point[1] - 10}
text={`${idx + 1}`}
fontSize={14}
fill="#1890ff"
/>
))}
</>
)}
</Layer>
</Stage>
) : (
<div>...</div>
<div style={{ color: '#999' }}>...</div>
)}
</div>
<div style={{ width: 300 }}>
<Card title="ROI列表" size="small">
{rois.map((roi) => (
<div
key={roi.id}
style={{
padding: 8,
marginBottom: 8,
background: '#fafafa',
borderRadius: 4,
cursor: 'pointer',
border: selectedROI?.id === roi.id ? '2px solid #1890ff' : '1px solid #d9d9d9',
}}
onClick={() => {
setSelectedROI(roi);
form.setFieldsValue(roi);
setDrawerVisible(true);
}}
>
<div style={{ fontWeight: 'bold' }}>{roi.name}</div>
<div style={{ fontSize: 12, color: '#666' }}>
: {roi.type} | : {roi.rule}
</div>
<Button
type="text"
danger
size="small"
onClick={(e) => {
e.stopPropagation();
handleDeleteROI(roi.id);
<div style={{ width: 280, flexShrink: 0 }}>
<Card title="ROI列表" size="small" bodyStyle={{ maxHeight: 500, overflow: 'auto' }}>
{rois.length === 0 ? (
<div style={{ color: '#999', textAlign: 'center', padding: 20 }}>
ROI区域"添加ROI"
</div>
) : (
rois.map((roi) => (
<div
key={roi.id}
style={{
padding: 8,
marginBottom: 8,
background: selectedROI?.id === roi.id ? '#e6f7ff' : '#fafafa',
borderRadius: 4,
cursor: 'pointer',
border: selectedROI?.id === roi.id ? '2px solid #1890ff' : '1px solid #d9d9d9',
}}
onClick={() => {
setSelectedROI(roi);
form.setFieldsValue({
name: roi.name,
roi_type: roi.type,
rule_type: roi.rule,
threshold_sec: roi.threshold_sec,
confirm_sec: roi.confirm_sec,
enabled: roi.enabled,
});
setDrawerVisible(true);
}}
>
</Button>
</div>
))}
<div style={{ fontWeight: 'bold', marginBottom: 4 }}>{roi.name}</div>
<div style={{ fontSize: 12, color: '#666', marginBottom: 4 }}>
: {roi.type === 'polygon' ? '多边形' : '线段'} | : {roi.rule === 'intrusion' ? '入侵检测' : '离岗检测'}
</div>
<Space size={4}>
<Button
type="link"
size="small"
danger
onClick={(e) => {
e.stopPropagation();
handleDeleteROI(roi.id);
}}
>
</Button>
</Space>
</div>
))
)}
</Card>
</div>
</div>
</Card>
<Drawer
title="编辑ROI"
title={selectedROI ? `编辑ROI - ${selectedROI.name}` : '编辑ROI'}
open={drawerVisible}
onClose={() => setDrawerVisible(false)}
onClose={() => {
setDrawerVisible(false);
setSelectedROI(null);
}}
width={400}
>
<Form form={form} layout="vertical" onFinish={handleSaveROI}>
<Form.Item name="name" label="名称" rules={[{ required: true }]}>
<Input />
<Form.Item name="name" label="名称" rules={[{ required: true, message: '请输入名称' }]}>
<Input placeholder="例如:入口入侵区域" />
</Form.Item>
<Form.Item name="roi_type" label="类型">
<Select options={[{ label: '多边形', value: 'polygon' }, { label: '线段', value: 'line' }]} />
</Form.Item>
<Form.Item name="rule_type" label="规则">
<Form.Item name="roi_type" label="类型" rules={[{ required: true }]}>
<Select
options={[
{ label: '离岗检测', value: 'leave_post' },
{ label: '周界入侵', value: 'intrusion' },
{ label: '多边形区域', value: 'polygon' },
{ label: '线段', value: 'line' },
]}
/>
</Form.Item>
<Form.Item name="threshold_sec" label="超时时间(秒)">
<InputNumber min={60} style={{ width: '100%' }} />
<Form.Item name="rule_type" label="检测规则" rules={[{ required: true }]}>
<Select
options={[
{ label: '周界入侵检测', value: 'intrusion' },
{ label: '离岗检测', value: 'leave_post' },
]}
/>
</Form.Item>
<Form.Item name="confirm_sec" label="确认时间(秒)">
<InputNumber min={5} style={{ width: '100%' }} />
</Form.Item>
<Form.Item name="enabled" label="启用" valuePropName="checked">
<input type="checkbox" />
{selectedROI?.rule === 'leave_post' && (
<>
<Form.Item name="threshold_sec" label="超时时间(秒)" rules={[{ required: true }]}>
<InputNumber min={60} style={{ width: '100%' }} />
</Form.Item>
<Form.Item name="confirm_sec" label="确认时间(秒)" rules={[{ required: true }]}>
<InputNumber min={5} style={{ width: '100%' }} />
</Form.Item>
</>
)}
<Form.Item name="enabled" label="启用状态" valuePropName="checked">
<Switch checkedChildren="启用" unCheckedChildren="停用" />
</Form.Item>
<Form.Item>
<Space>
<Button type="primary" htmlType="submit">
</Button>
<Button onClick={() => setDrawerVisible(false)}></Button>
<Button onClick={() => {
setDrawerVisible(false);
setSelectedROI(null);
}}>
</Button>
</Space>
</Form.Item>
</Form>

View File

@@ -1,4 +1,7 @@
import os
os.environ["TENSORRT_DISABLE_MYELIN"] = "1"
import time
from typing import Any, Dict, List, Optional, Tuple
@@ -6,12 +9,146 @@ import cv2
import numpy as np
import tensorrt as trt
import torch
import onnxruntime as ort
from ultralytics import YOLO
from ultralytics.engine.results import Results
from ultralytics.engine.results import Results, Boxes as UltralyticsBoxes
from config import get_config
class ONNXEngine:
def __init__(self, onnx_path: Optional[str] = None, device: int = 0):
config = get_config()
self.onnx_path = onnx_path or config.model.onnx_path
self.device = device
self.imgsz = tuple(config.model.imgsz)
self.conf_thresh = config.model.conf_threshold
self.iou_thresh = config.model.iou_threshold
self.session = None
self.input_names = None
self.output_names = None
self.load_model()
def load_model(self):
if not os.path.exists(self.onnx_path):
raise FileNotFoundError(f"ONNX模型文件不存在: {self.onnx_path}")
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.device >= 0 else ['CPUExecutionProvider']
self.session = ort.InferenceSession(self.onnx_path, providers=providers)
self.input_names = [inp.name for inp in self.session.get_inputs()]
self.output_names = [out.name for out in self.session.get_outputs()]
def preprocess(self, frame: np.ndarray) -> np.ndarray:
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, self.imgsz)
img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
return img
def postprocess(self, output: np.ndarray, orig_img: np.ndarray) -> List[Results]:
c, n = output.shape
output = output.T
boxes = output[:, :4]
scores = output[:, 4]
classes = output[:, 5:].argmax(axis=1) if output.shape[1] > 5 else np.zeros(len(output), dtype=np.int32)
mask = scores > self.conf_thresh
boxes = boxes[mask]
scores = scores[mask]
classes = classes[mask]
if len(boxes) == 0:
return [Results(orig_img=orig_img, path="", names={0: "person"})]
indices = cv2.dnn.NMSBoxes(
boxes.tolist(),
scores.tolist(),
self.conf_thresh,
self.iou_thresh,
)
orig_h, orig_w = orig_img.shape[:2]
scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0]
filtered_boxes = []
for idx in indices:
if idx >= len(boxes):
continue
box = boxes[idx]
x1, y1, x2, y2 = box
w, h = x2 - x1, y2 - y1
filtered_boxes.append([
int(x1 * scale_x),
int(y1 * scale_y),
int(w * scale_x),
int(h * scale_y),
float(scores[idx]),
int(classes[idx])
])
from ultralytics.engine.results import Boxes as BoxesObj
if filtered_boxes:
box_tensor = torch.tensor(filtered_boxes)
boxes_obj = BoxesObj(
box_tensor,
orig_shape=(orig_h, orig_w)
)
result = Results(
orig_img=orig_img,
path="",
names={0: "person"},
boxes=boxes_obj
)
return [result]
return [Results(orig_img=orig_img, path="", names={0: "person"})]
def inference(self, images: List[np.ndarray]) -> List[Results]:
if not images:
return []
batch_imgs = []
for frame in images:
img = self.preprocess(frame)
batch_imgs.append(img)
batch = np.stack(batch_imgs, axis=0)
inputs = {self.input_names[0]: batch}
outputs = self.session.run(self.output_names, inputs)
results = []
output = outputs[0]
if output.shape[0] == 1:
result = self.postprocess(output[0], images[0])
results.extend(result)
else:
for i in range(output.shape[0]):
result = self.postprocess(output[i], images[i])
results.extend(result)
return results
def inference_single(self, frame: np.ndarray) -> List[Results]:
return self.inference([frame])
def warmup(self, num_warmup: int = 10):
dummy_frame = np.zeros((640, 640, 3), dtype=np.uint8)
for _ in range(num_warmup):
self.inference_single(dummy_frame)
def __del__(self):
if self.session:
try:
self.session.end_profiling()
except Exception:
pass
class TensorRTEngine:
def __init__(self, engine_path: Optional[str] = None, device: int = 0):
config = get_config()
@@ -25,9 +162,11 @@ class TensorRTEngine:
self.logger = trt.Logger(trt.Logger.INFO)
self.engine = None
self.context = None
self.stream = None
self.stream = torch.cuda.Stream(device=self.device)
self.input_buffer = None
self.output_buffers = []
self.input_name = None
self.output_name = None
self._load_engine()
@@ -44,29 +183,39 @@ class TensorRTEngine:
self.context = self.engine.create_execution_context()
self.stream = torch.cuda.Stream(device=self.device)
self.batch_size = 1
for i in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(i)
dtype = self.engine.get_tensor_dtype(name)
shape = self.engine.get_tensor_shape(name)
shape = list(self.engine.get_tensor_shape(name))
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self.context.set_tensor_address(name, None)
if -1 in shape:
shape = [self.batch_size if d == -1 else d for d in shape]
if dtype == trt.float16:
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
else:
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
self.input_buffer = buffer
self.input_name = name
else:
if -1 in shape:
shape = [self.batch_size if d == -1 else d for d in shape]
if dtype == trt.float16:
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
else:
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
self.output_buffers.append(buffer)
self.context.set_tensor_address(name, buffer.data_ptr())
if self.output_name is None:
self.output_name = name
self.context.set_optimization_profile_async(0, self.stream)
self.context.set_tensor_address(name, buffer.data_ptr())
self.input_buffer = torch.zeros(
(1, 3, self.imgsz[0], self.imgsz[1]),
dtype=torch.float16 if self.half else torch.float32,
device=self.device,
)
stream_handle = torch.cuda.current_stream(self.device).cuda_stream
self.context.set_optimization_profile_async(0, stream_handle)
self.batch_size = 1
def preprocess(self, frame: np.ndarray) -> torch.Tensor:
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
@@ -95,16 +244,20 @@ class TensorRTEngine:
)
self.context.set_tensor_address(
"input", input_tensor.contiguous().data_ptr()
self.input_name, input_tensor.contiguous().data_ptr()
)
input_shape = list(input_tensor.shape)
self.context.set_input_shape(self.input_name, input_shape)
torch.cuda.synchronize(self.stream)
self.context.execute_async_v3(self.stream.handle)
self.context.execute_async_v3(self.stream.cuda_stream)
torch.cuda.synchronize(self.stream)
results = []
for i in range(batch_size):
pred = self.output_buffers[0][i].cpu().numpy()
pred = pred.T # 转置: (8400, 84)
boxes = pred[:, :4]
scores = pred[:, 4]
classes = pred[:, 5].astype(np.int32)
@@ -142,7 +295,7 @@ class TensorRTEngine:
orig_img=images[i],
path="",
names={0: "person"},
boxes=Boxes(
boxes=UltralyticsBoxes(
torch.tensor([box_orig + [conf, cls]]),
orig_shape=(orig_h, orig_w),
),
@@ -161,9 +314,15 @@ class TensorRTEngine:
def __del__(self):
if self.context:
self.context.synchronize()
try:
self.context.synchronize()
except Exception:
pass
if self.stream:
self.stream.synchronize()
try:
self.stream.synchronize()
except Exception:
pass
class Boxes:
@@ -196,6 +355,15 @@ class Boxes:
return self.data[:, 5]
def _check_pt_file_valid(pt_path: str) -> bool:
try:
with open(pt_path, 'rb') as f:
header = f.read(10)
return len(header) == 10
except Exception:
return False
class YOLOEngine:
def __init__(
self,
@@ -203,38 +371,61 @@ class YOLOEngine:
device: int = 0,
use_trt: bool = True,
):
self.use_trt = use_trt
self.device = device
self.use_trt = False
self.onnx_engine = None
self.trt_engine = None
self.device = device
config = get_config()
if not use_trt:
if model_path:
pt_path = model_path
elif hasattr(get_config().model, 'pt_model_path'):
pt_path = get_config().model.pt_model_path
else:
pt_path = get_config().model.engine_path.replace(".engine", ".pt")
self.model = YOLO(pt_path)
self.model.to(device)
else:
if use_trt:
try:
self.trt_engine = TensorRTEngine(device=device)
self.trt_engine.warmup()
self.use_trt = True
print("TensorRT引擎加载成功")
return
except Exception as e:
print(f"TensorRT加载失败回退到PyTorch: {e}")
self.use_trt = False
if model_path:
pt_path = model_path
elif hasattr(get_config().model, 'pt_model_path'):
pt_path = get_config().model.pt_model_path
else:
pt_path = get_config().model.engine_path.replace(".engine", ".pt")
print(f"TensorRT加载失败: {e}")
try:
onnx_path = config.model.onnx_path
if os.path.exists(onnx_path):
self.onnx_engine = ONNXEngine(device=device)
self.onnx_engine.warmup()
print("ONNX引擎加载成功")
return
else:
print(f"ONNX模型不存在: {onnx_path}")
except Exception as e:
print(f"ONNX加载失败: {e}")
try:
pt_path = model_path or config.model.pt_model_path
if os.path.exists(pt_path) and _check_pt_file_valid(pt_path):
self.model = YOLO(pt_path)
self.model.to(device)
print(f"PyTorch模型加载成功: {pt_path}")
else:
raise FileNotFoundError(f"PT文件无效或不存在: {pt_path}")
except Exception as e:
print(f"PyTorch加载失败: {e}")
raise RuntimeError("所有模型加载方式均失败")
def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]:
if self.use_trt:
return self.trt_engine.inference_single(frame)
if self.use_trt and self.trt_engine:
try:
return self.trt_engine.inference_single(frame)
except Exception as e:
print(f"TensorRT推理失败切换到ONNX: {e}")
self.use_trt = False
if self.onnx_engine:
return self.onnx_engine.inference_single(frame)
elif self.model:
return self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
else:
return []
elif self.onnx_engine:
return self.onnx_engine.inference_single(frame)
else:
results = self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
return results
@@ -242,3 +433,5 @@ class YOLOEngine:
def __del__(self):
if self.trt_engine:
del self.trt_engine
if self.onnx_engine:
del self.onnx_engine

View File

@@ -7,6 +7,8 @@ from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional
os.environ["TENSORRT_DISABLE_MYELIN"] = "1"
import cv2
import numpy as np
from fastapi import FastAPI, HTTPException
@@ -19,6 +21,7 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from api.alarm import router as alarm_router
from api.camera import router as camera_router
from api.roi import router as roi_router
from api.sync import router as sync_router
from config import get_config, load_config
from db.models import init_db
from inference.pipeline import get_pipeline, start_pipeline, stop_pipeline
@@ -81,6 +84,7 @@ app.add_middleware(
app.include_router(camera_router)
app.include_router(roi_router)
app.include_router(alarm_router)
app.include_router(sync_router)
@app.get("/")

Binary file not shown.