
Commit 0e3231e

Add multi-point input, foreground/background point input, and box input to the EfficientSAM model (#291)
* a
* add efficientsam model and basic demo
* update license
* remove example images
* update readme
* update readme
* update demo
* update demo
* update readme
* update SAM and __init__
* update demo and sam
* update label
* add present gif
* update readme
* add efficientSAM gif to readme of opencvzoo
* cv version 4.10.0, remove camera branch
* 1. add multi-point inferring (max: 6) 2. add box prompt (drag), add background point (long press) 3. fix model input to 1024*1024 4. label padding -1 5. update demo
* replace the model with a new model supporting multi-point input, update demo
* update readme
* update readme
* change window size to (800*600); input pictures cannot exceed it
* add int8 model
* update demo
* update README
* check OpenCV version
* update model name in demo
* update model name in demo
* Add a key to exit ('q' and 'Q'); when clicks reach maximum, no box shows; comment out useless print, delete useless whitespace
* update demo with some ASCII
1 parent 21f2e86 commit 0e3231e

File tree

5 files changed: +262 −75 lines changed


models/image_segmentation_efficientsam/README.md

Lines changed: 13 additions & 5 deletions
@@ -3,9 +3,16 @@
 EfficientSAM: Leveraged Masked Image Pretraining for Efficient Segment Anything

 Notes:
-- The current implementation of the EfficientSAM demo uses the EfficientSAM-Ti model, which is specifically tailored for scenarios requiring higher speed and lightweight.
-- MD5 value of "efficient_sam_vitt.pt" is 7A804DA508F30EFC59EC06711C8DCD62
-- SHA-256 value of "efficient_sam_vitt.pt" is DFF858B19600A46461CBB7DE98F796B23A7A888D9F5E34C0B033F7D6EB9E4E6A
+- The current implementation of the EfficientSAM demo uses the EfficientSAM-Ti model, which is tailored for scenarios requiring higher speed and a lightweight footprint.
+- image_segmentation_efficientsam_ti_2024may.onnx (supports only single-point inference)
+  - MD5 value: 117d6a6cac60039a20b399cc133c2a60
+  - SHA-256 value: e3957d2cd1422855f350aa7b044f47f5b3eafada64b5904ed330b696229e2943
+- image_segmentation_efficientsam_ti_2025april.onnx
+  - MD5 value: f23cecbb344547c960c933ff454536a3
+  - SHA-256 value: 4eb496e0a7259d435b49b66faf1754aa45a5c382a34558ddda9a8c6fe5915d77
+- image_segmentation_efficientsam_ti_2025april_int8.onnx
+  - MD5 value: a1164f44b0495b82e9807c7256e95a50
+  - SHA-256 value: 5ecc8d59a2802c32246e68553e1cf8ce74cf74ba707b84f206eb9181ff774b4e


 ## Demo
@@ -17,7 +24,7 @@ Run the following command to try the demo:
 python demo.py --input /path/to/image
 ```

-Click only **once** on the object you wish to segment in the displayed image. After the click, the segmentation result will be shown in a new window.
+**Click** to select foreground points, **drag** to draw a selection box, and **long-press** to select background points on the object you wish to segment in the displayed image. After pressing **Enter**, the segmentation result will be shown in a new window. Press **Backspace** to clear all prompts.

 ## Result

@@ -41,4 +48,5 @@ All files in this directory are licensed under [Apache 2.0 License](./LICENSE).
 ## Reference

 - https://arxiv.org/abs/2312.00863
-- https://github.com/yformer/EfficientSAM
+- https://github.com/yformer/EfficientSAM
+- https://github.com/facebookresearch/segment-anything
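Note: the prompt scheme added to the README above maps directly onto the points and labels arguments that demo.py passes to model.infer. A minimal sketch of that encoding, assuming the same infer(image=..., points=..., labels=...) signature used in the demo (the coordinates are made up for illustration):

```python
# Hypothetical prompt encoding, mirroring demo.py's conventions:
#   label  1 -> foreground point (click)
#   label -1 -> background point (long press)
#   labels 2, 3 -> box top-left / bottom-right corners (drag)
points = [[120, 80],             # foreground click
          [300, 200],            # background long press
          [50, 40], [400, 300]]  # box corners (top-left, bottom-right)
labels = [1, -1, 2, 3]

# result = model.infer(image=image, points=points, labels=labels)
# vis = visualize(image, result)
```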

models/image_segmentation_efficientsam/demo.py

Lines changed: 152 additions & 42 deletions
@@ -20,8 +20,8 @@
 parser = argparse.ArgumentParser(description='EfficientSAM Demo')
 parser.add_argument('--input', '-i', type=str,
                     help='Set input path to a certain image.')
-parser.add_argument('--model', '-m', type=str, default='image_segmentation_efficientsam_ti_2024may.onnx',
-                    help='Set model path, defaults to image_segmentation_efficientsam_ti_2024may.onnx.')
+parser.add_argument('--model', '-m', type=str, default='image_segmentation_efficientsam_ti_2025april.onnx',
+                    help='Set model path, defaults to image_segmentation_efficientsam_ti_2025april.onnx.')
 parser.add_argument('--backend_target', '-bt', type=int, default=0,
                     help='''Choose one of the backend-target pair to run this demo:
                     {:d}: (default) OpenCV implementation + CPU,
@@ -34,10 +34,14 @@
                     help='Specify to save a file with results. Invalid in case of camera input.')
 args = parser.parse_args()

-#global click listener
-clicked_left = False
-#global point record in the window
-point = []
+# Global configuration
+WINDOW_SIZE = (800, 600)  # Fixed window size (width, height)
+MAX_POINTS = 6  # Maximum allowed points
+points = []  # Store clicked coordinates (original image scale)
+labels = []  # Point labels (-1: background/padding, 0: background, 1: foreground, 2: box top-left, 3: box bottom-right)
+backend_point = []  # Coordinates where the left button went down
+rectangle = False  # Whether a box drag is in progress
+current_img = None  # Image currently shown in the window

 def visualize(image, result):
     """
@@ -55,26 +59,88 @@ def visualize(image, result):
     mask = np.copy(result)
     # change mask to binary image
     t, binary = cv.threshold(mask, 127, 255, cv.THRESH_BINARY)
-    assert set(np.unique(binary)) <= {0, 255}, "The mask must be a binary image"
+    assert set(np.unique(binary)) <= {0, 255}, "The mask must be a binary image."
     # enhance red channel to make the segmentation more obviously
     enhancement_factor = 1.8
-    red_channel = vis_result[:, :, 2]
+    red_channel = vis_result[:, :, 2]
     # update the channel
     red_channel = np.where(binary == 255, np.minimum(red_channel * enhancement_factor, 255), red_channel)
-    vis_result[:, :, 2] = red_channel
-
+    vis_result[:, :, 2] = red_channel
+
     # draw borders
     contours, hierarchy = cv.findContours(binary, cv.RETR_LIST, cv.CHAIN_APPROX_TC89_L1)
     cv.drawContours(vis_result, contours, contourIdx = -1, color = (255,255,255), thickness=2)
     return vis_result

 def select(event, x, y, flags, param):
-    global clicked_left
-    # When the left mouse button is pressed, record the coordinates of the point where it is pressed
-    if event == cv.EVENT_LBUTTONUP:
-        point.append([x,y])
-        print("point:",point[0])
-        clicked_left = True
+    """Handle mouse events with coordinate conversion."""
+    global points, labels, backend_point, rectangle, current_img
+    orig_img = param['original_img']
+    image_window = param['image_window']
+
+    if event == cv.EVENT_LBUTTONDOWN:
+        # remember when and where the button went down (for long-press detection)
+        param['mouse_down_time'] = cv.getTickCount()
+        backend_point = [x, y]
+
+    elif event == cv.EVENT_MOUSEMOVE:
+        if rectangle:
+            # live preview of the box while dragging
+            rectangle_change_img = current_img.copy()
+            cv.rectangle(rectangle_change_img, (backend_point[0], backend_point[1]), (x, y), (255, 0, 0), 2)
+            cv.imshow(image_window, rectangle_change_img)
+        elif len(backend_point) != 0 and len(points) < MAX_POINTS:
+            rectangle = True
+
+    elif event == cv.EVENT_LBUTTONUP:
+        if len(points) >= MAX_POINTS:
+            print(f"Maximum of {MAX_POINTS} points reached.")
+            return
+
+        if not rectangle:
+            # classify click vs. long press by how long the button was held
+            duration = (cv.getTickCount() - param['mouse_down_time']) / cv.getTickFrequency()
+            label = -1 if duration > 0.5 else 1  # long press = background
+            points.append([backend_point[0], backend_point[1]])
+            labels.append(label)
+            print(f"Added {'foreground' if label == 1 else 'background'} point {backend_point}.")
+        else:
+            if len(points) + 1 >= MAX_POINTS:
+                rectangle = False
+                backend_point.clear()
+                cv.imshow(image_window, current_img)
+                print(f"Point limit {MAX_POINTS} reached, cannot add a box.")
+                return
+            # order the two drag endpoints into top-left / bottom-right corners
+            point_leftup = []
+            point_rightdown = []
+            if x > backend_point[0] or y > backend_point[1]:
+                point_leftup.extend(backend_point)
+                point_rightdown.extend([x, y])
+            else:
+                point_leftup.extend([x, y])
+                point_rightdown.extend(backend_point)
+            points.append(point_leftup)
+            points.append(point_rightdown)
+            print(f"Added box from {point_leftup} to {point_rightdown}.")
+            labels.append(2)
+            labels.append(3)
+            rectangle = False
+            backend_point.clear()
+
+        # redraw all prompts on a fresh copy of the original image
+        marked_img = orig_img.copy()
+        top_left = None
+        for (px, py), lbl in zip(points, labels):
+            if lbl == -1:
+                cv.circle(marked_img, (px, py), 5, (0, 0, 255), -1)  # background point: red
+            elif lbl == 1:
+                cv.circle(marked_img, (px, py), 5, (0, 255, 0), -1)  # foreground point: green
+            elif lbl == 2:
+                top_left = (px, py)
+            elif lbl == 3:
+                bottom_right = (px, py)
+                cv.rectangle(marked_img, top_left, bottom_right, (255, 0, 0), 2)
+        cv.imshow(image_window, marked_img)
+        current_img = marked_img.copy()

 if __name__ == '__main__':
     backend_id = backend_target_pairs[args.backend_target][0]
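One caveat in the box branch above: the corner swap keys on `x > backend_point[0] or y > backend_point[1]`, which only normalizes drags that go strictly down-right or strictly up-left; a drag toward down-left or up-right leaves the corners crossed on one axis. A per-axis min/max would handle any drag direction (helper name is ours, not part of the commit):

```python
def normalize_box(p0, p1):
    """Return (top_left, bottom_right) corners for any drag direction."""
    (x0, y0), (x1, y1) = p0, p1
    return [min(x0, x1), min(y0, y1)], [max(x0, x1), max(y0, y1)]

# Dragging up-right from (100, 300) to (400, 50):
# normalize_box([100, 300], [400, 50]) -> ([100, 50], [400, 300])
```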
@@ -89,49 +155,93 @@ def select(event, x, y, flags, param):
             print('Could not open or find the image:', args.input)
             exit(0)
         # create window
-        image_window = "image: click on the thing whick you want to segment!"
+        image_window = "Origin image"
         cv.namedWindow(image_window, cv.WINDOW_NORMAL)
         # change window size
-        cv.resizeWindow(image_window, 800 if image.shape[0] > 800 else image.shape[0], 600 if image.shape[1] > 600 else image.shape[1])
+        # shrink the window uniformly so the image fits inside WINDOW_SIZE without upscaling
+        rate1 = 1
+        rate2 = 1
+        if image.shape[1] > WINDOW_SIZE[0]:
+            rate1 = WINDOW_SIZE[0] / image.shape[1]
+        if image.shape[0] > WINDOW_SIZE[1]:
+            rate2 = WINDOW_SIZE[1] / image.shape[0]
+        rate = min(rate1, rate2)
+        # width, height
+        WINDOW_SIZE = (int(image.shape[1] * rate), int(image.shape[0] * rate))
+        cv.resizeWindow(image_window, WINDOW_SIZE[0], WINDOW_SIZE[1])
         # put the window on the left of the screen
         cv.moveWindow(image_window, 50, 100)
         # set listener to record user's click point
-        cv.setMouseCallback(image_window, select)
+        param = {
+            'original_img': image,
+            'mouse_down_time': 0,
+            'image_window': image_window
+        }
+        cv.setMouseCallback(image_window, select, param)
         # tips in the terminal
-        print("click the picture on the LEFT and see the result on the RIGHT!")
+        print("Click — Select foreground point\n"
+              "Long press — Select background point\n"
+              "Drag — Create selection box\n"
+              "Enter — Infer\n"
+              "Backspace — Clear the prompts\n"
+              "Q — Quit")
         # show image
         cv.imshow(image_window, image)
+        current_img = image.copy()
+        # create window to show visualized result
+        vis_image = image.copy()
+        segmentation_window = "Segment result"
+        cv.namedWindow(segmentation_window, cv.WINDOW_NORMAL)
+        cv.resizeWindow(segmentation_window, WINDOW_SIZE[0], WINDOW_SIZE[1])
+        cv.moveWindow(segmentation_window, WINDOW_SIZE[0] + 51, 100)
+        cv.imshow(segmentation_window, vis_image)
         # waiting for click
-        while cv.waitKey(1) == -1 or clicked_left:
-            # receive click
-            if clicked_left:
-                # put the click point (x,y) into the model to predict
-                result = model.infer(image=image, points=point, labels=[1])
-                # get the visualized result
-                vis_result = visualize(image, result)
-                # create window to show visualized result
-                cv.namedWindow("vis_result", cv.WINDOW_NORMAL)
-                cv.resizeWindow("vis_result", 800 if vis_result.shape[0] > 800 else vis_result.shape[0], 600 if vis_result.shape[1] > 600 else vis_result.shape[1])
-                cv.moveWindow("vis_result", 851, 100)
-                cv.imshow("vis_result", vis_result)
-                # set click false to listen another click
-                clicked_left = False
-            elif cv.getWindowProperty(image_window, cv.WND_PROP_VISIBLE) < 1:
-                # if click × to close the image window then ending
-                break
-            else:
-                # when not clicked, set point to empty
-                point = []
+        while True:
+            # if the user clicks × to close either window, stop
+            if (cv.getWindowProperty(image_window, cv.WND_PROP_VISIBLE) < 1 or
+                    cv.getWindowProperty(segmentation_window, cv.WND_PROP_VISIBLE) < 1):
+                break
+
+            # handle keyboard input
+            key = cv.waitKey(1)
+
+            if key == 13:  # Enter: run inference with the collected prompts
+                vis_image = image.copy()
+                cv.putText(vis_image, "inferring...",
+                           (50, vis_image.shape[0] // 2),
+                           cv.FONT_HERSHEY_SIMPLEX, 10, (255, 255, 255), 5)
+                cv.imshow(segmentation_window, vis_image)
+
+                result = model.infer(image=image, points=points, labels=labels)
+                if len(result) == 0:
+                    print("Clear the prompts and select points again!")
+                else:
+                    vis_result = visualize(image, result)
+                    cv.imshow(segmentation_window, vis_result)
+            elif key == 8 or key == 127:  # ASCII for Backspace or Delete
+                points.clear()
+                labels.clear()
+                backend_point = []
+                rectangle = False
+                current_img = image
+                print("Points are cleared.")
+                cv.imshow(image_window, image)
+            elif key == ord('q') or key == ord('Q'):
+                break

         cv.destroyAllWindows()

         # Save results if save is true
         if args.save:
             cv.imwrite('./example_outputs/vis_result.jpg', vis_result)
             cv.imwrite("./example_outputs/mask.jpg", result)
             print('vis_result.jpg and mask.jpg are saved to ./example_outputs/')

     else:
         print('Set input path to a certain image.')
         pass
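The window-fit arithmetic in the hunk above reduces to a single uniform scale factor clamped at 1 (no upscaling). An equivalent standalone form with a worked example (function name is ours, not part of the commit):

```python
def fit_scale(img_w, img_h, max_w=800, max_h=600):
    """Largest uniform scale that fits (img_w, img_h) inside (max_w, max_h), never upscaling."""
    return min(1.0, max_w / img_w, max_h / img_h)

# A 1920x1080 image into the demo's 800x600 window:
# fit_scale(1920, 1080) = min(1.0, 0.4167, 0.5556) ~ 0.4167 -> display size 800x450
```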
247+
