diff --git a/sys/dev/if_wg/module/if_wg_session.c b/sys/dev/if_wg/module/if_wg_session.c index 492f356bca9..13ad2ea5c9f 100644 --- a/sys/dev/if_wg/module/if_wg_session.c +++ b/sys/dev/if_wg/module/if_wg_session.c @@ -1859,6 +1859,40 @@ wg_index_drop(struct wg_softc *sc, uint32_t key0) SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry); } +static int +wg_update_endpoint_addrs(struct wg_endpoint *e, const struct sockaddr *srcsa, + struct ifnet *rcvif) +{ + const struct sockaddr_in *sa4; + const struct sockaddr_in6 *sa6; + int ret = 0; + + /* + * UDP passes a 2-element sockaddr array: first element is the + * source addr/port, second the destination addr/port. + */ + if (srcsa->sa_family == AF_INET) { + sa4 = (const struct sockaddr_in *)srcsa; + e->e_remote.r_sin = sa4[0]; + /* Only update dest if not mcast/bcast */ + if (!(IN_MULTICAST(ntohl(sa4[1].sin_addr.s_addr)) || + sa4[1].sin_addr.s_addr == INADDR_BROADCAST || + in_broadcast(sa4[1].sin_addr, rcvif))) { + e->e_local.l_in = sa4[1].sin_addr; + } + } else if (srcsa->sa_family == AF_INET6) { + sa6 = (const struct sockaddr_in6 *)srcsa; + e->e_remote.r_sin6 = sa6[0]; + /* Only update dest if not multicast */ + if (!IN6_IS_ADDR_MULTICAST(&sa6[1].sin6_addr)) + e->e_local.l_in6 = sa6[1].sin6_addr; + } else { + ret = EAFNOSUPPORT; + } + + return (ret); +} + static void wg_input(struct mbuf *m0, int offset, struct inpcb *inpcb, const struct sockaddr *srcsa, void *_sc) @@ -1890,12 +1924,11 @@ wg_input(struct mbuf *m0, int offset, struct inpcb *inpcb, goto free; } e = wg_mbuf_endpoint_get(m); - if (srcsa->sa_family == AF_INET) - e->e_remote.r_sin = *(const struct sockaddr_in *)srcsa; - else if (srcsa->sa_family == AF_INET6) - e->e_remote.r_sin6 = *(const struct sockaddr_in6 *)srcsa; - else - e->e_remote.r_sa = *srcsa; + + if (wg_update_endpoint_addrs(e, srcsa, m->m_pkthdr.rcvif)) { + DPRINTF(sc, "unknown family\n"); + goto free; + } verify_endpoint(m); if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);