# NOTE(review): repository page residue kept as a comment so the file parses:
# commit 3cf13b815c ("clean"), 2024-07-18 00:42:59 +02:00, 306 lines, 11 KiB, Python
"""
main code for track
"""
import sys, os
import numpy as np
import torch
import cv2
from PIL import Image
from tqdm import tqdm
import yaml
from loguru import logger
import argparse
from tracking_utils.envs import select_device
from tracking_utils.tools import *
from tracking_utils.visualization import plot_img, save_video
from my_timer import Timer
from tracker_dataloader import TestDataset
# trackers
from trackers.byte_tracker import ByteTracker
from trackers.sort_tracker import SortTracker
from trackers.botsort_tracker import BotTracker
from trackers.c_biou_tracker import C_BIoUTracker
from trackers.ocsort_tracker import OCSortTracker
from trackers.deepsort_tracker import DeepSortTracker
from trackers.strongsort_tracker import StrongSortTracker
from trackers.sparse_tracker import SparseTracker
# YOLOX modules
try:
from yolox.exp import get_exp
from yolox_utils.postprocess import postprocess_yolox
from yolox.utils import fuse_model
except Exception as e:
logger.warning(e)
logger.warning('Load yolox fail. If you want to use yolox, please check the installation.')
pass
# YOLOv7 modules
try:
sys.path.append(os.getcwd())
from models.experimental import attempt_load
from utils.torch_utils import select_device, time_synchronized, TracedModel
from utils.general import non_max_suppression, scale_coords, check_img_size
from yolov7_utils.postprocess import postprocess as postprocess_yolov7
except Exception as e:
logger.warning(e)
logger.warning('Load yolov7 fail. If you want to use yolov7, please check the installation.')
pass
# YOLOv8 modules
try:
from ultralytics import YOLO
from yolov8_utils.postprocess import postprocess as postprocess_yolov8
except Exception as e:
logger.warning(e)
logger.warning('Load yolov8 fail. If you want to use yolov8, please check the installation.')
pass
# Registry mapping the --tracker CLI name (see get_args) to the tracker class
# imported from trackers.* above; instantiated per-sequence in main().
TRACKER_DICT = {
'sort': SortTracker,
'bytetrack': ByteTracker,
'botsort': BotTracker,
'c_bioutrack': C_BIoUTracker,
'ocsort': OCSortTracker,
'deepsort': DeepSortTracker,
'strongsort': StrongSortTracker,
'sparsetrack': SparseTracker
}
def get_args():
    """Parse command-line options for the tracking script.

    Returns:
        argparse.Namespace with dataset/detector/tracker selection, thresholds,
        model paths and output options.
    """
    def str2bool(v):
        # argparse's `type=bool` is broken: bool('False') == True, so the
        # original `--trace False` / `--track_eval False` had no effect.
        # Accept the usual textual spellings instead.
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'boolean value expected, got {v!r}')

    parser = argparse.ArgumentParser()
    # general
    parser.add_argument('--dataset', type=str, default='visdrone_part', help='visdrone, mot17, etc.')
    parser.add_argument('--detector', type=str, default='yolov8', help='yolov7, yolox, etc.')
    parser.add_argument('--tracker', type=str, default='sort', help='sort, deepsort, etc')
    parser.add_argument('--reid_model', type=str, default='osnet_x0_25', help='osnet or deepsort')
    parser.add_argument('--kalman_format', type=str, default='default', help='use what kind of Kalman, sort, deepsort, byte, etc.')
    parser.add_argument('--img_size', type=int, default=1280, help='image size, [h, w]')
    parser.add_argument('--conf_thresh', type=float, default=0.2, help='filter tracks')
    parser.add_argument('--nms_thresh', type=float, default=0.7, help='thresh for NMS')
    parser.add_argument('--iou_thresh', type=float, default=0.5, help='IOU thresh to filter tracks')
    parser.add_argument('--device', type=str, default='6', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    # yolox
    parser.add_argument('--yolox_exp_file', type=str, default='./tracker/yolox_utils/yolox_m.py')
    # model path
    parser.add_argument('--detector_model_path', type=str, default='./weights/best.pt', help='model path')
    parser.add_argument('--trace', type=str2bool, default=False, help='traced model of YOLO v7')
    # other model path
    parser.add_argument('--reid_model_path', type=str, default='./weights/osnet_x0_25.pth', help='path for reid model path')
    parser.add_argument('--dhn_path', type=str, default='./weights/DHN.pth', help='path of DHN path for DeepMOT')
    # other options
    parser.add_argument('--discard_reid', action='store_true', help='discard reid model, only work in bot-sort etc. which need a reid part')
    parser.add_argument('--track_buffer', type=int, default=30, help='tracking buffer')
    parser.add_argument('--gamma', type=float, default=0.1, help='param to control fusing motion and appearance dist')
    parser.add_argument('--min_area', type=float, default=150, help='use to filter small bboxs')
    parser.add_argument('--save_dir', type=str, default='track_results/{dataset_name}/{split}')
    parser.add_argument('--save_images', action='store_true', help='save tracking results (image)')
    parser.add_argument('--save_videos', action='store_true', help='save tracking results (video)')
    parser.add_argument('--track_eval', type=str2bool, default=True, help='Use TrackEval to evaluate')
    return parser.parse_args()
def _load_detector(args, device):
    """Load the detection model selected by ``args.detector``.

    Returns:
        (model, model_img_size, stride): ``stride`` is only meaningful for
        yolov7 (None otherwise); ``model_img_size`` is [None, None] for
        yolov8, which handles resizing internally.

    Raises:
        SystemExit: if ``args.detector`` is not one of yolox/yolov7/yolov8.
    """
    if args.detector == 'yolox':
        exp = get_exp(args.yolox_exp_file, None)  # TODO: modify num_classes etc. for specific dataset
        model_img_size = exp.input_size
        model = exp.get_model()
        model.to(device)
        model.eval()
        logger.info(f"loading detector {args.detector} checkpoint {args.detector_model_path}")
        ckpt = torch.load(args.detector_model_path, map_location=device)
        model.load_state_dict(ckpt['model'])
        logger.info("loaded checkpoint done")
        model = fuse_model(model)
        stride = None  # match with yolo v7
        logger.info(f'Now detector is on device {next(model.parameters()).device}')
    elif args.detector == 'yolov7':
        logger.info(f"loading detector {args.detector} checkpoint {args.detector_model_path}")
        model = attempt_load(args.detector_model_path, map_location=device)
        # get inference img size from the model stride
        stride = int(model.stride.max())
        model_img_size = check_img_size(args.img_size, s=stride)
        # Traced model
        model = TracedModel(model, device=device, img_size=args.img_size)
        # model.half()
        logger.info("loaded checkpoint done")
        logger.info(f'Now detector is on device {next(model.parameters()).device}')
    elif args.detector == 'yolov8':
        logger.info(f"loading detector {args.detector} checkpoint {args.detector_model_path}")
        model = YOLO(args.detector_model_path)
        model_img_size = [None, None]
        stride = None
        logger.info("loaded checkpoint done")
    else:
        # was exit(0): an unsupported detector is an error, not success
        logger.error(f"detector {args.detector} is not supported")
        sys.exit(1)
    return model, model_img_size, stride


def main(args, dataset_cfgs):
    """Run detection + tracking over every sequence of the configured dataset.

    Args:
        args: parsed command-line options (see ``get_args``).
        dataset_cfgs: dict loaded from the dataset yaml; keys used here are
            DATASET_ROOT, SPLIT, IGNORE_SEQS, CERTAIN_SEQS, CATEGORY_NAMES.
    """
    # 1. set some params
    # NOTE: if save video, you must save image
    if args.save_videos:
        args.save_images = True

    # 2. load detector
    device = select_device(args.device)
    model, model_img_size, stride = _load_detector(args, device)

    # 3. load sequences
    DATA_ROOT = dataset_cfgs['DATASET_ROOT']
    SPLIT = dataset_cfgs['SPLIT']
    seqs = sorted(os.listdir(os.path.join(DATA_ROOT, 'images', SPLIT)))
    seqs = [seq for seq in seqs if seq not in dataset_cfgs['IGNORE_SEQS']]
    # config convention: CERTAIN_SEQS contains None when unused; any other
    # content replaces the directory listing entirely
    if None not in dataset_cfgs['CERTAIN_SEQS']:
        seqs = dataset_cfgs['CERTAIN_SEQS']
    logger.info(f'Total {len(seqs)} seqs will be tracked: {seqs}')
    save_dir = args.save_dir.format(dataset_name=args.dataset, split=SPLIT)

    # 4. Tracking
    timer = Timer()
    seq_fps = []
    for seq in seqs:
        logger.info(f'--------------tracking seq {seq}--------------')
        dataset = TestDataset(DATA_ROOT, SPLIT, seq_name=seq, img_size=model_img_size, model=args.detector, stride=stride)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
        tracker = TRACKER_DICT[args.tracker](args, )
        process_bar = enumerate(data_loader)
        process_bar = tqdm(process_bar, total=len(data_loader), ncols=150)
        results = []
        # count frames explicitly: the original used the loop variable
        # `frame_idx` after the loop, which was off by one and raised
        # NameError on an empty sequence
        frame_count = 0
        for frame_idx, (ori_img, img) in process_bar:
            timer.tic()  # start timing this frame
            if args.detector == 'yolov8':
                # ultralytics predict() takes a numpy image directly
                img = img.squeeze(0).cpu().numpy()
            else:
                img = img.to(device)  # (1, C, H, W)
                img = img.float()
            ori_img = ori_img.squeeze(0)
            # get detector output
            with torch.no_grad():
                if args.detector == 'yolov8':
                    output = model.predict(img, conf=args.conf_thresh, iou=args.nms_thresh)
                else:
                    output = model(img)
            # postprocess output to original scales
            if args.detector == 'yolox':
                output = postprocess_yolox(output, len(dataset_cfgs['CATEGORY_NAMES']), conf_thresh=args.conf_thresh,
                                           img=img, ori_img=ori_img)
            elif args.detector == 'yolov7':
                output = postprocess_yolov7(output, args.conf_thresh, args.nms_thresh, img.shape[2:], ori_img.shape)
            elif args.detector == 'yolov8':
                output = postprocess_yolov8(output)
            else:
                raise NotImplementedError
            # output rows are (tlbr, conf, cls); convert tlbr -> tlwh in place
            if isinstance(output, torch.Tensor):
                output = output.detach().cpu().numpy()
            output[:, 2] -= output[:, 0]
            output[:, 3] -= output[:, 1]
            current_tracks = tracker.update(output, img, ori_img.cpu().numpy())
            # collect per-frame results
            cur_tlwh, cur_id, cur_cls, cur_score = [], [], [], []
            for trk in current_tracks:
                bbox = trk.tlwh
                id = trk.track_id
                cls = trk.category
                score = trk.score
                # filter low area bbox
                if bbox[2] * bbox[3] > args.min_area:
                    cur_tlwh.append(bbox)
                    cur_id.append(id)
                    cur_cls.append(cls)
                    cur_score.append(score)
            results.append((frame_idx + 1, cur_id, cur_tlwh, cur_cls, cur_score))
            timer.toc()
            frame_count += 1
            if args.save_images:
                plot_img(img=ori_img, frame_id=frame_idx, results=[cur_tlwh, cur_id, cur_cls],
                         save_dir=os.path.join(save_dir, 'vis_results'))
        save_results(folder_name=os.path.join(args.dataset, SPLIT),
                     seq_name=seq,
                     results=results)
        # show the fps (guard against empty sequences / zero elapsed time)
        if frame_count > 0 and timer.total_time > 0:
            seq_fps.append(frame_count / timer.total_time)
            logger.info(f'fps of seq {seq}: {seq_fps[-1]}')
        else:
            logger.warning(f'seq {seq} produced no timed frames, skipping fps stats')
        timer.clear()
        if args.save_videos:
            save_video(images_path=os.path.join(save_dir, 'vis_results'))
            logger.info(f'save video of {seq} done')
    # show the average fps (np.mean([]) would warn and return nan)
    if seq_fps:
        logger.info(f'average fps: {np.mean(seq_fps)}')
if __name__ == '__main__':
    args = get_args()
    # dataset-specific settings live in a per-dataset yaml; safe_load is
    # sufficient for plain config files and never executes arbitrary tags
    # (equivalent to yaml.load(..., Loader=yaml.FullLoader) for such files)
    with open(f'./tracker/config_files/{args.dataset}.yaml', 'r') as f:
        cfgs = yaml.safe_load(f)
    main(args, cfgs)