Skip to content

Commit 2f47a22

Browse files
committed
[Feature] bind delete_keys parameter on tf_client update_priority, like client API
1 parent 2b74d54 commit 2f47a22

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

reverb/cc/ops/client.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ REGISTER_OP("ReverbClientUpdatePriorities")
7070
.Input("table: string")
7171
.Input("keys: uint64")
7272
.Input("priorities: double")
73+
.Input("keys_to_delete: uint64")
7374
.Doc(R"doc(
7475
Blocking call to update the priorities of a collection of items. Keys that could
7576
not be found in table `table` on server are ignored and does not impact the rest
@@ -187,7 +188,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
187188
const tensorflow::Tensor* keys;
188189
OP_REQUIRES_OK(context, context->input("keys", &keys));
189190
const tensorflow::Tensor* priorities;
190-
OP_REQUIRES_OK(context, context->input("priorities", &priorities));
191+
OP_REQUIRES_OK(context, context->input("priorities", &priorities));
192+
const tensorflow::Tensor* keys_to_delete;
193+
OP_REQUIRES_OK(context, context->input("keys_to_delete", &keys_to_delete));
191194

192195
OP_REQUIRES(
193196
context, keys->dims() == 1,
@@ -197,6 +200,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
197200
"Tensors `keys` and `priorities` do not match in shape (",
198201
keys->shape().DebugString(), " vs. ",
199202
priorities->shape().DebugString(), ")"));
203+
OP_REQUIRES(
204+
context, keys_to_delete->dims() == 1,
205+
InvalidArgument("Tensors `keys_to_delete` must be of rank 1."));
200206

201207
std::string table_str = table->scalar<tstring>()();
202208
std::vector<KeyWithPriority> updates;
@@ -207,14 +213,19 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
207213
updates.push_back(std::move(update));
208214
}
209215

216+
std::vector<uint64_t> deletes;
217+
for (int i = 0; i < keys_to_delete->dim_size(0); i++) {
218+
deletes.push_back(keys_to_delete->flat<tensorflow::uint64>()(i));
219+
}
220+
210221
// The call will only fail if the Reverb-server is brought down during an
211222
// active call (e.g preempted). When this happens the request is retried and
212223
// since MutatePriorities sets `wait_for_ready` the request will no be sent
213224
// before the server is brought up again. It is therefore no problem to have
214225
// this retry in this tight loop.
215226
absl::Status status;
216227
do {
217-
status = resource->client()->MutatePriorities(table_str, updates, {});
228+
status = resource->client()->MutatePriorities(table_str, updates, deletes);
218229
} while (absl::IsUnavailable(status) || absl::IsDeadlineExceeded(status));
219230
OP_REQUIRES_OK(context, ToTensorflowStatus(status));
220231
}

reverb/tf_client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def update_priorities(self,
117117
table: str,
118118
keys: tf.Tensor,
119119
priorities: tf.Tensor,
120+
keys_to_delete: Optional[tf.Tensor] = None,
120121
name: str = None):
121122
"""Creates op for updating priorities of existing items in the replay.
122123
@@ -126,16 +127,20 @@ def update_priorities(self,
126127
table: Probability table to update.
127128
keys: Keys of the items to update. Must be same length as `priorities`.
128129
priorities: New priorities for `keys`. Must be same length as `keys`.
130+
keys_to_delete: Keys of the items to delete
129131
name: Optional name for the operation.
130132
131133
Returns:
132134
A tf-op for performing the update.
133135
"""
134136

137+
if keys_to_delete is None:
138+
keys_to_delete = tf.constant([], dtype=tf.uint64)
139+
135140
with tf.name_scope(name, f'{self._name}_update_priorities',
136141
['update_priorities']) as scope:
137142
return gen_client_ops.reverb_client_update_priorities(
138-
self._handle, table, keys, priorities, name=scope)
143+
self._handle, table, keys, priorities, keys_to_delete, name=scope)
139144

140145
def dataset(self,
141146
table: str,

reverb/tf_client_test.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from reverb import tf_client
2626
import tensorflow.compat.v1 as tf
2727

28-
2928
def make_server():
3029
return server.Server(
3130
tables=[
@@ -161,6 +160,39 @@ def test_priority_update_is_applied(self):
161160
self.fail('Updated item was not found')
162161

163162

163+
def test_delete_key_is_applied(self):
164+
# Start with 4 items
165+
for i in range(4):
166+
self._client.insert([np.array([i], dtype=np.uint32)], {'dist': 1})
167+
168+
# Until we have recieved all 4 items.
169+
items = {}
170+
while len(items) < 4:
171+
item = next(self._client.sample('dist'))[0]
172+
items[item.info.key] = item.info.probability
173+
174+
# remove 2 items
175+
items_to_keep = [*items.keys()][:2]
176+
items_to_remove = [*items.keys()][2:]
177+
with self.session() as session:
178+
client = tf_client.TFClient(self._client.server_address)
179+
for key in items_to_remove:
180+
update_op = client.update_priorities(
181+
table=tf.constant('dist'),
182+
keys=tf.constant([], dtype=tf.uint64),
183+
priorities=tf.constant([], dtype=tf.float64),
184+
keys_to_delete=tf.constant([key], dtype=tf.uint64))
185+
self.assertIsNone(session.run(update_op))
186+
187+
# 2 remaining items must persist
188+
final_items = {}
189+
for _ in range(1000):
190+
item = next(self._client.sample('dist'))[0]
191+
self.assertTrue(item.info.key in items_to_keep)
192+
final_items[item.info.key] = item.info.probability
193+
self.assertEqual(len(final_items), 2)
194+
195+
164196
class InsertOpTest(tf.test.TestCase):
165197

166198
@classmethod

0 commit comments

Comments
 (0)