diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index b7298e1..44b9406 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -28,8 +28,7 @@
 import torch
 import numpy as np
 
-from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
-from jetstream.engine import sampling_utils
+from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils, sampling_utils
 import torch_xla2
 from torch.utils import _pytree as pytree
 
@@ -44,6 +43,7 @@
 from jetstream_pt.third_party.mixtral import config as mixtral_config, model as mixtral_model
 
 from absl import flags
+from collections.abc import Callable
 
 FLAGS = flags.FLAGS
 
@@ -60,6 +60,7 @@ class Prefix:
   token: jax.Array  # [1, seqlen]
   caches: List[Tuple[jax.Array, jax.Array]]
   seq_len: int  # true seqlen front pad
+  sampler: List[Any] | int  # User defined Sampler
 
 
 @struct.dataclass
@@ -73,8 +74,12 @@ class DecodeState:
   current_position: int
   lens: jax.Array  # [batch_size, 1], the output token length
   start: jax.Array  # [batch_size, 1], the starting pos for each slot
-  input_pos: jax.Array  # [batch_size, 1] input pos for each slot
+  input_pos: (
+      jax.Array
+  )  # [batch_size, 1] total (prefill + decode) length for each slot
   mask: jax.Array  # [batch_size, seqlen] -inf for invalid; 0 for valid
+  # The sampling function
+  samplers: Any
 
 
 # NOTE model specific
@@ -93,7 +98,8 @@ def __init__(
     self.pt_model = pt_model
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
-    self.rng = jax.random.PRNGKey(0)
+    self.rng = jax.random.key(0)
+
     self.weights = weights
 
     self.y_sharding = env.sharding_by_axis(1)
@@ -119,6 +125,7 @@ def __init__(
         donate_argnums=(1,),
         out_shardings=(self.get_decode_state_sharding(), None),
     )
+    # self.generate = self.generate_impl
 
     if self.env.page_attention:
       max_pages_per_sequence = (
@@ -168,6 +175,7 @@ def init_decode_state(
     scalers = []
     if self.env.quant_config.enable_kv_quantization:
       scalers = [c.scalers() for c in caches_obj]
+
     return DecodeState(
         jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32),
         caches,
@@ -181,6 +189,7 @@ def init_decode_state(
             float("-inf"),
             dtype=self.default_dtype,
         ),  # mask
+        None,
     )
 
   # pylint: disable-next=all
@@ -280,19 +289,42 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
     caches_res = [c.state() for c in caches]
     return torchjax.from_torch((res, caches_res))
 
-  def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
+  # Temporarily disabled because handling per-request sampling is not ready yet.
+  # @classmethod
+  # def _custom_sampling(self, logits, samplers) -> jnp.ndarray:
+  #   if len(logits.shape) == 2:
+  #     logits = jnp.expand_dims(logits, 0)
+
+  #   logits = logits[:, -1]
+
+  #   # Prefill and Generate have different batch size
+  #   current_batch_size = logits.shape[0]
+
+  #   idx = jnp.arange(current_batch_size)
+  #   apply_sampler = lambda i, l: jax.lax.switch(i, samplers, l)
+  #   apply_vmap = jax.vmap(apply_sampler, in_axes=(0, 0))
+  #   return apply_vmap(idx, logits).reshape(current_batch_size, -1)
+
+
+  def _sampling(
+      self, logits: Any, algorithm, rng, temperature, topk, nucleus_topp
+  ) -> jnp.ndarray:
     if len(logits.shape) == 2:
       logits = jnp.expand_dims(logits, 0)
+
+    logits = logits[:, -1]
+    current_batch_size = logits.shape[0]
+
     return (
         sampling_utils.sampling(
-            logits[:, -1],
-            self.rng,
-            self.env.sampling_algorithm,
-            self.env.topk,
-            self.env.nucleus_topp,
-            self.env.temperature,
+            logits=logits,
+            rng=rng,
+            algorithm=algorithm,
+            topk=topk,
+            nucleus_topp=nucleus_topp,
+            temperature=temperature,
         )
-        .reshape(batch_size, -1)
+        .reshape(current_batch_size, -1)
         .astype(jnp.int32)
     )
 
@@ -301,7 +333,7 @@ def prefill(
       *,
       params: Any,  # Weights
       existing_prefix: Optional[Prefix] = None,
-      padded_tokens: PrefillInputs,  # PrefillInputs[jax.Array],
+      padded_tokens: PrefillInputs,  # PrefillInputs[jax.Array]
       true_length: int,
       sampler: Optional[Callable[[Any], Any]] = None,
   ) -> Tuple[Prefix, engine_api.ResultTokens]:
@@ -321,6 +353,7 @@ def prefill(
     )
     if len(logits.shape) == 3:  # b, seqlen, num words
       logits = logits[0]  # seqlen, num words
+
     if sampler:
       token = sampler(logits[true_length - 1])
     else:
@@ -332,6 +365,7 @@ def prefill(
           self.env.nucleus_topp,
           self.env.temperature,
       )
+    token = jnp.reshape(token, (1,))
     token_out = jnp.reshape(token, (1, 1))
     data = jnp.concatenate(
         [
@@ -357,7 +391,10 @@ def prefill(
     #       v, seq_len - true_length, true_length, axis=2))
     #   for k, v in updated_caches
     # ]
-    return Prefix(token, updated_caches, true_length), result
+    return (
+        Prefix(token, updated_caches, true_length, sampler),
+        result,
+    )
 
   def shrink_prefix(
       self,
@@ -476,6 +513,8 @@ def insert(cache, scaler, new_entry, update_index):
           caches.append((kcache, vcache))
           scales.append((kscale, vscale))
     lens = decode_state.lens.at[slot].set(1)
+
+    sampler = prefix.sampler if prefix.sampler else decode_state.samplers
     return DecodeState(
         tokens,
         caches,
@@ -485,6 +524,7 @@ def insert(cache, scaler, new_entry, update_index):
         start,
         input_pos,
         mask,
+        sampler,
     )
 
   # pylint: disable-next=all
@@ -569,6 +609,9 @@ def insert(cache, scaler, new_entry):
         scales.append((kscale, vscale))
 
     lens = decode_state.lens.at[slot].set(1)
+
+    sampler = prefix.sampler if prefix.sampler else decode_state.samplers
+
     return DecodeState(
         tokens,
         caches,
@@ -578,6 +621,7 @@ def insert(cache, scaler, new_entry):
         start,
         input_pos,
         mask,
+        sampler,
     )
 
   def _insert_page_attention(
@@ -613,6 +657,8 @@ def _insert_page_attention(
     input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len)
     scales = None
     lens = decode_state.lens.at[slot].set(1)
+
+    sampler = prefix.sampler if prefix.sampler else decode_state.samplers
     return DecodeState(
         tokens,
         caches,
@@ -622,6 +668,7 @@ def _insert_page_attention(
         start,
         input_pos,
         mask,
+        sampler,
     )
 
   def insert(
@@ -729,7 +776,9 @@ def false_comp(b, i, bk, start, end):
     return b_next, i_next
 
   def generate(
-      self, params: Any, decode_state: DecodeState, sampler=None
+      self,
+      params: Any,
+      decode_state: DecodeState,
   ) -> tuple[DecodeState, engine_api.ResultTokens]:
     return (None, None)
 
@@ -752,7 +801,6 @@ def generate_impl(
       self,
       params: Any,
       decode_state: DecodeState,
-      sampler=None,
       page_token_indices=None,
   ) -> tuple[DecodeState, engine_api.ResultTokens]:
     # seq_len = padded_tokens.shape[0]
@@ -764,12 +812,16 @@ def generate_impl(
     else:
       input_indexes = decode_state.input_pos
 
-    ragged_batch_index, ragged_block_index = (
-        self.precompute_ragged_block_indices(decode_state)
-    )
-    ragged_batch_index, ragged_block_index = ragged_batch_index.reshape(
-        (-1)
-    ), ragged_block_index.reshape((-1))
+    # TODO(lancewang): Remove ragged index precomputation
+    # ragged_batch_index, ragged_block_index = (
+    #     self.precompute_ragged_block_indices(decode_state)
+    # )
+    # ragged_batch_index, ragged_block_index = ragged_batch_index.reshape(
+    #     (-1)
+    # ), ragged_block_index.reshape((-1))
+
+    ragged_batch_index = 0
+    ragged_block_index = 0
 
     def update_mask():
       if self.env.ring_buffer:
@@ -799,10 +851,20 @@ def update_mask():
       # fill mask later, now use flash attention
       mask = update_mask()
 
-    if sampler:
-      next_token = sampler(logits[:, -1])
+    # Temporarily disabled because handling per-request sampling is not ready yet.
+    # next_token = self._custom_sampling(logits, decode_state.samplers)
+    if decode_state.samplers:
+      next_token = decode_state.samplers(logits)
     else:
-      next_token = self._sampling(logits, self.env.batch_size)
+      next_token = self._sampling(
+          logits,
+          self.env.sampling_algorithm,
+          self.rng,
+          self.env.temperature,
+          self.env.topk,
+          self.env.nucleus_topp,
+      )
+
     if self.env.ring_buffer:
       input_pos = decode_state.input_pos + 1
       lens = decode_state.lens + 1
@@ -844,6 +906,7 @@ def update_mask():
         decode_state.start,
         input_pos,
         mask,
+        decode_state.samplers,
     )
     return new_decode_state, result_tokens
 
@@ -963,6 +1026,7 @@ def get_prefix_destination_sharding(self) -> Prefix:
         if self.env.page_attention
         else self.cache_sharding,
         self.replicated,
+        self.replicated,
     )
 
   def get_decode_state_sharding(self) -> DecodeState:
@@ -976,6 +1040,7 @@ def get_decode_state_sharding(self) -> DecodeState:
         self.replicated,
         self.replicated,
         self.replicated,
+        self.replicated,
     )
 
   def get_prefix_sequence_ddim(self) -> Any:
diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index 4917705..0202711 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -52,7 +52,10 @@ class JetEngineEnvironmentData:
   batch_size: int = 32  # batch size is generate step batch size
   cache_sequence_length: int = 2048  # size of the cache.
 
-  quant_config: QuantizationConfig = QuantizationConfig()
+  # quant_config: QuantizationConfig = QuantizationConfig()
+  quant_config: QuantizationConfig = dataclasses.field(
+      default_factory=QuantizationConfig
+  )
 
   model_type: str = "llama-2-13b"  # this implies the model config
 
diff --git a/run_interactive.py b/run_interactive.py
index 8463658..2e00159 100644
--- a/run_interactive.py
+++ b/run_interactive.py
@@ -23,6 +23,7 @@
 from absl import app
 from jetstream.engine import token_utils
 from jetstream_pt.config import FLAGS, create_engine_from_config_flags
+from jetstream.engine import sampling_utils
 
 
 # pylint: disable-next=all
@@ -30,6 +31,21 @@ def main(argv):
 
   engine = create_engine_from_config_flags()
 
+  rng = jax.random.key(1)
+  temperature = 1
+  topk = 1
+  topp = 0.2
+
+  sampler = jax.tree_util.Partial(
+      sampling_utils.jittable_sample_topk_logits,
+      rng=rng,
+      temperature=temperature,
+      topk=topk,
+  )
+  # sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_topp_logits, rng=rng, temperature=temperature, topp=topp)
+  # sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_greedy_logits)
+  # sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_weighted_logits, rng=rng, temperature=temperature)
+
   start = time.perf_counter()
   params = engine.load_params()
   print("Load params ", time.perf_counter() - start)
@@ -77,7 +93,10 @@ def main(argv):
       jax.profiler.start_trace(profiling_output)
 
     prefill_result, _ = engine.prefill(
-        params=params, padded_tokens=tokens, true_length=true_length
+        params=params,
+        padded_tokens=tokens,
+        true_length=true_length,
+        sampler=sampler
     )
     # pylint: disable-next=all
     decode_state = engine.insert(prefill_result, decode_state, slot=slot)
diff --git a/tests/helpers.py b/tests/helpers.py
index ac0ea5f..860389d 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -7,7 +7,9 @@
 
 
 # pylint: disable-next=all
-def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
+def make_env_tiny(
+    bf16_enable=True, env_data_update_fn=lambda _: None, batch_size=1
+):
   torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
   torch.set_default_dtype(torch_dtype)
   jax.config.update("jax_dynamic_shapes", False)
@@ -19,7 +21,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
   environment_data.cache_sequence_length = 128
   environment_data.bf16_enable = bf16_enable
   environment_data.model_type = "llama-2-tiny"
-  environment_data.batch_size = 1
+  environment_data.batch_size = batch_size
   environment_data.num_layers = config.n_layers
   environment_data.cache_shape = (
       1,
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 57245c0..9704ce1 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -20,16 +20,64 @@
 
 from jetstream_pt.third_party.llama import model_exportable
 from jetstream_pt.engine import PyTorchEngine
+from jetstream_pt.engine import DecodeState
+from jetstream_pt.engine import Prefix
 from tests import helpers
+from jetstream_pt import cache_manager
+# from jetstream_pt.engine import BaseSampler, GreedySampler, WeightedSampler, TopkSampler, NucleusSampler
+# from jetstream.engine.sampling_util import BaseSampler, GreedySampler, WeightedSampler, TopkSampler, NucleusSampler
+from jetstream.engine.sampling_utils import jittable_sample_greedy_logits, jittable_sample_topp_logits, jittable_sample_topk_logits, jittable_sample_weighted_logits
+
+
+class MockEngine(PyTorchEngine):
+
+  def _call_model_prefill(self, weights, tokens, input_indexes):
+    caches = [
+        cache_manager.KVCachePrefill(
+            self.env.quant_config.enable_kv_quantization
+        )
+        for _ in self.pt_model.layers
+    ]
+    # logits = jnp.ones((self.env.batch_size, 1), jnp.float32)
+    assert (
+        self.env.batch_size == 1
+    ), f"The batch size {self.env.batch_size} != 1"
+    logits = jnp.array([[0.5, 0.6, 0.7, 0.8]])
+    return logits, caches
+
+  def _call_model_generate(
+      self,
+      weights,
+      tokens,
+      input_indexes,
+      caches,
+      cache_scales,
+      mask,
+      start,
+      input_pos,
+      ragged_batch_index,
+      ragged_block_index,
+      page_token_indices,
+  ):
+    logits = jnp.array(
+        [
+            [[0.5, 0.6, 0.7, 0.8]],
+            # [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
+            [[0.4, 0.3, 0.2, 0.1]],
+        ]
+    )
+    return logits, caches, cache_scales
 
 
 class EngineTest(unittest.TestCase):
 
-  def setup(self):
-    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
+  def setup(self, batch_size=1):
+    env, model_arg = helpers.make_env_tiny(
+        bf16_enable=True, batch_size=batch_size
+    )
     model_ours = model_exportable.Transformer(model_arg, env)
-    engine = PyTorchEngine(pt_model=model_ours, env=env)
-    engine.rng = jax.random.PRNGKey(0)
+    engine = MockEngine(pt_model=model_ours, env=env)
+    engine.rng = jax.random.key(0)
     return engine
 
   def test_sampling_2D(self):
@@ -37,14 +85,23 @@ def test_sampling_2D(self):
     engine = self.setup()
     self.assertEqual(engine.env.sampling_algorithm, "greedy")
     logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]])
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits, "greedy", engine.rng, temperature=1.0, topk=1, nucleus_topp=0.0
+    )
     self.assertEqual(token, jnp.array([[0]]))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
     # test weighted
     engine.env.sampling_algorithm = "weighted"
     engine.env.temperature = 5.0
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        temperature=5.0,
+        topk=1,
+        nucleus_topp=0.0,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
@@ -52,21 +109,36 @@ def test_sampling_2D(self):
     engine.env.sampling_algorithm = "topk"
     engine.env.temperature = 5.0
     engine.env.topk = 4
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        temperature=5.0,
+        topk=4,
+        nucleus_topp=0.0,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
     # test nucleus
     engine.env.sampling_algorithm = "nucleus"
-    engine.env.temperature = 0.0
+    engine.env.temperature = 1.0
     engine.env.nucleus_topp = 0.8
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        temperature=0.0,
+        topk=1,
+        nucleus_topp=0.8,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
   def test_sampling_3D(self):
     # test greedy
-    engine = self.setup()
+    engine = self.setup(batch_size=2)
+
     self.assertEqual(engine.env.sampling_algorithm, "greedy")
     logits = jnp.array(
         [
@@ -74,14 +146,28 @@ def test_sampling_3D(self):
             [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
         ]
     )
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
     # test weighted
     engine.env.sampling_algorithm = "weighted"
     engine.env.temperature = 10.0
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
@@ -89,7 +175,14 @@ def test_sampling_3D(self):
     engine.env.sampling_algorithm = "topk"
     engine.env.temperature = 1.0
     engine.env.topk = 3
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
@@ -97,10 +190,343 @@ def test_sampling_3D(self):
     engine.env.sampling_algorithm = "nucleus"
     engine.env.temperature = 1.0
     engine.env.nucleus_topp = 0.8
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
+    # Temporarily disabled because handling per-request sampling is not ready yet.
+#   def test_custom_sampling_3D(self):
+#     engine = self.setup(batch_size=2)
+#     rng = jax.random.key(3)
+
+#     engine.env.sampling_algorithm = ""
+
+#     # Need a different engine of batch size of 1 to reshape the output
+#     rng_b1 = jax.random.key(3)
+#     logits = jnp.array(
+#         [
+#             [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]],
+#             [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
+#         ]
+#     )
+
+#     # test greedy
+#     sampler = jittable_sample_greedy_logits
+#     samplers = [sampler, sampler]
+#     token = engine._custom_sampling(logits, samplers)
+
+#     original_tokens = []
+#     for i in range(2):
+#       original_token = engine._sampling(
+#           logits[i],
+#           "greedy",
+#           rng=rng,
+#           temperature=0.0,
+#           topk=0,
+#           nucleus_topp=0.0,
+#       )
+#       original_tokens.append(original_token)
+#     original_tokens = jnp.concatenate(original_tokens)
+
+#     print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#     self.assertTrue(jnp.array_equal(token, original_tokens))
+#     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#     # test weighted
+#     sampler1 = jax.tree_util.Partial(
+#         jittable_sample_weighted_logits, rng=rng, temperature=1.0
+#     )
+#     sampler2 = jax.tree_util.Partial(
+#         jittable_sample_weighted_logits, rng=rng, temperature=1.0
+#     )
+#     samplers = [sampler1, sampler2]
+#     token = engine._custom_sampling(logits, samplers)
+
+#     original_tokens = []
+#     for i in range(2):
+#       rng_b1, sub_rng = jax.random.split(rng_b1)
+#       original_token = engine._sampling(
+#           logits[i],
+#           "weighted",
+#           rng,
+#           temperature=1,
+#           topk=0,
+#           nucleus_topp=0.0,
+#       )
+#       original_tokens.append(original_token)
+#     original_tokens = jnp.concatenate(original_tokens)
+
+#     print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#     self.assertTrue(jnp.array_equal(token, original_tokens))
+#     self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#     # # test topk
+#     sampler1 = jax.tree_util.Partial(
+#         jittable_sample_topk_logits, rng=rng, temperature=1.0, topk=3
+#     )
+#     sampler2 = jax.tree_util.Partial(
+#         jittable_sample_topk_logits, rng=rng, temperature=1.0, topk=3
+#     )
+#     samplers = [sampler1, sampler2]
+#     token = engine._custom_sampling(logits, samplers)
+
+#     original_tokens = []
+#     for i in range(2):
+#       #   rng_b1, sub_rng = jax.random.split(rng_b1)
+#       sub_rng = rng
+#       original_token = engine._sampling(
+#           logits[i],
+#           "topk",
+#           rng=sub_rng,
+#           temperature=1.0,
+#           topk=3,
+#           nucleus_topp=0.0,
+#       )
+#       original_tokens.append(original_token)
+#     original_tokens = jnp.concatenate(original_tokens)
+
+#     print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#     self.assertTrue(jnp.array_equal(token, original_tokens))
+#     self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#     # test nucleus
+#     sampler1 = jax.tree_util.Partial(
+#         jittable_sample_topp_logits, rng=rng, temperature=1.0, topp=0.8
+#     )
+#     sampler2 = jax.tree_util.Partial(
+#         jittable_sample_topp_logits, rng=rng, temperature=1.0, topp=0.8
+#     )
+#     samplers = [sampler1, sampler2]
+#     token = engine._custom_sampling(logits, samplers)
+
+#     original_tokens = []
+#     for i in range(2):
+#       original_token = engine._sampling(
+#           logits[i],
+#           "nucleus",
+#           rng,
+#           temperature=1.0,
+#           topk=0,
+#           nucleus_topp=0.8,
+#       )
+#       original_tokens.append(original_token)
+#     original_tokens = jnp.concatenate(original_tokens)
+#     print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#     self.assertTrue(jnp.array_equal(token, original_tokens))
+#     self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#     # # test topk + greedy
+#     sampler1 = jax.tree_util.Partial(
+#         jittable_sample_topk_logits, rng=rng, temperature=1.0, topk=3
+#     )
+#     sampler2 = jax.tree_util.Partial(jittable_sample_greedy_logits)
+#     samplers = [sampler1, sampler2]
+#     token = engine._custom_sampling(logits, samplers)
+
+#     original_tokens = []
+#     i = 0
+#     original_token = engine._sampling(
+#         logits[i],
+#         "topk",
+#         rng,
+#         temperature=1.0,
+#         topk=3,
+#         nucleus_topp=0.8,
+#     )
+#     original_tokens.append(original_token)
+
+#     i = 1
+#     original_token = engine._sampling(
+#         logits[i],
+#         "greedy",
+#         rng,
+#         temperature=0.0,
+#         topk=0,
+#         nucleus_topp=0.0,
+#     )
+#     original_tokens.append(original_token)
+
+#     original_tokens = jnp.concatenate(original_tokens)
+
+#     print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#     self.assertTrue(jnp.array_equal(token, original_tokens))
+#     self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#   # test Prefill
+#   def test_prefill_with_custom_sampling(self):
+#     engine = self.setup()
+#     engine.rng = jax.random.key(3)
+
+#     engine.env.sampling_algorithm = ""
+
+#     # Inputs don't matter
+#     params = jnp.zeros((1,), jnp.float32)
+#     padded_tokens = jnp.zeros((1,), jnp.float32)
+#     true_length = 1
+
+#     # Greedy
+#     sampler = jax.tree_util.Partial(jittable_sample_greedy_logits)
+#     prefix, _ = engine.prefill(
+#         params=params,
+#         padded_tokens=padded_tokens,
+#         true_length=true_length,
+#         sampler=sampler,
+#     )
+#     token = prefix.token
+#     print(f"Greedy output: {token}")
+#     self.assertTrue(jnp.array_equal(token, jnp.array([3])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#     # Weighted
+#     sampler = jax.tree_util.Partial(
+#         jittable_sample_weighted_logits, rng=engine.rng, temperature=1.0
+#     )
+#     prefix, _ = engine.prefill(
+#         params=params,
+#         padded_tokens=padded_tokens,
+#         true_length=true_length,
+#         sampler=sampler,
+#     )
+#     token = prefix.token
+#     print(f"Weighted output: {token}")
+#     self.assertTrue(jnp.array_equal(token, jnp.array([2])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#     # Nucleus
+#     sampler = jax.tree_util.Partial(
+#         jittable_sample_topp_logits, rng=engine.rng, temperature=1.0, topp=0.8
+#     )
+#     prefix, _ = engine.prefill(
+#         params=params,
+#         padded_tokens=padded_tokens,
+#         true_length=true_length,
+#         sampler=sampler,
+#     )
+#     token = prefix.token
+#     print(f"Nucleus output: {token}")
+#     self.assertTrue(jnp.array_equal(token, jnp.array([2])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#     # Topk
+#     sampler = jax.tree_util.Partial(
+#         jittable_sample_topk_logits, rng=engine.rng, temperature=1.0, topk=3
+#     )
+
+#     prefix, _ = engine.prefill(
+#         params=params,
+#         padded_tokens=padded_tokens,
+#         true_length=true_length,
+#         sampler=sampler,
+#     )
+#     token = prefix.token
+#     print(f"Topk output: {token}")
+#     self.assertTrue(jnp.array_equal(token, jnp.array([1])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#   def test_insert_no_wrap_with_custom_sampling(self):
+#     engine = self.setup()
+#     engine.env.sampling_algorithm = ""
+#     engine.env.batch_size = 2
+#     cache_shape = engine.env.cache_shape
+
+#     prefill_cache_shape = (1, cache_shape[1], 16, cache_shape[3])
+#     prefill_cache = []
+#     for _ in range(engine.env.num_layers):
+#       prefill_cache.append(
+#           (
+#               jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16),
+#               jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16),
+#           )
+#       )
+
+#     sampler = jittable_sample_greedy_logits
+#     prefix = Prefix(
+#         token=jnp.ones((1)),
+#         caches=prefill_cache,
+#         seq_len=16,
+#         sampler=sampler,
+#     )
+
+#     doesnt_matter = jnp.array([0])
+#     kv_cache = engine.env.make_caches_generate()
+#     kv_cache = [c.state() for c in kv_cache]
+
+#     base_sampler = jax.tree_util.Partial(
+#         engine._sampling,
+#         algorithm=engine.env.sampling_algorithm,
+#         rng=engine.rng,
+#         temperature=engine.env.temperature,
+#         topk=engine.env.topk,
+#         nucleus_topp=engine.env.nucleus_topp,
+#     )
+#     decode_state = DecodeState(
+#         tokens=jnp.zeros((engine.env.batch_size, 1)),
+#         caches=kv_cache,
+#         cache_scales=[doesnt_matter],
+#         current_position=16,
+#         lens=jnp.zeros((engine.env.batch_size, 1)),
+#         start=jnp.zeros((engine.env.batch_size, 1)),
+#         input_pos=jnp.zeros((engine.env.batch_size,)),
+#         mask=jnp.zeros((engine.env.batch_size, 128)),
+#         # samplers = [base_sampler] * engine.env.batch_size
+#         samplers=None,
+#     )
+
+#     # Insert to slot 1
+#     result_decode_state = engine._insert_no_wrap(prefix, decode_state, slot=1)
+
+#     self.assertAlmostEqual(
+#         result_decode_state.tokens.all(), decode_state.tokens.all()
+#     )
+#     self.assertEqual(result_decode_state.samplers, prefix.sampler)
+
+#   def test_generate_with_custom_sampling(self):
+#     engine = self.setup(batch_size=2)
+#     engine.rng = jax.random.key(3)
+#     engine.env.sampling_algorithm = ""
+
+#     # Inputs don't matter
+#     doesnt_matter = jnp.array([0])
+#     params = doesnt_matter
+
+#     greedy_sampler = jax.tree_util.Partial(jittable_sample_greedy_logits)
+#     weighted_sampler = jax.tree_util.Partial(
+#         jittable_sample_weighted_logits, rng=engine.rng, temperature=1.0
+#     )
+#     decode_state = DecodeState(
+#         tokens=jnp.zeros((engine.env.batch_size, 1)),
+#         caches=[doesnt_matter],
+#         cache_scales=[doesnt_matter],
+#         current_position=0,
+#         lens=jnp.zeros((engine.env.batch_size, 1)),
+#         start=doesnt_matter,
+#         input_pos=jnp.zeros((engine.env.batch_size,)),
+#         mask=jnp.zeros((engine.env.batch_size, 1)),
+#         samplers=weighted_sampler,
+#     )
+
+#     # Topk + Weighted
+#     # algorithm, temperature, topk, nucleus_topp
+#     decode_state, _ = engine.generate_impl(
+#         params=params, decode_state=decode_state
+#     )
+#     token = decode_state.tokens
+#     print(f"Topk + Weighted output: {token}")
+#     self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]])))
+#     self.assertTrue(jnp.isdtype(token, jnp.int32))
+
 
 #     def test_insert(self):
 #         seqlen = 32