# 169 lines · 4.0 KiB · Python
import pytest
|
|
import sys
|
|
import os
|
|
|
|
# Make the project root (the parent of this tests directory) importable so the
# function-local imports below (inference, db, config, utils) resolve.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)


def test_roi_filter_parse_points():
    """ROIFilter.parse_points turns a JSON-style point list string into (x, y) tuples."""
    from inference.roi.roi_filter import ROIFilter

    # Named `roi_filter` rather than `filter` to avoid shadowing the builtin.
    roi_filter = ROIFilter()
    points = roi_filter.parse_points("[[100, 200], [300, 400]]")

    assert len(points) == 2
    assert points[0] == (100.0, 200.0)
    assert points[1] == (300.0, 400.0)


def test_roi_filter_polygon():
    """A 100x100 square ROI has the expected area and point-containment results."""
    from inference.roi.roi_filter import ROIFilter

    # Named `roi_filter` rather than `filter` to avoid shadowing the builtin.
    roi_filter = ROIFilter()
    points = [(0, 0), (100, 0), (100, 100), (0, 100)]
    polygon = roi_filter.create_polygon(points)

    # The area is a float; compare with a tolerance rather than exact equality.
    assert polygon.area == pytest.approx(10000)
    # Use truthiness instead of `== True` / `== False` (PEP 8, E712).
    assert roi_filter.is_point_in_polygon((50, 50), polygon)
    assert not roi_filter.is_point_in_polygon((150, 150), polygon)


def test_roi_filter_bbox_center():
    """get_bbox_center returns the midpoint of an [x1, y1, x2, y2] box."""
    from inference.roi.roi_filter import ROIFilter

    # Named `roi_filter` rather than `filter` to avoid shadowing the builtin.
    roi_filter = ROIFilter()
    center = roi_filter.get_bbox_center([10, 20, 100, 200])

    # (10 + 100) / 2 = 55, (20 + 200) / 2 = 110
    assert center == (55.0, 110.0)


def test_leave_post_algorithm_init():
    """Each constructor argument of LeavePostAlgorithm is exposed as an attribute."""
    from inference.rules.algorithms import LeavePostAlgorithm

    expected = {"threshold_sec": 300, "confirm_sec": 30, "return_sec": 5}
    algo = LeavePostAlgorithm(**expected)

    for attr_name, value in expected.items():
        assert getattr(algo, attr_name) == value


def test_leave_post_algorithm_process():
    """process() returns a list when fed a single tracked detection."""
    from datetime import datetime

    from inference.rules.algorithms import LeavePostAlgorithm

    algo = LeavePostAlgorithm(threshold_sec=360, confirm_sec=30, return_sec=5)
    single_track = {"bbox": [100, 100, 200, 200], "conf": 0.9, "cls": 0}

    result = algo.process("roi_1", "test_cam", [single_track], datetime.now())
    assert isinstance(result, list)


def test_intrusion_algorithm_init():
    """IntrusionAlgorithm keeps the check interval it was constructed with."""
    from inference.rules.algorithms import IntrusionAlgorithm

    interval = 1.0
    algo = IntrusionAlgorithm(check_interval_sec=interval, direction_sensitive=False)

    assert algo.check_interval_sec == interval


def test_algorithm_manager():
    """register_algorithm files the algorithm under its ROI id and algorithm name."""
    from inference.rules.algorithms import AlgorithmManager

    roi_id, algo_name = "roi_1", "leave_post"
    manager = AlgorithmManager()
    manager.register_algorithm(roi_id, algo_name, {"threshold_sec": 300})

    registered = manager.algorithms
    assert roi_id in registered
    assert algo_name in registered[roi_id]


def test_config_load():
    """load_config() yields a supported DB dialect and the expected model settings."""
    # Only load_config is used; the previously imported get_config was unused.
    from config import load_config

    config = load_config()

    assert config.database.dialect in ("sqlite", "mysql")
    assert config.model.imgsz == [480, 480]
    assert config.model.batch_size == 8


def test_database_models():
    """A Camera model instance keeps the field values it was constructed with."""
    # ROI and Alarm were imported but never used; only Camera is exercised here.
    from db.models import Camera, init_db

    init_db()

    camera = Camera(
        name="测试摄像头",
        rtsp_url="rtsp://test.local/cam1",
        fps_limit=30,
        process_every_n_frames=3,
        enabled=True,
    )

    assert camera.name == "测试摄像头"
    # Truthiness instead of `== True` (PEP 8, E712).
    assert camera.enabled


def test_camera_crud():
    """Create a camera through the CRUD layer and read it back by id.

    The row is written to whatever database ``init_db`` / ``get_session_factory``
    are configured for. It is deleted again afterwards so that repeated runs do
    not keep accumulating test cameras.
    """
    from db.crud import create_camera, get_camera_by_id
    from db.models import get_session_factory, init_db

    init_db()
    SessionLocal = get_session_factory()
    db = SessionLocal()
    camera = None
    try:
        camera = create_camera(
            db,
            name="测试摄像头",
            rtsp_url="rtsp://test.local/cam1",
            fps_limit=30,
        )
        assert camera.id is not None

        fetched = get_camera_by_id(db, camera.id)
        assert fetched is not None
        assert fetched.name == "测试摄像头"
    finally:
        # Nested try/finally so the session is always closed, even if the
        # cleanup itself fails.
        try:
            if camera is not None and camera.id is not None:
                # NOTE(review): assumes `db` is a SQLAlchemy Session
                # (delete/commit); confirm against get_session_factory.
                db.delete(camera)
                db.commit()
        finally:
            db.close()


def test_stream_reader_init():
    """StreamReader stores the camera id and buffer size it was given."""
    from inference.stream import StreamReader

    cam_id = "test_cam"
    size = 2
    reader = StreamReader(
        camera_id=cam_id,
        rtsp_url="rtsp://test.local/stream",
        buffer_size=size,
    )

    assert reader.camera_id == cam_id
    assert reader.buffer_size == size


def test_utils_helpers():
    """draw_bbox preserves image shape; format_duration renders whole minutes."""
    import numpy as np

    # draw_roi was imported but never used; only the helpers exercised are imported.
    from utils.helpers import draw_bbox, format_duration

    image = np.zeros((480, 640, 3), dtype=np.uint8)
    result = draw_bbox(image, [100, 100, 200, 200], "Test")
    assert result.shape == image.shape

    # 125.5 s is just over two minutes; the output is expected to contain "2分".
    duration = format_duration(125.5)
    assert "2分" in duration


if __name__ == "__main__":
    # Propagate pytest's exit code so a failing run is visible to the shell / CI;
    # previously the return value was discarded and the script always exited 0.
    sys.exit(pytest.main([__file__, "-v"]))