diff --git a/labelme/shape.py b/labelme/shape.py index 50d73fc0e..609e9b68e 100644 --- a/labelme/shape.py +++ b/labelme/shape.py @@ -325,6 +325,10 @@ def nearestEdge(self, point, epsilon): def containsPoint(self, point) -> bool: if self.shape_type in ["line", "linestrip", "points"]: return False + if self.shape_type == "point": + if not self.points: + return False + return labelme.utils.distance(point - self.points[0]) <= self.point_size / 2 if self.mask is not None: raw_y = int(round(point.y() - self.points[0].y())) raw_x = int(round(point.x() - self.points[0].x())) diff --git a/tests/unit/shape_contains_point_test.py b/tests/unit/shape_contains_point_test.py new file mode 100644 index 000000000..5e9e518d4 --- /dev/null +++ b/tests/unit/shape_contains_point_test.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from PyQt5 import QtCore + +from labelme.shape import Shape + + +def _make_point_shape(x: float, y: float) -> Shape: + """Create a point shape with a single point at (x, y).""" + shape = Shape(shape_type="point") + shape.addPoint(QtCore.QPointF(x, y)) + return shape + + +def test_point_shape_contains_center(): + """Clicking exactly on a point shape should return True.""" + shape = _make_point_shape(100.0, 200.0) + assert shape.containsPoint(QtCore.QPointF(100.0, 200.0)) is True + + +def test_point_shape_contains_within_radius(): + """Clicking within point_size/2 of the center should return True.""" + shape = _make_point_shape(100.0, 200.0) + # point_size defaults to 8, so radius = 4. A point 3px away should hit. + assert shape.containsPoint(QtCore.QPointF(103.0, 200.0)) is True + + +def test_point_shape_at_exact_boundary(): + """Clicking exactly at point_size/2 distance should return True (inclusive).""" + shape = _make_point_shape(100.0, 200.0) + # point_size defaults to 8, so radius = 4. Exactly 4px away should hit. + assert shape.containsPoint(QtCore.QPointF(104.0, 200.0)) is True + + +def test_point_shape_outside_radius(): + """Clicking more than point_size/2 away should return False.""" + shape = _make_point_shape(100.0, 200.0) + # 10px away, well outside the radius of 4 + assert shape.containsPoint(QtCore.QPointF(110.0, 200.0)) is False + + +def test_point_shape_empty_points(): + """A point shape with no points should return False, not raise.""" + shape = Shape(shape_type="point") + assert shape.containsPoint(QtCore.QPointF(0.0, 0.0)) is False