@@ -70,6 +70,7 @@ REGISTER_OP("ReverbClientUpdatePriorities")
7070 .Input(" table: string" )
7171 .Input(" keys: uint64" )
7272 .Input(" priorities: double" )
73+ .Input(" keys_to_delete: uint64" )
7374 .Doc(R"doc(
7475Blocking call to update the priorities of a collection of items. Keys that could
7576not be found in table `table` on server are ignored and do not impact the rest
@@ -187,7 +188,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
187188 const tensorflow::Tensor* keys;
188189 OP_REQUIRES_OK (context, context->input (" keys" , &keys));
189190 const tensorflow::Tensor* priorities;
190- OP_REQUIRES_OK (context, context->input (" priorities" , &priorities));
191+ OP_REQUIRES_OK (context, context->input (" priorities" , &priorities));
192+ const tensorflow::Tensor* keys_to_delete;
193+ OP_REQUIRES_OK (context, context->input (" keys_to_delete" , &keys_to_delete));
191194
192195 OP_REQUIRES (
193196 context, keys->dims () == 1 ,
@@ -197,6 +200,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
197200 " Tensors `keys` and `priorities` do not match in shape (" ,
198201 keys->shape ().DebugString (), " vs. " ,
199202 priorities->shape ().DebugString (), " )" ));
203+ OP_REQUIRES (
204+ context, keys_to_delete->dims () == 1 ,
205+ InvalidArgument (" Tensors `keys_to_delete` must be of rank 1." ));
200206
201207 std::string table_str = table->scalar <tstring>()();
202208 std::vector<KeyWithPriority> updates;
@@ -207,14 +213,19 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
207213 updates.push_back (std::move (update));
208214 }
209215
216+ std::vector<uint64_t > deletes;
217+ for (int i = 0 ; i < keys_to_delete->dim_size (0 ); i++) {
218+ deletes.push_back (keys_to_delete->flat <tensorflow::uint64>()(i));
219+ }
220+
210221 // The call will only fail if the Reverb-server is brought down during an
211222 // active call (e.g preempted). When this happens the request is retried and
212223 // since MutatePriorities sets `wait_for_ready` the request will no be sent
213224 // before the server is brought up again. It is therefore no problem to have
214225 // this retry in this tight loop.
215226 absl::Status status;
216227 do {
217- status = resource->client ()->MutatePriorities (table_str, updates, {} );
228+ status = resource->client ()->MutatePriorities (table_str, updates, deletes );
218229 } while (absl::IsUnavailable (status) || absl::IsDeadlineExceeded (status));
219230 OP_REQUIRES_OK (context, ToTensorflowStatus (status));
220231 }
0 commit comments