diff --git a/PixelSAM.py b/PixelSAM.py index d337313..3ff1a7a 100644 --- a/PixelSAM.py +++ b/PixelSAM.py @@ -21,9 +21,27 @@ def __init__(self, container): super().__init__(container) # self['text'] = 'Options' + # The left frame containing buttons + self.left_tab = ttk.Frame(self) + # Load button - self.load_btn = tk.Button(self, text="Load Dataset", font='sans 10 bold', height=2, width=12, background="#343434", foreground="white", command=self.load_data) - self.load_btn.pack(side=tk.LEFT, padx=(30,30), pady=20, anchor="n") + self.load_btn = tk.Button(self.left_tab, text="Load Dataset", font='sans 10 bold', height=2, width=12, background="#343434", foreground="white", command=self.load_data) + self.load_btn.pack(side=tk.TOP, padx=(30,30), pady=20, anchor="n") + + self.boundingbox_var = tk.IntVar() + self.boundingbox_logo = ImageTk.PhotoImage(Image.open(os.path.join(".","assets","bbox.png")).resize((40,40), Image.Resampling.LANCZOS)) # https://www.flaticon.com/free-icon/square_7559227 + self.boundingbox_selector = tk.Checkbutton(self.left_tab, image=self.boundingbox_logo, variable=self.boundingbox_var, font='sans 10 bold', indicatoron=False, text="Show Bbox", width=100, height = 60, compound="top", selectcolor="#34b233", command=self.statechange_callback) + self.boundingbox_selector.pack(side=tk.BOTTOM, padx=(30,30), pady=20, anchor="n") + + # Checkbox for selecting between outer edged and all points + # The coco based annotations have only outer edge marked + # To make it compatible this is being added here. + self.checkbox_var = tk.IntVar() + self.polygon_logo = ImageTk.PhotoImage(Image.open(os.path.join(".","assets","polygon.png")).resize((40,40), Image.Resampling.LANCZOS)) # https://www.flaticon.com/free-icon/polygon_9726538 + self.checkbox = tk.Checkbutton(self.left_tab, image=self.polygon_logo, variable=self.checkbox_var, font='sans 10 bold', indicatoron=False, text="Outer Edge", width=100, height = 60, compound="top", selectcolor="#34b233", command=self.statechange_callback) + self.checkbox.pack(side=tk.BOTTOM, padx=(30,30), pady=20, anchor="n") + + self.left_tab.pack(side=tk.LEFT, padx=5, pady=5, anchor="n") # The frame which includes the image player self.imageplayer = ttk.Frame(self) @@ -90,6 +108,7 @@ def __init__(self, container): self.cur_image_index = 0 # The index of the current image self.image_list = [] # The list of images to be displayed self.window_height = 0 # The height of the app window + self.state_changed = False # The state of the bbox and outer edge buttons # Path to the intro image self.cur_image_path = os.path.join(".","assets","intro.png") @@ -143,7 +162,8 @@ def frame_update(self): # Get the height of the app window self.window_height = app.winfo_height()-20 # Display the image - if self.cur_image_path != self.prev_image_path or len(self.cur_annotation) != self.annotation_count or self.window_height != self.prev_window_height or self.mask_count != len(self.mask_images): + # If a new image is loaded or there is a new annotation or the window is resized, or there is a change in the state of the buttons, then update the image + if self.cur_image_path != self.prev_image_path or len(self.cur_annotation) != self.annotation_count or self.window_height != self.prev_window_height or self.state_changed: # Read the image and convert it to RGB self.OCV_image = cv2.imread(self.cur_image_path) self.cv2image = cv2.cvtColor(self.OCV_image, cv2.COLOR_BGR2RGB) @@ -160,13 +180,14 @@ def frame_update(self): self.annotation_count = len(self.cur_annotation) self.prev_window_height = self.window_height self.mask_count = len(self.mask_images) + self.state_changed = False # Get the image dimensions self.img_height, self.img_width, _ = self.OCV_image.shape # Draw the annotations if len(self.cur_annotation) > 0: - self.cv2image, self.mask_image, self.bbox_corners = SAM_prediction(self.cv2image, self.cur_annotation, self.predictor, self.img_height, self.img_width,self.mask_images) + self.cv2image, self.mask_image, self.bbox_corners = SAM_prediction(self.cv2image, self.cur_annotation, self.predictor, self.img_height, self.img_width, self.mask_images, self.checkbox_var.get(), self.boundingbox_var.get()) #get SAM polygons @@ -263,7 +284,8 @@ def new_object(self, event=None): if len(self.object_list.curselection())>0: self.cur_annotation = [] self.mask_images.append(self.mask_image) - self.bbox_list.append([self.object_list.curselection()[0], *self.bbox_corners]) + if len(self.bbox_corners) > 0: + self.bbox_list.append([self.object_list.curselection()[0], *self.bbox_corners]) self.object_list.selection_clear(0, tk.END) print("Previous masks:",len(self.mask_images)) else: @@ -293,14 +315,15 @@ def save_annotation(self, event): if len(self.cur_annotation) > 0: self.cur_annotation = [] self.mask_images.append(self.mask_image) - self.bbox_list.append([self.object_list.curselection()[0], *self.bbox_corners]) + if len(self.bbox_corners) > 0: + self.bbox_list.append([self.object_list.curselection()[0], *self.bbox_corners]) self.object_list.selection_clear(0, tk.END) # Save the bbox list to the file with open(os.path.join(self.file_path,os.path.basename(self.cur_image_path).split(".")[0]+".txt"), "w") as f: for bbox in self.bbox_list: f.write(str(bbox[0])+" "+str(bbox[1])+" "+str(bbox[2])+" "+str(bbox[3])+" "+str(bbox[4])+"\n") # If there is only one object labelled - elif len(self.cur_annotation) > 0: + elif len(self.cur_annotation) > 0 and len(self.bbox_corners)>0: if len(self.object_list.curselection())>0: # Save the bbox list to the file with open(os.path.join(self.file_path,os.path.basename(self.cur_image_path).split(".")[0]+".txt"), "w") as f: @@ -309,6 +332,10 @@ def save_annotation(self, event): messagebox.showwarning("Warning","Please select an object from the list before saving") else: messagebox.showwarning("Warning","Please label the image before saving") + + # When there is a state change + def statechange_callback(self, event=None): + self.state_changed = True # App class diff --git a/assets/bbox.png b/assets/bbox.png new file mode 100644 index 0000000..ad692bb Binary files /dev/null and b/assets/bbox.png differ diff --git a/assets/polygon.png b/assets/polygon.png new file mode 100644 index 0000000..75f5943 Binary files /dev/null and b/assets/polygon.png differ diff --git a/utils/predict.py b/utils/predict.py index 691978a..f9d09c8 100644 --- a/utils/predict.py +++ b/utils/predict.py @@ -5,7 +5,87 @@ import numpy as np from segment_anything import sam_model_registry, SamPredictor +# BBox to Yolo format +def bbox_to_yolo(bbox, image_width, image_height): + # Get the center of the bbox + center_x = (bbox[0] + bbox[2]) / 2 + center_y = (bbox[1] + bbox[3]) / 2 + # Get the width and height of the bbox + width = bbox[2] - bbox[0] + height = bbox[3] - bbox[1] + + # Convert the bbox to Yolo format + yolo_bbox = [ + center_x / image_width, + center_y / image_height, + width / image_width, + height / image_height, + ] + + # Round the values to 6 decimal places + yolo_bbox = [round(value, 6) for value in yolo_bbox] + + return yolo_bbox + +# Function to filter the contour by area +def filter_contour_by_area(contour, min_area, max_area): + area = cv2.contourArea(contour) + if area < min_area or area > max_area: + return False + return True + +# Function to approximate the polygon +def approx_contour(contour, percentage, epsilon_step=0.005): + if percentage < 0 or percentage >= 1: + raise ValueError("Percentage must be in the range [0, 1).") + + target_points = max(int(contour.shape[0] * (1 - percentage)), 3) + + epsilon = 0 + while True: + epsilon += epsilon_step + approximated_contour = cv2.approxPolyDP(contour, epsilon, closed=True) + if approximated_contour.shape[0] <= target_points: + break + + return approximated_contour + +# Function to convert the mask to a polygon +def mask_to_polygon(mask_image, approximation_percentage = 0.75): + + # Add padding to the mask_image to capture the outer edge + padded_mask_image = cv2.copyMakeBorder(mask_image, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0) + + # Get the contours of the mask + contours, _ = cv2.findContours(cv2.Canny(padded_mask_image, 100, 200), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS) + + # Remove the contours which have less than 10 points + contours = [contour for contour in contours if contour.shape[0] >= 10] + + # Remove the previously added padding from the contours + contours = [np.subtract(contour,(1,1)) for contour in contours] + + # Filter the contours by area + height, width = mask_image.shape[-2:] + contours = [contour for contour in contours if filter_contour_by_area( + contour=contour, + min_area=0.005*height*width, + max_area=1*height*width)] + + # Reduce the complexity of contours + best_contour = [approx_contour(contour=contour, percentage=0.75) for contour in contours] + + # Ensure that the contour is closed + best_contour = [np.vstack((contour, contour[0][np.newaxis, :])) for contour in best_contour if not np.array_equal(contour[0], contour[-1])] + + # Convert the contours to polygons + best_polygon = [contour.flatten().tolist() for contour in best_contour] + + return best_contour, best_polygon + + +# Function to setup the SAM model def SAM_setup(model_type, model_path, device_id): sam = sam_model_registry[model_type](checkpoint=model_path) if device_id != "cpu": @@ -14,7 +94,9 @@ def SAM_setup(model_type, model_path, device_id): print("Warning: Running on CPU. This will be slow.") return SamPredictor(sam) -def SAM_prediction(image, points, predictor, img_height, img_width, mask_array=[]): + +# Function to predict the mask +def SAM_prediction(image, points, predictor, img_height, img_width, mask_array=[], outer_edge=0, bounding_box=0): # The points are in the format [x, y, color, label] input_point = np.array([[p[0], p[1]] for p in points]) input_label = np.array([p[3] for p in points]) @@ -45,33 +127,59 @@ def SAM_prediction(image, points, predictor, img_height, img_width, mask_array=[ # Morphological operations on the mask to be saved mask_save_image = cv2.morphologyEx(mask_save_image, cv2.MORPH_OPEN, kernel) - - # Get the edges of the mask - edges = cv2.Canny(mask_image[:, :, 0], 100, 200) - - # Overlay the mask on the image - image = cv2.addWeighted(mask_image, 0.3, image, 0.7, 0) + + # Add the previous masks to the image if len(mask_array) > 0: for m in mask_array: image = cv2.addWeighted(m, 0.3, image, 1, 0) - # Plot the edges on the image - gy, gx = np.where(edges != 0) - for i in range(len(gx)): - image = cv2.circle(image, (gx[i], gy[i]), int((img_height+img_width)/400), (0, 0, 255),-1) + # Check if the outer edge is to be detected + if outer_edge==1: + + # Create a copy of the image + overlay = image.copy() + + # Get the best contours in the mask and its corresponding polygon + best_contour, best_polygon = mask_to_polygon(mask_image) + + # Overlay the controur + cv2.drawContours(overlay, best_contour, -1, (0, 0, 255), thickness=cv2.FILLED) + + # Add the overlay to the image + image = cv2.addWeighted(overlay, 0.5, image, 0.5, 0) + # Marking the bounding box for each of the contours + if bounding_box==1: + for cnt_cur in best_contour: + x, y, w, h = cv2.boundingRect(cnt_cur) + image = cv2.rectangle(image, (x, y), (x+w, y+h), (0, 255, 0), int((img_height+img_width)/400)) + + gx, gy = [], [] + # Extract x and y coordinates into separate arrays + if len(best_polygon) > 0: + gx = np.array(np.hstack(best_polygon)).reshape(-1, 2)[:, 0] + gy = np.array(np.hstack(best_polygon)).reshape(-1, 2)[:, 1] + + else: + # Overlay the mask on the image + image = cv2.addWeighted(mask_image, 0.3, image, 0.7, 0) + + # Get the edges of the mask + edges = cv2.Canny(mask_image[:, :, 0], 100, 200) + + # Plot the edges on the image + gy, gx = np.where(edges != 0) + for i in range(len(gx)): + image = cv2.circle(image, (gx[i], gy[i]), int((img_height+img_width)/400), (0, 0, 255),-1) + + bbox_corners = [] # Find the bounding box for the object if len(gx) > 0: # Calculate bounding box dimensions and center coordinates in YOLO format - bbox_width = (np.max(gx) - np.min(gx)) / img_width - bbox_height = (np.max(gy) - np.min(gy)) / img_height - bbox_center_x = (np.max(gx) + np.min(gx)) / (2 * img_width) - bbox_center_y = (np.max(gy) + np.min(gy)) / (2 * img_height) - - bbox_corners = [bbox_center_x, bbox_center_y, bbox_width, bbox_height] - bbox_corners = [round(x, 6) for x in bbox_corners] - else: - bbox_corners = [0, 0, 0, 0] + bbox_corners = bbox_to_yolo([np.min(gx), np.min(gy), np.max(gx), np.max(gy)], img_width, img_height) + # If bounding_box is 1, plot the bounding box for the object + if bounding_box==1: + image = cv2.rectangle(image, (int((bbox_corners[0] - bbox_corners[2]/2) * img_width), int((bbox_corners[1] - bbox_corners[3]/2) * img_height)), (int((bbox_corners[0] + bbox_corners[2]/2) * img_width), int((bbox_corners[1] + bbox_corners[3]/2) * img_height)), (255, 0, 0), int((img_height+img_width)/400)) return image, mask_save_image, bbox_corners