diff --git a/common/packet.c b/common/packet.c index 9722dd01..91f427f8 100644 --- a/common/packet.c +++ b/common/packet.c @@ -48,7 +48,6 @@ static char copyright[] = #include "dhcpd.h" #include -#include #include #include #include @@ -69,14 +68,14 @@ void assemble_hw_header (interface, buf, bufix, to) struct ether_header eh; if (to) - memcpy (eh.ether_dhost, to -> haddr, sizeof eh.ether_dhost); + memcpy (ETHER_DEST (&eh), to -> haddr, sizeof eh.ether_dhost); else - memset (eh.ether_dhost, 0xff, sizeof (eh.ether_shost)); + memset (ETHER_DEST (&eh), 0xff, sizeof (eh.ether_shost)); if (interface -> hw_address.hlen == sizeof (eh.ether_shost)) - memcpy (eh.ether_shost, interface -> hw_address.haddr, + memcpy (ETHER_SRC (&eh), interface -> hw_address.haddr, sizeof (eh.ether_shost)); else - memset (eh.ether_shost, 0x00, sizeof (eh.ether_shost)); + memset (ETHER_SRC (&eh), 0x00, sizeof (eh.ether_shost)); eh.ether_type = htons (ETHERTYPE_IP); memcpy (&buf [*bufix], &eh, sizeof eh); @@ -152,10 +151,14 @@ size_t decode_hw_header (interface, buf, bufix, from) { struct ether_header eh; - memcpy (&eh, buf +bufix, sizeof eh); + memcpy (&eh, buf + bufix, sizeof eh); +#ifdef USERLAND_FILTER + if (ntohs (eh.ether_type) != ETHERTYPE_IP) + return -1; +#endif memcpy (from -> haddr, eh.ether_shost, sizeof (eh.ether_shost)); - from -> htype = ETHERTYPE_IP; + from -> htype = ARPHRD_ETHER; from -> hlen = sizeof eh.ether_shost; return sizeof eh; @@ -175,22 +178,42 @@ size_t decode_udp_ip_header (interface, buf, bufix, from, data, len) struct udphdr *udp; u_int32_t ip_len = (buf [bufix] & 0xf) << 2; u_int32_t sum, usum; +#ifdef USERLAND_FILTER + u_int32_t ibcst = INADDR_BROADCAST; +#endif + + ip = (struct ip *)(buf + bufix); + udp = (struct udphdr *)(buf + bufix + ip_len); + +#ifdef USERLAND_FILTER + /* Is it a UDP packet? */ + if (ip -> ip_p != IPPROTO_UDP) + return -1; + + /* Is it to the port we're serving? */ + if (udp -> uh_dport != server_port && udp -> uh_dport != server_port + 1) + return -1; + + /* Is it to this IP address? */ + if (memcmp (&ip -> ip_dst, &ibcst, sizeof ibcst) && + memcmp (&ip -> ip_dst, interface -> address.iabuf, 4)) + return -1; +#endif /* USERLAND_FILTER */ /* Check the IP header checksum - it should be zero. */ if (wrapsum (checksum (buf + bufix, ip_len, 0))) { note ("Bad IP checksum: %x", - wrapsum (checksum (buf + bufix, sizeof ip, 0))); + wrapsum (checksum (buf + bufix, sizeof *ip, 0))); + return -1; } /* Copy out the IP source address... */ - ip = (struct ip *)(buf + bufix); memcpy (&from -> sin_addr, &ip -> ip_src, 4); /* Compute UDP checksums, including the ``pseudo-header'', the UDP header and the data. If the UDP checksum field is zero, we're not supposed to do a checksum. */ - udp = (struct udphdr *)(buf + bufix + ip_len); if (!data) { data = buf + bufix + ip_len + sizeof *udp; len -= ip_len + sizeof *udp; @@ -198,6 +221,7 @@ size_t decode_udp_ip_header (interface, buf, bufix, from, data, len) usum = udp -> uh_sum; udp -> uh_sum = 0; + sum = wrapsum (checksum ((unsigned char *)udp, sizeof *udp, checksum (data, len, checksum ((unsigned char *) @@ -220,14 +244,18 @@ size_t decode_udp_ip_header (interface, buf, bufix, from, data, len) /* Compute the easy part of the checksum on a range of bytes. */ -static u_int32_t checksum (unsigned char *buf, int nbytes, u_int32_t sum) +static u_int32_t checksum (buf, nbytes, sum) + unsigned char *buf; + int nbytes; + u_int32_t sum; { int i; /* Checksum all the pairs of bytes first... */ - for (i = 0; i < (nbytes & ~1); i += 2) + for (i = 0; i < (nbytes & ~1); i += 2) { sum += (u_int16_t) ntohs(*((u_int16_t *)buf)++); - + } + /* If there's a single byte left over, checksum it, too. Network byte order is big-endian, so the remaining byte is the high byte. */ if (i < nbytes) { @@ -240,13 +268,14 @@ static u_int32_t checksum (unsigned char *buf, int nbytes, u_int32_t sum) /* Fold the upper sixteen bits of the checksum down into the lower bits, complement the sum, and then put it into network byte order. */ -static u_int32_t wrapsum (u_int32_t sum) +static u_int32_t wrapsum (sum) + u_int32_t sum; { while (sum > 0x10000) { sum = (sum >> 16) + (sum & 0xFFFF); sum += (sum >> 16); } - sum = ~sum; + sum = sum ^ 0xFFFF; return htons(sum); } diff --git a/packet.c b/packet.c index 9722dd01..91f427f8 100644 --- a/packet.c +++ b/packet.c @@ -48,7 +48,6 @@ static char copyright[] = #include "dhcpd.h" #include -#include #include #include #include @@ -69,14 +68,14 @@ void assemble_hw_header (interface, buf, bufix, to) struct ether_header eh; if (to) - memcpy (eh.ether_dhost, to -> haddr, sizeof eh.ether_dhost); + memcpy (ETHER_DEST (&eh), to -> haddr, sizeof eh.ether_dhost); else - memset (eh.ether_dhost, 0xff, sizeof (eh.ether_shost)); + memset (ETHER_DEST (&eh), 0xff, sizeof (eh.ether_shost)); if (interface -> hw_address.hlen == sizeof (eh.ether_shost)) - memcpy (eh.ether_shost, interface -> hw_address.haddr, + memcpy (ETHER_SRC (&eh), interface -> hw_address.haddr, sizeof (eh.ether_shost)); else - memset (eh.ether_shost, 0x00, sizeof (eh.ether_shost)); + memset (ETHER_SRC (&eh), 0x00, sizeof (eh.ether_shost)); eh.ether_type = htons (ETHERTYPE_IP); memcpy (&buf [*bufix], &eh, sizeof eh); @@ -152,10 +151,14 @@ size_t decode_hw_header (interface, buf, bufix, from) { struct ether_header eh; - memcpy (&eh, buf +bufix, sizeof eh); + memcpy (&eh, buf + bufix, sizeof eh); +#ifdef USERLAND_FILTER + if (ntohs (eh.ether_type) != ETHERTYPE_IP) + return -1; +#endif memcpy (from -> haddr, eh.ether_shost, sizeof (eh.ether_shost)); - from -> htype = ETHERTYPE_IP; + from -> htype = ARPHRD_ETHER; from -> hlen = sizeof eh.ether_shost; return sizeof eh; @@ -175,22 +178,42 @@ size_t decode_udp_ip_header (interface, buf, bufix, from, data, len) struct udphdr *udp; u_int32_t ip_len = (buf [bufix] & 0xf) << 2; u_int32_t sum, usum; +#ifdef USERLAND_FILTER + u_int32_t ibcst = INADDR_BROADCAST; +#endif + + ip = (struct ip *)(buf + bufix); + udp = (struct udphdr *)(buf + bufix + ip_len); + +#ifdef USERLAND_FILTER + /* Is it a UDP packet? */ + if (ip -> ip_p != IPPROTO_UDP) + return -1; + + /* Is it to the port we're serving? */ + if (udp -> uh_dport != server_port && udp -> uh_dport != server_port + 1) + return -1; + + /* Is it to this IP address? */ + if (memcmp (&ip -> ip_dst, &ibcst, sizeof ibcst) && + memcmp (&ip -> ip_dst, interface -> address.iabuf, 4)) + return -1; +#endif /* USERLAND_FILTER */ /* Check the IP header checksum - it should be zero. */ if (wrapsum (checksum (buf + bufix, ip_len, 0))) { note ("Bad IP checksum: %x", - wrapsum (checksum (buf + bufix, sizeof ip, 0))); + wrapsum (checksum (buf + bufix, sizeof *ip, 0))); + return -1; } /* Copy out the IP source address... */ - ip = (struct ip *)(buf + bufix); memcpy (&from -> sin_addr, &ip -> ip_src, 4); /* Compute UDP checksums, including the ``pseudo-header'', the UDP header and the data. If the UDP checksum field is zero, we're not supposed to do a checksum. */ - udp = (struct udphdr *)(buf + bufix + ip_len); if (!data) { data = buf + bufix + ip_len + sizeof *udp; len -= ip_len + sizeof *udp; @@ -198,6 +221,7 @@ size_t decode_udp_ip_header (interface, buf, bufix, from, data, len) usum = udp -> uh_sum; udp -> uh_sum = 0; + sum = wrapsum (checksum ((unsigned char *)udp, sizeof *udp, checksum (data, len, checksum ((unsigned char *) @@ -220,14 +244,18 @@ size_t decode_udp_ip_header (interface, buf, bufix, from, data, len) /* Compute the easy part of the checksum on a range of bytes. */ -static u_int32_t checksum (unsigned char *buf, int nbytes, u_int32_t sum) +static u_int32_t checksum (buf, nbytes, sum) + unsigned char *buf; + int nbytes; + u_int32_t sum; { int i; /* Checksum all the pairs of bytes first... */ - for (i = 0; i < (nbytes & ~1); i += 2) + for (i = 0; i < (nbytes & ~1); i += 2) { sum += (u_int16_t) ntohs(*((u_int16_t *)buf)++); - + } + /* If there's a single byte left over, checksum it, too. Network byte order is big-endian, so the remaining byte is the high byte. */ if (i < nbytes) { @@ -240,13 +268,14 @@ static u_int32_t checksum (unsigned char *buf, int nbytes, u_int32_t sum) /* Fold the upper sixteen bits of the checksum down into the lower bits, complement the sum, and then put it into network byte order. */ -static u_int32_t wrapsum (u_int32_t sum) +static u_int32_t wrapsum (sum) + u_int32_t sum; { while (sum > 0x10000) { sum = (sum >> 16) + (sum & 0xFFFF); sum += (sum >> 16); } - sum = ~sum; + sum = sum ^ 0xFFFF; return htons(sum); }