ROI选区01
This commit is contained in:
168
tests/test_core.py
Normal file
168
tests/test_core.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
def test_roi_filter_parse_points():
|
||||
from inference.roi.roi_filter import ROIFilter
|
||||
|
||||
filter = ROIFilter()
|
||||
points = filter.parse_points("[[100, 200], [300, 400]]")
|
||||
assert len(points) == 2
|
||||
assert points[0] == (100.0, 200.0)
|
||||
|
||||
|
||||
def test_roi_filter_polygon():
|
||||
from inference.roi.roi_filter import ROIFilter
|
||||
|
||||
filter = ROIFilter()
|
||||
points = [(0, 0), (100, 0), (100, 100), (0, 100)]
|
||||
polygon = filter.create_polygon(points)
|
||||
|
||||
assert polygon.area == 10000
|
||||
assert filter.is_point_in_polygon((50, 50), polygon) == True
|
||||
assert filter.is_point_in_polygon((150, 150), polygon) == False
|
||||
|
||||
|
||||
def test_roi_filter_bbox_center():
|
||||
from inference.roi.roi_filter import ROIFilter
|
||||
|
||||
filter = ROIFilter()
|
||||
center = filter.get_bbox_center([10, 20, 100, 200])
|
||||
assert center == (55.0, 110.0)
|
||||
|
||||
|
||||
def test_leave_post_algorithm_init():
|
||||
from inference.rules.algorithms import LeavePostAlgorithm
|
||||
|
||||
algo = LeavePostAlgorithm(
|
||||
threshold_sec=300,
|
||||
confirm_sec=30,
|
||||
return_sec=5,
|
||||
)
|
||||
|
||||
assert algo.threshold_sec == 300
|
||||
assert algo.confirm_sec == 30
|
||||
assert algo.return_sec == 5
|
||||
|
||||
|
||||
def test_leave_post_algorithm_process():
|
||||
from inference.rules.algorithms import LeavePostAlgorithm
|
||||
from datetime import datetime
|
||||
|
||||
algo = LeavePostAlgorithm(threshold_sec=360, confirm_sec=30, return_sec=5)
|
||||
|
||||
tracks = [
|
||||
{"bbox": [100, 100, 200, 200], "conf": 0.9, "cls": 0},
|
||||
]
|
||||
|
||||
alerts = algo.process("test_cam", tracks, datetime.now())
|
||||
assert isinstance(alerts, list)
|
||||
|
||||
|
||||
def test_intrusion_algorithm_init():
|
||||
from inference.rules.algorithms import IntrusionAlgorithm
|
||||
|
||||
algo = IntrusionAlgorithm(
|
||||
check_interval_sec=1.0,
|
||||
direction_sensitive=False,
|
||||
)
|
||||
|
||||
assert algo.check_interval_sec == 1.0
|
||||
|
||||
|
||||
def test_algorithm_manager():
|
||||
from inference.rules.algorithms import AlgorithmManager
|
||||
|
||||
manager = AlgorithmManager()
|
||||
|
||||
manager.register_algorithm("roi_1", "leave_post", {"threshold_sec": 300})
|
||||
|
||||
assert "roi_1" in manager.algorithms
|
||||
assert "leave_post" in manager.algorithms["roi_1"]
|
||||
|
||||
|
||||
def test_config_load():
|
||||
from config import load_config, get_config
|
||||
|
||||
config = load_config()
|
||||
|
||||
assert config.database.dialect in ["sqlite", "mysql"]
|
||||
assert config.model.imgsz == [480, 480]
|
||||
assert config.model.batch_size == 8
|
||||
|
||||
|
||||
def test_database_models():
|
||||
from db.models import Camera, ROI, Alarm, init_db
|
||||
|
||||
init_db()
|
||||
|
||||
camera = Camera(
|
||||
name="测试摄像头",
|
||||
rtsp_url="rtsp://test.local/cam1",
|
||||
fps_limit=30,
|
||||
process_every_n_frames=3,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
assert camera.name == "测试摄像头"
|
||||
assert camera.enabled == True
|
||||
|
||||
|
||||
def test_camera_crud():
|
||||
from db.crud import create_camera, get_camera_by_id
|
||||
from db.models import get_session_factory, init_db
|
||||
|
||||
init_db()
|
||||
SessionLocal = get_session_factory()
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
camera = create_camera(
|
||||
db,
|
||||
name="测试摄像头",
|
||||
rtsp_url="rtsp://test.local/cam1",
|
||||
fps_limit=30,
|
||||
)
|
||||
|
||||
assert camera.id is not None
|
||||
|
||||
fetched = get_camera_by_id(db, camera.id)
|
||||
assert fetched is not None
|
||||
assert fetched.name == "测试摄像头"
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_stream_reader_init():
|
||||
from inference.stream import StreamReader
|
||||
|
||||
reader = StreamReader(
|
||||
camera_id="test_cam",
|
||||
rtsp_url="rtsp://test.local/stream",
|
||||
buffer_size=2,
|
||||
)
|
||||
|
||||
assert reader.camera_id == "test_cam"
|
||||
assert reader.buffer_size == 2
|
||||
|
||||
|
||||
def test_utils_helpers():
|
||||
from utils.helpers import draw_bbox, draw_roi, format_duration
|
||||
|
||||
import numpy as np
|
||||
|
||||
image = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
|
||||
result = draw_bbox(image, [100, 100, 200, 200], "Test")
|
||||
|
||||
assert result.shape == image.shape
|
||||
|
||||
duration = format_duration(125.5)
|
||||
assert "2分" in duration
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user