Skip to content

Commit df95e89

Browse files
committed
Merge branch 'main' of github.com:pytorch/vision into Alexandre-SCHOEPP/main
2 parents 2788662 + 904dad4 commit df95e89

File tree

4 files changed

+265
-76
lines changed

4 files changed

+265
-76
lines changed

test/common_utils.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,13 @@ def sample_position(values, max_value):
423423
format = tv_tensors.BoundingBoxFormat[format]
424424

425425
dtype = dtype or torch.float32
426+
int_dtype = dtype in (
427+
torch.uint8,
428+
torch.int8,
429+
torch.int16,
430+
torch.int32,
431+
torch.int64,
432+
)
426433

427434
h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size)
428435
y = sample_position(h, canvas_size[0])
@@ -449,17 +456,17 @@ def sample_position(values, max_value):
449456
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
450457
r_rad = r * torch.pi / 180.0
451458
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
452-
x1, y1 = x, y
453-
x2 = x1 + w * cos
454-
y2 = y1 - w * sin
455-
x3 = x2 + h * sin
456-
y3 = y2 + h * cos
457-
x4 = x1 + h * sin
458-
y4 = y1 + h * cos
459+
x1 = torch.round(x) if int_dtype else x
460+
y1 = torch.round(y) if int_dtype else y
461+
x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos
462+
y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin
463+
x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin
464+
y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos
465+
x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin
466+
y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos
459467
parts = (x1, y1, x2, y2, x3, y3, x4, y4)
460468
else:
461469
raise ValueError(f"Format {format} is not supported")
462-
463470
return tv_tensors.BoundingBoxes(
464471
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
465472
)

0 commit comments

Comments
 (0)