Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from jax.experimental import checkify
import jax.extend as jex
from jax.extend import sharding as jex_sharding
from jax.lib import xla_client
import jax.numpy as jnp
import jax.scipy.special
import numpy as np
Expand All @@ -33,7 +32,6 @@
from tf2jax._src import utils
from tf2jax._src import xla_utils


ArrayLike = Union[np.ndarray, jnp.ndarray]

# NoOp inserted to trigger side effects in function with no return values.
Expand Down Expand Up @@ -2547,11 +2545,10 @@ def _xla_sharding(proto):
if not sharding_str:
return lambda x: x

sharding = xla_client.OpSharding()
if sharding_v2 and sharding_v2.s:
sharding.ParseFromString(sharding_v2.s)
sharding = jex.sharding.get_op_sharding_from_serialized_proto(sharding_v2.s)
else:
sharding.ParseFromString(sharding_str)
sharding = jex.sharding.get_op_sharding_from_serialized_proto(sharding_str)

# TODO(shaobohou): Replace with jax.sharding.NamedSharding as GSPMDSharding is
# deprecated.
Expand Down
Loading