Baidu PaddlePaddle Pedestrian Multi-Object Tracking Notes

Reposted by x33g5p2x on 2021-11-10 (category: Other)

Source code:

PaddleDetection/configs/mot at release/2.3 · PaddlePaddle/PaddleDetection · GitHub

Baidu PaddlePaddle (PaddleDetection) integrates several multi-object tracking algorithms, all under the configs/mot directory linked above:

- DeepSORT
- JDE
- FairMOT

My test results are below; I may post follow-up updates later.

Local code (runs OK):

PaddleDetection-release-2.3

Environment: Python 3.7 (py37)

Test entry script:

tools/infer_mot.py

Test result: there are some missed detections.

Something odd:

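When the input is a video file, the code first converts it to images with ffmpeg, sorts them, and then reads the image list. Couldn't it simply read the frames directly, for example with OpenCV: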

```python
cap = cv2.VideoCapture(self.video_file)
```
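That idea can be sketched as follows, assuming OpenCV (cv2) is installed: frames are read with cv2.VideoCapture and, because PaddleDetection's loader works on a folder of frames, written out as JPEGs with no ffmpeg involved. `extract_frames` and its `cap`/`imwrite` parameters are hypothetical names; the two parameters exist only so the function can be exercised without a real video. Unlike the ffmpeg step, this sketch keeps every frame and does not resample to `frame_rate`.

```python
import os


def extract_frames(video_file, out_dir=None, cap=None, imwrite=None):
    """Write every frame of video_file to out_dir as 000001.jpg, 000002.jpg, ...

    Hypothetical helper: cap and imwrite default to OpenCV
    (cv2.VideoCapture / cv2.imwrite) and are injectable for testing.
    Returns the written paths in frame order.
    """
    if cap is None or imwrite is None:
        import cv2  # lazy import: only needed when OpenCV does the work
        if cap is None:
            cap = cv2.VideoCapture(video_file)
        if imwrite is None:
            imwrite = cv2.imwrite
    if out_dir is None:
        # For ordinary names this is the folder the modified loader reads: 1.mp4 -> 1/
        out_dir = os.path.splitext(video_file)[0]
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    try:
        while True:
            ok, frame = cap.read()  # ok becomes False once the stream ends
            if not ok:
                break
            # Zero-padded names keep frame order even with plain sorted()
            path = os.path.join(out_dir, '{:06d}.jpg'.format(len(paths) + 1))
            imwrite(path, frame)
            paths.append(path)
    finally:
        cap.release()
    return paths
```

Calling `extract_frames('1.mp4')` would fill the folder `1/`. If the video cannot be opened, `read()` fails immediately and an empty list comes back, which the loader then reports as "No image file found".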

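ffmpeg isn't installed on my computer, so I changed the code to skip the conversion and read the images in a folder directly. The frames must already exist as .jpg files in a folder named after the video without its extension (for 1.mp4, the folder 1/):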

```python
# Modified _load_video_images: the ffmpeg step (video2frames) is commented out,
# and frames are read from an existing folder of .jpg files named after the
# video (1.mp4 -> 1/). Besides os, glob, cv2 and numpy (as np), this needs
# `from natsort import natsorted` (pip install natsort).
def _load_video_images(self):
    if self.frame_rate == -1:
        # if frame_rate is not set for video, use cv2.VideoCapture
        cap = cv2.VideoCapture(self.video_file)
        self.frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
        cap.release()  # only the FPS is needed, so close the video again
    extension = self.video_file.split('.')[-1]
    output_path = self.video_file.replace('.{}'.format(extension), '')
    # frames_path = video2frames(self.video_file, output_path,
    #                            self.frame_rate)
    # natsorted keeps 2.jpg before 10.jpg, unlike plain sorted()
    self.video_frames = natsorted(
        glob.glob(os.path.join(output_path, '*.jpg')))
    self.video_length = len(self.video_frames)
    logger.info('Length of the video: {:d} frames.'.format(
        self.video_length))
    ct = 0
    records = []
    for image in self.video_frames:
        assert image != '' and os.path.isfile(image), \
            "Image {} not found".format(image)
        # optionally stop after sample_num frames
        if self.sample_num > 0 and ct >= self.sample_num:
            break
        # one record per frame: frame index + image path
        rec = {'im_id': np.array([ct]), 'im_file': image}
        if self.keep_ori_im:
            rec.update({'keep_ori_im': 1})
        self._imid2path[ct] = image
        ct += 1
        records.append(rec)
    assert len(records) > 0, "No image file found"
    return records
```
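The loader above relies on `natsort.natsorted` (a third-party package) because frames named 1.jpg, 2.jpg, ..., 10.jpg sort lexicographically as 1, 10, 2, which would feed frames to the tracker out of order. If natsort isn't available, a standard-library natural-sort key gives the same order for simple names like these; `natural_key` is a hypothetical helper, not part of PaddleDetection.

```python
import re


def natural_key(path):
    """Split a path into text and integer chunks so 'frames/2.jpg' < 'frames/10.jpg'."""
    # re.split with a capture group alternates text and digit runs, so odd
    # positions are always digits and two keys compare chunk-by-chunk safely.
    return [int(tok) if i % 2 else tok.lower()
            for i, tok in enumerate(re.split(r'(\d+)', path))]


frames = ['frames/10.jpg', 'frames/2.jpg', 'frames/1.jpg']
print(sorted(frames))                   # ['frames/1.jpg', 'frames/10.jpg', 'frames/2.jpg']
print(sorted(frames, key=natural_key))  # ['frames/1.jpg', 'frames/2.jpg', 'frames/10.jpg']
```

In the loader this would become `sorted(glob.glob(os.path.join(output_path, '*.jpg')), key=natural_key)`.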

Modified entry script:

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

import warnings
warnings.filterwarnings('ignore')

import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Tracker
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.logger import setup_logger

logger = setup_logger('train')


def parse_args():
    parser = ArgsParser()
    # ArgsParser already registers -c/--config and -o/--opt; adding --config
    # a second time raises an argparse "conflicting option string" error,
    # so only its default is changed here.
    parser.set_defaults(
        config="../configs/mot/fairmot/fairmot_dla34_30e_576x320.yml")
    parser.add_argument('--video_file', type=str, default="1.mp4",
                        help='Video name for tracking.')
    parser.add_argument('--frame_rate', type=int, default=-1,
                        help='Video frame rate for tracking.')
    parser.add_argument("--image_dir", type=str, default=None,
                        help="Directory for images to perform inference on.")
    parser.add_argument("--det_results_dir", type=str, default='',
                        help="Directory name for detection results.")
    parser.add_argument('--output_dir', type=str, default='output',
                        help='Directory name for output tracking results.')
    parser.add_argument('--save_images', default=False,
                        help='Save tracking results (image).')
    parser.add_argument('--save_videos', default=False,
                        help='Save tracking results (video).')
    parser.add_argument('--show_image', default=True,
                        help='Show tracking results (image).')
    parser.add_argument('--scaled', type=bool, default=False,
                        help="Whether coords after detector outputs are scaled, "
                        "False in JDE YOLOv3, True in general detector.")
    parser.add_argument("--draw_threshold", type=float, default=0.5,
                        help="Threshold to reserve the result for visualization.")
    args = parser.parse_args()
    return args


def run(FLAGS, cfg):
    # build Tracker
    tracker = Tracker(cfg, mode='test')

    # load weights: DeepSORT uses a separate detector + ReID model,
    # JDE/FairMOT use a single joint model
    if cfg.architecture in ['DeepSORT']:
        if cfg.det_weights != 'None':
            tracker.load_weights_sde(cfg.det_weights, cfg.reid_weights)
        else:
            tracker.load_weights_sde(None, cfg.reid_weights)
    else:
        tracker.load_weights_jde(cfg.weights)

    # inference
    tracker.mot_predict(
        video_file=FLAGS.video_file,
        frame_rate=FLAGS.frame_rate,
        image_dir=FLAGS.image_dir,
        data_type=cfg.metric.lower(),
        model_type=cfg.architecture,
        output_dir=FLAGS.output_dir,
        save_images=FLAGS.save_images,
        save_videos=FLAGS.save_videos,
        show_image=FLAGS.show_image,
        scaled=FLAGS.scaled,
        det_results_dir=FLAGS.det_results_dir,
        draw_threshold=FLAGS.draw_threshold)


if __name__ == '__main__':
    FLAGS = parse_args()
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()

    place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
    place = paddle.set_device(place)
    run(FLAGS, cfg)
```
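A caveat about the boolean options in this script: argparse hands the raw command-line string to the `type` converter, and `bool('False')` is `True`. An option with no `type` (such as `--save_images`, `--save_videos`, `--show_image` here) keeps the non-empty string `'False'`, which is also truthy. So `--scaled False` or `--show_image False` does not switch the option off. Below is a minimal sketch of a string-to-bool converter that could be passed as `type=`; `str2bool` is a hypothetical helper, not part of ppdet, and `--scaled_bool` exists only to show the pitfall side by side.

```python
import argparse


def str2bool(value):
    """Convert common true/false spellings into a real bool for argparse."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


parser = argparse.ArgumentParser()
parser.add_argument('--scaled_bool', type=bool, default=False)  # the pitfall
parser.add_argument('--scaled', type=str2bool, default=False)   # converted properly
parser.add_argument('--show_image', type=str2bool, default=True)

args = parser.parse_args(['--scaled_bool', 'False', '--scaled', 'False',
                          '--show_image', 'no'])
print(args.scaled_bool, args.scaled, args.show_image)  # True False False
```

Using `type=str2bool` keeps `--scaled True`-style invocations working while making `False` actually mean false.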
