Lời mở đầu
Phân vùng đối tượng là một bài toán khá phổ biến trong lĩnh vực thị giác máy tính (computer vision). OpenCV có hỗ trợ cho chúng ta một số hàm phân vùng đối tượng rất dễ sử dụng, nhưng đặc điểm chung của các hàm này là độ chính xác không cao lắm. Ở bài viết này, chúng ta sẽ tìm hiểu cách sử dụng một mô hình mạng nơ-ron sâu (DNN) đã được huấn luyện sẵn (pretrained model) để phân vùng các đối tượng trong ảnh.
Sử dụng pretrain model
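Đầu tiên, các bạn download file pretrain model và giải nén ra. Đoạn mã bên dưới dùng đường dẫn tương đối, nên thư mục mask_rcnn_inception_v2_coco_2018_01_28 sau khi giải nén cần nằm trong thư mục làm việc hiện tại (thư mục mà các bạn chạy chương trình), hoặc các bạn sửa lại biến PATH_TO_CKPT cho phù hợp. Các bạn có thể download file pretrain model ở http://download.tensorflow.org/models/object_detection/mask_rcnn_inception_v2_coco_2018_01_28.tar.gz. Các bạn có thể download các file pretrain khác nếu có hứng thú tìm hiểu.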
Tiếp theo, chúng ta sẽ load mô hình lên:
1import numpy as np
2import os
3import sys
4import tarfile
5import tensorflow as tf
6
7from collections import defaultdict
8from io import StringIO
9from matplotlib import pyplot as plt
10from PIL import Image
11import PIL.ImageDraw as ImageDraw
12import PIL.ImageFont as ImageFont
13import cv2
14
15import pprint
16
17import PIL.Image as Image
18import PIL.ImageColor as ImageColor
19
# Model preparation

# Location of the frozen detection graph, i.e. the serialized model that is
# actually used for inference. The path is relative to the working directory.
PATH_TO_CKPT = 'mask_rcnn_inception_v2_coco_2018_01_28' + '/frozen_inference_graph.pb'

# Label map file that could be used to build category_index (not used here).
# PATH_TO_LABELS = 'mscoco_label_map.pbtxt'

NUM_CLASSES = 1

# Maps a class id to its description; only "person" is labelled in this demo.
category_index = {
    1: {'id': 1, 'name': 'person'},
    # 3: {'id': 3, 'name': 'car'},
}

# Read the serialized GraphDef from disk first ...
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()

# ... decode it into a GraphDef protobuf ...
od_graph_def = tf.GraphDef()
od_graph_def.ParseFromString(serialized_graph)

# ... and import it, without a name prefix, into a dedicated graph so that
# tensors can later be fetched by their original names.
detection_graph = tf.Graph()
with detection_graph.as_default():
    tf.import_graph_def(od_graph_def, name='')
Ở đây, mình chỉ demo detect người trong hình, nên mình để category_index chỉ gồm “person”. Thực tế, bảng nhãn COCO bên dưới gồm 80 loại đối tượng khác nhau (id nằm trong khoảng từ 1 đến 90, một số id bị bỏ trống); các bạn có nhu cầu tìm hiểu thì thay bằng đoạn mã sau:
# (class id, class name) pairs of the COCO label map used by the model.
# Ids are not contiguous: some numbers between 1 and 90 are unused.
_COCO_LABELS = (
    (1, 'person'), (2, 'bicycle'), (3, 'car'), (4, 'motorcycle'),
    (5, 'airplane'), (6, 'bus'), (7, 'train'), (8, 'truck'),
    (9, 'boat'), (10, 'traffic light'), (11, 'fire hydrant'),
    (13, 'stop sign'), (14, 'parking meter'), (15, 'bench'),
    (16, 'bird'), (17, 'cat'), (18, 'dog'), (19, 'horse'),
    (20, 'sheep'), (21, 'cow'), (22, 'elephant'), (23, 'bear'),
    (24, 'zebra'), (25, 'giraffe'), (27, 'backpack'), (28, 'umbrella'),
    (31, 'handbag'), (32, 'tie'), (33, 'suitcase'), (34, 'frisbee'),
    (35, 'skis'), (36, 'snowboard'), (37, 'sports ball'), (38, 'kite'),
    (39, 'baseball bat'), (40, 'baseball glove'), (41, 'skateboard'),
    (42, 'surfboard'), (43, 'tennis racket'), (44, 'bottle'),
    (46, 'wine glass'), (47, 'cup'), (48, 'fork'), (49, 'knife'),
    (50, 'spoon'), (51, 'bowl'), (52, 'banana'), (53, 'apple'),
    (54, 'sandwich'), (55, 'orange'), (56, 'broccoli'), (57, 'carrot'),
    (58, 'hot dog'), (59, 'pizza'), (60, 'donut'), (61, 'cake'),
    (62, 'chair'), (63, 'couch'), (64, 'potted plant'), (65, 'bed'),
    (67, 'dining table'), (70, 'toilet'), (72, 'tv'), (73, 'laptop'),
    (74, 'mouse'), (75, 'remote'), (76, 'keyboard'), (77, 'cell phone'),
    (78, 'microwave'), (79, 'oven'), (80, 'toaster'), (81, 'sink'),
    (82, 'refrigerator'), (84, 'book'), (85, 'clock'), (86, 'vase'),
    (87, 'scissors'), (88, 'teddy bear'), (89, 'hair drier'),
    (90, 'toothbrush'),
)

# Maps a class id to {'id': <class id>, 'name': <class name>}. Building it
# from a single table guarantees that each key always equals its entry's 'id'.
category_index = {
    class_id: {'id': class_id, 'name': class_name}
    for class_id, class_name in _COCO_LABELS
}
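Tiếp theo, chúng ta sẽ khai báo một số hàm hỗ trợ việc hậu xử lý ảnh để vẽ các mask, giúp chúng ta xem kết quả trực quan hơn. Lưu ý: đoạn mã bên dưới có dùng STANDARD_COLORS và draw_mask_on_image_array nhưng không liệt kê định nghĩa của chúng ở đây; các bạn cần tự bổ sung (chúng có trong file visualization_utils.py của TensorFlow Object Detection API).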
def _text_size(font, text):
    """Return the (width, height) of `text` rendered with `font`."""
    if hasattr(font, 'getsize'):
        return font.getsize(text)
    # Newer Pillow releases dropped getsize(); getbbox() returns
    # (left, top, right, bottom) instead.
    _, _, right, bottom = font.getbbox(text)
    return right, bottom


def draw_bounding_box_on_image(image,
                               ymin,
                               xmin,
                               ymax,
                               xmax,
                               color='red',
                               thickness=4,
                               display_str_list=(),
                               use_normalized_coordinates=True):
    """Draw one bounding box and its label strings onto a PIL image.

    NOTE(review): the `def` line of this body was missing from the listing.
    The name, parameter order and defaults are reconstructed from the names
    the body uses and from the call made by
    visualize_boxes_and_labels_on_image_array -- confirm against the
    original source.

    Args:
        image: PIL image that is drawn on in place.
        ymin, xmin, ymax, xmax: Box corners. They are multiplied by the image
            size when `use_normalized_coordinates` is true, and used as-is
            otherwise.
        color: Colour of the box outline and of the label background.
        thickness: Width of the outline.
        display_str_list: Strings to show next to the box, one per line.
        use_normalized_coordinates: Whether the corners are fractions of the
            image size rather than absolute values.
    """
    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size
    if use_normalized_coordinates:
        (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                      ymin * im_height, ymax * im_height)
    else:
        (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
    draw.line([(left, top), (left, bottom), (right, bottom),
               (right, top), (left, top)], width=thickness, fill=color)

    try:
        font = ImageFont.truetype('arial.ttf', 24)
    except IOError:
        # The TrueType font is not available; use the built-in one.
        font = ImageFont.load_default()

    # If the total height of the display strings added to the top of the
    # bounding box exceeds the top of the image, stack the strings below the
    # bounding box instead of above.
    display_str_heights = [_text_size(font, ds)[1] for ds in display_str_list]
    # Each display_str has a top and bottom margin of 0.05x.
    total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)

    if top > total_display_str_height:
        text_bottom = top
    else:
        text_bottom = bottom + total_display_str_height

    # Reverse list and print from bottom to top.
    for display_str in display_str_list[::-1]:
        text_width, text_height = _text_size(font, display_str)
        margin = np.ceil(0.05 * text_height)
        draw.rectangle(
            [(left, text_bottom - text_height - 2 * margin),
             (left + text_width, text_bottom)],
            fill=color)
        draw.text(
            (left + margin, text_bottom - text_height - margin),
            display_str,
            fill='black',
            font=font)
        # Move up by the full height of the rectangle just drawn (text plus
        # both margins) so that stacked labels do not overlap. The listing
        # subtracted `text_height - 2 * margin`, which is too little.
        text_bottom -= text_height + 2 * margin


def draw_bounding_box_on_image_array(image,
                                     ymin,
                                     xmin,
                                     ymax,
                                     xmax,
                                     color='red',
                                     thickness=4,
                                     display_str_list=(),
                                     use_normalized_coordinates=True):
    """Draw one bounding box and its labels onto a numpy image, in place.

    NOTE(review): this wrapper is not part of the listing. It is added
    because visualize_boxes_and_labels_on_image_array calls a function of
    this name with a numpy array, while the drawing code itself needs a PIL
    image -- confirm against the original source.

    Args:
        image: Array holding the image; it is overwritten with the annotated
            pixels. Assumed to be a uint8 RGB image -- TODO confirm.
        ymin, xmin, ymax, xmax, color, thickness, display_str_list,
        use_normalized_coordinates: Forwarded unchanged to
            draw_bounding_box_on_image.
    """
    image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
    draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
                               thickness, display_str_list,
                               use_normalized_coordinates)
    # Write the annotated pixels back into the caller's array.
    np.copyto(image, np.array(image_pil))


def visualize_boxes_and_labels_on_image_array(
        image,
        boxes,
        classes,
        scores,
        category_index,
        instance_masks=None,
        instance_boundaries=None,
        keypoints=None,
        use_normalized_coordinates=False,
        max_boxes_to_draw=20,
        min_score_thresh=.5,
        agnostic_mode=False,
        line_thickness=4,
        groundtruth_box_visualization_color='black',
        skip_scores=False,
        skip_labels=False):
    """Overlay detection boxes, labels and masks on an image.

    Boxes are grouped by their coordinates, so detections that share exactly
    the same box are drawn once with all of their label strings.

    This function relies on three module-level names that are not defined in
    this block: STANDARD_COLORS, draw_mask_on_image_array and
    draw_bounding_box_on_image_array.

    Args:
        image: Image to annotate; passed on to the drawing helpers.
        boxes: Array whose row i holds (ymin, xmin, ymax, xmax) of box i.
        classes: Class id of each box, looked up in `category_index`.
        scores: Score of each box, or None. With None every box is drawn
            without a label in `groundtruth_box_visualization_color`.
        category_index: Dict mapping a class id to a dict with a 'name' key.
            Ids that are missing are labelled 'N/A'.
        instance_masks: Optional per-box masks, drawn in the box colour.
        instance_boundaries: Accepted for interface compatibility; not drawn.
        keypoints: Accepted for interface compatibility; not drawn.
        use_normalized_coordinates: Forwarded to the box drawing helper.
        max_boxes_to_draw: Only the first this-many boxes are considered. A
            falsy value (None or 0) means all boxes.
        min_score_thresh: Boxes are drawn only if their score is greater
            than this value.
        agnostic_mode: If true, class names are omitted and every box is
            drawn in 'DarkOrange'.
        line_thickness: Outline width forwarded to the box drawing helper.
        groundtruth_box_visualization_color: Box colour used when `scores`
            is None.
        skip_scores: If true, the score is left out of the label.
        skip_labels: If true, the class name is left out of the label.

    Returns:
        The `image` object that was passed in.
    """
    # `defaultdict` is what the file imports (`from collections import
    # defaultdict`); the listing's `collections.defaultdict` raised NameError.
    box_to_display_str_map = defaultdict(list)
    box_to_color_map = defaultdict(str)
    box_to_instance_masks_map = {}

    if not max_boxes_to_draw:
        max_boxes_to_draw = boxes.shape[0]

    for i in range(min(max_boxes_to_draw, boxes.shape[0])):
        if not (scores is None or scores[i] > min_score_thresh):
            continue

        # Tuples are hashable, so the box itself serves as the grouping key.
        box = tuple(boxes[i].tolist())
        if instance_masks is not None:
            box_to_instance_masks_map[box] = instance_masks[i]

        if scores is None:
            box_to_color_map[box] = groundtruth_box_visualization_color
            continue

        display_str = ''
        if not skip_labels and not agnostic_mode:
            if classes[i] in category_index:
                class_name = category_index[classes[i]]['name']
            else:
                class_name = 'N/A'
            display_str = str(class_name)
        if not skip_scores:
            if not display_str:
                display_str = '{}%'.format(int(100 * scores[i]))
            else:
                display_str = '{}: {}%'.format(
                    display_str, int(100 * scores[i]))
        box_to_display_str_map[box].append(display_str)

        if agnostic_mode:
            box_to_color_map[box] = 'DarkOrange'
        else:
            # Cycle through the palette so every class id gets a colour.
            box_to_color_map[box] = STANDARD_COLORS[
                classes[i] % len(STANDARD_COLORS)]

    # Draw all boxes onto image.
    for box, color in box_to_color_map.items():
        ymin, xmin, ymax, xmax = box
        if instance_masks is not None:
            draw_mask_on_image_array(
                image, box_to_instance_masks_map[box], color=color)

        draw_bounding_box_on_image_array(
            image,
            ymin,
            xmin,
            ymax,
            xmax,
            color=color,
            thickness=line_thickness,
            display_str_list=box_to_display_str_map[box],
            use_normalized_coordinates=use_normalized_coordinates)

    return image


def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
                                     image_width):
    """Paste per-box masks into full-size image masks.

    Each mask is defined relative to its own bounding box. This resamples it
    so that it covers the whole image, with zeros outside of the box.

    Args:
        box_masks: A tf.float32 tensor of size
            [num_masks, mask_height, mask_width].
        boxes: A tf.float32 tensor of size [num_masks, 4]. Row i holds
            [ymin, xmin, ymax, xmax] of the box belonging to mask i, in
            normalized coordinates.
        image_height: Height of the output masks.
        image_width: Width of the output masks.

    Returns:
        A tf.float32 tensor of size [num_masks, image_height, image_width].
    """
    def _image_frame_in_box_coordinates(num_masks):
        """Express the whole image, [0, 0, 1, 1], relative to every box."""
        whole_image = tf.concat(
            [tf.zeros([num_masks, 2]), tf.ones([num_masks, 2])], axis=1)
        corners = tf.reshape(whole_image, [-1, 2, 2])
        box_min = tf.expand_dims(boxes[:, 0:2], 1)
        box_max = tf.expand_dims(boxes[:, 2:4], 1)
        relative = (corners - box_min) / (box_max - box_min)
        return tf.reshape(relative, [-1, 4])

    def _paste_masks():
        """Used when there is at least one mask."""
        # crop_and_resize expects a channel dimension.
        masks_4d = tf.expand_dims(box_masks, axis=3)
        num_masks = tf.shape(masks_4d)[0]
        # Cropping each mask with the image frame expressed in box
        # coordinates places it at the box location; everything that falls
        # outside of the mask is filled with the extrapolation value.
        return tf.image.crop_and_resize(
            image=masks_4d,
            boxes=_image_frame_in_box_coordinates(num_masks),
            box_ind=tf.range(num_masks),
            crop_size=[image_height, image_width],
            extrapolation_value=0.0)

    def _no_masks():
        """Used when there are no masks: an empty batch of the right size."""
        return tf.zeros([0, image_height, image_width, 1], dtype=tf.float32)

    has_masks = tf.shape(box_masks)[0] > 0
    image_masks = tf.cond(has_masks, _paste_masks, _no_masks)
    # Drop the channel dimension again.
    return tf.squeeze(image_masks, axis=3)
Cho hình ảnh vào và rút ra kết quả.
def detect_frame(image_np, sess, detection_graph):
    """Run the detection model on one image and visualize the detections.

    NOTE(review): when the graph provides 'detection_masks', the mask
    post-processing ops below are created again on every call, so calling
    this repeatedly with the same graph keeps adding ops to it.

    Args:
        image_np: Image array. shape[0] and shape[1] are used as the height
            and width of the output masks, and a leading batch dimension is
            added before it is fed to 'image_tensor:0'.
        sess: Session used to run the graph.
        detection_graph: Graph containing the imported detection model.

    Returns:
        The `image_np` object that was passed in. The detections are
        presumably drawn into it in place by
        visualize_boxes_and_labels_on_image_array -- verify against the
        drawing helpers.
    """
    output_keys = ('num_detections', 'detection_boxes', 'detection_scores',
                   'detection_classes', 'detection_masks')

    with detection_graph.as_default():
        graph = tf.get_default_graph()

        # Fetch only those outputs that the loaded model really provides.
        available_names = {
            output.name
            for op in graph.get_operations()
            for output in op.outputs
        }
        tensor_dict = {
            key: graph.get_tensor_by_name(key + ':0')
            for key in output_keys
            if key + ':0' in available_names
        }

        if 'detection_masks' in tensor_dict:
            # The following processing is only for a single image.
            boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
            masks = tf.squeeze(tensor_dict['detection_masks'], [0])
            # Keep only the first num_detections entries.
            num_valid = tf.cast(tensor_dict['num_detections'][0], tf.int32)
            boxes = tf.slice(boxes, [0, 0], [num_valid, -1])
            masks = tf.slice(masks, [0, 0, 0], [num_valid, -1, -1])
            # Translate the masks from box coordinates to image coordinates
            # and fit them to the image size, then binarize them.
            height, width = image_np.shape[0], image_np.shape[1]
            full_size_masks = reframe_box_masks_to_image_masks(
                masks, boxes, height, width)
            binary_masks = tf.cast(
                tf.greater(full_size_masks, 0.5), tf.uint8)
            # Follow the convention by adding back the batch dimension.
            tensor_dict['detection_masks'] = tf.expand_dims(binary_masks, 0)

        image_tensor = graph.get_tensor_by_name('image_tensor:0')

        # Run inference on a batch that contains just this image.
        batch = np.expand_dims(image_np, 0)
        output_dict = sess.run(tensor_dict, feed_dict={image_tensor: batch})

        # Strip the batch dimension and convert types as appropriate.
        num_detections = int(output_dict['num_detections'][0])
        output_dict['num_detections'] = num_detections
        output_dict['detection_classes'] = (
            output_dict['detection_classes'][0].astype(np.uint8))
        for key in ('detection_boxes', 'detection_scores'):
            output_dict[key] = output_dict[key][0]
        if 'detection_masks' in output_dict:
            output_dict['detection_masks'] = output_dict['detection_masks'][0]

        visualize_boxes_and_labels_on_image_array(
            image_np,
            output_dict['detection_boxes'],
            output_dict['detection_classes'],
            output_dict['detection_scores'],
            category_index,
            instance_masks=output_dict.get('detection_masks'),
            use_normalized_coordinates=True,
            line_thickness=1,
            max_boxes_to_draw=min(num_detections, 20))

    return image_np
# cv2.imread returns None instead of raising when the file is missing or
# cannot be decoded, so fail here with a clear message.
image = cv2.imread('img2.jpg')
if image is None:
    raise IOError("Could not read input image 'img2.jpg'")

# OpenCV decodes images in BGR channel order. The detection model is fed RGB
# images and the drawing helpers use RGB colour names, so convert before
# detection and convert back before saving.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image_np = detect_frame(image_rgb, sess, detection_graph)

# Keep `image` as the annotated BGR picture, which is what gets written.
image = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
cv2.imwrite('output.jpg', image)
Kết quả file output.jpg của chúng ta là:
Thử với bức ảnh người và xe hơi.
Cảm ơn các bạn đã theo dõi. Hẹn gặp bạn ở các bài viết tiếp theo.
Comments