diff --git a/backends/cadence/generic/operators/op_linalg_svd.cpp b/backends/cadence/generic/operators/op_linalg_svd.cpp
new file mode 100644
index 00000000000..8f3ea8583ae
--- /dev/null
+++ b/backends/cadence/generic/operators/op_linalg_svd.cpp
@@ -0,0 +1,359 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+// NOTE(review): the targets of the #include directives below are missing in
+// this copy of the patch; restore them from the repository before building.
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+const float EPSILON = 1e-10;
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+namespace impl {
+namespace generic {
+namespace native {
+namespace {
+
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::KernelRuntimeContext;
+
+// Plain 3x3 float matrix, indexed as m[row][col].
+struct Matrix3x3 {
+  float m[3][3];
+};
+
+// Builds the 3x3 identity: value-initialized to zero, then ones on the
+// diagonal.
+Matrix3x3 identityMatrix() {
+  Matrix3x3 eye{};
+  for (int d = 0; d < 3; ++d) {
+    eye.m[d][d] = 1.0;
+  }
+  return eye;
+}
+
+// Returns the transpose of A (rows and columns swapped).
+Matrix3x3 transpose(const Matrix3x3& A) {
+  Matrix3x3 flipped{};
+  for (int col = 0; col < 3; ++col) {
+    for (int row = 0; row < 3; ++row) {
+      flipped.m[row][col] = A.m[col][row];
+    }
+  }
+  return flipped;
+}
+
+// Returns the matrix product A * B.
+Matrix3x3 multiply(const Matrix3x3& A, const Matrix3x3& B) {
+  Matrix3x3 product{};
+  for (int row = 0; row < 3; ++row) {
+    for (int col = 0; col < 3; ++col) {
+      // Accumulate the dot product of row 'row' of A with column 'col' of B.
+      float acc = 0.0;
+      for (int k = 0; k < 3; ++k) {
+        acc += A.m[row][k] * B.m[k][col];
+      }
+      product.m[row][col] = acc;
+    }
+  }
+  return product;
+}
+
+// Diagonalizes the symmetric 3x3 matrix A with Jacobi rotations. On return
+// 'diag' holds the diagonal of the rotated matrix (the eigenvalues) and the
+// columns of V hold the accumulated rotations (the eigenvectors).
+void jacobiEigenDecomposition(const Matrix3x3& A, float diag[3], Matrix3x3& V) {
+  // Working copy that the rotations drive towards a diagonal matrix.
+  Matrix3x3 D = A;
+  V = identityMatrix();
+
+  // The three entries above the diagonal, in the order they are scanned.
+  const int kUpper[3][2] = {{0, 1}, {0, 2}, {1, 2}};
+
+  // At most 100 rotations; stop early once the off-diagonal part is
+  // negligible.
+  for (int rotation = 0; rotation < 100; ++rotation) {
+    // Select the off-diagonal entry with the largest magnitude as the pivot.
+    int p = kUpper[0][0];
+    int q = kUpper[0][1];
+    float maxOff = std::fabs(D.m[p][q]);
+    for (int k = 1; k < 3; ++k) {
+      const float candidate = std::fabs(D.m[kUpper[k][0]][kUpper[k][1]]);
+      if (candidate > maxOff) {
+        maxOff = candidate;
+        p = kUpper[k][0];
+        q = kUpper[k][1];
+      }
+    }
+
+    if (maxOff < EPSILON) {
+      break;
+    }
+
+    // Rotation angle that zeroes D[p][q]; (nearly) equal diagonal entries use
+    // pi/4.
+    const float dpp = D.m[p][p];
+    const float dqq = D.m[q][q];
+    const float dpq = D.m[p][q];
+    const float theta = (std::fabs(dpp - dqq) < EPSILON)
+        ? M_PI / 4.0
+        : 0.5 * std::atan2(2 * dpq, dqq - dpp);
+
+    const float c = std::cos(theta);
+    const float s = std::sin(theta);
+
+    // Rotate the 2x2 pivot block; its off-diagonal entry becomes zero.
+    D.m[p][p] = c * c * dpp - 2 * s * c * dpq + s * s * dqq;
+    D.m[q][q] = s * s * dpp + 2 * s * c * dpq + c * c * dqq;
+    D.m[p][q] = D.m[q][p] = 0.0;
+
+    // Rotate the remaining row/column. In a 3x3 matrix exactly one index is
+    // neither p nor q; both mirrored entries are written to keep D symmetric.
+    const int r = 3 - p - q;
+    const float dpr = c * D.m[p][r] - s * D.m[q][r];
+    const float dqr = s * D.m[p][r] + c * D.m[q][r];
+    D.m[p][r] = D.m[r][p] = dpr;
+    D.m[q][r] = D.m[r][q] = dqr;
+
+    // Apply the same rotation to columns p and q of V.
+    for (int row = 0; row < 3; ++row) {
+      const float vp = V.m[row][p];
+      const float vq = V.m[row][q];
+      V.m[row][p] = c * vp - s * vq;
+      V.m[row][q] = s * vp + c * vq;
+    }
+  }
+
+  for (int i = 0; i < 3; ++i) {
+    diag[i] = D.m[i][i];
+  }
+}
+
+// Sorts the eigenvalues (and the corresponding eigenvectors in V) in descending
+// order.
+void sortEigenDecomposition(float eigenvalues[3], Matrix3x3& V) {
+  // Sort a permutation instead of the values so the eigenvector columns can be
+  // reordered with the same permutation.
+  int indices[3] = {0, 1, 2};
+  std::sort(indices, indices + 3, [&](int a, int b) {
+    return eigenvalues[a] > eigenvalues[b];
+  });
+
+  float sortedEigenvalues[3];
+  Matrix3x3 sortedV{};
+  for (int i = 0; i < 3; i++) {
+    sortedEigenvalues[i] = eigenvalues[indices[i]];
+    for (int j = 0; j < 3; j++) {
+      sortedV.m[j][i] = V.m[j][indices[i]];
+    }
+  }
+  for (int i = 0; i < 3; i++) {
+    eigenvalues[i] = sortedEigenvalues[i];
+  }
+  V = sortedV;
+}
+
+// Computes the cross product of two 3D vectors.
+void crossProduct(const float a[3], const float b[3], float result[3]) {
+  result[0] = a[1] * b[2] - a[2] * b[1];
+  result[1] = a[2] * b[0] - a[0] * b[2];
+  result[2] = a[0] * b[1] - a[1] * b[0];
+}
+
+// Normalizes a 3D vector in place. Vectors whose length is not above EPSILON
+// are left untouched to avoid dividing by (nearly) zero.
+void normalize(float v[3]) {
+  float norm = std::sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]);
+  if (norm > EPSILON) {
+    v[0] /= norm;
+    v[1] /= norm;
+    v[2] /= norm;
+  }
+}
+
+// Computes the singular value decomposition of A such that A = U * S * Vt.
+// U and Vt are orthogonal matrices and S is a diagonal matrix with singular
+// values in descending order.
+//
+// Returns the tuple (U, S, Vt).
+std::tuple<Matrix3x3, Matrix3x3, Matrix3x3> svd(const Matrix3x3& A) {
+  // Compute A^T * A (which is symmetric).
+  Matrix3x3 At = transpose(A);
+  Matrix3x3 ATA = multiply(At, A);
+
+  // Compute the eigen-decomposition of ATA.
+  float eigenvalues[3];
+  Matrix3x3 V{};
+  jacobiEigenDecomposition(ATA, eigenvalues, V);
+  sortEigenDecomposition(eigenvalues, V);
+
+  // The singular values are the square roots of the eigenvalues. Round-off can
+  // make an eigenvalue slightly negative; those are clamped to zero.
+  float sigma[3];
+  for (int i = 0; i < 3; i++) {
+    sigma[i] = (eigenvalues[i] > 0.0) ? std::sqrt(eigenvalues[i]) : 0.0;
+  }
+
+  // Compute U = A * V * S^{-1}, one column at a time.
+  Matrix3x3 U{};
+  for (int i = 0; i < 3; i++) {
+    float av[3] = {0, 0, 0};
+    // Multiply A by the i-th eigenvector (column of V).
+    for (int r = 0; r < 3; r++) {
+      for (int c = 0; c < 3; c++) {
+        av[r] += A.m[r][c] * V.m[c][i];
+      }
+    }
+    if (sigma[i] > EPSILON) {
+      for (int r = 0; r < 3; r++) {
+        U.m[r][i] = av[r] / sigma[i];
+      }
+    } else {
+      // If sigma[i] is nearly zero, choose a vector orthogonal to the previous
+      // ones.
+      float vec[3] = {0, 0, 0};
+      if (i == 0) {
+        // The largest singular value is (nearly) zero, so there is no earlier
+        // column to be orthogonal to and any unit vector is valid. Without
+        // this case column 0 would stay zero and U would not be orthogonal.
+        vec[0] = 1;
+      } else if (i == 1) {
+        float u0[3] = {U.m[0][0], U.m[1][0], U.m[2][0]};
+        // Cross u0 with a coordinate axis; switch axis when u0 is almost
+        // parallel to the x axis so the cross product cannot vanish.
+        float tmp[3] = {1, 0, 0};
+        float dot = u0[0] * tmp[0] + u0[1] * tmp[1] + u0[2] * tmp[2];
+        if (std::fabs(dot) > 0.9) {
+          tmp[0] = 0;
+          tmp[1] = 1;
+          tmp[2] = 0;
+        }
+        crossProduct(u0, tmp, vec);
+      } else if (i == 2) {
+        float u0[3] = {U.m[0][0], U.m[1][0], U.m[2][0]};
+        float u1[3] = {U.m[0][1], U.m[1][1], U.m[2][1]};
+        crossProduct(u0, u1, vec);
+      }
+      normalize(vec);
+      for (int r = 0; r < 3; r++) {
+        U.m[r][i] = vec[r];
+      }
+    }
+  }
+
+  // Construct the diagonal S matrix.
+  Matrix3x3 S{};
+  for (int i = 0; i < 3; i++) {
+    for (int j = 0; j < 3; j++) {
+      S.m[i][j] = (i == j) ? sigma[i] : 0.0;
+    }
+  }
+
+  // Vt is the transpose of V.
+  Matrix3x3 Vt = transpose(V);
+
+  return std::make_tuple(U, S, Vt);
+}
+} // namespace
+
+// Batched SVD of 3x3 float matrices.
+//
+// A is read as batch_size consecutive row-major 3x3 matrices
+// (batch_size = A.numel() / 9). For every matrix, 9 values are written to U,
+// 3 singular values to S and 9 values to Vh.
+//
+// full_matrices, compute_uv and driver are accepted for signature
+// compatibility but are not read by this implementation.
+//
+// On a failed check the context is marked InvalidArgument and the (U, S, Vh)
+// tuple is returned without writing any output.
+std::tuple<Tensor&, Tensor&, Tensor&> linalg_svd_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& A,
+    bool full_matrices,
+    bool compute_uv,
+    ::executorch::aten::optional<::executorch::aten::string_view> driver,
+    Tensor& U,
+    Tensor& S,
+    Tensor& Vh) {
+  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(U, S, Vh);
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      A.scalar_type() == ScalarType::Float,
+      InvalidArgument,
+      ret_val,
+      "input.dtype(): %s must be %s",
+      ::torch::executor::toString(A.scalar_type()),
+      ::torch::executor::toString(ScalarType::Float));
+
+  ET_KERNEL_CHECK_MSG(
+      ctx, A.numel() > 0, InvalidArgument, ret_val, "input.size() must be > 0");
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      A.numel() % 9 == 0,
+      InvalidArgument,
+      ret_val,
+      "SVD of only 3x3 matrix is supported! Expected the input to have (batch_size x 9) number of elements, but received %d elements instead",
+      int(A.numel()));
+
+  int batch_size = A.numel() / 9;
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      U.numel() / 9 == batch_size,
+      InvalidArgument,
+      ret_val,
+      "Output tensor U must have the same batch_size as input: %d, but got: %d instead",
+      batch_size,
+      int(U.numel() / 9));
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      S.numel() / 3 == batch_size,
+      InvalidArgument,
+      ret_val,
+      "Output tensor S must have the same batch_size as input: %d, but got: %d instead",
+      batch_size,
+      int(S.numel() / 3));
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      Vh.numel() / 9 == batch_size,
+      InvalidArgument,
+      ret_val,
+      "Output tensor Vh must have the same batch_size as input: %d, but got: %d instead",
+      batch_size,
+      int(Vh.numel() / 9));
+
+  const float* A_data = A.const_data_ptr<float>();
+  float* U_data = U.mutable_data_ptr<float>();
+  float* S_data = S.mutable_data_ptr<float>();
+  float* Vh_data = Vh.mutable_data_ptr<float>();
+
+  for (int i = 0; i < batch_size; i++) {
+    // A, U and Vh hold 9 values per matrix, S only 3. S therefore needs its
+    // own stride; indexing it with the matrix offset would misplace the
+    // singular values and run past the end of S.
+    const int mat_offset = i * 9;
+    const int vec_offset = i * 3;
+
+    Matrix3x3 A_mat = {{
+        {A_data[mat_offset + 0], A_data[mat_offset + 1], A_data[mat_offset + 2]},
+        {A_data[mat_offset + 3], A_data[mat_offset + 4], A_data[mat_offset + 5]},
+        {A_data[mat_offset + 6], A_data[mat_offset + 7], A_data[mat_offset + 8]},
+    }};
+
+    Matrix3x3 U_mat{}, S_mat{}, Vh_mat{};
+    std::tie(U_mat, S_mat, Vh_mat) = svd(A_mat);
+
+    // Write U and Vh row-major and the diagonal of S as a 3-vector.
+    for (int r = 0; r < 3; r++) {
+      for (int c = 0; c < 3; c++) {
+        U_data[mat_offset + r * 3 + c] = U_mat.m[r][c];
+        Vh_data[mat_offset + r * 3 + c] = Vh_mat.m[r][c];
+      }
+      S_data[vec_offset + r] = S_mat.m[r][r];
+    }
+  }
+
+  return ret_val;
+}
+
+} // namespace native
+} // namespace generic
+} // namespace impl
diff --git a/backends/cadence/generic/operators/op_linalg_svd.h b/backends/cadence/generic/operators/op_linalg_svd.h
new file mode 100644
index 00000000000..3975e38a9bd
--- /dev/null
+++ b/backends/cadence/generic/operators/op_linalg_svd.h
@@ -0,0 +1,30 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#pragma once
+
+#include
+
+#include
+#include
+
+namespace impl {
+namespace generic {
+namespace native {
+
+std::tuple<
+    ::executorch::aten::Tensor&,
+    ::executorch::aten::Tensor&,
+    ::executorch::aten::Tensor&>
+linalg_svd_out(
+    ::executorch::runtime::KernelRuntimeContext& ctx,
+    const ::executorch::aten::Tensor& A,
+    bool full_matrices,
+    bool compute_uv,
+    ::executorch::aten::optional<::executorch::aten::string_view> driver,
+    ::executorch::aten::Tensor& U,
+    ::executorch::aten::Tensor& S,
+    ::executorch::aten::Tensor& Vh);
+
+} // namespace native
+} // namespace generic
+} // namespace impl
diff --git a/backends/cadence/generic/operators/op_roi_align_box_processor.cpp b/backends/cadence/generic/operators/op_roi_align_box_processor.cpp
new file mode 100644
index 00000000000..5bbc1776c75
--- /dev/null
+++ b/backends/cadence/generic/operators/op_roi_align_box_processor.cpp
@@ -0,0 +1,173 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+// NOTE(review): in this copy of the patch the #include targets, the
+// std::array element types/sizes, the template parameter lists and the
+// explicit template arguments are missing; restore them from the repository.
+#include
+
+#include
+
+namespace impl {
+namespace generic {
+namespace native {
+namespace {
+
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::KernelRuntimeContext;
+
+// UnpackedVec: the values before bit packing (indices 0..21 are written below).
+// PackedVec: the packed output (its size is checked against 80 further down).
+// IterVec: per-dimension shape/stride/increment values (6 entries are
+// supplied at the call sites below).
+using UnpackedVec = std::array;
+using PackedVec = std::array;
+using IterVec = std::array;
+
+// Derives per-dimension increments from strides:
+//   inc[0] = strides[0]
+//   inc[n] = strides[n] - strides[n - 1] * shape[n - 1] + inc[n - 1]
+IterVec computeAddrIncr(const IterVec& shape, const IterVec& strides) {
+  auto rank = shape.size();
+  auto inc = strides;
+  for (int n = 1; n < static_cast(rank); ++n) {
+    inc[n] = strides[n] - strides[n - 1] * shape[n - 1] + inc[n - 1];
+  }
+  return inc;
+}
+
+// Packs the low perItemBitWidth bits of every entry of vals back to back into
+// the result, least-significant bit first. Negative entries are stored as
+// (1 << perItemBitWidth) + v. Range violations are only caught by assert().
+template
+PackedVec packTuringVals(const UnpackedVec& vals, bool is_signed) {
+  PackedVec result{};
+  int bitPos = 0; // bit position in output vector
+  for (int v : vals) {
+    assert(is_signed || v >= 0);
+    if (is_signed) {
+      assert(
+          v >= -(1 << (perItemBitWidth - 1)) &&
+          v < (1 << (perItemBitWidth - 1)));
+    } else {
+      assert(v < (1 << perItemBitWidth));
+    }
+
+    // Map negative values to their perItemBitWidth-bit two's-complement form.
+    if (v < 0) {
+      v = (1 << perItemBitWidth) + v;
+    }
+
+    // Extract bit by bit and store in the output array
+    for (int bit = 0; bit < perItemBitWidth; ++bit) {
+      auto outBitIndex = bitPos + bit;
+      auto byteIndex = outBitIndex / 8;
+      auto bitInByte = outBitIndex % 8;
+      // Extract bit from val
+      uint8_t bitVal = (v >> bit) & 1;
+      // Set bit in output byte
+      result[byteIndex] |= (bitVal << bitInByte);
+    }
+    bitPos += perItemBitWidth;
+  }
+  assert(bitPos == vals.size() * perItemBitWidth);
+  return result;
+}
+
+// Scale factor of a fixed-point format with frac_bits fractional bits.
+template
+constexpr int get_fp_scale() {
+  return 1 << frac_bits;
+}
+
+// Converts fp to fixed point by scaling with get_fp_scale() and rounding to
+// the nearest integer.
+template
+int convert_to_S13(float fp) {
+  return int(std::round(fp * get_fp_scale()));
+}
+
+// Builds the packed configuration for one box.
+//
+// The sampling grid has output_size * sampling_ratio points per axis, and the
+// step between points is the box extent divided by that count. The anchor is
+// the top-left corner (shifted by -0.5 when 'aligned') plus half a step.
+// vals[0..1] hold the anchors, vals[2..11] the X increments and vals[12..21]
+// the Y increments, all in fixed point.
+PackedVec convertBoxPosToTuringConfig(
+    float topLeftX,
+    float topLeftY,
+    float bottomRightX,
+    float bottomRightY,
+    int roiAlignNumBoxes,
+    int output_size_h,
+    int output_size_w,
+    int sampling_ratio,
+    bool aligned) {
+  // NOTE(review): precisionMode is not read in the code visible here.
+  constexpr int precisionMode = 0;
+  auto dstImgH = output_size_h * sampling_ratio;
+  auto dstImgW = output_size_w * sampling_ratio;
+  auto dstTileH = dstImgH;
+  auto dstTileW = dstImgW;
+
+  float stepX = (bottomRightX - topLeftX) / dstImgW;
+  float stepY = (bottomRightY - topLeftY) / dstImgH;
+
+  if (aligned) {
+    topLeftX -= 0.5;
+    topLeftY -= 0.5;
+  }
+
+  auto anchorX = convert_to_S13(topLeftX + stepX / 2);
+  auto anchorY = convert_to_S13(topLeftY + stepY / 2);
+
+  UnpackedVec vals{};
+  vals[0] = anchorX;
+  vals[1] = anchorY;
+
+  IterVec shape = {dstTileW, dstTileH, 1, 1, 1, roiAlignNumBoxes};
+  auto addrIncrementX = computeAddrIncr(
+      shape,
+      {convert_to_S13(stepX),
+       0,
+       convert_to_S13(stepX * dstTileW),
+       0,
+       0,
+       0});
+  auto addrIncrementY = computeAddrIncr(
+      shape,
+      {0,
+       convert_to_S13(stepY),
+       0,
+       convert_to_S13(stepY * dstTileH),
+       0,
+       0});
+
+  // Ten slots per axis; slots beyond the computed increments repeat the last
+  // computed increment.
+  for (int i = 0; i < 10; ++i) {
+    vals[i + 2] = i < addrIncrementX.size()
+        ? addrIncrementX[i]
+        : addrIncrementX[addrIncrementX.size() - 1];
+    vals[i + 12] = i < addrIncrementY.size()
+        ? addrIncrementY[i]
+        : addrIncrementY[addrIncrementY.size() - 1];
+  }
+
+  return packTuringVals(vals, true);
+}
+
+} // namespace
+
+// Converts every ROI in 'rois' into its packed configuration and stores the
+// results back to back in 'out'.
+//
+// Each ROI is read as 5 consecutive values: an image index followed by
+// x1, y1, x2, y2. The image index must be 0, which is enforced by assert()
+// only. 'out' receives turing_roi.size() values per ROI (80, per the
+// static_assert) and is returned.
+// NOTE(review): the element types of rois and out are not visible in this
+// copy; confirm before relying on them.
+Tensor& roi_align_box_processor_out(
+    ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& rois,
+    int64_t output_size_h,
+    int64_t output_size_w,
+    int64_t sampling_ratio,
+    bool aligned,
+    Tensor& out) {
+  int K = static_cast(rois.size(0));
+  auto roi = rois.const_data_ptr();
+  for (int i = 0; i < K; ++i) {
+    assert(
+        static_cast(roi[i * 5]) == 0 && "Only support 1 image for now.");
+    auto x1 = roi[i * 5 + 1];
+    auto y1 = roi[i * 5 + 2];
+    auto x2 = roi[i * 5 + 3];
+    auto y2 = roi[i * 5 + 4];
+    auto turing_roi = convertBoxPosToTuringConfig(
+        x1,
+        y1,
+        x2,
+        y2,
+        static_cast(K),
+        static_cast(output_size_h),
+        static_cast(output_size_w),
+        static_cast(sampling_ratio),
+        aligned);
+    static_assert(turing_roi.size() == 80);
+
+    // Copy this ROI's packed values into its slot of the output.
+    auto out_ptr = out.mutable_data_ptr() + i * turing_roi.size();
+    for (auto val : turing_roi) {
+      *out_ptr++ = val;
+    }
+  }
+  return out;
+}
+} // namespace native
+} // namespace generic
+} // namespace impl
diff --git
a/backends/cadence/generic/operators/op_roi_align_box_processor.h b/backends/cadence/generic/operators/op_roi_align_box_processor.h new file mode 100644 index 00000000000..9948dbb0b1d --- /dev/null +++ b/backends/cadence/generic/operators/op_roi_align_box_processor.h @@ -0,0 +1,23 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& roi_align_box_processor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& rois, + int64_t output_size_h, + int64_t output_size_w, + int64_t sampling_ratio, + bool aligned, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_rope.cpp b/backends/cadence/generic/operators/op_rope.cpp new file mode 100644 index 00000000000..a11c56b3d19 --- /dev/null +++ b/backends/cadence/generic/operators/op_rope.cpp @@ -0,0 +1,59 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+
+#include "executorch/backends/cadence/generic/operators/op_rope.h"
+
+namespace impl {
+namespace generic {
+namespace native {
+
+using ::executorch::aten::optional;
+using ::executorch::aten::Tensor;
+
+// Applies a rotary position embedding to `input` and writes the result to
+// `out`.
+//
+// Each consecutive element pair (x_0, x_1) of a head is rotated using the
+// values read from `sin_tensor` and `cos_tensor`:
+//   out_0 = x_0 * cos - x_1 * sin
+//   out_1 = x_0 * sin + x_1 * cos
+// The sin/cos row is selected by `pos[s]` when `pos` is provided and by the
+// sequence index `s` otherwise.
+//
+// Returns `out`. An input with an empty sequence or head dimension is a no-op.
+//
+// NOTE(review): the element types passed to const_data_ptr/mutable_data_ptr
+// (float data; int32 positions for ScalarType::Int, int64 otherwise) were
+// reconstructed because the template arguments are missing in this copy of
+// the patch; confirm against the repository.
+Tensor& rope_out(
+    ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx,
+    const Tensor& input,
+    const Tensor& sin_tensor,
+    const Tensor& cos_tensor,
+    const optional<Tensor>& pos,
+    Tensor& out) {
+  // Input shape is [1, seq, h, hd / 2, 2] or [1, seq, h, hd]
+  const auto kSeq = input.size(1);
+  const auto kH = input.size(2);
+  if (kSeq == 0 || kH == 0) {
+    // Nothing to rotate; this also avoids dividing by zero below.
+    return out;
+  }
+  const auto kHd = input.numel() / (kSeq * kH);
+
+  // Fetch the data pointers once instead of once per element.
+  const float* const in_data = input.const_data_ptr<float>();
+  const float* const sin_data = sin_tensor.const_data_ptr<float>();
+  const float* const cos_data = cos_tensor.const_data_ptr<float>();
+  float* const out_data = out.mutable_data_ptr<float>();
+
+  const bool pos_is_int32 = pos.has_value() &&
+      pos->scalar_type() == ::executorch::aten::ScalarType::Int;
+
+  for (int64_t s = 0; s < kSeq; ++s) {
+    // Row of the sin/cos tables used for this sequence position. It depends
+    // only on s, so it is resolved once per position.
+    int64_t token_id = s;
+    if (pos.has_value()) {
+      token_id = pos_is_int32 ? pos->const_data_ptr<int32_t>()[s]
+                              : pos->const_data_ptr<int64_t>()[s];
+    }
+
+    for (int64_t h = 0; h < kH; ++h) {
+      // Offset of the first element of head h at position s.
+      const int64_t base = s * kH * kHd + h * kHd;
+      for (int64_t hd_o = 0; hd_o < kHd / 2; ++hd_o) {
+        // Read both inputs before writing so input and out may alias.
+        const float x_0 = in_data[base + hd_o * 2];
+        const float x_1 = in_data[base + hd_o * 2 + 1];
+        const float sin_val = sin_data[token_id * kHd / 2 + hd_o];
+        const float cos_val = cos_data[token_id * kHd / 2 + hd_o];
+
+        out_data[base + hd_o * 2] = x_0 * cos_val - x_1 * sin_val;
+        out_data[base + hd_o * 2 + 1] = x_0 * sin_val + x_1 * cos_val;
+      }
+    }
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace generic
+} // namespace impl
diff --git a/backends/cadence/generic/operators/op_rope.h b/backends/cadence/generic/operators/op_rope.h
new file mode 100644
index 00000000000..24308eb9dce
--- /dev/null
+++ b/backends/cadence/generic/operators/op_rope.h
@@ -0,0 +1,22 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#pragma once
+
+#include
+#include
+
+namespace impl {
+namespace generic {
+namespace native {
+
+::executorch::aten::Tensor& rope_out(
+    ::executorch::runtime::KernelRuntimeContext& ctx,
+    const ::executorch::aten::Tensor& input,
+    const ::executorch::aten::Tensor& sin_tensor,
+    const ::executorch::aten::Tensor& cos_tensor,
+    const ::executorch::aten::optional<::executorch::aten::Tensor>& pos,
+    ::executorch::aten::Tensor& out);
+
+} // namespace native
+} // namespace generic
+} // namespace impl
diff --git a/backends/cadence/generic/operators/op_where_scalar.cpp b/backends/cadence/generic/operators/op_where_scalar.cpp
new file mode 100644
index 00000000000..7c08c8f9312
--- /dev/null
+++ b/backends/cadence/generic/operators/op_where_scalar.cpp
@@ -0,0 +1,27 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#include
+
+namespace impl {
+namespace generic {
+namespace native {
+
+// Elementwise select: out[i] = condition[i] ? val1 : val2 for every element of
+// `out`, with both scalars converted to float first. Returns `out`.
+//
+// `condition` is read at every index of `out`; no size check is performed
+// here.
+//
+// NOTE(review): the template arguments (float for the scalars and the output,
+// bool for the condition) were reconstructed because they are missing in this
+// copy of the patch; confirm against the repository.
+::executorch::aten::Tensor& where_Scalar_out(
+    ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx,
+    const ::executorch::aten::Tensor& condition,
+    const double val1,
+    const double val2,
+    ::executorch::aten::Tensor& out) {
+  const float val1_f = static_cast<float>(val1);
+  const float val2_f = static_cast<float>(val2);
+
+  // Fetch the data pointers once instead of once per element.
+  float* const out_data = out.mutable_data_ptr<float>();
+  const bool* const cond_data = condition.const_data_ptr<bool>();
+
+  // Count with the type numel() returns so a large tensor cannot overflow a
+  // 32-bit index.
+  for (decltype(out.numel()) i = 0, count = out.numel(); i < count; ++i) {
+    out_data[i] = cond_data[i] ? val1_f : val2_f;
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace generic
+} // namespace impl
diff --git a/backends/cadence/generic/operators/op_where_scalar.h b/backends/cadence/generic/operators/op_where_scalar.h
new file mode 100644
index 00000000000..4176ae86a02
--- /dev/null
+++ b/backends/cadence/generic/operators/op_where_scalar.h
@@ -0,0 +1,21 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+ +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& where_Scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& condition, + double val1, + double val2, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index fa0f128b229..a04dd70f53e 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -183,6 +183,71 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "op_where_scalar", + srcs = ["op_where_scalar.cpp"], + exported_headers = ["op_where_scalar.h", "operators.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_rope", + srcs = ["op_rope.cpp"], + exported_headers = ["op_rope.h", "operators.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_linalg_svd", + srcs = ["op_linalg_svd.cpp"], + headers = ["op_linalg_svd.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/util:tensor_util", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = 
"op_roi_align_box_processor", + srcs = ["op_roi_align_box_processor.cpp"], + exported_headers = ["op_roi_align_box_processor.h", "operators.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + # Combined target for backward compatibility # NOTE: cadence_aot_lib now uses individual targets directly for better linking runtime.cxx_library( diff --git a/backends/cadence/vision/kernels/kernels.cpp b/backends/cadence/vision/kernels/kernels.cpp index 70c811df741..d87ff98e06f 100644 --- a/backends/cadence/vision/kernels/kernels.cpp +++ b/backends/cadence/vision/kernels/kernels.cpp @@ -25,8 +25,8 @@ void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) { // Quantize a fp32 value to an int8_t/uint8_t value template T quantize(const float x, float scale, int32_t zero_point) { - constexpr float min_val = std::numeric_limits::min(); - constexpr float max_val = std::numeric_limits::max(); + constexpr float min_val = static_cast(std::numeric_limits::min()); + constexpr float max_val = static_cast(std::numeric_limits::max()); float tmp = roundf(x * scale + zero_point); return std::max(std::min(tmp, max_val), min_val); }