@@ -313,37 +313,45 @@ def test_repeat(x, kw, data):
313
313
314
314
assume (n_repititions <= hh .SQRT_MAX_ARRAY_SIZE )
315
315
316
- out = xp .repeat (x , repeats , ** kw )
317
- ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
318
- if axis is None :
319
- expected_shape = (n_repititions ,)
320
- else :
321
- expected_shape = list (shape )
322
- expected_shape [axis ] = n_repititions
323
- expected_shape = tuple (expected_shape )
324
- ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
316
+ repro_snippet = ph .format_snippet (f"xp.repeat({ x !r} ,{ repeats !r} , **kw) with { kw = } " )
317
+ try :
318
+ out = xp .repeat (x , repeats , ** kw )
325
319
326
- # Test values
320
+ ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
321
+ if axis is None :
322
+ expected_shape = (n_repititions ,)
323
+ else :
324
+ expected_shape = list (shape )
325
+ expected_shape [axis ] = n_repititions
326
+ expected_shape = tuple (expected_shape )
327
+ ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
328
+
329
+ # Test values
330
+
331
+ if isinstance (repeats , int ):
332
+ repeats_array = xp .full (size , repeats , dtype = xp .int32 )
333
+ else :
334
+ repeats_array = repeats
335
+
336
+ if kw .get ("axis" ) is None :
337
+ x = xp .reshape (x , (- 1 ,))
338
+ axis = 0
339
+
340
+ for idx , in sh .iter_indices (x .shape , skip_axes = axis ):
341
+ x_slice = x [idx ]
342
+ out_slice = out [idx ]
343
+ start = 0
344
+ for i , count in enumerate (repeats_array ):
345
+ end = start + count
346
+ ph .assert_array_elements ("repeat" , out = out_slice [start :end ],
347
+ expected = xp .full ((count ,), x_slice [i ], dtype = x .dtype ),
348
+ kw = kw )
349
+ start = end
350
+
351
+ except Exception as exc :
352
+ exc .add_note (repro_snippet )
353
+ raise
327
354
328
- if isinstance (repeats , int ):
329
- repeats_array = xp .full (size , repeats , dtype = xp .int32 )
330
- else :
331
- repeats_array = repeats
332
-
333
- if kw .get ("axis" ) is None :
334
- x = xp .reshape (x , (- 1 ,))
335
- axis = 0
336
-
337
- for idx , in sh .iter_indices (x .shape , skip_axes = axis ):
338
- x_slice = x [idx ]
339
- out_slice = out [idx ]
340
- start = 0
341
- for i , count in enumerate (repeats_array ):
342
- end = start + count
343
- ph .assert_array_elements ("repeat" , out = out_slice [start :end ],
344
- expected = xp .full ((count ,), x_slice [i ], dtype = x .dtype ),
345
- kw = kw )
346
- start = end
347
355
348
356
reshape_shape = st .shared (hh .shapes (), key = "reshape_shape" )
349
357
0 commit comments