Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions mlx_lm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,11 @@ def progress(tokens_processed, tokens_total):
logging.info(f"We have {ncaches} kv caches that take {nbytes/1e9:.2f} GB")

# Process the prompt and generate tokens
kv_kwargs = {}
if self.cli_args.kv_bits is not None:
kv_kwargs["kv_bits"] = self.cli_args.kv_bits
kv_kwargs["kv_group_size"] = self.cli_args.kv_group_size
kv_kwargs["quantized_kv_start"] = self.cli_args.quantized_kv_start
for gen in stream_generate(
model=model,
tokenizer=tokenizer,
Expand All @@ -994,6 +999,7 @@ def progress(tokens_processed, tokens_total):
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
prompt_progress_callback=progress,
**kv_kwargs,
):
rqueue.put(
Response(
Expand Down Expand Up @@ -1920,6 +1926,25 @@ def main():
action="store_true",
help="Use pipelining instead of tensor parallelism",
)
parser.add_argument(
"--kv-bits",
type=int,
default=None,
choices=[4, 8],
help="Number of bits for KV cache quantization (4 or 8). Default: None (no quantization)",
)
parser.add_argument(
"--kv-group-size",
type=int,
default=64,
help="Group size for KV cache quantization (default: 64)",
)
parser.add_argument(
"--quantized-kv-start",
type=int,
default=0,
help="Step to begin using a quantized KV cache (default: 0)",
)
args = parser.parse_args()
if mx.metal.is_available():
wired_limit = mx.device_info()["max_recommended_working_set_size"]
Expand Down