Closed
Labels
RFC, feature request, good first issue, multi-modality (#4194)
Description
🚀 The feature, motivation and pitch
Currently, we use `_parse_and_validate_*_input` to validate the multi-modal inputs. However, only minimal checks are made, and some models check only the type of the inputs. The actual shape of the inputs can easily diverge from what is documented in classes like `*ImagePixelInputs`, confusing model developers and maintainers.
To avoid this, I propose adding a base class `TensorSchema` to validate the model inputs. For example:
Original code:
```python
from typing import List, Literal, TypedDict, Union

import torch


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`"""
    image_sizes: torch.Tensor
    """Shape: `(batch_size * num_images, 2)`"""
```
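To illustrate why the documented shapes can drift, here is a small sketch showing that a `TypedDict` performs no runtime checks at all. The class `ToyImagePixelInputs` is a hypothetical, simplified stand-in that uses nested lists instead of tensors so that it runs without any dependencies.

```python
from typing import List, Literal, TypedDict


class ToyImagePixelInputs(TypedDict):
    # Hypothetical stand-in for a real *ImagePixelInputs class.
    type: Literal["pixel_values"]
    # Documented shape: (num_images, 3)
    data: List[List[float]]


# Neither the wrong `type` tag nor the 2-element row (documented as 3)
# is rejected: a TypedDict instance is just a plain dict at runtime.
bad = ToyImagePixelInputs(type="oops", data=[[0.1, 0.2]])
print(type(bad) is dict)
```

A static type checker would flag the wrong `type` literal, but nothing, static or dynamic, checks the shape stated in the comment.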
The idea:
```python
class Phi3VImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size (number of prompts)
        - n: Number of images
        - p: Number of patches
        - h: Height of each patch
        - w: Width of each patch
    """
    type: Literal["pixel_values"] = "pixel_values"
    data: Annotated[Union[torch.Tensor, List[torch.Tensor]], TensorShape("bn", "p", 3, "h", "w")]
    image_sizes: Annotated[Union[torch.Tensor, List[torch.Tensor]], TensorShape("bn", 2)]
```
- Validation is done automatically, similar to Pydantic models
- To avoid performance issues, we should be able to disable validation using a flag
- We tag each tensor field with `typing_extensions.Annotated` and use the additional metadata to perform validation.
  - Can switch to `typing.Annotated` once we drop support for Python 3.8 (`typing.Annotated` is available from Python 3.9)
- Dimensions that are constants can be checked directly
  - Example: We validate that `data.shape[2] == 3`
- Dimensions with the same name should be consistent between fields
  - Example: Since `data.shape[0]` and `image_sizes.shape[0]` share the name `bn`, we validate that `data.shape[0] == image_sizes.shape[0]`
- If a field is a list/tuple instead of a tensor, we use `len` instead of `shape` to check the leading dimension, then recurse into each element of the list to check the remaining dimensions.
Notes
This idea can benefit projects outside of vLLM as well, so we should consider developing this as a separate package.
CC List
@ywang96 @Isotr0py @mgoin
@hmellor in case this is already a project on HF
Status
Done