@@ -423,6 +423,13 @@ def sample_position(values, max_value):
423
423
format = tv_tensors .BoundingBoxFormat [format ]
424
424
425
425
dtype = dtype or torch .float32
426
+ int_dtype = dtype in (
427
+ torch .uint8 ,
428
+ torch .int8 ,
429
+ torch .int16 ,
430
+ torch .int32 ,
431
+ torch .int64 ,
432
+ )
426
433
427
434
h , w = (torch .randint (1 , s , (num_boxes ,)) for s in canvas_size )
428
435
y = sample_position (h , canvas_size [0 ])
@@ -449,17 +456,17 @@ def sample_position(values, max_value):
449
456
elif format is tv_tensors .BoundingBoxFormat .XYXYXYXY :
450
457
r_rad = r * torch .pi / 180.0
451
458
cos , sin = torch .cos (r_rad ), torch .sin (r_rad )
452
- x1 , y1 = x , y
453
- x2 = x1 + w * cos
454
- y2 = y1 - w * sin
455
- x3 = x2 + h * sin
456
- y3 = y2 + h * cos
457
- x4 = x1 + h * sin
458
- y4 = y1 + h * cos
459
+ x1 = torch .round (x ) if int_dtype else x
460
+ y1 = torch .round (y ) if int_dtype else y
461
+ x2 = torch .round (x1 + w * cos ) if int_dtype else x1 + w * cos
462
+ y2 = torch .round (y1 - w * sin ) if int_dtype else y1 - w * sin
463
+ x3 = torch .round (x2 + h * sin ) if int_dtype else x2 + h * sin
464
+ y3 = torch .round (y2 + h * cos ) if int_dtype else y2 + h * cos
465
+ x4 = torch .round (x1 + h * sin ) if int_dtype else x1 + h * sin
466
+ y4 = torch .round (y1 + h * cos ) if int_dtype else y1 + h * cos
459
467
parts = (x1 , y1 , x2 , y2 , x3 , y3 , x4 , y4 )
460
468
else :
461
469
raise ValueError (f"Format { format } is not supported" )
462
-
463
470
return tv_tensors .BoundingBoxes (
464
471
torch .stack (parts , dim = - 1 ).to (dtype = dtype , device = device ), format = format , canvas_size = canvas_size
465
472
)
0 commit comments