     "where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
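+# rope: rotary position embedding applied via precomputed sin/cos tables;
+# defined with functional and out variants like the other schemas above.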
+lib.define(
+    "rope(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos) -> (Tensor out)"
+)
+lib.define(
+    "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 # ------------------------------------ #
 # Migrated from custom_ops.yaml        #
 # ------------------------------------ #
@@ -954,3 +961,42 @@ def where_Scalar_meta(
     other: float,
 ) -> torch.Tensor:
     return condition.new_empty(condition.size(), dtype=torch.float32)
+
+
+@register_fake("cadence::rope")
+def rope_meta(
+    input: torch.Tensor,
+    sin_tensor: torch.Tensor,
+    cos_tensor: torch.Tensor,
+    pos: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # Fake (meta) kernel: validates shapes only and allocates the output.
+    input_shape = list(input.shape)
+    assert (
+        len(input_shape) in (4, 5) and input_shape[0] == 1
+    ), f"input shape {input_shape} must be (1, seq, h, hd) or (1, seq, h, hd/2, 2)"
+    seq = input_shape[1]
+    h = input_shape[2]
+    # Recover the head dim; integer division keeps hd an int.
+    hd = prod(input_shape) // (seq * h)
+    sin_shape = list(sin_tensor.shape)
+    cos_shape = list(cos_tensor.shape)
+    assert sin_shape == cos_shape, f"{sin_shape=} must match {cos_shape=}"
+    assert (
+        len(sin_shape) == 2 and sin_shape[-1] == hd // 2
+    ), f"{sin_shape=} must be [seq, hd/2]"
+    assert pos is None or (
+        len(pos.shape) == 1 and pos.shape[0] == seq
+    ), f"{pos.shape=} must be [{seq}]"
+    return input.new_empty(input.shape, dtype=input.dtype)
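+
+# A minimal sketch of how the fake kernel above resolves shapes, assuming
+# dispatch through torch.ops.cadence.rope under FakeTensorMode; the example
+# shapes are hypothetical, not taken from a real model:
+#
+#     with torch._subclasses.fake_tensor.FakeTensorMode():
+#         x = torch.empty(1, 128, 8, 64)  # (1, seq, h, hd)
+#         sin = torch.empty(128, 32)      # (seq, hd // 2)
+#         cos = torch.empty(128, 32)      # (seq, hd // 2)
+#         out = torch.ops.cadence.rope(x, sin, cos, None)
+#         assert out.shape == x.shape and out.dtype == x.dtype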