@@ -49,8 +49,8 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, LazyLoader, cdiv)
+from vllm.utils import (DeviceMemoryProfiler, LazyLoader, cdiv,
+                        is_pin_memory_available)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -143,7 +143,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         else:
             self.chunked_prefill_enabled = True
         self.device = device
-
+        self.pin_memory = is_pin_memory_available()
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.block_size = vllm_config.cache_config.block_size
 
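Note on the new `self.pin_memory` flag: `is_pin_memory_available()` from `vllm.utils` reports whether page-locked (pinned) host memory can be allocated in the current environment, and the flag is later forwarded to `MultiModalKwargs.batch(..., pin_memory=self.pin_memory)` in `profile_run` so the dummy multimodal batch is staged in pinned buffers before being moved to `self.device`. A rough illustration of why that matters, as a standalone sketch rather than code from this PR (the `stage_to_gpu` helper is hypothetical):

```python
import torch

def stage_to_gpu(host_tensor: torch.Tensor, device: torch.device,
                 pin_memory: bool) -> torch.Tensor:
    # Page-locked (pinned) host memory allows the host-to-device copy
    # to be performed asynchronously via DMA.
    staging = host_tensor.pin_memory() if pin_memory else host_tensor
    # non_blocking=True only overlaps the copy with compute when the
    # source buffer is pinned; otherwise it degrades to a blocking copy.
    return staging.to(device, non_blocking=pin_memory)
```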
@@ -1607,89 +1607,6 @@ def execute_model(
 
         return model_runner_output
 
-    def _profile_multimodal(self) -> None:
-        # TODO: handle encoder-decoder models once we support them.
-        # NOTE: Currently model is profiled with a single non-text
-        # modality with the max possible input tokens even when
-        # it supports multiple.
-
-        if (not self.is_multimodal_model
-                or self.max_num_encoder_input_tokens <= 0
-                or self.encoder_cache_size <= 0):
-            return
-
-        max_tokens_by_modality_dict = (
-            MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
-                self.model_config))
-        dummy_data_modality, max_tokens_per_mm_item = max(
-            max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-        # Check how many items of this modality can be supported by
-        # the encoder budget.
-        encoder_budget = min(self.max_num_encoder_input_tokens,
-                             self.encoder_cache_size)
-
-        max_num_mm_items_encoder_budget = cdiv(encoder_budget,
-                                               max_tokens_per_mm_item)
-
-        # Check how many items of this modality can be supported by
-        # the decoder budget.
-        max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
-            self.model_config)[dummy_data_modality]
-
-        # NOTE: We do not consider max_num_batched_tokens on purpose
-        # because the multimodal embeddings can be generated in advance
-        # and chunked prefilled.
-        max_num_mm_items_decoder_budget = self.max_num_reqs * \
-            max_mm_items_per_req
-
-        max_num_mm_items = min(max_num_mm_items_encoder_budget,
-                               max_num_mm_items_decoder_budget)
-
-        logger.info(
-            "Encoder cache will be initialized with a budget of %s tokens,"
-            " and profiled with %s %s items of the maximum feature size.",
-            encoder_budget, max_num_mm_items, dummy_data_modality)
-
-        # Create dummy batch of multimodal inputs.
-        dummy_request_data = self.input_registry.dummy_data_for_profiling(
-            model_config=self.model_config,
-            seq_len=self.max_num_tokens,
-            mm_registry=self.mm_registry,
-        )
-        dummy_mm_data = dummy_request_data.multi_modal_data
-
-        if not isinstance(dummy_mm_data, MultiModalKwargs):
-            # TODO: Delete this check once input mapper is fully removed.
-            raise RuntimeError("Legacy input mapper is not supported in V1")
-
-        # Dummy data definition in V0 may contain multiple multimodal items
-        # (e.g, multiple images) for a single request, therefore here we
-        # always replicate first item by max_num_mm_items times since in V1
-        # they are scheduled to be processed separately.
-
-        dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality,
-                                               item_index=0)
-        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
-
-        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
-                                                         max_num_mm_items)
-        batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-            batched_dummy_mm_inputs, device=self.device)
-
-        # Run multimodal encoder.
-        dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-            **batched_dummy_mm_inputs)
-        assert len(dummy_encoder_outputs) == max_num_mm_items, (
-            "Expected dimension 0 of encoder outputs to match the number "
-            f"of multimodal data items: {max_num_mm_items}, got "
-            f"{len(dummy_encoder_outputs)=} instead. This is most likely "
-            "due to the 'get_multimodal_embeddings' method of the model "
-            "not implemented correctly.")
-
-        # Cache the dummy encoder outputs.
-        self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
-
     @torch.inference_mode()
     def _dummy_run(
         self,
@@ -1796,12 +1713,76 @@ def _dummy_run(
             self.drafter.dummy_run(num_tokens)
         return hidden_states
 
+    @torch.inference_mode()
     def profile_run(self) -> None:
-        # FIXME Profile with multimodal encoder & encoder cache.
-        # current _profile_multimodal() using PyTorch SDPA backend method not
-        # support for window/full attn to reduce Memcpy operations, so will cause
-        # Out Of Memory problem, so we currently don't use self._profile_multimodal()
-        # self._profile_multimodal()
+        # Profile with multimodal encoder & encoder cache.
+        # TODO: handle encoder-decoder models once we support them.
+        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
+                and self.encoder_cache_size > 0):
+
+            # NOTE: Currently model is profiled with a single non-text
+            # modality with the max possible input tokens even when
+            # it supports multiple.
+            max_tokens_by_modality_dict = self.mm_registry \
+                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
+            dummy_data_modality, max_tokens_per_mm_item = max(
+                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
+
+            # Check how many items of this modality can be supported by
+            # the encoder budget.
+            encoder_budget = min(self.max_num_encoder_input_tokens,
+                                 self.encoder_cache_size)
+
+            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
+                                                   max_tokens_per_mm_item)
+
+            # Check how many items of this modality can be supported by
+            # the decoder budget.
+            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
+                self.model_config)[dummy_data_modality]
+
+            # NOTE: We do not consider max_num_batched_tokens on purpose
+            # because the multimodal embeddings can be generated in advance
+            # and chunked prefilled.
+            max_num_mm_items_decoder_budget = self.max_num_reqs * \
+                max_mm_items_per_req
+
+            max_num_mm_items = min(max_num_mm_items_encoder_budget,
+                                   max_num_mm_items_decoder_budget)
+
+            logger.info(
+                "Encoder cache will be initialized with a budget of %s tokens,"
+                " and profiled with %s %s items of the maximum feature size.",
+                encoder_budget, max_num_mm_items, dummy_data_modality)
+
+            # Create dummy batch of multimodal inputs.
+            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_counts={
+                    dummy_data_modality: 1
+                },
+            ).multi_modal_data
+
+            batched_dummy_mm_inputs = MultiModalKwargs.batch(
+                [dummy_mm_kwargs] * max_num_mm_items,
+                pin_memory=self.pin_memory)
+            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_dummy_mm_inputs,
+                device=self.device,
+            )
+
+            # Run multimodal encoder.
+            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                **batched_dummy_mm_inputs)
+
+            sanity_check_mm_encoder_outputs(
+                dummy_encoder_outputs,
+                expected_num_items=max_num_mm_items,
+            )
+
+            # Cache the dummy encoder outputs.
+            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
         # For profile, have maximum num_reqs and that collectively have
         # maximum num_tokens.
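For intuition on the budget arithmetic above: the number of items used for the dummy encoder run is the minimum of what the encoder budget can hold (ceiling division of the token budget by the largest single item of the chosen modality) and what the decoder side could ever reference at once (`max_num_reqs * max_mm_items_per_req`). A worked example with made-up numbers, not values taken from this PR:

```python
from math import ceil

# Hypothetical configuration values, chosen only for illustration.
max_num_encoder_input_tokens = 8192  # scheduler-side encoder token budget
encoder_cache_size = 8192            # encoder cache capacity, in tokens
max_tokens_per_mm_item = 2048        # largest single item of the chosen modality
max_num_reqs = 256                   # maximum number of concurrent requests
max_mm_items_per_req = 1             # per-prompt limit for that modality

encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)  # 8192
# cdiv in vllm.utils is ceiling division.
items_encoder_budget = ceil(encoder_budget / max_tokens_per_mm_item)    # 4
items_decoder_budget = max_num_reqs * max_mm_items_per_req              # 256
max_num_mm_items = min(items_encoder_budget, items_decoder_budget)      # 4
# The dummy encoder run is therefore batched with 4 maximum-size items.
```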