Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit f8a964f

Browse files
committed
Add stax.ImageResize mimicking https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html#jax.image.resize (thanks schsam@ for suggestion!)
PiperOrigin-RevId: 392300224
1 parent 8bd2542 commit f8a964f

2 files changed

Lines changed: 471 additions & 0 deletions

File tree

neural_tangents/stax.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3460,6 +3460,195 @@ def nngp_fn_diag(nngp):
34603460
return _elementwise(fn, f'ElementwiseNumerical({fn},deg={deg})', kernel_fn)
34613461

34623462

3463+
@layer
@supports_masking(remask_kernel=True)
def ImageResize(
    shape: Sequence[int],
    method: Union[str, jax.image.ResizeMethod],
    antialias: bool = True,
    precision: lax.Precision = lax.Precision.HIGHEST,
    batch_axis: int = 0,
    channel_axis: int = -1
) -> InternalLayer:
  """Image resize function mimicking `jax.image.resize`.

  Docstring adapted from https://jax.readthedocs.io/en/latest/_modules/jax/_src/image/scale.html#resize.
  Note two changes:
    1. Only `"linear"` and `"nearest"` interpolation methods are supported in
       `kernel_fn`;
    2. Set `shape[i]` to `-1` if you want dimension `i` of `inputs` unchanged.

  The `method` argument expects one of the following resize methods:

  `ResizeMethod.NEAREST`, `"nearest"`:
    `Nearest neighbor interpolation`_. The values of `antialias` and
    `precision` are ignored.

  `ResizeMethod.LINEAR`, `"linear"`, `"bilinear"`, `"trilinear"`, `"triangle"`:
    `Linear interpolation`_. If `antialias` is ``True``, uses a triangular
    filter when downsampling.

  The following methods are NOT SUPPORTED in `kernel_fn` (only `init_fn` and
  `apply_fn` work):

  `ResizeMethod.CUBIC`, `"cubic"`, `"bicubic"`, `"tricubic"`:
    `Cubic interpolation`_, using the Keys cubic kernel.

  `ResizeMethod.LANCZOS3`, `"lanczos3"`:
    `Lanczos resampling`_, using a kernel of radius 3.

  `ResizeMethod.LANCZOS5`, `"lanczos5"`:
    `Lanczos resampling`_, using a kernel of radius 5.

  .. _Nearest neighbor interpolation: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
  .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
  .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
  .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

  Args:
    shape:
      the output shape, as a sequence of integers with length equal to
      the number of dimensions of `image`. Note that :func:`resize` does not
      distinguish spatial dimensions from batch or channel dimensions, so this
      includes all dimensions of the image. To leave a certain dimension
      (e.g. batch or channel) unchanged, set the respective entry to `-1`.
      Note that setting it to the respective size of the `input` also works,
      but will make `kernel_fn` computation much more expensive with no benefit.
      Further, note that `kernel_fn` does not support resizing the
      `channel_axis`, therefore `shape[channel_axis]` should be set to `-1`.

    method:
      the resizing method to use; either a `ResizeMethod` instance or a
      string. Available methods are: `"LINEAR"`, `"NEAREST"`. Other methods
      like `"LANCZOS3"`, `"LANCZOS5"`, `"CUBIC"` only work for `apply_fn`, but
      not `kernel_fn`.

    antialias:
      should an antialiasing filter be used when downsampling? Defaults to
      `True`. Has no effect when upsampling.

    precision:
      `np.einsum` precision.

    batch_axis:
      batch axis for `inputs`. Defaults to `0`, the leading axis.

    channel_axis:
      channel axis for `inputs`. Defaults to `-1`, the trailing axis. For
      `kernel_fn`, channel size is considered to be infinite.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.

  Raises:
    NotImplementedError:
      raised by `kernel_fn` if `method` is neither linear nor nearest.
    ValueError:
      raised by `kernel_fn` if `shape[channel_axis] != -1`.
  """
  def _shape(input_shape):
    # Substitute every `-1` placeholder in `shape` with the matching input size.
    return tuple(s if s != -1 else input_shape[i] for i, s in enumerate(shape))

  def init_fn(rng, input_shape):
    # The layer has no trainable parameters; only the output shape is computed.
    return _shape(input_shape), ()

  def apply_fn(params, x, **kwargs):
    return jax.image.resize(image=x,
                            shape=_shape(x.shape),
                            method=method,
                            antialias=antialias,
                            precision=precision)

  def mask_fn(mask, input_shape):
    # Interpolation (except for "NEAREST") is done in float format:
    # https://github.com/google/jax/issues/3811. Float converted back to bool
    # rounds up all non-zero elements to `True`, so naively resizing the `mask`
    # will mark any output that has at least one contribution from a masked
    # input as fully masked. This can lead to mask growing unexpectedly, e.g.
    # consider a 5x5 image with a single masked pixel in the center:
    #
    # >>> mask = np.array([[0, 0, 0, 0, 0],
    # >>>                  [0, 0, 0, 0, 0],
    # >>>                  [0, 0, 1, 0, 0],
    # >>>                  [0, 0, 0, 0, 0],
    # >>>                  [0, 0, 0, 0, 0]], dtype=np.bool_)
    #
    # Downsampling this mask to 2x2 will mark all output pixels as masked!
    #
    # >>> jax.image.resize(mask, (2, 2), method='bilinear').astype(np.bool_)
    # >>> DeviceArray([[ True,  True],
    # >>>              [ True,  True]], dtype=bool)
    #
    # Therefore, throughout `stax` we rather follow the convention of marking
    # outputs as masked if they _only_ have contributions from masked elements
    # (in other words, we don't let the mask destroy information; let content
    # have preference over mask). For this we invert the mask before and after
    # resizing, to round up unmasked outputs instead.
    return ~jax.image.resize(image=~mask,
                             shape=_shape(mask.shape),
                             method=method,
                             antialias=antialias,
                             precision=precision).astype(np.bool_)

  batch_axis, channel_axis = utils.mod((batch_axis, channel_axis), shape)

  # The batch axis can stay diagonal only if it is left unchanged.
  diagonal_batch = shape[batch_axis] == -1
  # Resizing any spatial axis requires the full (non-diagonal) spatial
  # covariance at the input.
  diagonal_spatial = _Diagonal(
      input=_Bool.NO
      if any(shape[i] != -1 for i in range(len(shape))
             if i not in (batch_axis, channel_axis))
      else _Bool.YES)  # pytype:disable=wrong-keyword-args

  @_requires(batch_axis=batch_axis,
             channel_axis=channel_axis,
             diagonal_batch=diagonal_batch,
             diagonal_spatial=diagonal_spatial)  # pytype:disable=wrong-keyword-args
  def kernel_fn(k: Kernel, **kwargs) -> Kernel:
    # Normalize `method` to a `ResizeMethod`. Both branches must bind
    # `_method`: the layer accepts `ResizeMethod` instances as well as
    # strings, and leaving it unbound for the former raised an
    # `UnboundLocalError` in the check below.
    if isinstance(method, str):
      _method = jax.image.ResizeMethod.from_string(method)
    else:
      _method = method

    if _method not in (jax.image.ResizeMethod.LINEAR,
                       jax.image.ResizeMethod.NEAREST):
      raise NotImplementedError(
          f'Only "linear" (`jax.image.ResizeMethod.LINEAR`) and '
          f'"nearest" (`jax.image.ResizeMethod.NEAREST`) interpolation is '
          f'supported in `kernel_fn`, got {_method}.')

    if shape[channel_axis] != -1:
      raise ValueError(f'Resizing the channel axis {channel_axis} is not '
                       f'well-defined in the infinite-width limit. Please '
                       f'either set `shape[channel_axis] = -1` or file '
                       f'an issue describing your use case at '
                       f'https://github.com/google/neural-tangents/issues/new.')

    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    # Local to `kernel_fn`: the layout of the incoming kernel, read by `resize`.
    diagonal_spatial = k.diagonal_spatial

    def resize(k, shape1, shape2, diagonal_batch):
      # Missing or scalar kernel entries carry nothing to resize.
      if k is None or k.ndim == 0:
        return k

      # Build the target kernel shape: batch size(s) first, then one entry per
      # non-batch, non-channel axis (two entries when not `diagonal_spatial`).
      k_shape = (shape1[batch_axis],)
      if not diagonal_batch:
        k_shape += (shape2[batch_axis],)

      for i, (s1, s2) in enumerate(zip(shape1, shape2)):
        if i not in (batch_axis, channel_axis):
          k_shape += (s1,)
          if not diagonal_spatial:
            k_shape += (s2,)

      return jax.image.resize(image=k,
                              shape=k_shape,
                              method=_method,
                              antialias=antialias,
                              precision=precision)

    shape1 = _shape(k.shape1)
    shape2 = _shape(k.shape2)

    k = k.replace(cov1=resize(cov1, shape1, shape1, k.diagonal_batch),
                  nngp=resize(nngp, shape1, shape2, False),
                  cov2=resize(cov2, shape2, shape2, k.diagonal_batch),
                  ntk=resize(ntk, shape1, shape2, False))
    return k

  return init_fn, apply_fn, kernel_fn, mask_fn
3650+
3651+
34633652
# INTERNAL UTILITIES
34643653

34653654

0 commit comments

Comments
 (0)