summaryrefslogtreecommitdiff
path: root/net/ipv6/udp.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv6/udp.c')
-rw-r--r--net/ipv6/udp.c86
1 files changed, 65 insertions, 21 deletions
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 9da3287a3..422dd014a 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -47,6 +47,7 @@
#include <net/xfrm.h>
#include <net/inet6_hashtables.h>
#include <net/busy_poll.h>
+#include <net/sock_reuseport.h>
#include <linux/proc_fs.h>
#include <linux/seq_file.h>
@@ -76,7 +77,14 @@ static u32 udp6_ehashfn(const struct net *net,
udp_ipv6_hash_secret + net_hash_mix(net));
}
-int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2)
+/* match_wildcard == true: IPV6_ADDR_ANY equals to any IPv6 addresses if IPv6
+ * only, and any IPv4 addresses if not IPv6 only
+ * match_wildcard == false: addresses must be exactly the same, i.e.
+ * IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY,
+ * and 0.0.0.0 equals to 0.0.0.0 only
+ */
+int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
+ bool match_wildcard)
{
const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2);
int sk2_ipv6only = inet_v6_ipv6only(sk2);
@@ -84,16 +92,24 @@ int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2)
int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;
/* if both are mapped, treat as IPv4 */
- if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED)
- return (!sk2_ipv6only &&
- (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr ||
- sk->sk_rcv_saddr == sk2->sk_rcv_saddr));
+ if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) {
+ if (!sk2_ipv6only) {
+ if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr)
+ return 1;
+ if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr)
+ return match_wildcard;
+ }
+ return 0;
+ }
- if (addr_type2 == IPV6_ADDR_ANY &&
+ if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY)
+ return 1;
+
+ if (addr_type2 == IPV6_ADDR_ANY && match_wildcard &&
!(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED))
return 1;
- if (addr_type == IPV6_ADDR_ANY &&
+ if (addr_type == IPV6_ADDR_ANY && match_wildcard &&
!(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED))
return 1;
@@ -235,11 +251,13 @@ static inline int compute_score2(struct sock *sk, struct net *net,
static struct sock *udp6_lib_lookup2(struct net *net,
const struct in6_addr *saddr, __be16 sport,
const struct in6_addr *daddr, unsigned int hnum, int dif,
- struct udp_hslot *hslot2, unsigned int slot2)
+ struct udp_hslot *hslot2, unsigned int slot2,
+ struct sk_buff *skb)
{
struct sock *sk, *result;
struct hlist_nulls_node *node;
int score, badness, matches = 0, reuseport = 0;
+ bool select_ok = true;
u32 hash = 0;
begin:
@@ -255,6 +273,17 @@ begin:
if (reuseport) {
hash = udp6_ehashfn(net, daddr, hnum,
saddr, sport);
+ if (select_ok) {
+ struct sock *sk2;
+
+ sk2 = reuseport_select_sock(sk, hash, skb,
+ sizeof(struct udphdr));
+ if (sk2) {
+ result = sk2;
+ select_ok = false;
+ goto found;
+ }
+ }
matches = 1;
}
} else if (score == badness && reuseport) {
@@ -273,6 +302,7 @@ begin:
goto begin;
if (result) {
+found:
if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
result = NULL;
else if (unlikely(compute_score2(result, net, saddr, sport,
@@ -287,7 +317,8 @@ begin:
struct sock *__udp6_lib_lookup(struct net *net,
const struct in6_addr *saddr, __be16 sport,
const struct in6_addr *daddr, __be16 dport,
- int dif, struct udp_table *udptable)
+ int dif, struct udp_table *udptable,
+ struct sk_buff *skb)
{
struct sock *sk, *result;
struct hlist_nulls_node *node;
@@ -295,6 +326,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
unsigned int hash2, slot2, slot = udp_hashfn(net, hnum, udptable->mask);
struct udp_hslot *hslot2, *hslot = &udptable->hash[slot];
int score, badness, matches = 0, reuseport = 0;
+ bool select_ok = true;
u32 hash = 0;
rcu_read_lock();
@@ -307,7 +339,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
result = udp6_lib_lookup2(net, saddr, sport,
daddr, hnum, dif,
- hslot2, slot2);
+ hslot2, slot2, skb);
if (!result) {
hash2 = udp6_portaddr_hash(net, &in6addr_any, hnum);
slot2 = hash2 & udptable->mask;
@@ -317,7 +349,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
result = udp6_lib_lookup2(net, saddr, sport,
&in6addr_any, hnum, dif,
- hslot2, slot2);
+ hslot2, slot2, skb);
}
rcu_read_unlock();
return result;
@@ -334,6 +366,17 @@ begin:
if (reuseport) {
hash = udp6_ehashfn(net, daddr, hnum,
saddr, sport);
+ if (select_ok) {
+ struct sock *sk2;
+
+ sk2 = reuseport_select_sock(sk, hash, skb,
+ sizeof(struct udphdr));
+ if (sk2) {
+ result = sk2;
+ select_ok = false;
+ goto found;
+ }
+ }
matches = 1;
}
} else if (score == badness && reuseport) {
@@ -352,6 +395,7 @@ begin:
goto begin;
if (result) {
+found:
if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
result = NULL;
else if (unlikely(compute_score(result, net, hnum, saddr, sport,
@@ -377,13 +421,13 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb,
return sk;
return __udp6_lib_lookup(dev_net(skb_dst(skb)->dev), &iph->saddr, sport,
&iph->daddr, dport, inet6_iif(skb),
- udptable);
+ udptable, skb);
}
struct sock *udp6_lib_lookup(struct net *net, const struct in6_addr *saddr, __be16 sport,
const struct in6_addr *daddr, __be16 dport, int dif)
{
- return __udp6_lib_lookup(net, saddr, sport, daddr, dport, dif, &udp_table);
+ return __udp6_lib_lookup(net, saddr, sport, daddr, dport, dif, &udp_table, NULL);
}
EXPORT_SYMBOL_GPL(udp6_lib_lookup);
@@ -402,6 +446,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int peeked, off = 0;
int err;
int is_udplite = IS_UDPLITE(sk);
+ bool checksum_valid = false;
int is_udp4;
bool slow;
@@ -433,11 +478,12 @@ try_again:
*/
if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) {
- if (udp_lib_checksum_complete(skb))
+ checksum_valid = !udp_lib_checksum_complete(skb);
+ if (!checksum_valid)
goto csum_copy_err;
}
- if (skb_csum_unnecessary(skb))
+ if (checksum_valid || skb_csum_unnecessary(skb))
err = skb_copy_datagram_msg(skb, sizeof(struct udphdr),
msg, copied);
else {
@@ -547,8 +593,8 @@ void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
int err;
struct net *net = dev_net(skb->dev);
- sk = __udp6_lib_lookup(net, daddr, uh->dest,
- saddr, uh->source, inet6_iif(skb), udptable);
+ sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source,
+ inet6_iif(skb), udptable, skb);
if (!sk) {
ICMP6_INC_STATS_BH(net, __in6_dev_get(skb->dev),
ICMP6_MIB_INERRORS);
@@ -916,11 +962,9 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
ret = udpv6_queue_rcv_skb(sk, skb);
sock_put(sk);
- /* a return value > 0 means to resubmit the input, but
- * it wants the return to be -protocol, or 0
- */
+ /* a return value > 0 means to resubmit the input */
if (ret > 0)
- return -ret;
+ return ret;
return 0;
}