Skip to content

Commit 8a9048f

Browse files
committed
Fix: Replace RET_CHECK crash with Status return in TfLiteWeightAccessor
Replaces hard RET_CHECK failures with absl::InternalError and absl::InvalidArgumentError to prevent crashes on invalid models. Fixes #5600.
1 parent ab0fcee commit 8a9048f

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

mediapipe/tasks/cc/genai/inference/utils/xnn_utils/tflite_weight_accessor.cc

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,11 @@ void TfLiteWeightAccessor::BuildWeightsMapFromTfliteModel(char* data) {
103103
absl::StatusOr<std::shared_ptr<Tensor>> TfLiteWeightAccessor::LoadWeight(
104104
absl::string_view tensor_name, Tensor::DimsType expected_dims,
105105
size_t dim_scale_if_any) const {
106-
RET_CHECK(tflite_model_);
106+
// FIX: Replaced RET_CHECK with proper Status return
107+
if (!tflite_model_) {
108+
return absl::InternalError("TfLiteWeightAccessor: TFLite model is not initialized.");
109+
}
110+
107111
if (!weights_.contains(tensor_name)) {
108112
ABSL_DLOG(WARNING) << "Tensor not found: " << tensor_name;
109113
return nullptr;
@@ -137,7 +141,15 @@ absl::StatusOr<std::shared_ptr<Tensor>> TfLiteWeightAccessor::LoadWeight(
137141
absl::StrCat("Scale tensor not found: ", scale_tensor_name));
138142
}
139143
std::shared_ptr<Tensor> scale_tensor = weights_.at(scale_tensor_name);
140-
RET_CHECK_EQ(expected_dims[dim_scale_if_any], scale_tensor->num_elements);
144+
145+
// FIX: Replaced RET_CHECK_EQ with safe check
146+
if (expected_dims[dim_scale_if_any] != scale_tensor->num_elements) {
147+
return absl::InvalidArgumentError(absl::StrCat(
148+
"Scale tensor dimension mismatch for ", scale_tensor_name,
149+
". Expected: ", expected_dims[dim_scale_if_any],
150+
", Actual: ", scale_tensor->num_elements));
151+
}
152+
141153
switch (qtensor->datatype) {
142154
case xnn_datatype_qcint8:
143155
result = std::make_shared<QCTensor>(
@@ -166,12 +178,20 @@ absl::StatusOr<std::shared_ptr<Tensor>>
166178
TfLiteWeightAccessor::LoadTransposedWeight(absl::string_view tensor_name,
167179
Tensor::DimsType expected_dims,
168180
size_t dim_scale_if_any) const {
169-
RET_CHECK(tflite_model_);
170-
RET_CHECK_EQ(expected_dims.size(), 2);
181+
// FIX: Replaced RET_CHECK with proper Status return
182+
if (!tflite_model_) {
183+
return absl::InternalError("TfLiteWeightAccessor: TFLite model is not initialized.");
184+
}
185+
// FIX: Replaced RET_CHECK_EQ with proper Status return
186+
if (expected_dims.size() != 2) {
187+
return absl::InvalidArgumentError(absl::StrCat(
188+
"LoadTransposedWeight expects 2 dimensions, got: ", expected_dims.size()));
189+
}
190+
171191
return LoadWeight(
172192
tensor_name,
173193
Tensor::DimsType(expected_dims.rbegin(), expected_dims.rend()),
174194
1 - dim_scale_if_any);
175195
}
176196

177-
} // namespace mediapipe::tasks::genai::xnn_utils
197+
} // namespace mediapipe::tasks::genai::xnn_utils

0 commit comments

Comments
 (0)