Description
In the SDXL JAX blog post, example inference code is provided, and the post states that it generates 4 images in about 2 s on a Cloud TPU v5e-4. However, even though I am running the same code on the same Google Cloud TPU (v5e-4, ubuntu 22.04 base software version, Python 3.10), generating the images takes about 3.23 seconds. Could you help me figure out what I am missing to achieve the faster inference?
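For reference, a quick sanity check (not from the blog, just a minimal sketch) confirming that JAX sees the TPU backend and all four v5e chips before any timing:

import jax

# Minimal sanity check (not part of the blog code): confirm the TPU backend
# is active and that all four v5e chips are visible.
print(jax.default_backend())  # expected to print "tpu"
print(jax.device_count())     # expected to print 4 on a v5e-4
print(jax.devices())          # lists the individual TPU devices

The full pip freeze of my environment is below.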
absl-py==2.1.0
attrs==21.2.0
Automat==20.2.0
Babel==2.8.0
bcrypt==3.2.0
blinker==1.4
certifi==2020.6.20
chardet==4.0.0
charset-normalizer==3.1.0
chex==0.1.86
click==8.0.3
cloud-init==23.1.2
colorama==0.4.4
command-not-found==0.3
configobj==5.0.6
constantly==15.1.0
cryptography==3.4.8
Cython==0.29.28
dbus-python==1.2.18
diffusers==0.27.2
distlib==0.3.6
distro==1.7.0
distro-info===1.1build1
etils==1.7.0
filelock==3.12.0
flax==0.8.3
fsspec==2024.3.1
httplib2==0.20.2
huggingface-hub==0.22.2
hyperlink==21.0.0
idna==3.3
importlib-metadata==4.6.4
importlib_resources==6.4.0
incremental==21.3.0
jax==0.4.26
jaxlib==0.4.26
jeepney==0.7.1
Jinja2==3.0.3
jsonpatch==1.32
jsonpointer==2.0
jsonschema==3.2.0
keyring==23.5.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
libtpu-nightly==0.1.dev20240403
markdown-it-py==3.0.0
MarkupSafe==2.0.1
mdurl==0.1.2
ml-dtypes==0.4.0
more-itertools==8.10.0
msgpack==1.0.8
nest-asyncio==1.6.0
netifaces==0.11.0
numpy==1.26.4
oauthlib==3.2.0
opt-einsum==3.3.0
optax==0.2.2
orbax-checkpoint==0.5.10
packaging==21.3
pexpect==4.8.0
pillow==10.3.0
platformdirs==3.5.0
protobuf==5.26.1
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.1
Pygments==2.17.2
PyGObject==3.42.1
PyHamcrest==2.0.2
PyJWT==2.3.0
pyOpenSSL==21.0.0
pyparsing==2.4.7
pyrsistent==0.18.1
pyserial==3.5
python-apt==2.4.0+ubuntu1
python-debian===0.1.43ubuntu1
python-magic==0.4.24
pytz==2022.1
PyYAML==5.4.1
regex==2024.4.28
requests==2.29.0
rich==13.7.1
safetensors==0.4.3
scipy==1.13.0
SecretStorage==3.3.1
service-identity==18.1.0
six==1.16.0
sos==4.4
ssh-import-id==5.11
systemd-python==234
tensorstore==0.1.58
tokenizers==0.19.1
toolz==0.12.1
tqdm==4.66.2
transformers==4.40.1
Twisted==22.1.0
typing_extensions==4.11.0
ubuntu-advantage-tools==8001
ufw==0.36.1
unattended-upgrades==0.1
urllib3==1.26.5
virtualenv==20.23.0
wadllib==1.3.6
zipp==1.0.0
zope.interface==5.4.0
Show best practices for SDXL JAX
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
import time
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", split_head_dim=True
)
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state
default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = "illustration, low-quality"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25
def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids
NUM_DEVICES = jax.device_count()
# Model parameters don't change during inference,
# so we only need to replicate them once.
p_params = replicate(params)
def replicate_all(prompt_ids, neg_prompt_ids, seed):
    # Replicate the tokenized prompts and split the RNG so that each
    # device gets its own random key.
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    # jit=True runs the pipeline's pmapped generation path, one image per device
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images

    # merge the device and batch dimensions, then convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))
start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")
for i in range(100):
    start = time.time()
    prompt = "llama in ancient Greece, oil on canvas"
    neg_prompt = "cartoon, illustration, animation"
    images = generate(prompt, neg_prompt)
    print(f"Inference in {time.time() - start}")