diff --git a/protobuf/sk-opts.proto b/protobuf/sk-opts.proto index f85dea949..aaaf92064 100644 --- a/protobuf/sk-opts.proto +++ b/protobuf/sk-opts.proto @@ -17,6 +17,8 @@ message sk_opts_entry { optional bool so_no_check = 14; optional uint32 so_bound_dev = 15; + + repeated fixed64 so_filter = 16; } enum sk_shutdown { diff --git a/sockets.c b/sockets.c index c0a29d10d..3557133a0 100644 --- a/sockets.c +++ b/sockets.c @@ -5,6 +5,7 @@ #include #include #include +#include #include "libnetlink.h" #include "sockets.h" @@ -32,6 +33,10 @@ #define SK_HASH_SIZE 32 +#ifndef SO_GET_FILTER +#define SO_GET_FILTER SO_ATTACH_FILTER +#endif + static int dump_bound_dev(int sk, SkOptsEntry *soe) { int dev = 0, ret; @@ -69,6 +74,95 @@ static int restore_bound_dev(int sk, SkOptsEntry *soe) return do_restore_opt(sk, SOL_SOCKET, SO_BINDTODEVICE, n, IFNAMSIZ); } +/* + * Protobuf handles le/be himself, but the sock_filter is not just u64, + * it's a structure and we have to preserve the fields order to be able + * to move socket image across architectures. + */ + +static void encode_filter(struct sock_filter *f, uint64_t *img, int n) +{ + int i; + + BUILD_BUG_ON(sizeof(*f) != sizeof(*img)); + + for (i = 0; i < n; i++) + img[i] = ((uint64_t)f[i].code << 48) | + ((uint64_t)f[i].jt << 40) | + ((uint64_t)f[i].jf << 32) | + ((uint64_t)f[i].k << 0); +} + +static void decode_filter(uint64_t *img, struct sock_filter *f, int n) +{ + int i; + + for (i = 0; i < n; i++) { + f[i].code = img[i] >> 48; + f[i].jt = img[i] >> 40; + f[i].jf = img[i] >> 32; + f[i].k = img[i] >> 0; + } +} + +static int dump_socket_filter(int sk, SkOptsEntry *soe) +{ + socklen_t len = 0; + int ret; + struct sock_filter *flt; + + ret = getsockopt(sk, SOL_SOCKET, SO_GET_FILTER, NULL, &len); + if (ret && errno != ENOPROTOOPT) { + pr_perror("Can't get socket filter len"); + return ret; + } + + if (!len) { + pr_info("No filter for socket\n"); + return 0; + } + + flt = xmalloc(len * sizeof(*flt)); + if (!flt) + return -1; + + ret = getsockopt(sk, SOL_SOCKET, SO_GET_FILTER, flt, &len); + if (ret) { + pr_perror("Can't get socket filter\n"); + return ret; + } + + soe->so_filter = xmalloc(len * sizeof(*soe->so_filter)); + if (!soe->so_filter) + return -1; + + encode_filter(flt, soe->so_filter, len); + soe->n_so_filter = len; + xfree(flt); + return 0; +} + +static int restore_socket_filter(int sk, SkOptsEntry *soe) +{ + int ret; + struct sock_fprog sfp; + + if (!soe->n_so_filter) + return 0; + + pr_info("Restoring socket filter\n"); + sfp.len = soe->n_so_filter; + sfp.filter = xmalloc(soe->n_so_filter * sfp.len); + if (!sfp.filter) + return -1; + + decode_filter(soe->so_filter, sfp.filter, sfp.len); + ret = restore_opt(sk, SOL_SOCKET, SO_ATTACH_FILTER, &sfp); + xfree(sfp.filter); + + return ret; +} + static struct socket_desc *sockets[SK_HASH_SIZE]; struct socket_desc *lookup_socket(int ino, int family) @@ -160,6 +254,7 @@ int restore_socket_opts(int sk, SkOptsEntry *soe) ret |= restore_opt(sk, SOL_SOCKET, SO_RCVTIMEO, &tv); ret |= restore_bound_dev(sk, soe); + ret |= restore_socket_filter(sk, soe); /* The restore of SO_REUSEADDR depends on type of socket */ @@ -227,12 +322,14 @@ int dump_socket_opts(int sk, SkOptsEntry *soe) soe->so_no_check = val ? true : false; ret |= dump_bound_dev(sk, soe); + ret |= dump_socket_filter(sk, soe); return ret; } void release_skopts(SkOptsEntry *soe) { + xfree(soe->so_filter); } int dump_socket(struct fd_parms *p, int lfd, const struct cr_fdset *cr_fdset)