diff --git a/src/WireGuard-ESP32.h b/src/WireGuard-ESP32.h index b30c884..0d68094 100644 --- a/src/WireGuard-ESP32.h +++ b/src/WireGuard-ESP32.h @@ -4,12 +4,25 @@ */ #pragma once #include +#include + +typedef struct { + IPAddress localIP; + IPAddress gateway; + IPAddress subnet; + const char *privateKey; + const char *peerHost; + uint16_t peerPort; + const char *peerPublickKey; + struct netif *underlineNetif; +} WireGuardConfig; class WireGuard { private: bool _is_initialized = false; public: + bool begin(WireGuardConfig *config); bool begin(const IPAddress& localIP, const IPAddress& Subnet, const IPAddress& Gateway, const char* privateKey, const char* remotePeerAddress, const char* remotePeerPublicKey, uint16_t remotePeerPort); bool begin(const IPAddress& localIP, const char* privateKey, const char* remotePeerAddress, const char* remotePeerPublicKey, uint16_t remotePeerPort); void end(); diff --git a/src/WireGuard.cpp b/src/WireGuard.cpp index ce69650..361a287 100644 --- a/src/WireGuard.cpp +++ b/src/WireGuard.cpp @@ -14,6 +14,8 @@ #include "lwip/ip.h" #include "lwip/netdb.h" +#include "tcpip_adapter.h" + #include "esp32-hal-log.h" extern "C" { @@ -29,7 +31,15 @@ static uint8_t wireguard_peer_index = WIREGUARDIF_INVALID_INDEX; #define TAG "[WireGuard] " -bool WireGuard::begin(const IPAddress& localIP, const IPAddress& Subnet, const IPAddress& Gateway, const char* privateKey, const char* remotePeerAddress, const char* remotePeerPublicKey, uint16_t remotePeerPort) { +bool WireGuard::begin(WireGuardConfig *config) { + IPAddress localIP = config->localIP; + IPAddress Subnet = config->subnet; + IPAddress Gateway = config->gateway; + const char* privateKey = config->privateKey; + const char* remotePeerAddress = config->peerHost; + const char* remotePeerPublicKey = config->peerPublickKey; + uint16_t remotePeerPort = config->peerPort; + struct wireguardif_init_data wg; struct wireguardif_peer peer; ip_addr_t ipaddr = IPADDR4_INIT(static_cast(localIP)); @@ -43,9 +53,14 @@ bool WireGuard::begin(const IPAddress& localIP, const IPAddress& Subnet, const I // Setup the WireGuard device structure wg.private_key = privateKey; - wg.listen_port = remotePeerPort; + wg.listen_port = remotePeerPort; - wg.bind_netif = NULL; + if (config->underlineNetif) { + wg.bind_netif = config->underlineNetif; + } else { + tcpip_adapter_get_netif(TCPIP_ADAPTER_IF_STA, (void **)&wg.bind_netif); + log_i(TAG "underlying_netif = %p", wg.bind_netif); + } // Initialise the first WireGuard peer structure wireguardif_peer_init(&peer); @@ -119,6 +134,20 @@ bool WireGuard::begin(const IPAddress& localIP, const IPAddress& Subnet, const I return true; } +bool WireGuard::begin(const IPAddress& localIP, const IPAddress& Subnet, const IPAddress& Gateway, const char* privateKey, const char* remotePeerAddress, const char* remotePeerPublicKey, uint16_t remotePeerPort) { + WireGuardConfig config = { + .localIP = localIP, + .gateway = Gateway, + .subnet = Subnet, + .privateKey = privateKey, + .peerHost = remotePeerAddress, + .peerPort = remotePeerPort, + .peerPublickKey = remotePeerPublicKey, + .underlineNetif = NULL, + }; + return WireGuard::begin(&config); +} + bool WireGuard::begin(const IPAddress& localIP, const char* privateKey, const char* remotePeerAddress, const char* remotePeerPublicKey, uint16_t remotePeerPort) { // Maintain compatiblity with old begin auto subnet = IPAddress(255,255,255,255); @@ -144,4 +173,4 @@ void WireGuard::end() { wg_netif = nullptr; this->_is_initialized = false; -} \ No newline at end of file +} diff --git a/src/wireguardif.c b/src/wireguardif.c index d64ad85..4dcc906 100644 --- a/src/wireguardif.c +++ b/src/wireguardif.c @@ -920,10 +920,6 @@ err_t wireguardif_init(struct netif *netif) { uint8_t private_key[WIREGUARD_PRIVATE_KEY_LEN]; size_t private_key_len = sizeof(private_key); - struct netif* underlying_netif; - tcpip_adapter_get_netif(TCPIP_ADAPTER_IF_STA, &underlying_netif); - log_i(TAG "underlying_netif = %p", underlying_netif); - LWIP_ASSERT("netif != NULL", (netif != NULL)); LWIP_ASSERT("state != NULL", (netif->state != NULL)); @@ -950,7 +946,7 @@ err_t wireguardif_init(struct netif *netif) { device = (struct wireguard_device *)mem_calloc(1, sizeof(struct wireguard_device)); if (device) { device->netif = netif; - device->underlying_netif = underlying_netif; + device->underlying_netif = init_data->bind_netif; //udp_bind_netif(udp, underlying_netif); device->udp_pcb = udp;