Skip to content

Commit 6875c8e

Browse files
authored
: Adds DSP op for RoPE
Differential Revision: D75605145 Pull Request resolved: pytorch#11264
1 parent 70f1456 commit 6875c8e

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,13 @@
167167
"where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
168168
)
169169

170+
lib.define(
171+
"rope(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos) -> (Tensor out)"
172+
)
173+
lib.define(
174+
"rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
175+
)
176+
170177
# ------------------------------------ #
171178
# Migrated from custom_ops.yaml #
172179
# ------------------------------------ #
@@ -954,3 +961,29 @@ def where_Scalar_meta(
954961
other: float,
955962
) -> torch.Tensor:
956963
return condition.new_empty(condition.size(), dtype=torch.float32)
964+
965+
966+
@register_fake("cadence::rope")
967+
def rope_meta(
968+
input: torch.Tensor,
969+
sin_tensor: torch.Tensor,
970+
cos_tensor: torch.Tensor,
971+
pos: Optional[torch.Tensor],
972+
) -> torch.Tensor:
973+
input_shape = list(input.shape)
974+
assert (
975+
len(input_shape) in (4, 5) and input_shape[0] == 1
976+
), f"input shape {input_shape} must be (1, seq, h, hd) or (1, seq, h, hd / 2, 2)"
977+
seq = input_shape[1]
978+
h = input_shape[2]
979+
hd = prod(input_shape) / (seq * h)
980+
sin_shape = list(sin_tensor.shape)
981+
cos_shape = list(cos_tensor.shape)
982+
assert sin_shape == cos_shape, f"{sin_shape=} must be same as {cos_shape}"
983+
assert (
984+
len(sin_shape) == 2 and sin_shape[-1] == hd // 2
985+
), f"{sin_shape=} must be [seq, hd/2]"
986+
assert (
987+
pos is None or len(pos.shape) == 1 and pos.shape[0] == seq
988+
), f"{pos.shape} must be [{seq}]"
989+
return input.new_empty(input.shape, dtype=input.dtype)

0 commit comments

Comments
 (0)