diff --git a/include/util-net.h b/include/util-net.h index bb68aa343..ec5888153 100644 --- a/include/util-net.h +++ b/include/util-net.h @@ -25,22 +25,23 @@ struct scm_fdset { struct msghdr hdr; struct iovec iov; char msg_buf[CR_SCM_MSG_SIZE]; - int msg; /* We are to send at least one byte */ + char msg[CR_SCM_MAX_FD]; }; -extern int send_fds(int sock, struct sockaddr_un *saddr, int saddr_len, int *fds, int nr_fds); -extern int recv_fds(int sock, int *fds, int nr_fds); +extern int send_fds(int sock, struct sockaddr_un *saddr, int saddr_len, + int *fds, int nr_fds, bool with_flags); +extern int recv_fds(int sock, int *fds, int nr_fds, char *flags); static inline int send_fd(int sock, struct sockaddr_un *saddr, int saddr_len, int fd) { - return send_fds(sock, saddr, saddr_len, &fd, 1); + return send_fds(sock, saddr, saddr_len, &fd, 1, false); } static inline int recv_fd(int sock) { int fd, ret; - ret = recv_fds(sock, &fd, 1); + ret = recv_fds(sock, &fd, 1, NULL); if (ret) return -1; diff --git a/parasite-syscall.c b/parasite-syscall.c index fe5af5a0e..3678f44af 100644 --- a/parasite-syscall.c +++ b/parasite-syscall.c @@ -583,7 +583,7 @@ int parasite_drain_fds_seized(struct parasite_ctl *ctl, int *fds, int *lfds, int goto err; } - ret = recv_fds(sock, lfds, nr_fds); + ret = recv_fds(sock, lfds, nr_fds, NULL); if (ret) { pr_err("Can't retrieve FDs from socket\n"); goto err; diff --git a/parasite.c b/parasite.c index 346787838..52a831000 100644 --- a/parasite.c +++ b/parasite.c @@ -373,7 +373,7 @@ static int drain_fds(struct parasite_drain_fd *args) int ret; ret = send_fds(tsock, &args->saddr, args->sun_len, - args->fds, args->nr_fds); + args->fds, args->nr_fds, false); if (ret) { sys_write_msg("send_fds failed\n"); SET_PARASITE_RET(st, ret); diff --git a/util-net.c b/util-net.c index e904ea879..090a8d7e7 100644 --- a/util-net.c +++ b/util-net.c @@ -20,17 +20,16 @@ static void scm_fdset_init_chunk(struct scm_fdset *fdset, int nr_fds) cmsg->cmsg_len = fdset->hdr.msg_controllen; } -static int *scm_fdset_init(struct scm_fdset *fdset, struct sockaddr_un *saddr, int saddr_len) +static int *scm_fdset_init(struct scm_fdset *fdset, struct sockaddr_un *saddr, + int saddr_len, bool with_flags) { struct cmsghdr *cmsg; BUILD_BUG_ON(CR_SCM_MAX_FD > SCM_MAX_FD); BUILD_BUG_ON(sizeof(fdset->msg_buf) < (CMSG_SPACE(sizeof(int) * CR_SCM_MAX_FD))); - fdset->msg = '*'; - fdset->iov.iov_base = &fdset->msg; - fdset->iov.iov_len = sizeof(fdset->msg); + fdset->iov.iov_len = with_flags ? sizeof(fdset->msg) : 1; fdset->hdr.msg_iov = &fdset->iov; fdset->hdr.msg_iovlen = 1; @@ -48,18 +47,33 @@ static int *scm_fdset_init(struct scm_fdset *fdset, struct sockaddr_un *saddr, i return (int *)CMSG_DATA(cmsg); } -int send_fds(int sock, struct sockaddr_un *saddr, int len, int *fds, int nr_fds) +int send_fds(int sock, struct sockaddr_un *saddr, int len, + int *fds, int nr_fds, bool with_flags) { struct scm_fdset fdset; int *cmsg_data; int i, min_fd, ret; - cmsg_data = scm_fdset_init(&fdset, saddr, len); + cmsg_data = scm_fdset_init(&fdset, saddr, len, with_flags); for (i = 0; i < nr_fds; i += min_fd) { min_fd = min(CR_SCM_MAX_FD, nr_fds - i); scm_fdset_init_chunk(&fdset, min_fd); builtin_memcpy(cmsg_data, &fds[i], sizeof(int) * min_fd); + if (with_flags) { + int j; + + for (j = 0; j < min_fd; j++) { + int flags; + + flags = sys_fcntl(fds[i + j], F_GETFD, 0); + if (flags < 0) + return -1; + + fdset.msg[j] = (char)flags; + } + } + ret = sys_sendmsg(sock, &fdset.hdr, 0); if (ret <= 0) return ret ? : -1; @@ -68,7 +82,7 @@ int send_fds(int sock, struct sockaddr_un *saddr, int len, int *fds, int nr_fds) return 0; } -int recv_fds(int sock, int *fds, int nr_fds) +int recv_fds(int sock, int *fds, int nr_fds, char *flags) { struct scm_fdset fdset; struct cmsghdr *cmsg; @@ -76,7 +90,7 @@ int recv_fds(int sock, int *fds, int nr_fds) int ret; int i, min_fd; - cmsg_data = scm_fdset_init(&fdset, NULL, 0); + cmsg_data = scm_fdset_init(&fdset, NULL, 0, flags != NULL); for (i = 0; i < nr_fds; i += min_fd) { min_fd = min(CR_SCM_MAX_FD, nr_fds - i); scm_fdset_init_chunk(&fdset, min_fd); @@ -104,6 +118,8 @@ int recv_fds(int sock, int *fds, int nr_fds) if (unlikely(min_fd <= 0)) return -1; builtin_memcpy(&fds[i], cmsg_data, sizeof(int) * min_fd); + if (flags) + builtin_memcpy(flags, fdset.msg, sizeof(char) * min_fd); } return 0;