bpf: tcp: Exactly-once socket iteration #9026
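
This combined diff replaces the BPF TCP iterator's offset-based resume (tcp_seek_last_pos() plus st->offset) with a cookie-based scheme: when a batch of sockets is released between read() calls, the iterator records each not-yet-shown socket's cookie (a stable 64-bit identifier from sock_gen_cookie(), hence the new <linux/sock_diag.h> include) and later resumes at the first remembered socket still present in the bucket. The intent is that a socket which exists for the whole pass is visited exactly once, with no skips or repeats when buckets change between reads.

For orientation, here is a minimal BPF program that consumes such an iterator, modeled on the kernel selftests (progs/bpf_iter_tcp4.c). It is an illustration, not part of this patch; it assumes a generated vmlinux.h, and the program name and output format are made up.

#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

char _license[] SEC("license") = "GPL";

SEC("iter/tcp")
int dump_tcp(struct bpf_iter__tcp *ctx)
{
        struct sock_common *sk_common = ctx->sk_common;
        struct seq_file *seq = ctx->meta->seq;

        if (!sk_common)
                return 0;

        /* One line per socket; with this series, each socket that exists
         * for the whole pass is shown exactly once.
         */
        BPF_SEQ_PRINTF(seq, "family=%d port=%u\n",
                       sk_common->skc_family, sk_common->skc_num);
        return 0;
}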

Changes from all commits
net/ipv4/tcp_ipv4.c (263 changes: 196 additions & 67 deletions)
@@ -58,6 +58,7 @@
 #include <linux/times.h>
 #include <linux/slab.h>
 #include <linux/sched.h>
+#include <linux/sock_diag.h>
 
 #include <net/net_namespace.h>
 #include <net/icmp.h>
@@ -3014,13 +3015,17 @@ static int tcp4_seq_show(struct seq_file *seq, void *v)
 }
 
 #ifdef CONFIG_BPF_SYSCALL
+union bpf_tcp_iter_batch_item {
+        struct sock *sk;
+        __u64 cookie;
+};
+
 struct bpf_tcp_iter_state {
         struct tcp_iter_state state;
         unsigned int cur_sk;
         unsigned int end_sk;
         unsigned int max_sk;
-        struct sock **batch;
-        bool st_bucket_done;
+        union bpf_tcp_iter_batch_item *batch;
 };
 
 struct bpf_iter__tcp {
@@ -3043,134 +3048,265 @@ static int tcp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,

 static void bpf_iter_tcp_put_batch(struct bpf_tcp_iter_state *iter)
 {
-        while (iter->cur_sk < iter->end_sk)
-                sock_gen_put(iter->batch[iter->cur_sk++]);
+        union bpf_tcp_iter_batch_item *item;
+        unsigned int cur_sk = iter->cur_sk;
+        __u64 cookie;
+
+        /* Remember the cookies of the sockets we haven't seen yet, so we can
+         * pick up where we left off next time around.
+         */
+        while (cur_sk < iter->end_sk) {
+                item = &iter->batch[cur_sk++];
+                cookie = sock_gen_cookie(item->sk);
+                sock_gen_put(item->sk);
+                item->cookie = cookie;
+        }
 }
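
The put path above is what makes resumption possible: before each reference is dropped, the socket's identity is captured with sock_gen_cookie(), and the cookie is written back into the same batch slot the pointer occupied. A userspace model of that handoff (illustration only; the get_cookie/put_ref callbacks are hypothetical stand-ins for sock_gen_cookie()/sock_gen_put()):

#include <stdint.h>

union slot {
        void *ptr;              /* valid while a reference is held */
        uint64_t cookie;        /* valid after the reference is dropped */
};

static void put_batch(union slot *batch, unsigned int cur, unsigned int end,
                      uint64_t (*get_cookie)(void *), void (*put_ref)(void *))
{
        uint64_t cookie;

        while (cur < end) {
                /* Read the identity before letting go of the object. */
                cookie = get_cookie(batch[cur].ptr);
                put_ref(batch[cur].ptr);
                batch[cur++].cookie = cookie;
        }
}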

 static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter,
-                                      unsigned int new_batch_sz)
+                                      unsigned int new_batch_sz, gfp_t flags)
 {
-        struct sock **new_batch;
+        union bpf_tcp_iter_batch_item *new_batch;
 
         new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz,
-                             GFP_USER | __GFP_NOWARN);
+                             flags | __GFP_NOWARN);
         if (!new_batch)
                 return -ENOMEM;
 
-        bpf_iter_tcp_put_batch(iter);
+        if (flags != GFP_NOWAIT)
+                bpf_iter_tcp_put_batch(iter);
+
+        memcpy(new_batch, iter->batch, sizeof(*iter->batch) * iter->end_sk);
         kvfree(iter->batch);
         iter->batch = new_batch;
         iter->max_sk = new_batch_sz;
 
         return 0;
 }
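
Note the two allocation modes: GFP_USER is used when the caller has dropped the bucket lock and may sleep, in which case the old batch is put (converted to cookies) first; GFP_NOWAIT is used while the bucket lock is still held, so the allocation must not sleep and the held references must be carried over intact, which is what the memcpy() provides. A userspace analog of the keep-the-entries path (illustration only, assuming a simple growable array):

#include <stdlib.h>
#include <string.h>

struct batch {
        void **items;
        unsigned int len, cap;
};

/* Grow without invalidating live entries, as the GFP_NOWAIT path must. */
static int batch_grow_keep(struct batch *b, unsigned int new_cap)
{
        void **n = malloc(sizeof(*n) * new_cap);

        if (!n)
                return -1;
        memcpy(n, b->items, sizeof(*n) * b->len);
        free(b->items);
        b->items = n;
        b->cap = new_cap;
        return 0;
}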

+static struct sock *bpf_iter_tcp_resume_bucket(struct sock *first_sk,
+                                               union bpf_tcp_iter_batch_item *cookies,
+                                               int n_cookies)
+{
+        struct hlist_nulls_node *node;
+        struct sock *sk;
+        int i;
+
+        for (i = 0; i < n_cookies; i++) {
+                sk = first_sk;
+                sk_nulls_for_each_from(sk, node) {
+                        if (cookies[i].cookie == atomic64_read(&sk->sk_cookie))
+                                return sk;
+                }
+        }
+
+        return NULL;
+}
+
+static struct sock *bpf_iter_tcp_resume_listening(struct seq_file *seq)
+{
+        struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
+        struct bpf_tcp_iter_state *iter = seq->private;
+        struct tcp_iter_state *st = &iter->state;
+        unsigned int find_cookie = iter->cur_sk;
+        unsigned int end_cookie = iter->end_sk;
+        int resume_bucket = st->bucket;
+        struct sock *sk;
+
+        if (end_cookie && find_cookie == end_cookie)
+                ++st->bucket;
+
+        sk = listening_get_first(seq);
+        iter->cur_sk = 0;
+        iter->end_sk = 0;
+
+        if (sk && st->bucket == resume_bucket && end_cookie) {
+                sk = bpf_iter_tcp_resume_bucket(sk, &iter->batch[find_cookie],
+                                                end_cookie - find_cookie);
+                if (!sk) {
+                        spin_unlock(&hinfo->lhash2[st->bucket].lock);
+                        ++st->bucket;
+                        sk = listening_get_first(seq);
+                }
+        }
+
+        return sk;
+}

+static struct sock *bpf_iter_tcp_resume_established(struct seq_file *seq)
+{
+        struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
+        struct bpf_tcp_iter_state *iter = seq->private;
+        struct tcp_iter_state *st = &iter->state;
+        unsigned int find_cookie = iter->cur_sk;
+        unsigned int end_cookie = iter->end_sk;
+        int resume_bucket = st->bucket;
+        struct sock *sk;
+
+        if (end_cookie && find_cookie == end_cookie)
+                ++st->bucket;
+
+        sk = established_get_first(seq);
+        iter->cur_sk = 0;
+        iter->end_sk = 0;
+
+        if (sk && st->bucket == resume_bucket && end_cookie) {
+                sk = bpf_iter_tcp_resume_bucket(sk, &iter->batch[find_cookie],
+                                                end_cookie - find_cookie);
+                if (!sk) {
+                        spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket));
+                        ++st->bucket;
+                        sk = established_get_first(seq);
+                }
+        }
+
+        return sk;
+}

+static struct sock *bpf_iter_tcp_resume(struct seq_file *seq)
+{
+        struct bpf_tcp_iter_state *iter = seq->private;
+        struct tcp_iter_state *st = &iter->state;
+        struct sock *sk = NULL;
+
+        switch (st->state) {
+        case TCP_SEQ_STATE_LISTENING:
+                sk = bpf_iter_tcp_resume_listening(seq);
+                if (sk)
+                        break;
+                st->bucket = 0;
+                st->state = TCP_SEQ_STATE_ESTABLISHED;
+                fallthrough;
+        case TCP_SEQ_STATE_ESTABLISHED:
+                sk = bpf_iter_tcp_resume_established(seq);
+        }
+
+        return sk;
+}
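
The resume logic centers on bpf_iter_tcp_resume_bucket(): if the iterator is still on the bucket it left off in, it walks the bucket once per remembered cookie, in the order the cookies were saved, and resumes at the first remembered socket that still exists; if every remembered socket is gone, the bucket is treated as finished and iteration moves on. A self-contained sketch of that search over a plain singly linked list (illustration only):

#include <stdint.h>
#include <stddef.h>

struct node {
        uint64_t cookie;
        struct node *next;
};

static struct node *resume_bucket(struct node *first,
                                  const uint64_t *cookies, int n)
{
        struct node *pos;
        int i;

        for (i = 0; i < n; i++)
                for (pos = first; pos; pos = pos->next)
                        if (pos->cookie == cookies[i])
                                return pos;
        /* Every remembered node is gone: caller advances to the next bucket. */
        return NULL;
}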

 static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq,
-                                                 struct sock *start_sk)
+                                                 struct sock **start_sk)
 {
         struct bpf_tcp_iter_state *iter = seq->private;
         struct hlist_nulls_node *node;
         unsigned int expected = 1;
         struct sock *sk;
 
-        sock_hold(start_sk);
-        iter->batch[iter->end_sk++] = start_sk;
+        sock_hold(*start_sk);
+        iter->batch[iter->end_sk++].sk = *start_sk;
 
-        sk = sk_nulls_next(start_sk);
+        sk = sk_nulls_next(*start_sk);
+        *start_sk = NULL;
         sk_nulls_for_each_from(sk, node) {
                 if (seq_sk_match(seq, sk)) {
                         if (iter->end_sk < iter->max_sk) {
                                 sock_hold(sk);
-                                iter->batch[iter->end_sk++] = sk;
+                                iter->batch[iter->end_sk++].sk = sk;
+                        } else if (!*start_sk) {
+                                /* Remember where we left off. */
+                                *start_sk = sk;
                         }
                         expected++;
                 }
         }
-        spin_unlock(&hinfo->lhash2[st->bucket].lock);
 
         return expected;
 }
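
The batch functions now take struct sock **start_sk so they can report overflow back to the caller: all matching sockets in the bucket are counted in the return value, but only those that fit are stored, and the first socket that did not fit is remembered in *start_sk so a refill can continue from there without rescanning. A sketch of that contract, reusing struct node from the sketch above (illustration only):

static unsigned int fill_batch(struct node *chain, struct node **start,
                               struct node **batch, unsigned int max)
{
        unsigned int found = 0, stored = 0;
        struct node *pos;

        *start = NULL;
        for (pos = chain; pos; pos = pos->next) {
                if (stored < max)
                        batch[stored++] = pos;
                else if (!*start)
                        *start = pos;   /* remember where we left off */
                found++;
        }
        return found;   /* caller resizes toward this if found > stored */
}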

 static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq,
-                                                   struct sock *start_sk)
+                                                   struct sock **start_sk)
 {
-        struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
         struct bpf_tcp_iter_state *iter = seq->private;
-        struct tcp_iter_state *st = &iter->state;
         struct hlist_nulls_node *node;
         unsigned int expected = 1;
         struct sock *sk;
 
-        sock_hold(start_sk);
-        iter->batch[iter->end_sk++] = start_sk;
+        sock_hold(*start_sk);
+        iter->batch[iter->end_sk++].sk = *start_sk;
 
-        sk = sk_nulls_next(start_sk);
+        sk = sk_nulls_next(*start_sk);
+        *start_sk = NULL;
         sk_nulls_for_each_from(sk, node) {
                 if (seq_sk_match(seq, sk)) {
                         if (iter->end_sk < iter->max_sk) {
                                 sock_hold(sk);
-                                iter->batch[iter->end_sk++] = sk;
+                                iter->batch[iter->end_sk++].sk = sk;
+                        } else if (!*start_sk) {
+                                /* Remember where we left off. */
+                                *start_sk = sk;
                         }
                         expected++;
                 }
         }
-        spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket));
 
         return expected;
 }

-static struct sock *bpf_iter_tcp_batch(struct seq_file *seq)
+static void bpf_iter_tcp_unlock_bucket(struct seq_file *seq)
 {
         struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
         struct bpf_tcp_iter_state *iter = seq->private;
         struct tcp_iter_state *st = &iter->state;
 
+        if (st->state == TCP_SEQ_STATE_LISTENING)
+                spin_unlock(&hinfo->lhash2[st->bucket].lock);
+        else
+                spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket));
+}
+
+static struct sock *bpf_iter_tcp_batch(struct seq_file *seq)
+{
+        struct bpf_tcp_iter_state *iter = seq->private;
+        struct tcp_iter_state *st = &iter->state;
+        int prev_bucket, prev_state;
         unsigned int expected;
-        bool resized = false;
+        int resizes = 0;
         struct sock *sk;
-
-        /* The st->bucket is done. Directly advance to the next
-         * bucket instead of having the tcp_seek_last_pos() to skip
-         * one by one in the current bucket and eventually find out
-         * it has to advance to the next bucket.
-         */
-        if (iter->st_bucket_done) {
-                st->offset = 0;
-                st->bucket++;
-                if (st->state == TCP_SEQ_STATE_LISTENING &&
-                    st->bucket > hinfo->lhash2_mask) {
-                        st->state = TCP_SEQ_STATE_ESTABLISHED;
-                        st->bucket = 0;
-                }
-        }
+        int err;
 
 again:
         /* Get a new batch */
-        iter->cur_sk = 0;
-        iter->end_sk = 0;
-        iter->st_bucket_done = false;
-
-        sk = tcp_seek_last_pos(seq);
+        prev_bucket = st->bucket;
+        prev_state = st->state;
+        sk = bpf_iter_tcp_resume(seq);
         if (!sk)
                 return NULL; /* Done */
+        if (st->bucket != prev_bucket || st->state != prev_state)
+                resizes = 0;
+        expected = 0;
 
+fill_batch:
         if (st->state == TCP_SEQ_STATE_LISTENING)
-                expected = bpf_iter_tcp_listening_batch(seq, sk);
+                expected += bpf_iter_tcp_listening_batch(seq, &sk);
         else
-                expected = bpf_iter_tcp_established_batch(seq, sk);
+                expected += bpf_iter_tcp_established_batch(seq, &sk);
 
-        if (iter->end_sk == expected) {
-                iter->st_bucket_done = true;
-                return sk;
-        }
+        if (unlikely(resizes <= 1 && iter->end_sk != expected)) {
+                resizes++;
+
+                if (resizes == 1) {
+                        bpf_iter_tcp_unlock_bucket(seq);
+
+                        err = bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2,
+                                                         GFP_USER);
+                        if (err)
+                                return ERR_PTR(err);
+                        goto again;
+                }
+
+                err = bpf_iter_tcp_realloc_batch(iter, expected, GFP_NOWAIT);
+                if (err) {
+                        bpf_iter_tcp_unlock_bucket(seq);
+                        return ERR_PTR(err);
+                }
 
-        if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2)) {
-                resized = true;
-                goto again;
+                expected = iter->end_sk;
+                goto fill_batch;
         }
 
-        return sk;
+        bpf_iter_tcp_unlock_bucket(seq);
+
+        WARN_ON_ONCE(iter->end_sk != expected);
+        return iter->batch[0].sk;
 }
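
Taken together, this caps the work per bucket visit at two reallocations: on the first shortfall the bucket lock is dropped, the array grows to 1.5x the observed size with GFP_USER, and the bucket is re-entered through the cookie-based resume path; if the bucket grew again in the meantime, the second resize happens in place under the lock with GFP_NOWAIT and filling continues from the remembered *start_sk. From userspace, the case all of this serves is reading the iterator in chunks smaller than a bucket's batch. A libbpf sketch (illustration only; assumes prog is an already-loaded iter/tcp program, error handling trimmed):

#include <bpf/libbpf.h>
#include <bpf/bpf.h>
#include <stdio.h>
#include <unistd.h>

int dump_all(struct bpf_program *prog)
{
        char buf[64];   /* deliberately tiny: forces many read() calls */
        struct bpf_link *link;
        ssize_t n;
        int iter_fd;

        link = bpf_program__attach_iter(prog, NULL);
        if (!link)
                return -1;

        iter_fd = bpf_iter_create(bpf_link__fd(link));
        if (iter_fd >= 0) {
                /* Sockets are batched per bucket in the kernel; cookies
                 * saved between these reads keep the pass exactly-once.
                 */
                while ((n = read(iter_fd, buf, sizeof(buf))) > 0)
                        fwrite(buf, 1, n, stdout);
                close(iter_fd);
        }

        bpf_link__destroy(link);
        return 0;
}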

 static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos)
@@ -3200,16 +3336,11 @@ static void *bpf_iter_tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos)
                  * meta.seq_num is used instead.
                  */
                 st->num++;
-                /* Move st->offset to the next sk in the bucket such that
-                 * the future start() will resume at st->offset in
-                 * st->bucket. See tcp_seek_last_pos().
-                 */
-                st->offset++;
-                sock_gen_put(iter->batch[iter->cur_sk++]);
+                sock_gen_put(iter->batch[iter->cur_sk++].sk);
         }
 
         if (iter->cur_sk < iter->end_sk)
-                sk = iter->batch[iter->cur_sk];
+                sk = iter->batch[iter->cur_sk].sk;
         else
                 sk = bpf_iter_tcp_batch(seq);

@@ -3275,10 +3406,8 @@ static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v)
                 (void)tcp_prog_seq_show(prog, &meta, v, 0);
         }
 
-        if (iter->cur_sk < iter->end_sk) {
+        if (iter->cur_sk < iter->end_sk)
                 bpf_iter_tcp_put_batch(iter);
-                iter->st_bucket_done = false;
-        }
 }
 
 static const struct seq_operations bpf_iter_tcp_seq_ops = {
@@ -3596,7 +3725,7 @@ static int bpf_iter_init_tcp(void *priv_data, struct bpf_iter_aux_info *aux)
         if (err)
                 return err;
 
-        err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ);
+        err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ, GFP_USER);
         if (err) {
                 bpf_iter_fini_seq_net(priv_data);
                 return err;