diff options
Diffstat (limited to 'net/ipv4/udp.c')
-rw-r--r-- | net/ipv4/udp.c | 248 |
1 files changed, 178 insertions, 70 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index cec0f2cc49b7..5da703e699da 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -114,14 +114,36 @@ DEFINE_RWLOCK(udp_hash_lock); static int udp_port_rover; -static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[]) +/* + * Note about this hash function : + * Typical use is probably daddr = 0, only dport is going to vary hash + */ +static inline unsigned int udp_hash_port(__u16 port) +{ + return port; +} + +static inline int __udp_lib_port_inuse(unsigned int hash, int port, + const struct sock *this_sk, + struct hlist_head udptable[], + const struct udp_get_port_ops *ops) { struct sock *sk; struct hlist_node *node; + struct inet_sock *inet; - sk_for_each(sk, node, &udptable[num & (UDP_HTABLE_SIZE - 1)]) - if (sk->sk_hash == num) + sk_for_each(sk, node, &udptable[hash & (UDP_HTABLE_SIZE - 1)]) { + if (sk->sk_hash != hash) + continue; + inet = inet_sk(sk); + if (inet->num != port) + continue; + if (this_sk) { + if (ops->saddr_cmp(sk, this_sk)) + return 1; + } else if (ops->saddr_any(sk)) return 1; + } return 0; } @@ -132,16 +154,16 @@ static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[]) * @snum: port number to look up * @udptable: hash list table, must be of UDP_HTABLE_SIZE * @port_rover: pointer to record of last unallocated port - * @saddr_comp: AF-dependent comparison of bound local IP addresses + * @ops: AF-dependent address operations */ int __udp_lib_get_port(struct sock *sk, unsigned short snum, struct hlist_head udptable[], int *port_rover, - int (*saddr_comp)(const struct sock *sk1, - const struct sock *sk2 ) ) + const struct udp_get_port_ops *ops) { struct hlist_node *node; struct hlist_head *head; struct sock *sk2; + unsigned int hash; int error = 1; write_lock_bh(&udp_hash_lock); @@ -156,7 +178,8 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) { int size; - head = &udptable[result & (UDP_HTABLE_SIZE - 1)]; + hash = ops->hash_port_and_rcv_saddr(result, sk); + head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; if (hlist_empty(head)) { if (result > sysctl_local_port_range[1]) result = sysctl_local_port_range[0] + @@ -181,7 +204,16 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, result = sysctl_local_port_range[0] + ((result - sysctl_local_port_range[0]) & (UDP_HTABLE_SIZE - 1)); - if (! __udp_lib_lport_inuse(result, udptable)) + hash = udp_hash_port(result); + if (__udp_lib_port_inuse(hash, result, + NULL, udptable, ops)) + continue; + if (ops->saddr_any(sk)) + break; + + hash = ops->hash_port_and_rcv_saddr(result, sk); + if (! __udp_lib_port_inuse(hash, result, + sk, udptable, ops)) break; } if (i >= (1 << 16) / UDP_HTABLE_SIZE) @@ -189,21 +221,40 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, gotit: *port_rover = snum = result; } else { - head = &udptable[snum & (UDP_HTABLE_SIZE - 1)]; + hash = udp_hash_port(snum); + head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; sk_for_each(sk2, node, head) - if (sk2->sk_hash == snum && - sk2 != sk && - (!sk2->sk_reuse || !sk->sk_reuse) && - (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if - || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && - (*saddr_comp)(sk, sk2) ) + if (sk2->sk_hash == hash && + sk2 != sk && + inet_sk(sk2)->num == snum && + (!sk2->sk_reuse || !sk->sk_reuse) && + (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if || + sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && + ops->saddr_cmp(sk, sk2)) goto fail; + + if (!ops->saddr_any(sk)) { + hash = ops->hash_port_and_rcv_saddr(snum, sk); + head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; + + sk_for_each(sk2, node, head) + if (sk2->sk_hash == hash && + sk2 != sk && + inet_sk(sk2)->num == snum && + (!sk2->sk_reuse || !sk->sk_reuse) && + (!sk2->sk_bound_dev_if || + !sk->sk_bound_dev_if || + sk2->sk_bound_dev_if == + sk->sk_bound_dev_if) && + ops->saddr_cmp(sk, sk2)) + goto fail; + } } inet_sk(sk)->num = snum; - sk->sk_hash = snum; + sk->sk_hash = hash; if (sk_unhashed(sk)) { - head = &udptable[snum & (UDP_HTABLE_SIZE - 1)]; + head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; sk_add_node(sk, head); sock_prot_inc_use(sk->sk_prot); } @@ -214,12 +265,12 @@ fail: } int udp_get_port(struct sock *sk, unsigned short snum, - int (*scmp)(const struct sock *, const struct sock *)) + const struct udp_get_port_ops *ops) { - return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp); + return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, ops); } -int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) +static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) { struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2); @@ -228,9 +279,33 @@ int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) inet1->rcv_saddr == inet2->rcv_saddr )); } +static int ipv4_rcv_saddr_any(const struct sock *sk) +{ + return !inet_sk(sk)->rcv_saddr; +} + +static inline unsigned int ipv4_hash_port_and_addr(__u16 port, __be32 addr) +{ + addr ^= addr >> 16; + addr ^= addr >> 8; + return port ^ addr; +} + +static unsigned int ipv4_hash_port_and_rcv_saddr(__u16 port, + const struct sock *sk) +{ + return ipv4_hash_port_and_addr(port, inet_sk(sk)->rcv_saddr); +} + +const struct udp_get_port_ops udp_ipv4_ops = { + .saddr_cmp = ipv4_rcv_saddr_equal, + .saddr_any = ipv4_rcv_saddr_any, + .hash_port_and_rcv_saddr = ipv4_hash_port_and_rcv_saddr, +}; + static inline int udp_v4_get_port(struct sock *sk, unsigned short snum) { - return udp_get_port(sk, snum, ipv4_rcv_saddr_equal); + return udp_get_port(sk, snum, &udp_ipv4_ops); } /* UDP is nearly always wildcards out the wazoo, it makes no sense to try @@ -242,63 +317,77 @@ static struct sock *__udp4_lib_lookup(__be32 saddr, __be16 sport, { struct sock *sk, *result = NULL; struct hlist_node *node; - unsigned short hnum = ntohs(dport); - int badness = -1; + unsigned int hash, hashwild; + int score, best = -1, hport = ntohs(dport); + + hash = ipv4_hash_port_and_addr(hport, daddr); + hashwild = udp_hash_port(hport); read_lock(&udp_hash_lock); - sk_for_each(sk, node, &udptable[hnum & (UDP_HTABLE_SIZE - 1)]) { + +lookup: + + sk_for_each(sk, node, &udptable[hash & (UDP_HTABLE_SIZE - 1)]) { struct inet_sock *inet = inet_sk(sk); - if (sk->sk_hash == hnum && !ipv6_only_sock(sk)) { - int score = (sk->sk_family == PF_INET ? 1 : 0); - if (inet->rcv_saddr) { - if (inet->rcv_saddr != daddr) - continue; - score+=2; - } - if (inet->daddr) { - if (inet->daddr != saddr) - continue; - score+=2; - } - if (inet->dport) { - if (inet->dport != sport) - continue; - score+=2; - } - if (sk->sk_bound_dev_if) { - if (sk->sk_bound_dev_if != dif) - continue; - score+=2; - } - if (score == 9) { - result = sk; - break; - } else if (score > badness) { - result = sk; - badness = score; - } + if (sk->sk_hash != hash || ipv6_only_sock(sk) || + inet->num != hport) + continue; + + score = (sk->sk_family == PF_INET ? 1 : 0); + if (inet->rcv_saddr) { + if (inet->rcv_saddr != daddr) + continue; + score+=2; + } + if (inet->daddr) { + if (inet->daddr != saddr) + continue; + score+=2; + } + if (inet->dport) { + if (inet->dport != sport) + continue; + score+=2; + } + if (sk->sk_bound_dev_if) { + if (sk->sk_bound_dev_if != dif) + continue; + score+=2; } + if (score == 9) { + result = sk; + goto found; + } else if (score > best) { + result = sk; + best = score; + } + } + + if (hash != hashwild) { + hash = hashwild; + goto lookup; } +found: if (result) sock_hold(result); read_unlock(&udp_hash_lock); return result; } -static inline struct sock *udp_v4_mcast_next(struct sock *sk, - __be16 loc_port, __be32 loc_addr, +static inline struct sock *udp_v4_mcast_next(struct sock *sk, unsigned int hnum, + int hport, __be32 loc_addr, __be16 rmt_port, __be32 rmt_addr, int dif) { struct hlist_node *node; struct sock *s = sk; - unsigned short hnum = ntohs(loc_port); sk_for_each_from(s, node) { struct inet_sock *inet = inet_sk(s); if (s->sk_hash != hnum || + inet->num != hport || (inet->daddr && inet->daddr != rmt_addr) || (inet->dport != rmt_port && inet->dport) || (inet->rcv_saddr && inet->rcv_saddr != loc_addr) || @@ -633,8 +722,11 @@ int udp_sendmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg, .dport = dport } } }; security_sk_classify_flow(sk, &fl); err = ip_route_output_flow(&rt, &fl, sk, 1); - if (err) + if (err) { + if (err == -ENETUNREACH) + IP_INC_STATS_BH(IPSTATS_MIB_OUTNOROUTES); goto out; + } err = -EACCES; if ((rt->rt_flags & RTCF_BROADCAST) && @@ -917,7 +1009,7 @@ int udp_disconnect(struct sock *sk, int flags) } /* return: - * 1 if the the UDP system should process it + * 1 if the UDP system should process it * 0 if we should drop this packet * -1 if it should get processed by xfrm4_rcv_encap */ @@ -1129,29 +1221,45 @@ static int __udp4_lib_mcast_deliver(struct sk_buff *skb, __be32 saddr, __be32 daddr, struct hlist_head udptable[]) { - struct sock *sk; + struct sock *sk, *skw, *sknext; int dif; + int hport = ntohs(uh->dest); + unsigned int hash = ipv4_hash_port_and_addr(hport, daddr); + unsigned int hashwild = udp_hash_port(hport); - read_lock(&udp_hash_lock); - sk = sk_head(&udptable[ntohs(uh->dest) & (UDP_HTABLE_SIZE - 1)]); dif = skb->dev->ifindex; - sk = udp_v4_mcast_next(sk, uh->dest, daddr, uh->source, saddr, dif); - if (sk) { - struct sock *sknext = NULL; + read_lock(&udp_hash_lock); + + sk = sk_head(&udptable[hash & (UDP_HTABLE_SIZE - 1)]); + skw = sk_head(&udptable[hashwild & (UDP_HTABLE_SIZE - 1)]); + + sk = udp_v4_mcast_next(sk, hash, hport, daddr, uh->source, saddr, dif); + if (!sk) { + hash = hashwild; + sk = udp_v4_mcast_next(skw, hash, hport, daddr, uh->source, + saddr, dif); + } + if (sk) { do { struct sk_buff *skb1 = skb; - - sknext = udp_v4_mcast_next(sk_next(sk), uh->dest, daddr, - uh->source, saddr, dif); + sknext = udp_v4_mcast_next(sk_next(sk), hash, hport, + daddr, uh->source, saddr, dif); + if (!sknext && hash != hashwild) { + hash = hashwild; + sknext = udp_v4_mcast_next(skw, hash, hport, + daddr, uh->source, saddr, dif); + } if (sknext) skb1 = skb_clone(skb, GFP_ATOMIC); if (skb1) { int ret = udp_queue_rcv_skb(sk, skb1); if (ret > 0) - /* we should probably re-process instead - * of dropping packets here. */ + /* + * we should probably re-process + * instead of dropping packets here. + */ kfree_skb(skb1); } sk = sknext; @@ -1238,7 +1346,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct hlist_head udptable[], return __udp4_lib_mcast_deliver(skb, uh, saddr, daddr, udptable); sk = __udp4_lib_lookup(saddr, uh->source, daddr, uh->dest, - skb->dev->ifindex, udptable ); + skb->dev->ifindex, udptable); if (sk != NULL) { int ret = udp_queue_rcv_skb(sk, skb); |