Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add referring tracker tool for track anything #67

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions agentlego/tools/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Re-export the tracking tool so it can be imported from the package itself.
from .ref_tracker import ReferringTracker
# Public API of this package: only the referring tracker tool.
__all__ = ['ReferringTracker']
181 changes: 181 additions & 0 deletions agentlego/tools/tracking/ref_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from agentlego.types import Annotated, ImageIO, Info, VideoIO
from agentlego.utils import load_or_build_object, require

from agentlego.tools import BaseTool
from agentlego.tools.segmentation.segment_anything import load_sam_and_predictor

from mmdet.models.trackers import OCSORTTracker, ByteTracker, QuasiDenseTracker
import os
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2

class ReferringTracker(BaseTool):
    """Track, and optionally segment, the objects matching a description.

    For every frame of the input video a text-prompted detector
    (a ``DetInferencer`` built from ``model``/``weight``) is run, its best
    detections are handed to an MMDetection tracker, and, when the task asks
    for segmentation, the tracked boxes are used as box prompts for SAM.

    Args:
        model (str): Model name passed to ``DetInferencer``.
        weight (str): Checkpoint path passed to ``DetInferencer``.
        sam_weight (str): Checkpoint passed to ``load_sam_and_predictor``.
        device (str): Device used for the detector and for SAM.
        tracker (str): One of ``'ocsort'``, ``'bytetrack'`` or ``'qdtrack'``.
        toolmeta: Forwarded unchanged to ``BaseTool``.

    Raises:
        KeyError: If ``tracker`` is not one of the supported names.
    """

    default_desc = ('The tool can track and segment objects location according to description.')

    # NOTE(review): the default checkpoint paths below are machine-specific
    # absolute paths; callers on other machines must override them.
    @require('mmdet>=3.1.0')
    def __init__(self,
                 model: str = 'glip_atss_swin-t_b_fpn_dyhead_pretrain_obj365',
                 weight: str = '/root/justTrack/weights/glip_tiny_b_mmdet-6dfbd102.pth',
                 sam_weight: str = '/root/justTrack/weights/sam_vit_h_4b8939.pth',
                 device: str = 'cuda',
                 tracker: str = 'bytetrack',
                 toolmeta=None):
        super().__init__(toolmeta=toolmeta)
        self.model = model
        self.weights = weight
        self.sam_weight = sam_weight
        self.device = device

        # Supported tracker classes, keyed by the name accepted in `tracker`.
        self.TRACKER_DICT = {
            'ocsort': OCSORTTracker,
            'bytetrack': ByteTracker,
            'qdtrack': QuasiDenseTracker,
        }

        # Constructor keyword arguments used for each tracker class.
        self.TRACKER_CONFIG_DICT = {
            'ocsort': dict(obj_score_thr=0.5, init_track_thr=0.5),
            'bytetrack': dict(
                obj_score_thrs=dict(high=0.5, low=0.1), init_track_thr=0.5),
            'qdtrack': dict(init_score_thr=0.1, obj_score_thr=0.5),
        }

        if tracker not in self.TRACKER_DICT:
            # Still a KeyError (as the plain dict lookup was), but with a
            # message that tells the caller which names are valid.
            raise KeyError(
                f"Unknown tracker '{tracker}'; expected one of "
                f'{sorted(self.TRACKER_DICT)}.')

        # NOTE(review): `motion` is passed to every tracker class; confirm
        # that QuasiDenseTracker accepts this keyword.
        self.tracker = self.TRACKER_DICT[tracker](
            motion=dict(type='KalmanFilter'),
            **self.TRACKER_CONFIG_DICT[tracker])

        # Number of highest-scoring detections kept per frame.
        self.top_K = 1
        # Running frame counter; used in the report and in output file names.
        # NOTE(review): neither this counter nor the tracker state is reset
        # between `apply` calls - confirm that is intended.
        self.frame_cnt = 0

        # When True, every processed frame is rendered and written to disk.
        self.draw = True
        # Output directory template; `{text_disc}` is the object description.
        self.save_dir = './output_dir_{text_disc}'

    def setup(self):
        """Build the detector and load SAM together with its predictor."""
        from mmdet.apis import DetInferencer
        self._inferencer = load_or_build_object(
            DetInferencer,
            model=self.model,
            weights=self.weights,
            device=self.device)
        self._visualizer = self._inferencer.visualizer

        self.sam, self.sam_predictor = load_sam_and_predictor(
            self.sam_weight, device=self.device)

    def _draw_bboxes(self, image, bboxes, ids, text, masks=None):
        """Render the tracking result of the current frame and save it.

        The frame is written as ``<frame_cnt>.jpg`` (zero-padded to seven
        digits) into ``self.save_dir`` formatted with ``text``.

        Args:
            image: Current frame; converted with ``np.array`` and treated as
                RGB.
            bboxes: Per-object boxes, indexed as ``x1, y1, x2, y2``.
            ids: Per-object track ids, aligned with ``bboxes``.
            text (str): Object description, used in the directory name.
            masks: Optional per-object masks aligned with ``ids``; when
                given, each one is blended over the frame.
        """
        out_dir = self.save_dir.format(text_disc=text)
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(out_dir, exist_ok=True)

        canvas = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        for bbox, track_id in zip(bboxes, ids):
            x1, y1, x2, y2 = bbox.int()[:4].tolist()
            track_id = track_id.int().item()

            color = get_color(track_id)
            cv2.rectangle(canvas, (x1, y1), (x2, y2), color=color, thickness=3)
            cv2.putText(
                canvas,
                text=str(track_id),
                org=(x1, y1),
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=2.0,
                color=color,
                thickness=2)

        if masks is not None:
            for track_id, mask in zip(ids, masks):
                color = np.array(get_color(track_id.int().item()))

                mask = mask[0].cpu().numpy()
                h, w = mask.shape[-2:]
                # get_color already returns 0-254 channel values, so the
                # mask is scaled by the color only; an extra "* 255" would
                # wrap around in the uint8 cast and change the color.
                mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

                canvas = cv2.addWeighted(
                    canvas, 0.7, mask_image.astype('uint8'), 0.3, 0)

        out_file = os.path.join(out_dir, '{:07d}.jpg'.format(self.frame_cnt))
        cv2.imwrite(filename=out_file, img=canvas)

    def _get_mask_with_boxes(self, image, boxes_filt):
        """Predict one SAM mask per box.

        Args:
            image: Frame as an array; ``image.shape[:2]`` is used as the
                original image size when transforming the boxes.
            boxes_filt: Boxes used as SAM box prompts.

        Returns:
            The masks returned by ``sam_predictor.predict_torch``.
        """
        boxes_filt = boxes_filt.cpu()
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(
            boxes_filt, image.shape[:2]).to(self.device)

        features = self.sam_predictor.get_image_embedding(image)

        masks, _, _ = self.sam_predictor.predict_torch(
            features=features,
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return masks

    def _add_mask_to_image(self, ):
        """Placeholder; currently does nothing."""
        pass

    def apply(
        self,
        video: VideoIO,
        task: Annotated[str, Info("Task description, should be 'track' or 'segment'")],
        text: Annotated[str, Info('The object description in English.')],
    ) -> Annotated[str,
                   Info('Tracked objects, include a set of bboxes in '
                        '(x1, y1, x2, y2) format, and detection scores and ids.')]:
        """Run detection, tracking and optional segmentation on a video.

        Args:
            video: Frame source; frames are pulled with ``next_image`` until
                ``is_finish`` is true or ``next_image`` returns ``None``.
            task: Segmentation is enabled when this contains ``'segment'``
                (case-insensitive).
            text: Description of the object(s) to follow.

        Returns:
            One line per tracked object per frame (or a "No object found"
            line for frames without objects), joined with newlines. The
            string is empty when the video yields no frames.
        """
        from mmdet.structures import DetDataSample

        need_segment = 'segment' in task.lower()
        pred_tmpl = '(In frame {:d}: id {:d}, bbox {:.0f}, {:.0f}, {:.0f}, {:.0f}, score {:.0f})'

        pred_descs = []
        while not video.is_finish():
            image_pil = video.next_image()
            if image_pil is None:
                break
            # Count only frames that were actually read.
            self.frame_cnt += 1

            image = ImageIO(image_pil)

            # The channel order is reversed before the frame is handed to
            # the detector.
            results = self._inferencer(
                image.to_array()[:, :, ::-1],
                texts=text,
                return_datasamples=True,
            )
            preds = results['predictions'][0].pred_instances
            # Drop low-confidence detections, then keep the top_K best.
            preds = preds[preds.scores > 0.3]
            keep = min(preds.scores.shape[0], self.top_K)
            preds = preds[preds.scores.topk(keep).indices]

            data_sample = DetDataSample()
            data_sample.pred_instances = preds

            pred_track_instances = self.tracker.track(data_sample)

            bboxes = pred_track_instances.bboxes
            scores = pred_track_instances.scores
            # Track ids are shifted by one before being drawn and reported.
            ids = pred_track_instances.instances_id + 1

            has_objects = bboxes.shape[0] > 0

            masks = None
            if need_segment and has_objects:
                # SAM is only prompted when there is at least one box.
                masks = self._get_mask_with_boxes(image.to_array(), bboxes)

            if self.draw:
                self._draw_bboxes(image_pil, bboxes, ids, text, masks)

            if not has_objects:
                pred_descs.append(f'frame {self.frame_cnt}, No object found.')
                continue

            for track_id, bbox, score in zip(ids, bboxes, scores):
                pred_descs.append(
                    pred_tmpl.format(self.frame_cnt, track_id, bbox[0],
                                     bbox[1], bbox[2], bbox[3], score * 100))

        # Joined once, after the loop, so the result is always defined.
        return '\n'.join(pred_descs)

def get_color(idx):
    """Derive a deterministic three-channel color from an integer id.

    The id is tripled and each channel is that value times a fixed
    multiplier (37, 17 and 29 respectively), reduced modulo 255.
    """
    scaled = 3 * idx
    return tuple((factor * scaled) % 255 for factor in (37, 17, 29))
85 changes: 85 additions & 0 deletions agentlego/tools/tracking/video_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from agentlego.types import IOType, ImageIO
from agentlego.utils.file import temp_path
from io import BytesIO, IOBase
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
from PIL import Image
from typing_extensions import Annotated

import os

class VideoIO(IOType):
    """Sequential reader over a directory of frame images.

    ``value`` is a directory path; its entries are listed once, sorted by
    name, and handed out one at a time by :meth:`next_image`.

    NOTE(review): every entry of the directory is treated as a frame, and
    the ``to_*``/``_x_to_y`` converters mirror an image IO type - confirm
    they are meaningful for a directory path.
    """
    support_types = {'path': str}

    def __init__(self, value: str):
        """Validate the path and list the frames it contains.

        Raises:
            FileNotFoundError: If ``value`` does not exist.
        """
        super().__init__(value)
        if self.type == 'path' and not Path(self.value).exists():
            raise FileNotFoundError(f"No such file: '{self.value}'")

        self.root_path = Path(self.value)
        # Frame file names (not full paths), in playback order.
        self.images = sorted(os.listdir(self.root_path))
        # Index of the next frame to return.
        self.cnt = 0

    def to_path(self) -> str:
        return self.to('path')

    def to_pil(self) -> Image.Image:
        return self.to('pil')

    def to_array(self) -> np.ndarray:
        return self.to('array')

    def to_file(self) -> IOBase:
        # NOTE(review): for the 'path' type this opens `self.value`, which
        # __init__ lists as a directory - confirm this method is usable.
        if self.type == 'path':
            return open(self.value, 'rb')
        else:
            file = BytesIO()
            self.to_pil().save(file, 'PNG')
            file.seek(0)
            return file

    def next_image(self) -> Optional[Image.Image]:
        """Return the next frame, or ``None`` once all frames were read."""
        if self.cnt >= len(self.images):
            return None
        # `self.images` holds bare file names, so join with the directory.
        ret = Image.open(self.root_path / self.images[self.cnt])
        self.cnt += 1
        return ret

    def is_finish(self) -> bool:
        """Return True once every listed frame has been handed out."""
        # ">=" (rather than "== len - 1") keeps the last frame reachable and
        # also terminates immediately for an empty directory.
        return self.cnt >= len(self.images)

    @classmethod
    def from_file(cls, file: IOBase) -> 'VideoIO':
        # NOTE(review): this passes a PIL image to a constructor that is
        # annotated to take a `str` path - confirm the intended behavior.
        from PIL import Image
        return cls(Image.open(file))

    @staticmethod
    def _path_to_pil(path: str) -> Image.Image:
        return Image.open(path)

    @staticmethod
    def _path_to_array(path: str) -> np.ndarray:
        return np.array(Image.open(path).convert('RGB'))

    @staticmethod
    def _pil_to_path(image: Image.Image) -> str:
        filename = temp_path('image', '.png')
        image.save(filename)
        return filename

    @staticmethod
    def _pil_to_array(image: Image.Image) -> np.ndarray:
        return np.array(image.convert('RGB'))

    @staticmethod
    def _array_to_pil(image: np.ndarray) -> Image.Image:
        return Image.fromarray(image)

    @staticmethod
    def _array_to_path(image: np.ndarray) -> str:
        filename = temp_path('image', '.png')
        Image.fromarray(image).save(filename)
        return filename



95 changes: 94 additions & 1 deletion agentlego/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,98 @@ def apply(
"""
return Parameter(description=description, name=name, filetype=filetype)

import os
import cv2
class VideoIO(IOType):
    """Sequential frame reader over a frame directory or a video file.

    After construction ``self.type`` is ``'path'`` when the input resolves
    to a directory (frames are its entries, sorted by name) and ``'file'``
    otherwise (frames are decoded with ``cv2.VideoCapture``).

    NOTE(review): every entry of a frame directory is treated as a frame,
    and the ``to_*``/``_x_to_y`` converters mirror an image IO type -
    confirm they are meaningful for a video input.
    """
    support_types = {'path': str, 'file': str}

    def __init__(self, value: str):
        """Resolve ``value`` and open the matching frame source.

        If ``value`` does not exist, it is retried with its file extension
        removed, to accept a frame-directory path that was given a video
        suffix (.mp4, .avi, ...) by mistake.

        Raises:
            FileNotFoundError: If neither form of the path exists.
        """
        super().__init__(value)
        if not Path(self.value).exists():
            # Strip the real extension instead of a fixed four characters,
            # so extension-less or differently sized suffixes are handled.
            stripped = os.path.splitext(self.value)[0]
            if not Path(stripped).exists():
                # Report the path the caller actually supplied.
                raise FileNotFoundError(f"No such file: '{value}'")
            self.value = stripped

        # Decide by what is on disk, not by a substring of the name: a
        # directory that was passed with a bogus '.mp4' suffix is still a
        # frame directory.
        self.type = 'path' if Path(self.value).is_dir() else 'file'

        self.root_path, self.images, self.cap = None, None, None
        if self.type == 'path':
            self.root_path = Path(self.value)
            # Frame file names (not full paths), in playback order.
            self.images = sorted(os.listdir(self.root_path))
        else:
            # Open the resolved path, which is the one verified to exist.
            self.cap = cv2.VideoCapture(self.value)

        # Number of frames handed out so far.
        self.cnt = 0

    def to_path(self) -> str:
        return self.to('path')

    def to_pil(self) -> Image.Image:
        return self.to('pil')

    def to_array(self) -> np.ndarray:
        return self.to('array')

    def to_file(self) -> IOBase:
        # NOTE(review): for the 'path' type `self.value` is a directory -
        # confirm this method is usable for VideoIO.
        if self.type == 'path':
            return open(self.value, 'rb')
        else:
            file = BytesIO()
            self.to_pil().save(file, 'PNG')
            file.seek(0)
            return file

    def next_image(self) -> 'Optional[Image.Image]':
        """Return the next frame as a PIL image, or ``None`` at the end."""
        if self.type == 'path':
            if self.cnt >= len(self.images):
                return None
            ret = Image.open(
                os.path.join(self.root_path, self.images[self.cnt]))
        else:
            ok, frame = self.cap.read()
            if not ok:
                # End of stream (or unreadable source): free the capture so
                # is_finish() reports completion and the handle is not leaked.
                self.cap.release()
                return None

            # OpenCV decodes to BGR; PIL expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            ret = Image.fromarray(frame)

        self.cnt += 1
        return ret

    def is_finish(self) -> bool:
        """Return True once no further frame can be read."""
        if self.type == 'path':
            # ">=" (rather than "== len - 1") keeps the last frame reachable
            # and terminates immediately for an empty directory.
            return self.cnt >= len(self.images)
        # The capture is released by next_image() when reading fails, and
        # is never opened for an unreadable file.
        return not self.cap.isOpened()

    @classmethod
    def from_file(cls, file: IOBase) -> 'VideoIO':
        # NOTE(review): this passes a PIL image to a constructor that is
        # annotated to take a `str` path - confirm the intended behavior.
        from PIL import Image
        return cls(Image.open(file))

    @staticmethod
    def _path_to_pil(path: str) -> Image.Image:
        return Image.open(path)

    @staticmethod
    def _path_to_array(path: str) -> np.ndarray:
        return np.array(Image.open(path).convert('RGB'))

    @staticmethod
    def _pil_to_path(image: Image.Image) -> str:
        filename = temp_path('image', '.png')
        image.save(filename)
        return filename

    @staticmethod
    def _pil_to_array(image: Image.Image) -> np.ndarray:
        return np.array(image.convert('RGB'))

    @staticmethod
    def _array_to_pil(image: np.ndarray) -> Image.Image:
        return Image.fromarray(image)

    @staticmethod
    def _array_to_path(image: np.ndarray) -> str:
        filename = temp_path('image', '.png')
        Image.fromarray(image).save(filename)
        return filename

CatgoryToIO = {
'image': ImageIO,
Expand All @@ -257,6 +349,7 @@ def apply(
'int': int,
'float': float,
'file': File,
'video': VideoIO,
}

__all__ = ['ImageIO', 'AudioIO', 'CatgoryToIO', 'Info', 'Annotated']
__all__ = ['ImageIO', 'AudioIO', 'CatgoryToIO', 'Info', 'Annotated', 'VideoIO']