@@ -3460,6 +3460,195 @@ def nngp_fn_diag(nngp):
34603460 return _elementwise (fn , f'ElementwiseNumerical({ fn } ,deg={ deg } )' , kernel_fn )
34613461
34623462
3463+ @layer
3464+ @supports_masking (remask_kernel = True )
3465+ def ImageResize (
3466+ shape : Sequence [int ],
3467+ method : Union [str , jax .image .ResizeMethod ],
3468+ antialias : bool = True ,
3469+ precision : lax .Precision = lax .Precision .HIGHEST ,
3470+ batch_axis : int = 0 ,
3471+ channel_axis : int = - 1
3472+ ) -> InternalLayer :
3473+ """Image resize function mimicking `jax.image.resize`.
3474+
3475+ Docstring adapted from https://jax.readthedocs.io/en/latest/_modules/jax/_src/image/scale.html#resize.
3476+ Note two changes:
3477+ 1. Only `"linear"` and `"nearest"` interpolation methods are supported;
3478+ 2. Set `shape[i]` to `-1` if you want dimension `i` of `inputs` unchanged.
3479+
3480+ The `method` argument expects one of the following resize methods:
3481+
3482+ `ResizeMethod.NEAREST`, `"nearest"`:
3483+ Nearest neighbor interpolation_. The values of `antialias` and `precision`
3484+ are ignored.
3485+
3486+ `ResizeMethod.LINEAR`, `"linear"`, `"bilinear"`, `"trilinear"`, `"triangle"`:
3487+ Linear interpolation_. If `antialias` is ``True``, uses a triangular filter
3488+ when downsampling.
3489+
3490+ The following methods are NOT SUPPORTED in `kernel_fn` (only `init_fn` and
3491+ `apply_fn` work):
3492+
3493+ `ResizeMethod.CUBIC`, `"cubic"`, `"bicubic"`, `"tricubic"`:
3494+ Cubic interpolation_, using the Keys cubic kernel.
3495+
3496+ `ResizeMethod.LANCZOS3`, `"lanczos3"`:
3497+ Lanczos resampling_, using a kernel of radius 3.
3498+
3499+ `ResizeMethod.LANCZOS5`, `"lanczos5"`:
3500+ Lanczos resampling_, using a kernel of radius 5.
3501+
3502+ .. _Nearest neighbor interpolation: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
3503+ .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
3504+ .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
3505+ .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling
3506+
3507+ Args:
3508+ shape:
3509+ the output shape, as a sequence of integers with length equal to
3510+ the number of dimensions of `image`. Note that :func:`resize` does not
3511+ distinguish spatial dimensions from batch or channel dimensions, so this
3512+ includes all dimensions of the image. To leave a certain dimension
3513+ (e.g. batch or channel) unchanged, set the respective entry to `-1`.
3514+ Note that setting it to the respective size of the `input` also works,
3515+ but will make `kernel_fn` computation much more expensive with no benefit.
3516+ Further, note that `kernel_fn` does not support resizing the
3517+ `channel_axis`, therefore `shape[channel_axis]` should be set to `-1`.
3518+
3519+ method:
3520+ the resizing method to use; either a `ResizeMethod` instance or a
3521+ string. Available methods are: `"LINEAR"`, `"NEAREST"`. Other methods
3522+ like `"LANCZOS3"`, `"LANCZOS5"`, `"CUBIC"` only work for `apply_fn`, but
3523+ not `kernel_fn`.
3524+
3525+ antialias:
3526+ should an antialiasing filter be used when downsampling? Defaults to
3527+ `True`. Has no effect when upsampling.
3528+
3529+ precision:
3530+ `np.einsum` precision.
3531+
3532+ batch_axis:
3533+ batch axis for `inputs`. Defaults to `0`, the leading axis.
3534+
3535+ channel_axis:
3536+ channel axis for `inputs`. Defaults to `-1`, the trailing axis. For
3537+ `kernel_fn`, channel size is considered to be infinite.
3538+
3539+ Returns:
3540+ `(init_fn, apply_fn, kernel_fn)`.
3541+ """
3542+ def _shape (input_shape ):
3543+ return tuple (s if s != - 1 else input_shape [i ] for i , s in enumerate (shape ))
3544+
3545+ def init_fn (rng , input_shape ):
3546+ return _shape (input_shape ), ()
3547+
3548+ def apply_fn (params , x , ** kwargs ):
3549+ return jax .image .resize (image = x ,
3550+ shape = _shape (x .shape ),
3551+ method = method ,
3552+ antialias = antialias ,
3553+ precision = precision )
3554+
3555+ def mask_fn (mask , input_shape ):
3556+ # Interploation (except for "NEAREST") is done in float format:
3557+ # https://github.com/google/jax/issues/3811. Float converted back to bool
3558+ # rounds up all non-zero elements to `True`, so naively resizing the `mask`
3559+ # will mark any output that has at least one contribution from a masked
3560+ # input as fully masked. This can lead to mask growing unexpectedly, e.g.
3561+ # consider a 5x5 image with a single masked pixel in the center:
3562+ #
3563+ # >>> mask = np.array([[0, 0, 0, 0, 0],
3564+ # >>> [0, 0, 0, 0, 0],
3565+ # >>> [0, 0, 1, 0, 0],
3566+ # >>> [0, 0, 0, 0, 0],
3567+ # >>> [0, 0, 0, 0, 0]], dtype=np.bool_)
3568+ #
3569+ # Downsampling this mask to 2x2 will mark all output pixels as masked!
3570+ #
3571+ # >>> jax.image.resize(mask, (2, 2), method='bilinear').astype(np.bool_)
3572+ # >>> DeviceArray([[ True, True],
3573+ # >>> [ True, True]], dtype=bool)
3574+ #
3575+ # Therefore, througout `stax` we rather follow the convention of marking
3576+ # outputs as masked if they _only_ have contributions from masked elements
3577+ # (in other words, we don't let the mask destroy information; let content
3578+ # have preference over mask). For this we invert the mask before and after
3579+ # resizing, to round up unmasked outputs instead.
3580+ return ~ jax .image .resize (image = ~ mask ,
3581+ shape = _shape (mask .shape ),
3582+ method = method ,
3583+ antialias = antialias ,
3584+ precision = precision ).astype (np .bool_ )
3585+
3586+ batch_axis , channel_axis = utils .mod ((batch_axis , channel_axis ), shape )
3587+
3588+ diagonal_batch = shape [batch_axis ] == - 1
3589+ diagonal_spatial = _Diagonal (
3590+ input = _Bool .NO
3591+ if any (shape [i ] != - 1 for i in range (len (shape ))
3592+ if i not in (batch_axis , channel_axis ))
3593+ else _Bool .YES ) # pytype:disable=wrong-keyword-args
3594+
3595+ @_requires (batch_axis = batch_axis ,
3596+ channel_axis = channel_axis ,
3597+ diagonal_batch = diagonal_batch ,
3598+ diagonal_spatial = diagonal_spatial ) # pytype:disable=wrong-keyword-args
3599+ def kernel_fn (k : Kernel , ** kwargs ) -> Kernel :
3600+ if isinstance (method , str ):
3601+ _method = jax .image .ResizeMethod .from_string (method )
3602+
3603+ if _method not in (jax .image .ResizeMethod .LINEAR ,
3604+ jax .image .ResizeMethod .NEAREST ):
3605+ raise NotImplementedError (
3606+ f'Only "linear" (`jax.image.ResizeMethod.LINEAR`) and '
3607+ f'"nearest" (`jax.image.ResizeMethod.NEAREST`) interpolation is '
3608+ f'supported in `kernel_fn`, got { _method } .' )
3609+
3610+ if shape [channel_axis ] != - 1 :
3611+ raise ValueError (f'Resizing the channel axis { channel_axis } is not '
3612+ f'well-defined in the infinite-width limit. Please '
3613+ f'either set `shape[channel_axis] = -1` or file '
3614+ f'an issue describing your use case at '
3615+ f'https://github.com/google/neural-tangents/issues/new.' )
3616+
3617+ cov1 , nngp , cov2 , ntk = k .cov1 , k .nngp , k .cov2 , k .ntk
3618+ diagonal_spatial = k .diagonal_spatial
3619+
3620+ def resize (k , shape1 , shape2 , diagonal_batch ):
3621+ if k is None or k .ndim == 0 :
3622+ return k
3623+
3624+ k_shape = (shape1 [batch_axis ],)
3625+ if not diagonal_batch :
3626+ k_shape += (shape2 [batch_axis ],)
3627+
3628+ for i , (s1 , s2 ) in enumerate (zip (shape1 , shape2 )):
3629+ if i not in (batch_axis , channel_axis ):
3630+ k_shape += (s1 ,)
3631+ if not diagonal_spatial :
3632+ k_shape += (s2 ,)
3633+
3634+ return jax .image .resize (image = k ,
3635+ shape = k_shape ,
3636+ method = _method ,
3637+ antialias = antialias ,
3638+ precision = precision )
3639+
3640+ shape1 = _shape (k .shape1 )
3641+ shape2 = _shape (k .shape2 )
3642+
3643+ k = k .replace (cov1 = resize (cov1 , shape1 , shape1 , k .diagonal_batch ),
3644+ nngp = resize (nngp , shape1 , shape2 , False ),
3645+ cov2 = resize (cov2 , shape2 , shape2 , k .diagonal_batch ),
3646+ ntk = resize (ntk , shape1 , shape2 , False ))
3647+ return k
3648+
3649+ return init_fn , apply_fn , kernel_fn , mask_fn
3650+
3651+
34633652# INTERNAL UTILITIES
34643653
34653654
0 commit comments