Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions PixelSAM.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,27 @@ def __init__(self, container):
super().__init__(container)
# self['text'] = 'Options'

# The left frame containing buttons
self.left_tab = ttk.Frame(self)

# Load button
self.load_btn = tk.Button(self, text="Load Dataset", font='sans 10 bold', height=2, width=12, background="#343434", foreground="white", command=self.load_data)
self.load_btn.pack(side=tk.LEFT, padx=(30,30), pady=20, anchor="n")
self.load_btn = tk.Button(self.left_tab, text="Load Dataset", font='sans 10 bold', height=2, width=12, background="#343434", foreground="white", command=self.load_data)
self.load_btn.pack(side=tk.TOP, padx=(30,30), pady=20, anchor="n")

self.boundingbox_var = tk.IntVar()
self.boundingbox_logo = ImageTk.PhotoImage(Image.open(os.path.join(".","assets","bbox.png")).resize((40,40), Image.Resampling.LANCZOS)) # https://www.flaticon.com/free-icon/square_7559227
self.boundingbox_selector = tk.Checkbutton(self.left_tab, image=self.boundingbox_logo, variable=self.boundingbox_var, font='sans 10 bold', indicatoron=False, text="Show Bbox", width=100, height = 60, compound="top", selectcolor="#34b233", command=self.statechange_callback)
self.boundingbox_selector.pack(side=tk.BOTTOM, padx=(30,30), pady=20, anchor="n")

# Checkbox for selecting between the outer edge only and all edge points.
# COCO-based annotations mark only the outer edge, so this option is
# provided to stay compatible with them.
self.checkbox_var = tk.IntVar()
self.polygon_logo = ImageTk.PhotoImage(Image.open(os.path.join(".","assets","polygon.png")).resize((40,40), Image.Resampling.LANCZOS)) # https://www.flaticon.com/free-icon/polygon_9726538
self.checkbox = tk.Checkbutton(self.left_tab, image=self.polygon_logo, variable=self.checkbox_var, font='sans 10 bold', indicatoron=False, text="Outer Edge", width=100, height = 60, compound="top", selectcolor="#34b233", command=self.statechange_callback)
self.checkbox.pack(side=tk.BOTTOM, padx=(30,30), pady=20, anchor="n")

self.left_tab.pack(side=tk.LEFT, padx=5, pady=5, anchor="n")

# The frame which includes the image player
self.imageplayer = ttk.Frame(self)
Expand Down Expand Up @@ -90,6 +108,7 @@ def __init__(self, container):
self.cur_image_index = 0 # The index of the current image
self.image_list = [] # The list of images to be displayed
self.window_height = 0 # The height of the app window
self.state_changed = False # The state of the bbox and outer edge buttons

# Path to the intro image
self.cur_image_path = os.path.join(".","assets","intro.png")
Expand Down Expand Up @@ -143,7 +162,8 @@ def frame_update(self):
# Get the height of the app window
self.window_height = app.winfo_height()-20
# Display the image
if self.cur_image_path != self.prev_image_path or len(self.cur_annotation) != self.annotation_count or self.window_height != self.prev_window_height or self.mask_count != len(self.mask_images):
# If a new image is loaded or there is a new annotation or the window is resized, or there is a change in the state of the buttons, then update the image
if self.cur_image_path != self.prev_image_path or len(self.cur_annotation) != self.annotation_count or self.window_height != self.prev_window_height or self.state_changed:
# Read the image and convert it to RGB
self.OCV_image = cv2.imread(self.cur_image_path)
self.cv2image = cv2.cvtColor(self.OCV_image, cv2.COLOR_BGR2RGB)
Expand All @@ -160,13 +180,14 @@ def frame_update(self):
self.annotation_count = len(self.cur_annotation)
self.prev_window_height = self.window_height
self.mask_count = len(self.mask_images)
self.state_changed = False

# Get the image dimensions
self.img_height, self.img_width, _ = self.OCV_image.shape

# Draw the annotations
if len(self.cur_annotation) > 0:
self.cv2image, self.mask_image, self.bbox_corners = SAM_prediction(self.cv2image, self.cur_annotation, self.predictor, self.img_height, self.img_width,self.mask_images)
self.cv2image, self.mask_image, self.bbox_corners = SAM_prediction(self.cv2image, self.cur_annotation, self.predictor, self.img_height, self.img_width, self.mask_images, self.checkbox_var.get(), self.boundingbox_var.get())
#get SAM polygons


Expand Down Expand Up @@ -263,7 +284,8 @@ def new_object(self, event=None):
if len(self.object_list.curselection())>0:
self.cur_annotation = []
self.mask_images.append(self.mask_image)
self.bbox_list.append([self.object_list.curselection()[0], *self.bbox_corners])
if len(self.bbox_corners) > 0:
self.bbox_list.append([self.object_list.curselection()[0], *self.bbox_corners])
self.object_list.selection_clear(0, tk.END)
print("Previous masks:",len(self.mask_images))
else:
Expand Down Expand Up @@ -293,14 +315,15 @@ def save_annotation(self, event):
if len(self.cur_annotation) > 0:
self.cur_annotation = []
self.mask_images.append(self.mask_image)
self.bbox_list.append([self.object_list.curselection()[0], *self.bbox_corners])
if len(self.bbox_corners) > 0:
self.bbox_list.append([self.object_list.curselection()[0], *self.bbox_corners])
self.object_list.selection_clear(0, tk.END)
# Save the bbox list to the file
with open(os.path.join(self.file_path,os.path.basename(self.cur_image_path).split(".")[0]+".txt"), "w") as f:
for bbox in self.bbox_list:
f.write(str(bbox[0])+" "+str(bbox[1])+" "+str(bbox[2])+" "+str(bbox[3])+" "+str(bbox[4])+"\n")
# If there is only one object labelled
elif len(self.cur_annotation) > 0:
elif len(self.cur_annotation) > 0 and len(self.bbox_corners)>0:
if len(self.object_list.curselection())>0:
# Save the bbox list to the file
with open(os.path.join(self.file_path,os.path.basename(self.cur_image_path).split(".")[0]+".txt"), "w") as f:
Expand All @@ -309,6 +332,10 @@ def save_annotation(self, event):
messagebox.showwarning("Warning","Please select an object from the list before saving")
else:
messagebox.showwarning("Warning","Please label the image before saving")

# When there is a state change
def statechange_callback(self, event=None):
    """Record that one of the toggle buttons changed state.

    Used as the ``command`` of the bounding-box and outer-edge
    checkbuttons. ``frame_update`` checks ``state_changed`` to decide
    whether the image must be refreshed and resets it afterwards.
    ``event`` is accepted but not used.
    """
    self.state_changed = True


# App class
Expand Down
Binary file added assets/bbox.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/polygon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
148 changes: 128 additions & 20 deletions utils/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,87 @@
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# BBox to Yolo format
def bbox_to_yolo(bbox, image_width, image_height):
    """Convert a corner-format bounding box to normalised YOLO format.

    Args:
        bbox: Sequence ``[x1, y1, x2, y2]`` holding two opposite corners in
            pixel coordinates. The corners may be given in either order.
        image_width: Width of the image in pixels; must be positive.
        image_height: Height of the image in pixels; must be positive.

    Returns:
        ``[center_x, center_y, width, height]`` as plain Python floats, each
        divided by the matching image dimension and rounded to 6 decimals.

    Raises:
        ValueError: If ``image_width`` or ``image_height`` is not positive.
    """
    if image_width <= 0 or image_height <= 0:
        raise ValueError("image_width and image_height must be positive.")

    # Sort each axis so swapped corners cannot yield a negative size, and
    # convert to float so narrow integer dtypes cannot overflow when summed.
    x_min, x_max = sorted((float(bbox[0]), float(bbox[2])))
    y_min, y_max = sorted((float(bbox[1]), float(bbox[3])))

    yolo_bbox = [
        (x_min + x_max) / 2 / image_width,
        (y_min + y_max) / 2 / image_height,
        (x_max - x_min) / image_width,
        (y_max - y_min) / image_height,
    ]

    # float() strips NumPy scalar types that the image dimensions may carry.
    return [round(float(value), 6) for value in yolo_bbox]

# Function to filter the contour by area
def filter_contour_by_area(contour, min_area, max_area):
    """Tell whether a contour's area lies within ``[min_area, max_area]``.

    Args:
        contour: Contour passed to ``cv2.contourArea``.
        min_area: Smallest accepted area (inclusive).
        max_area: Largest accepted area (inclusive).

    Returns:
        ``False`` when the area is below ``min_area`` or above ``max_area``,
        ``True`` otherwise.
    """
    contour_area = cv2.contourArea(contour)
    return not (contour_area < min_area or contour_area > max_area)

# Function to approximate the polygon
def approx_contour(contour, percentage, epsilon_step=0.005):
    """Simplify a closed contour until enough of its points are removed.

    ``cv2.approxPolyDP`` is called repeatedly with a tolerance that grows by
    ``epsilon_step`` per iteration; the first result whose point count is at
    or below the target is returned.

    Args:
        contour: Array of contour points; ``contour.shape[0]`` is taken as
            the number of points.
        percentage: Fraction of points to drop, in ``[0, 1)``.
        epsilon_step: Increment applied to the approximation tolerance on
            each iteration; must be greater than 0.

    Returns:
        The approximated contour as returned by ``cv2.approxPolyDP``.

    Raises:
        ValueError: If ``percentage`` is outside ``[0, 1)`` or
            ``epsilon_step`` is not greater than 0.
    """
    if not 0 <= percentage < 1:
        raise ValueError("Percentage must be in the range [0, 1).")
    # A non-positive step would never raise the tolerance, so the loop below
    # could spin forever.
    if not epsilon_step > 0:
        raise ValueError("epsilon_step must be greater than 0.")

    # Never ask for fewer than 3 points, the minimum for a polygon.
    target_points = max(int(contour.shape[0] * (1 - percentage)), 3)

    epsilon = 0
    while True:
        epsilon += epsilon_step
        approximated_contour = cv2.approxPolyDP(contour, epsilon, closed=True)
        if approximated_contour.shape[0] <= target_points:
            return approximated_contour

# Function to convert the mask to a polygon
def mask_to_polygon(mask_image, approximation_percentage=0.75):
    """Extract simplified outer contours of a mask and their flat polygons.

    Args:
        mask_image: Mask array handed to ``cv2.copyMakeBorder`` and
            ``cv2.Canny``. Assumed to be laid out rows x columns, optionally
            with a trailing channel axis -- TODO confirm against the caller.
        approximation_percentage: Fraction of contour points to remove,
            forwarded to ``approx_contour``.

    Returns:
        A tuple ``(best_contour, best_polygon)``: the list of simplified,
        closed contours and, for each, its coordinates flattened into a
        plain list.
    """
    # Add padding to the mask so that the outer edge is captured
    padded_mask_image = cv2.copyMakeBorder(
        mask_image, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0
    )

    # Get the external contours of the mask edges
    contours, _ = cv2.findContours(
        cv2.Canny(padded_mask_image, 100, 200),
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_TC89_KCOS,
    )

    # Drop contours with fewer than 10 points, then shift the rest back by
    # the one-pixel padding added above
    contours = [
        np.subtract(contour, (1, 1))
        for contour in contours
        if contour.shape[0] >= 10
    ]

    # Image size comes from the first two axes; shape[-2:] would return
    # (width, channels) for a mask that carries a channel axis.
    height, width = mask_image.shape[:2]
    image_area = height * width

    # Keep contours covering between 0.5% and 100% of the image area
    contours = [
        contour
        for contour in contours
        if filter_contour_by_area(
            contour=contour,
            min_area=0.005 * image_area,
            max_area=1 * image_area,
        )
    ]

    # Reduce the complexity of the contours by the requested fraction
    simplified = [
        approx_contour(contour=contour, percentage=approximation_percentage)
        for contour in contours
    ]

    # Ensure every contour is closed: append the first point when it is
    # missing at the end, and keep contours that are already closed.
    best_contour = []
    for contour in simplified:
        if not np.array_equal(contour[0], contour[-1]):
            contour = np.vstack((contour, contour[0][np.newaxis, :]))
        best_contour.append(contour)

    # Flatten each contour into a plain list of coordinates
    best_polygon = [contour.flatten().tolist() for contour in best_contour]

    return best_contour, best_polygon


# Function to setup the SAM model
def SAM_setup(model_type, model_path, device_id):
sam = sam_model_registry[model_type](checkpoint=model_path)
if device_id != "cpu":
Expand All @@ -14,7 +94,9 @@ def SAM_setup(model_type, model_path, device_id):
print("Warning: Running on CPU. This will be slow.")
return SamPredictor(sam)

def SAM_prediction(image, points, predictor, img_height, img_width, mask_array=[]):

# Function to predict the mask
def SAM_prediction(image, points, predictor, img_height, img_width, mask_array=[], outer_edge=0, bounding_box=0):
# The points are in the format [x, y, color, label]
input_point = np.array([[p[0], p[1]] for p in points])
input_label = np.array([p[3] for p in points])
Expand Down Expand Up @@ -45,33 +127,59 @@ def SAM_prediction(image, points, predictor, img_height, img_width, mask_array=[

# Morphological operations on the mask to be saved
mask_save_image = cv2.morphologyEx(mask_save_image, cv2.MORPH_OPEN, kernel)

# Get the edges of the mask
edges = cv2.Canny(mask_image[:, :, 0], 100, 200)

# Overlay the mask on the image
image = cv2.addWeighted(mask_image, 0.3, image, 0.7, 0)

# Add the previous masks to the image
if len(mask_array) > 0:
for m in mask_array:
image = cv2.addWeighted(m, 0.3, image, 1, 0)

# Plot the edges on the image
gy, gx = np.where(edges != 0)
for i in range(len(gx)):
image = cv2.circle(image, (gx[i], gy[i]), int((img_height+img_width)/400), (0, 0, 255),-1)
# Check if the outer edge is to be detected
if outer_edge==1:

# Create a copy of the image
overlay = image.copy()

# Get the best contours in the mask and its corresponding polygon
best_contour, best_polygon = mask_to_polygon(mask_image)

# Overlay the contour
cv2.drawContours(overlay, best_contour, -1, (0, 0, 255), thickness=cv2.FILLED)

# Add the overlay to the image
image = cv2.addWeighted(overlay, 0.5, image, 0.5, 0)

# Marking the bounding box for each of the contours
if bounding_box==1:
for cnt_cur in best_contour:
x, y, w, h = cv2.boundingRect(cnt_cur)
image = cv2.rectangle(image, (x, y), (x+w, y+h), (0, 255, 0), int((img_height+img_width)/400))

gx, gy = [], []
# Extract x and y coordinates into separate arrays
if len(best_polygon) > 0:
gx = np.array(np.hstack(best_polygon)).reshape(-1, 2)[:, 0]
gy = np.array(np.hstack(best_polygon)).reshape(-1, 2)[:, 1]

else:
# Overlay the mask on the image
image = cv2.addWeighted(mask_image, 0.3, image, 0.7, 0)

# Get the edges of the mask
edges = cv2.Canny(mask_image[:, :, 0], 100, 200)

# Plot the edges on the image
gy, gx = np.where(edges != 0)
for i in range(len(gx)):
image = cv2.circle(image, (gx[i], gy[i]), int((img_height+img_width)/400), (0, 0, 255),-1)

bbox_corners = []
# Find the bounding box for the object
if len(gx) > 0:
# Calculate bounding box dimensions and center coordinates in YOLO format
bbox_width = (np.max(gx) - np.min(gx)) / img_width
bbox_height = (np.max(gy) - np.min(gy)) / img_height
bbox_center_x = (np.max(gx) + np.min(gx)) / (2 * img_width)
bbox_center_y = (np.max(gy) + np.min(gy)) / (2 * img_height)

bbox_corners = [bbox_center_x, bbox_center_y, bbox_width, bbox_height]
bbox_corners = [round(x, 6) for x in bbox_corners]
else:
bbox_corners = [0, 0, 0, 0]
bbox_corners = bbox_to_yolo([np.min(gx), np.min(gy), np.max(gx), np.max(gy)], img_width, img_height)

# If bounding_box is 1, plot the bounding box for the object
if bounding_box==1:
image = cv2.rectangle(image, (int((bbox_corners[0] - bbox_corners[2]/2) * img_width), int((bbox_corners[1] - bbox_corners[3]/2) * img_height)), (int((bbox_corners[0] + bbox_corners[2]/2) * img_width), int((bbox_corners[1] + bbox_corners[3]/2) * img_height)), (255, 0, 0), int((img_height+img_width)/400))

return image, mask_save_image, bbox_corners