[RFC]: Schema for checking input shapes for multi-modal models #14764

@DarkLight1337

Description

🚀 The feature, motivation and pitch

Currently, we use _parse_and_validate_*_input to validate the multi-modal inputs. However, only minimal checks are performed, with some models checking nothing more than the type of the inputs. As a result, the actual shapes of the inputs can easily diverge from what is documented in classes like *ImagePixelInputs, confusing model developers and maintainers.

To avoid this, I propose adding a base class TensorSchema to validate the model inputs. For example:

Original code:

class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`"""

    image_sizes: torch.Tensor
    """Shape: `(batch_size * num_images, 2)`"""

The idea:

class Phi3VImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size (number of prompts)
        - n: Number of images
        - p: Number of patches
        - h: Height of each patch
        - w: Width of each patch
    """
    type: Literal["pixel_values"] = "pixel_values"
    data: Annotated[Union[torch.Tensor, List[torch.Tensor]], TensorShape("bn", "p", 3, "h", "w")]
    image_sizes: Annotated[Union[torch.Tensor, List[torch.Tensor]], TensorShape("bn", 2)]
  • Validation is done automatically, similar to Pydantic models (a rough sketch of how this could work follows this list)
    • To avoid performance issues, we should be able to disable validation using a flag
  • We tag each tensor field with typing_extensions.Annotated and use the additional metadata to perform validation.
    • Can switch to typing.Annotated once we drop support for Python 3.9
  • Dimensions that are constants can be checked directly
    • Example: We validate that data.shape[2] == 3
  • Dimensions with the same name must be consistent between fields
    • Example: Since data.shape[0] and image_sizes.shape[0] share the name bn, we validate that data.shape[0] == image_sizes.shape[0]
  • If a field is a list/tuple instead of a tensor, we use len instead of shape to check the leading dimension, then recurse into each element of the list to check the remaining dimensions.
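
Below is a minimal sketch of how TensorShape and TensorSchema could implement the rules above (it uses typing.Annotated and a validate constructor flag; the names and details are illustrative assumptions, not a final API):

from typing import Annotated, Union, get_args, get_origin, get_type_hints

import torch


class TensorShape:
    """Expected shape: ints are constant dimensions, strings are named
    dimensions that must be consistent across all fields of a schema."""

    def __init__(self, *dims: Union[int, str]) -> None:
        self.dims = dims


class TensorSchema:

    def __init__(self, *, validate: bool = True, **kwargs) -> None:
        for key, value in kwargs.items():
            setattr(self, key, value)
        if validate:  # flag so validation can be disabled in hot paths
            self._validate()

    def _validate(self) -> None:
        sizes = {}  # resolved size of each named dimension
        hints = get_type_hints(type(self), include_extras=True)
        for field, hint in hints.items():
            if get_origin(hint) is not Annotated:
                continue  # e.g. the `type` Literal carries no shape info
            *_, meta = get_args(hint)
            if not isinstance(meta, TensorShape):
                continue
            value = getattr(self, field, None)
            if value is not None:
                self._check_value(field, value, meta.dims, sizes)

    def _check_value(self, field, value, dims, sizes):
        if isinstance(value, (list, tuple)):
            # Use len() for the leading dimension, then recurse into each
            # element to check the remaining dimensions.
            self._check_dim(field, 0, len(value), dims[0], sizes)
            for item in value:
                self._check_value(field, item, dims[1:], sizes)
            return

        shape = tuple(value.shape)
        if len(shape) != len(dims):
            raise ValueError(
                f"{field}: expected {len(dims)} dims, got shape {shape}")
        for i, (size, dim) in enumerate(zip(shape, dims)):
            self._check_dim(field, i, size, dim, sizes)

    def _check_dim(self, field, i, size, dim, sizes):
        # Constant dims (ints) are checked directly; named dims (strings)
        # must resolve to the same size everywhere the name appears.
        expected = dim if isinstance(dim, int) else sizes.setdefault(dim, size)
        if size != expected:
            raise ValueError(
                f"{field}: dim {i} has size {size}, expected {expected} "
                f"(spec: {dim!r})")

With the Phi3VImagePixelInputs definition above, construction would then validate shapes automatically (the sizes below are made up for illustration):

pixel_inputs = Phi3VImagePixelInputs(
    data=torch.zeros(4, 17, 3, 336, 336),  # bn=4, p=17, 3 channels
    image_sizes=torch.zeros(4, 2),          # bn=4 matches data.shape[0]
)

Phi3VImagePixelInputs(
    data=torch.zeros(4, 17, 3, 336, 336),
    image_sizes=torch.zeros(5, 2),
)  # raises ValueError: "bn" resolves to 4 for data but 5 for image_sizes

# Validation can be skipped entirely, e.g. in performance-critical paths:
Phi3VImagePixelInputs(
    data=torch.zeros(4, 17, 3, 336, 336),
    image_sizes=torch.zeros(5, 2),
    validate=False,
)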

Notes

This idea can benefit projects outside of vLLM as well, so we should consider developing this as a separate package.

CC List

@ywang96 @Isotr0py @mgoin
@hmellor in case this is already a project on HF

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
