exo/inference/mlx/sharded_inference_engine.py (38 changes: 32 additions & 6 deletions)
@@ -95,11 +95,21 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr
     output_data, inference_state = result

     await self._eval_mlx(output_data)
-    output_data = await asyncio.get_running_loop().run_in_executor(
+
+    output_data_mx = output_data  # holds the result from the MLX model (an mx.array)
+
+    def convert_output_to_numpy(tensor):
+      # numpy has no bfloat16 dtype, so upcast bfloat16 tensors to float32 first.
+      if tensor.dtype == mx.bfloat16:
+        return np.array(tensor.astype(mx.float32), copy=False)
+      else:
+        return np.array(tensor, copy=False)
+
+    output_data_np = await asyncio.get_running_loop().run_in_executor(
       self._mlx_thread,
-      lambda: np.array(output_data, copy=False)
+      lambda: convert_output_to_numpy(output_data_mx)
     )
-    return output_data, inference_state
+    return output_data_np, inference_state

   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
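For reference, the added conversion can be exercised standalone. A minimal sketch, assuming mlx and numpy are installed; `to_numpy` is a name chosen here for illustration, not part of the PR:

import mlx.core as mx
import numpy as np

def to_numpy(tensor: mx.array) -> np.ndarray:
  # numpy cannot represent bfloat16, so upcast to float32 before converting.
  if tensor.dtype == mx.bfloat16:
    return np.array(tensor.astype(mx.float32), copy=False)
  return np.array(tensor, copy=False)

x = mx.ones((2, 3), dtype=mx.bfloat16)
print(to_numpy(x).dtype)  # float32

Routing through float32 is lossless here: every bfloat16 value is exactly representable in float32, since bfloat16 is a truncated float32.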
@@ -152,10 +162,26 @@ def train_step(inp, tar, lng):
     )
     await self._eval_mlx(*eval_args)

+    first_layer_np = np.array([])
     layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
-    first_layer = np.array(layers[0]['input_layernorm'], copy=False)
-    await self._eval_mlx(first_layer)
-    return score, first_layer
+
+    if layers and 'input_layernorm' in layers[0]:
+      first_layer_mx = layers[0]['input_layernorm']
+      await self._eval_mlx(first_layer_mx)
+
+      def convert_gradient_to_numpy(grad_tensor):
+        # numpy has no bfloat16 dtype, so upcast bfloat16 tensors to float32 first.
+        if grad_tensor.dtype == mx.bfloat16:
+          return np.array(grad_tensor.astype(mx.float32), copy=False)
+        else:
+          return np.array(grad_tensor, copy=False)
+
+      first_layer_np = await asyncio.get_running_loop().run_in_executor(
+        self._mlx_thread,
+        lambda: convert_gradient_to_numpy(first_layer_mx)
+      )
+
+    return score, first_layer_np

   async def ensure_shard(self, shard: Shard):
     async with self._shard_lock:
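The gradient path uses the same off-loop pattern as infer_tensor: evaluate the mx.array, then do the blocking numpy conversion on the dedicated MLX thread so the asyncio event loop stays responsive. A minimal self-contained sketch of that pattern; `mlx_thread` and `gradient_to_numpy` are illustrative stand-ins for `self._mlx_thread` and the inline lambda, not names from the PR:

import asyncio
from concurrent.futures import ThreadPoolExecutor

import mlx.core as mx
import numpy as np

# A single worker keeps all MLX evaluation and conversion on one thread.
mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")

async def gradient_to_numpy(grad: mx.array) -> np.ndarray:
  def convert():
    # The blocking conversion runs on the MLX thread, not the event loop.
    if grad.dtype == mx.bfloat16:
      return np.array(grad.astype(mx.float32), copy=False)
    return np.array(grad, copy=False)
  return await asyncio.get_running_loop().run_in_executor(mlx_thread, convert)

# usage:
# first_layer_np = asyncio.run(gradient_to_numpy(mx.zeros((4,), dtype=mx.bfloat16)))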