diff --git a/track.py b/track.py new file mode 100644 index 0000000000000000000000000000000000000000..e3df5553b1b7955e6a9c428deca0a4f3eef308d7 --- /dev/null +++ b/track.py @@ -0,0 +1,41 @@ +from collections import defaultdict + +import cv2 + +from ultralytics import YOLO +from ultralytics.utils.plotting import Annotator, colors + +track_history = defaultdict(lambda: []) + +model = YOLO("./runs/segment/train2/weights/best.pt") # segmentation model +cap = cv2.VideoCapture("./data/video/hens_smol.mp4") +w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) + +out = cv2.VideoWriter("hen-segmentation-object-tracking.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h)) + +while True: + ret, im0 = cap.read() + if not ret: + print("Video frame is empty or video processing has been successfully completed.") + break + + annotator = Annotator(im0, line_width=2) + + results = model.track(im0, persist=True) + + if results[0].boxes.id is not None and results[0].masks is not None: + masks = results[0].masks.xy + track_ids = results[0].boxes.id.int().cpu().tolist() + + for mask, track_id in zip(masks, track_ids): + annotator.seg_bbox(mask=mask, mask_color=colors(track_id, True), track_label=str(track_id)) + + out.write(im0) + # cv2.imshow("instance-segmentation-object-tracking", im0) + + if cv2.waitKey(1) & 0xFF == ord("q"): + break + +out.release() +cap.release() +cv2.destroyAllWindows() \ No newline at end of file