diff --git a/Makefile b/Makefile index ff087de..890fa2a 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,18 @@ CXX := i686-w64-mingw32-g++ -CXXFLAGS := -std=c++11 -Wall +CXXFLAGS := -std=c++11 -Wall -D_WIN32_WINNT=0x0600 -ggdb TEST_CXXFLAGS := $(CXXFLAGS) -I./googletest/include/ all: dpnet.dll check: tests/all-tests.exe +ifeq ($(OS),Windows_NT) + tests/all-tests.exe +else wine tests/all-tests.exe +endif -dpnet.dll: src/dpnet.o src/dpnet.def src/DirectPlay8Address.o src/DirectPlay8Peer.o src/network.o src/packet.o +dpnet.dll: src/dpnet.o src/dpnet.def src/DirectPlay8Address.o src/DirectPlay8Peer.o src/network.o src/packet.o src/SendQueue.o src/AsyncHandleAllocator.o src/HostEnumerator.o src/COMAPIException.o $(CXX) $(CXXFLAGS) -Wl,--enable-stdcall-fixup -shared -o $@ $^ -ldxguid -lws2_32 -static-libstdc++ -static-libgcc tests/DirectPlay8Address.exe: tests/DirectPlay8Address.o src/DirectPlay8Address.o googletest/src/gtest-all.o googletest/src/gtest_main.o @@ -19,8 +23,10 @@ tests/PacketSerialiser.exe: tests/PacketSerialiser.o src/packet.o googletest/src tests/all-tests.exe: tests/DirectPlay8Address.o src/DirectPlay8Address.o \ tests/PacketSerialiser.o tests/PacketDeserialiser.o src/packet.o \ + tests/DirectPlay8Peer.o src/DirectPlay8Peer.o \ + src/SendQueue.o src/AsyncHandleAllocator.o src/HostEnumerator.o src/network.o src/COMAPIException.o \ googletest/src/gtest-all.o googletest/src/gtest_main.o - $(CXX) $(TEST_CXXFLAGS) -o $@ $^ -ldxguid -lole32 -static-libstdc++ -static-libgcc + $(CXX) $(TEST_CXXFLAGS) -o $@ $^ -ldxguid -lole32 -static-libstdc++ -static-libgcc -lws2_32 src/%.o: src/%.cpp $(CXX) $(CXXFLAGS) -c -o $@ $< diff --git a/src/AsyncHandleAllocator.cpp b/src/AsyncHandleAllocator.cpp new file mode 100644 index 0000000..4d69f24 --- /dev/null +++ b/src/AsyncHandleAllocator.cpp @@ -0,0 +1,47 @@ +#include + +#include "AsyncHandleAllocator.hpp" + +AsyncHandleAllocator::AsyncHandleAllocator(): + next_enum_id(1), + next_connect_id(1), + next_send_id(1) {} + +DPNHANDLE AsyncHandleAllocator::new_enum() +{ + DPNHANDLE handle = next_enum_id++ | TYPE_ENUM; + + next_enum_id &= ~TYPE_MASK; + if(next_enum_id == 0) + { + next_enum_id = 1; + } + + return handle; +} + +DPNHANDLE AsyncHandleAllocator::new_connect() +{ + DPNHANDLE handle = next_connect_id++ | TYPE_ENUM; + + next_connect_id &= ~TYPE_MASK; + if(next_connect_id == 0) + { + next_connect_id = 1; + } + + return handle; +} + +DPNHANDLE AsyncHandleAllocator::new_send() +{ + DPNHANDLE handle = next_send_id++ | TYPE_ENUM; + + next_send_id &= ~TYPE_MASK; + if(next_send_id == 0) + { + next_send_id = 1; + } + + return handle; +} diff --git a/src/AsyncHandleAllocator.hpp b/src/AsyncHandleAllocator.hpp new file mode 100644 index 0000000..3b0ccd1 --- /dev/null +++ b/src/AsyncHandleAllocator.hpp @@ -0,0 +1,40 @@ +#ifndef DPLITE_ASYNCHANDLEALLOCATOR_HPP +#define DPLITE_ASYNCHANDLEALLOCATOR_HPP + +#include + +/* There is an instance of this class in each DirectPlay8Peer/etc instance to allocate DPNHANDLEs + * for async operations. + * + * Handles are allocated sequentially, they aren't currently tracked, but I doubt anyone will ever + * have enough running at once to wrap around and conflict. + * + * The handle's type is encoded in the high bits so CancelAsyncOperation() can know where it needs + * to look rather than having to search through each type of async task. + * + * 0x00000000 and 0xFFFFFFFF are both impossible values as they have significance to some parts of + * DirectPlay. +*/ + +class AsyncHandleAllocator +{ + private: + DPNHANDLE next_enum_id; + DPNHANDLE next_connect_id; + DPNHANDLE next_send_id; + + public: + static const DPNHANDLE TYPE_MASK = 0xC0000000; + + static const DPNHANDLE TYPE_ENUM = 0x00000000; + static const DPNHANDLE TYPE_CONNECT = 0x40000000; + static const DPNHANDLE TYPE_SEND = 0x80000000; + + AsyncHandleAllocator(); + + DPNHANDLE new_enum(); + DPNHANDLE new_connect(); + DPNHANDLE new_send(); +}; + +#endif /* !DPLITE_ASYNCHANDLEALLOCATOR_HPP */ diff --git a/src/COMAPIException.cpp b/src/COMAPIException.cpp new file mode 100644 index 0000000..c18fa65 --- /dev/null +++ b/src/COMAPIException.cpp @@ -0,0 +1,21 @@ +#include + +#include "COMAPIException.hpp" + +COMAPIException::COMAPIException(HRESULT result): + hr(result) +{ + snprintf(what_s, sizeof(what_s), "COMAPIException, HRESULT %08X", (unsigned)(result)); +} + +COMAPIException::~COMAPIException() {} + +HRESULT COMAPIException::result() const noexcept +{ + return hr; +} + +const char *COMAPIException::what() const noexcept +{ + return what_s; +} diff --git a/src/COMAPIException.hpp b/src/COMAPIException.hpp new file mode 100644 index 0000000..8759fb6 --- /dev/null +++ b/src/COMAPIException.hpp @@ -0,0 +1,22 @@ +#ifndef DPLITE_COMAPIEXCEPTION_HPP +#define DPLITE_COMAPIEXCEPTION_HPP + +#include +#include +#include + +class COMAPIException: public std::exception +{ + private: + const HRESULT hr; + char what_s[64]; + + public: + COMAPIException(HRESULT result); + virtual ~COMAPIException(); + + HRESULT result() const noexcept; + virtual const char *what() const noexcept; +}; + +#endif /* !DPLITE_COMAPIEXCEPTION_HPP */ diff --git a/src/DirectPlay8Peer.cpp b/src/DirectPlay8Peer.cpp index edc58da..b01ac51 100644 --- a/src/DirectPlay8Peer.cpp +++ b/src/DirectPlay8Peer.cpp @@ -1,13 +1,19 @@ #include #include #include +#include +#include #include +#include #include #include +#include #include +#include "COMAPIException.hpp" #include "DirectPlay8Address.hpp" #include "DirectPlay8Peer.hpp" +#include "Messages.hpp" #include "network.hpp" #define UNIMPLEMENTED(fmt, ...) \ @@ -17,15 +23,29 @@ DirectPlay8Peer::DirectPlay8Peer(std::atomic *global_refcount): global_refcount(global_refcount), local_refcount(0), - state(STATE_DISCONNECTED), + state(STATE_NEW), udp_socket(-1), listener_socket(-1), discovery_socket(-1) { + io_event = CreateEvent(NULL, FALSE, FALSE, NULL); + if(io_event == NULL) + { + throw std::runtime_error("Cannot create event object"); + } + AddRef(); } -DirectPlay8Peer::~DirectPlay8Peer() {} +DirectPlay8Peer::~DirectPlay8Peer() +{ + if(state != STATE_NEW) + { + Close(0); + } + + CloseHandle(io_event); +} HRESULT DirectPlay8Peer::QueryInterface(REFIID riid, void **ppvObject) { @@ -71,9 +91,23 @@ ULONG DirectPlay8Peer::Release(void) HRESULT DirectPlay8Peer::Initialize(PVOID CONST pvUserContext, CONST PFNDPNMESSAGEHANDLER pfn, CONST DWORD dwFlags) { + if(state != STATE_NEW) + { + return DPNERR_ALREADYINITIALIZED; + } + message_handler = pfn; message_handler_ctx = pvUserContext; + WSADATA wd; + if(WSAStartup(MAKEWORD(2,2), &wd) != 0) + { + /* TODO */ + return DPNERR_GENERIC; + } + + state = STATE_INITIALISED; + return S_OK; } @@ -84,7 +118,64 @@ HRESULT DirectPlay8Peer::EnumServiceProviders(CONST GUID* CONST pguidServiceProv HRESULT DirectPlay8Peer::CancelAsyncOperation(CONST DPNHANDLE hAsyncHandle, CONST DWORD dwFlags) { - UNIMPLEMENTED("DirectPlay8Peer::CancelAsyncOperation"); + if(dwFlags & DPNCANCEL_PLAYER_SENDS) + { + /* Cancel sends to player ID in hAsyncHandle */ + UNIMPLEMENTED("DirectPlay8Peer::CancelAsyncOperation"); + } + else if(dwFlags & (DPNCANCEL_ENUM | DPNCANCEL_CONNECT | DPNCANCEL_ALL_OPERATIONS)) + { + /* Cancel all outstanding operations of one or more types. */ + + if(dwFlags & (DPNCANCEL_ENUM | DPNCANCEL_ALL_OPERATIONS)) + { + std::unique_lock l(lock); + + for(auto ei = host_enums.begin(); ei != host_enums.end(); ++ei) + { + ei->second.cancel(); + } + } + + if(dwFlags & (DPNCANCEL_CONNECT | DPNCANCEL_ALL_OPERATIONS)) + { + /* TODO: Cancel in-progress connect. */ + } + + if(dwFlags & DPNCANCEL_ALL_OPERATIONS) + { + /* TODO: Cancel all sends */ + } + + return S_OK; + } + else if((hAsyncHandle & AsyncHandleAllocator::TYPE_MASK) == AsyncHandleAllocator::TYPE_ENUM) + { + std::unique_lock l(lock); + + auto ei = host_enums.find(hAsyncHandle); + if(ei == host_enums.end()) + { + return DPNERR_INVALIDHANDLE; + } + + /* TODO: Make successive cancels for the same handle before it is destroyed fail? */ + + ei->second.cancel(); + return S_OK; + } + else if((hAsyncHandle & AsyncHandleAllocator::TYPE_MASK) == AsyncHandleAllocator::TYPE_CONNECT) + { + UNIMPLEMENTED("DirectPlay8Peer::CancelAsyncOperation"); + } + else if((hAsyncHandle & AsyncHandleAllocator::TYPE_MASK) == AsyncHandleAllocator::TYPE_SEND) + { + UNIMPLEMENTED("DirectPlay8Peer::CancelAsyncOperation"); + } + else{ + /* Unrecognised handle type. */ + return DPNERR_INVALIDHANDLE; + } } HRESULT DirectPlay8Peer::Connect(CONST DPN_APPLICATION_DESC* CONST pdnAppDesc, IDirectPlay8Address* CONST pHostAddr, IDirectPlay8Address* CONST pDeviceInfo, CONST DPN_SECURITY_DESC* CONST pdnSecurity, CONST DPN_SECURITY_CREDENTIALS* CONST pdnCredentials, CONST void* CONST pvUserConnectData, CONST DWORD dwUserConnectDataSize, void* CONST pvPlayerContext, void* CONST pvAsyncContext, DPNHANDLE* CONST phAsyncHandle, CONST DWORD dwFlags) @@ -104,9 +195,12 @@ HRESULT DirectPlay8Peer::GetSendQueueInfo(CONST DPNID dpnid, DWORD* CONST pdwNum HRESULT DirectPlay8Peer::Host(CONST DPN_APPLICATION_DESC* CONST pdnAppDesc, IDirectPlay8Address **CONST prgpDeviceInfo, CONST DWORD cDeviceInfo, CONST DPN_SECURITY_DESC* CONST pdnSecurity, CONST DPN_SECURITY_CREDENTIALS* CONST pdnCredentials, void* CONST pvPlayerContext, CONST DWORD dwFlags) { - if(state != STATE_DISCONNECTED) + switch(state) { - return DPNERR_ALREADYCONNECTED; + case STATE_NEW: return DPNERR_UNINITIALIZED; + case STATE_INITIALISED: break; + case STATE_HOSTING: return DPNERR_ALREADYCONNECTED; + case STATE_CONNECTED: return DPNERR_ALREADYCONNECTED; } if(pdnAppDesc->dwSize != sizeof(DPN_APPLICATION_DESC)) @@ -124,6 +218,13 @@ HRESULT DirectPlay8Peer::Host(CONST DPN_APPLICATION_DESC* CONST pdnAppDesc, IDir /* Not supported yet. */ } + /* Generate a random GUID for this session. */ + HRESULT guid_err = CoCreateGuid(&instance_guid); + if(guid_err != S_OK) + { + return guid_err; + } + application_guid = pdnAppDesc->guidApplication; max_players = pdnAppDesc->dwMaxPlayers; session_name = pdnAppDesc->pwszSessionName; @@ -172,13 +273,56 @@ HRESULT DirectPlay8Peer::Host(CONST DPN_APPLICATION_DESC* CONST pdnAppDesc, IDir if(port == 0) { - port = DEFAULT_HOST_PORT; + /* Ephemeral port range as defined by IANA. */ + const int AUTO_PORT_MIN = 49152; + const int AUTO_PORT_MAX = 65535; + + for(int p = AUTO_PORT_MIN; p <= AUTO_PORT_MAX; ++p) + { + /* TODO: Only continue if creation failed due to address conflict. */ + + udp_socket = create_udp_socket(ipaddr, port); + if(udp_socket == -1) + { + continue; + } + + listener_socket = create_listener_socket(ipaddr, port); + if(listener_socket == -1) + { + closesocket(udp_socket); + udp_socket = -1; + + continue; + } + + break; + } + + if(udp_socket == -1) + { + return DPNERR_GENERIC; + } + } + else{ + udp_socket = create_udp_socket(ipaddr, port); + if(udp_socket == -1) + { + return DPNERR_GENERIC; + } + + listener_socket = create_listener_socket(ipaddr, port); + if(listener_socket == -1) + { + closesocket(udp_socket); + udp_socket = -1; + + return DPNERR_GENERIC; + } } - udp_socket = create_udp_socket (ipaddr, port); - listener_socket = create_listener_socket(ipaddr, port); - - if(udp_socket == -1 || listener_socket == -1) + if(WSAEventSelect(udp_socket, io_event, FD_READ | FD_WRITE) != 0 + || WSAEventSelect(listener_socket, io_event, FD_ACCEPT) != 0) { return DPNERR_GENERIC; } @@ -186,8 +330,17 @@ HRESULT DirectPlay8Peer::Host(CONST DPN_APPLICATION_DESC* CONST pdnAppDesc, IDir if(!(pdnAppDesc->dwFlags & DPNSESSION_NODPNSVR)) { discovery_socket = create_discovery_socket(); + + if(discovery_socket == -1 + || WSAEventSelect(discovery_socket, io_event, FD_READ) != 0) + { + return DPNERR_GENERIC; + } } + io_run = true; + io_thread = std::thread(&DirectPlay8Peer::io_main, this); + state = STATE_HOSTING; return S_OK; @@ -195,7 +348,7 @@ HRESULT DirectPlay8Peer::Host(CONST DPN_APPLICATION_DESC* CONST pdnAppDesc, IDir HRESULT DirectPlay8Peer::GetApplicationDesc(DPN_APPLICATION_DESC* CONST pAppDescBuffer, DWORD* CONST pcbDataSize, CONST DWORD dwFlags) { - if(state == STATE_DISCONNECTED) + if(state != STATE_HOSTING && state != STATE_CONNECTED) { return DPNERR_NOCONNECTION; } @@ -213,7 +366,7 @@ HRESULT DirectPlay8Peer::GetApplicationDesc(DPN_APPLICATION_DESC* CONST pAppDesc pAppDescBuffer->guidInstance = instance_guid; pAppDescBuffer->guidApplication = application_guid; pAppDescBuffer->dwMaxPlayers = max_players; - pAppDescBuffer->dwCurrentPlayers = peers.size() + 1; + pAppDescBuffer->dwCurrentPlayers = other_player_ids.size() + 1; if(!password.empty()) { @@ -257,7 +410,7 @@ HRESULT DirectPlay8Peer::SetApplicationDesc(CONST DPN_APPLICATION_DESC* CONST pa { if(state == STATE_HOSTING) { - if(pad->dwMaxPlayers > 0 && pad->dwMaxPlayers <= peers.size()) + if(pad->dwMaxPlayers > 0 && pad->dwMaxPlayers <= other_player_ids.size()) { /* Can't set dwMaxPlayers below current player count. */ return DPNERR_INVALIDPARAM; @@ -354,12 +507,114 @@ HRESULT DirectPlay8Peer::GetLocalHostAddresses(IDirectPlay8Address** CONST prgpA HRESULT DirectPlay8Peer::Close(CONST DWORD dwFlags) { - UNIMPLEMENTED("DirectPlay8Peer::Close"); + if(state == STATE_NEW) + { + return DPNERR_UNINITIALIZED; + } + + if(state == STATE_HOSTING || state == STATE_CONNECTED) + { + io_run = false; + SetEvent(io_event); + + io_thread.join(); + } + + CancelAsyncOperation(0, DPNCANCEL_ALL_OPERATIONS); + + /* TODO: Wait properly. */ + + while(1) + { + std::unique_lock l(lock); + if(host_enums.empty()) + { + break; + } + + Sleep(50); + } + + /* TODO: Clean sockets etc up. */ + + WSACleanup(); + + state = STATE_NEW; + + return S_OK; } HRESULT DirectPlay8Peer::EnumHosts(PDPN_APPLICATION_DESC CONST pApplicationDesc, IDirectPlay8Address* CONST pAddrHost, IDirectPlay8Address* CONST pDeviceInfo,PVOID CONST pUserEnumData, CONST DWORD dwUserEnumDataSize, CONST DWORD dwEnumCount, CONST DWORD dwRetryInterval, CONST DWORD dwTimeOut,PVOID CONST pvUserContext, DPNHANDLE* CONST pAsyncHandle, CONST DWORD dwFlags) { - UNIMPLEMENTED("DirectPlay8Peer::EnumHosts"); + if(state == STATE_NEW) + { + return DPNERR_UNINITIALIZED; + } + + try { + if(dwFlags & DPNENUMHOSTS_SYNC) + { + HRESULT result; + + HostEnumerator he( + global_refcount, + message_handler, message_handler_ctx, + pApplicationDesc, pAddrHost, pDeviceInfo, pUserEnumData, dwUserEnumDataSize, + dwEnumCount, dwRetryInterval, dwTimeOut, pvUserContext, + + [&result](HRESULT r) + { + result = r; + }); + + he.wait(); + + return result; + } + else{ + DPNHANDLE handle = handle_alloc.new_enum(); + + *pAsyncHandle = handle; + + std::unique_lock l(lock); + + host_enums.emplace( + std::piecewise_construct, + std::forward_as_tuple(handle), + std::forward_as_tuple( + global_refcount, + message_handler, message_handler_ctx, + pApplicationDesc, pAddrHost, pDeviceInfo, pUserEnumData, dwUserEnumDataSize, + dwEnumCount, dwRetryInterval, dwTimeOut, pvUserContext, + + [this, handle, pvUserContext](HRESULT r) + { + DPNMSG_ASYNC_OP_COMPLETE oc; + memset(&oc, 0, sizeof(oc)); + + oc.dwSize = sizeof(oc); + oc.hAsyncOp = handle; + oc.pvUserContext = pvUserContext; + oc.hResultCode = r; + + message_handler(message_handler_ctx, DPN_MSGID_ASYNC_OP_COMPLETE, &oc); + + std::unique_lock l(lock); + host_enums.erase(handle); + l.unlock(); + })); + + return DPNSUCCESS_PENDING; + } + } + catch(const COMAPIException &e) + { + return e.result(); + } + catch(...) + { + return DPNERR_GENERIC; + } } HRESULT DirectPlay8Peer::DestroyPeer(CONST DPNID dpnidClient, CONST void* CONST pvDestroyData, CONST DWORD dwDestroyDataSize, CONST DWORD dwFlags) @@ -416,3 +671,352 @@ HRESULT DirectPlay8Peer::TerminateSession(void* CONST pvTerminateData, CONST DWO { UNIMPLEMENTED("DirectPlay8Peer::TerminateSession"); } + +void DirectPlay8Peer::io_main() +{ + while(io_run) + { + WaitForSingleObject(io_event, INFINITE); + + io_udp_recv(udp_socket); + io_udp_send(udp_socket, udp_sq); + + if(discovery_socket != -1) + { + io_udp_recv(discovery_socket); + io_udp_send(discovery_socket, discovery_sq); + } + + io_listener_accept(listener_socket); + + std::unique_lock l(lock); + + for(auto p = peers.begin(); p != peers.end();) + { + auto next_p = std::next(p); + + if(!io_tcp_recv(&*p) || !io_tcp_send(&*p)) + { + /* TODO: Complete outstanding sends (failed), drop player */ + closesocket(p->sock); + peers.erase(p); + } + + p = next_p; + } + } +} + +void DirectPlay8Peer::io_udp_recv(int sock) +{ + struct sockaddr_in from_addr; + int fa_len = sizeof(from_addr); + + int r = recvfrom(sock, (char*)(recv_buf), sizeof(recv_buf), 0, (struct sockaddr*)(&from_addr), &fa_len); + if(r <= 0) + { + return; + } + + /* Process message */ + std::unique_ptr pd; + + try { + pd.reset(new PacketDeserialiser(recv_buf, r)); + } + catch(const PacketDeserialiser::Error &e) + { + /* Malformed packet received */ + return; + } + + switch(pd->packet_type()) + { + case DPLITE_MSGID_HOST_ENUM_REQUEST: + { + if(state == STATE_HOSTING) + { + handle_host_enum_request(*pd, &from_addr); + } + + break; + } + + default: + /* TODO: Log "unrecognised packet type" */ + break; + } +} + +void DirectPlay8Peer::io_udp_send(int sock, SendQueue &sq) +{ + SendQueue::Buffer *sqb; + + while((sqb = sq.get_next()) != NULL) + { + std::pair data = sqb->get_data(); + std::pair addr = sqb->get_dest_addr(); + + int s = sendto(sock, (const char*)(data.first), data.second, 0, addr.first, addr.second); + if(s == -1) + { + DWORD err = WSAGetLastError(); + + if(err != WSAEWOULDBLOCK) + { + /* TODO: More specific error codes */ + sq.complete(sqb, DPNERR_GENERIC); + } + + break; + } + + sq.complete(sqb, S_OK); + } +} + +void DirectPlay8Peer::io_listener_accept(int sock) +{ + struct sockaddr_in addr; + int addrlen = sizeof(addr); + + int newfd = accept(sock, (struct sockaddr*)(&addr), &addrlen); + if(newfd == -1) + { + DWORD err = WSAGetLastError(); + + if(err == WSAEWOULDBLOCK) + { + return; + } + else{ + /* TODO */ + abort(); + } + } + + u_long non_blocking = 1; + if(ioctlsocket(newfd, FIONBIO, &non_blocking) != 0) + { + closesocket(newfd); + return; + } + + peers.emplace_back(newfd, addr.sin_addr.s_addr, ntohs(addr.sin_port)); +} + +bool DirectPlay8Peer::io_tcp_recv(Player *player) +{ + int r = recv(player->sock, (char*)(player->recv_buf) + player->recv_buf_cur, sizeof(player->recv_buf) - player->recv_buf_cur, 0); + if(r == 0) + { + /* Connection closed */ + return false; + } + else if(r == -1) + { + DWORD err = WSAGetLastError(); + + if(err == WSAEWOULDBLOCK) + { + /* Nothing to read */ + return true; + } + else{ + /* Read error. */ + return false; + } + } + + player->recv_buf_cur += r; + + if(player->recv_buf_cur >= sizeof(TLVChunk)) + { + TLVChunk *header = (TLVChunk*)(player->recv_buf); + size_t full_packet_size = sizeof(TLVChunk) + header->value_length; + + if(full_packet_size > MAX_PACKET_SIZE) + { + /* Malformed packet received */ + return false; + } + + if(player->recv_buf_cur >= full_packet_size) + { + /* Process message */ + std::unique_ptr pd; + + try { + pd.reset(new PacketDeserialiser(player->recv_buf, full_packet_size)); + } + catch(const PacketDeserialiser::Error &e) + { + /* Malformed packet received */ + return false; + } + + /* Message at the front of the buffer has been dealt with, shift any + * remaining data beyond it to the front and truncate it. + */ + + memmove(player->recv_buf, player->recv_buf + full_packet_size, + player->recv_buf_cur - full_packet_size); + player->recv_buf_cur -= full_packet_size; + } + } + + return true; +} + +bool DirectPlay8Peer::io_tcp_send(Player *player) +{ + SendQueue::Buffer *sqb = player->sq.get_next(); + + while(player->send_buf != NULL || sqb != NULL) + { + if(player->sqb == NULL) + { + std::pair sqb_data = sqb->get_data(); + + player->send_buf = (const unsigned char*)(sqb_data.first); + player->send_remain = sqb_data.second; + } + + int s = send(player->sock, (const char*)(player->send_buf), player->send_remain, 0); + if(s < 0) + { + DWORD err = WSAGetLastError(); + + if(err == WSAEWOULDBLOCK) + { + break; + } + else{ + /* TODO: Better error codes */ + player->sq.complete(sqb, DPNERR_GENERIC); + return false; + } + } + + if((size_t)(s) == player->send_remain) + { + player->send_buf = NULL; + + player->sq.complete(sqb, S_OK); + sqb = player->sq.get_next(); + } + else{ + player->send_buf += s; + player->send_remain -= s; + } + } + + return true; +} + +class SQB_TODO: public SendQueue::Buffer +{ + private: + PFNDPNMESSAGEHANDLER message_handler; + PVOID message_handler_ctx; + + DPNMSG_ENUM_HOSTS_QUERY query; + + public: + SQB_TODO(const void *data, size_t data_size, const struct sockaddr_in *dest_addr, + PFNDPNMESSAGEHANDLER message_handler, PVOID message_handler_ctx, + DPNMSG_ENUM_HOSTS_QUERY query): + Buffer(data, data_size, (const struct sockaddr*)(dest_addr), sizeof(*dest_addr)), + message_handler(message_handler), + message_handler_ctx(message_handler_ctx), + query(query) {} + + virtual void complete(HRESULT result) override + { + if(query.pvResponseData != NULL) + { + DPNMSG_RETURN_BUFFER rb; + memset(&rb, 0, sizeof(rb)); + + rb.dwSize = sizeof(rb); + rb.hResultCode = result; + rb.pvBuffer = query.pvResponseData; + rb.pvUserContext = query.pvResponseContext; + + message_handler(message_handler_ctx, DPN_MSGID_RETURN_BUFFER, &rb); + } + } +}; + +void DirectPlay8Peer::handle_host_enum_request(const PacketDeserialiser &pd, const struct sockaddr_in *from_addr) +{ + if(!pd.is_null(0)) + { + GUID r_application_guid = pd.get_guid(0); + + if(application_guid != r_application_guid) + { + /* This isn't the application you're looking for. + * It can go about its business. + */ + return; + } + } + + DPNMSG_ENUM_HOSTS_QUERY query; + memset(&query, 0, sizeof(query)); + + query.dwSize = sizeof(query); + query.pAddressSender = NULL; // TODO + query.pAddressDevice = NULL; // TODO + + if(!pd.is_null(1)) + { + std::pair data = pd.get_data(1); + + query.pvReceivedData = (void*)(data.first); /* TODO: Make a non-const copy? */ + query.dwReceivedDataSize = data.second; + } + + query.dwMaxResponseDataSize = 9999; // TODO + + DWORD req_tick = pd.get_dword(2); + + if(message_handler(message_handler_ctx, DPN_MSGID_ENUM_HOSTS_QUERY, &query) == DPN_OK) + { + PacketSerialiser ps(DPLITE_MSGID_HOST_ENUM_RESPONSE); + + ps.append_dword(password.empty() ? 0 : DPNSESSION_REQUIREPASSWORD); + ps.append_guid(instance_guid); + ps.append_guid(application_guid); + ps.append_dword(max_players); + ps.append_dword(other_player_ids.size() + 1); + ps.append_wstring(session_name); + + if(!application_data.empty()) + { + ps.append_data(application_data.data(), application_data.size()); + } + else{ + ps.append_null(); + } + + if(query.pvResponseData != NULL && query.dwResponseDataSize != 0) + { + ps.append_data(query.pvResponseData, query.dwResponseDataSize); + } + else{ + ps.append_null(); + } + + ps.append_dword(req_tick); + + std::pair raw_pkt = ps.raw_packet(); + + udp_sq.send(SendQueue::SEND_PRI_MEDIUM, new SQB_TODO(raw_pkt.first, raw_pkt.second, from_addr, + message_handler, message_handler_ctx, query)); + } + else{ + /* Application rejected the DPNMSG_ENUM_HOSTS_QUERY message. */ + } +} diff --git a/src/DirectPlay8Peer.hpp b/src/DirectPlay8Peer.hpp index 3e13bcb..daaaf18 100644 --- a/src/DirectPlay8Peer.hpp +++ b/src/DirectPlay8Peer.hpp @@ -1,11 +1,20 @@ #ifndef DPLITE_DIRECTPLAY8PEER_HPP #define DPLITE_DIRECTPLAY8PEER_HPP +#include #include #include #include +#include #include #include +#include + +#include "AsyncHandleAllocator.hpp" +#include "HostEnumerator.hpp" +#include "network.hpp" +#include "packet.hpp" +#include "SendQueue.hpp" class DirectPlay8Peer: public IDirectPlay8Peer { @@ -17,11 +26,16 @@ class DirectPlay8Peer: public IDirectPlay8Peer PVOID message_handler_ctx; enum { - STATE_DISCONNECTED, + STATE_NEW, + STATE_INITIALISED, STATE_HOSTING, STATE_CONNECTED, } state; + AsyncHandleAllocator handle_alloc; + + std::map host_enums; + GUID instance_guid; GUID application_guid; DWORD max_players; @@ -33,8 +47,40 @@ class DirectPlay8Peer: public IDirectPlay8Peer int listener_socket; /* TCP listener socket. */ int discovery_socket; /* Discovery UDP sockets, RECIEVES broadcasts only. */ + unsigned char recv_buf[MAX_PACKET_SIZE]; + + SendQueue udp_sq; + SendQueue discovery_sq; + + HANDLE io_event; + std::thread io_thread; + std::atomic io_run; + struct Player { + enum PlayerState { + /* Peer has connected to us, we're waiting for the initial message from it. */ + PS_INIT, + + /* We are the host and the peer has sent the initial connect request, we are waiting + * for the application to process DPN_MSGID_INDICATE_CONNECT before we either add the + * player to the session or reject it. + */ + PS_CONNECTING, + + /* This is a fully-fledged peer. */ + PS_CONNECTED, + + /* This peer is closing down. Discard any future messages received from it, but flush + * anything waiting to be sent and keep the player DPNID valid until after the + * application has processed the DPN_MSGID_DESTROY_PLAYER message and any outstanding + * operations have completed or been cancelled. + */ + PS_CLOSING, + }; + + enum PlayerState state; + /* This is the TCP socket to the peer, we may have connected to it, or it * may have connected to us depending who joined the session first, that * doesn't really matter. @@ -42,10 +88,43 @@ class DirectPlay8Peer: public IDirectPlay8Peer int sock; uint32_t ip; /* IPv4 address, network byte order. */ - uint16_t port; /* Port, host byte order. */ + uint16_t port; /* TCP and UDP port, host byte order. */ + + DPNID id; /* Player ID, not initialised before state PS_CONNECTED. */ + + unsigned char recv_buf[MAX_PACKET_SIZE]; + size_t recv_buf_cur; + + SendQueue sq; + SendQueue::Buffer *sqb; + + const unsigned char *send_buf; + size_t send_remain; + + Player(int sock, uint32_t ip, uint16_t port): + state(PS_INIT), sock(sock), ip(ip), port(port), recv_buf_cur(0), send_buf(NULL) {} }; - std::map peers; + std::list peers; + std::map::iterator> other_player_ids; + + /* Serialises access to: + * + * host_enums + * pending_peers + * peers + */ + std::mutex lock; + + void io_main(); + + void io_udp_recv(int sock); + void io_udp_send(int sock, SendQueue &q); + void io_listener_accept(int sock); + bool io_tcp_recv(Player *player); + bool io_tcp_send(Player *player); + + void handle_host_enum_request(const PacketDeserialiser &pd, const struct sockaddr_in *from_addr); public: DirectPlay8Peer(std::atomic *global_refcount); diff --git a/src/HostEnumerator.cpp b/src/HostEnumerator.cpp new file mode 100644 index 0000000..07ecd58 --- /dev/null +++ b/src/HostEnumerator.cpp @@ -0,0 +1,302 @@ +#include +#include + +#include "DirectPlay8Address.hpp" +#include "HostEnumerator.hpp" +#include "Messages.hpp" +#include "packet.hpp" + +const GUID GUID_NULL = { 0, 0, 0, { 0, 0, 0, 0, 0, 0, 0, 0 } }; + +HostEnumerator::HostEnumerator( + std::atomic * const global_refcount, + + PFNDPNMESSAGEHANDLER message_handler, + PVOID message_handler_ctx, + + PDPN_APPLICATION_DESC const pApplicationDesc, + IDirectPlay8Address *const pdpaddrHost, + IDirectPlay8Address *const pdpaddrDeviceInfo, + PVOID const pvUserEnumData, + const DWORD dwUserEnumDataSize, + const DWORD dwEnumCount, + const DWORD dwRetryInterval, + const DWORD dwTimeOut, + PVOID const pvUserContext, + + std::function complete_cb): + + global_refcount(global_refcount), + message_handler(message_handler), + message_handler_ctx(message_handler_ctx), + complete_cb(complete_cb), + user_context(pvUserContext), + req_cancel(false) +{ + /* TODO: Use address in pdpaddrHost, if provided. */ + + send_addr.sin_family = AF_INET; + send_addr.sin_addr.s_addr = htonl(INADDR_BROADCAST); + send_addr.sin_port = htons(DISCOVERY_PORT); + + if(pApplicationDesc != NULL) + { + application_guid = pApplicationDesc->guidApplication; + } + else{ + application_guid = GUID_NULL; + } + + if(pvUserEnumData != NULL && dwUserEnumDataSize > 0) + { + user_data.insert( + user_data.end(), + (const unsigned char*)(pvUserEnumData), + (const unsigned char*)(pvUserEnumData) + dwUserEnumDataSize); + } + + tx_remain = (dwEnumCount == 0) ? DEFAULT_ENUM_COUNT : dwEnumCount; + tx_interval = (dwRetryInterval == 0) ? DEFAULT_ENUM_INTERVAL : dwRetryInterval; + rx_timeout = (dwTimeOut == 0) ? DEFAULT_ENUM_TIMEOUT : dwTimeOut; + + /* TODO: Bind to interface in pdpaddrDeviceInfo, if provided. */ + + sock = create_udp_socket(0, 0); + if(sock == -1) + { + throw std::runtime_error("Cannot create UDP socket"); + } + + wake_thread = CreateEvent(NULL, FALSE, FALSE, NULL); + if(wake_thread == NULL) + { + closesocket(sock); + throw std::runtime_error("Cannot create wake_thread object"); + } + + if(WSAEventSelect(sock, wake_thread, FD_READ)) + { + CloseHandle(wake_thread); + closesocket(sock); + throw std::runtime_error("Cannot WSAEventSelect"); + } + + thread = new std::thread(&HostEnumerator::main, this); +} + +HostEnumerator::~HostEnumerator() +{ + cancel(); + + if(thread->joinable()) + { + if(thread->get_id() == std::this_thread::get_id()) + { + thread->detach(); + } + else{ + thread->join(); + } + } + + delete thread; + + CloseHandle(wake_thread); +} + +void HostEnumerator::main() +{ + while(!req_cancel) + { + DWORD now = GetTickCount(); + + if(tx_remain > 0 && now >= next_tx_at) + { + PacketSerialiser ps(DPLITE_MSGID_HOST_ENUM_REQUEST); + + if(application_guid != GUID_NULL) + { + ps.append_guid(application_guid); + } + else{ + ps.append_null(); + } + + if(!user_data.empty()) + { + ps.append_data(user_data.data(), user_data.size()); + } + else{ + ps.append_null(); + } + + ps.append_dword(now); + + std::pair raw = ps.raw_packet(); + + sendto(sock, (const char*)(raw.first), raw.second, 0, + (struct sockaddr*)(&send_addr), sizeof(send_addr)); + + next_tx_at = now + tx_interval; + --tx_remain; + + if(rx_timeout != INFINITE) + { + stop_at = now + rx_timeout; + } + } + + struct sockaddr_in from_addr; + int addrlen = sizeof(from_addr); + + int r = recvfrom(sock, (char*)(recv_buf), sizeof(recv_buf), 0, (struct sockaddr*)(&from_addr), &addrlen); + if(r > 0) + { + handle_packet(recv_buf, r, &from_addr); + } + + if(tx_remain == 0 && stop_at > 0 && now >= stop_at) + { + /* No more requests to transmit and the wait for replies from the last one + * has timed out. + */ + break; + } + + DWORD timeout = INFINITE; + if(tx_remain > 0) { timeout = std::min((next_tx_at - now), timeout); } + if(stop_at > 0) { timeout = std::min((stop_at - now), timeout); } + + WaitForSingleObject(wake_thread, timeout); + } + + if(req_cancel) + { + complete_cb(DPNERR_USERCANCEL); + } + else{ + complete_cb(S_OK); + } +} + +void HostEnumerator::handle_packet(const void *data, size_t size, struct sockaddr_in *from_addr) +{ + std::unique_ptr pd; + + try { + pd.reset(new PacketDeserialiser(data, size)); + } + catch(const PacketDeserialiser::Error &e) + { + /* Malformed packet received */ + return; + } + + if(pd->packet_type() != DPLITE_MSGID_HOST_ENUM_RESPONSE) + { + /* Unexpected packet type. */ + return; + } + + DPN_APPLICATION_DESC app_desc; + memset(&app_desc, 0, sizeof(app_desc)); + std::wstring app_desc_pwszSessionName; + + const void *response_data = NULL; + size_t response_data_size = 0; + + DWORD request_tick_count; + + try { + app_desc.dwSize = sizeof(app_desc); + + app_desc.dwFlags = pd->get_dword(0); + app_desc.guidInstance = pd->get_guid(1); + app_desc.guidApplication = pd->get_guid(2); + app_desc.dwMaxPlayers = pd->get_dword(3); + app_desc.dwCurrentPlayers = pd->get_dword(4); + + app_desc_pwszSessionName = pd->get_wstring(5); + app_desc.pwszSessionName = (wchar_t*)(app_desc_pwszSessionName.c_str()); + + if(!pd->is_null(6)) + { + std::pair app_data = pd->get_data(6); + app_desc.pvApplicationReservedData = (void*)(app_data.first); + app_desc.dwApplicationReservedDataSize = app_data.second; + } + + if(!pd->is_null(7)) + { + std::pair r_data = pd->get_data(7); + response_data = r_data.first; + response_data_size = r_data.second; + } + + request_tick_count = pd->get_dword(8); + } + catch(const PacketDeserialiser::Error &e) + { + /* Malformed packet received */ + return; + } + + /* Build a DirectPlay8Address with the host/port where the response came from - thats the main + * port for the host. + */ + + IDirectPlay8Address *sender_address = new DirectPlay8Address(global_refcount); + sender_address->AddRef(); + sender_address->SetSP(&CLSID_DP8SP_TCPIP); /* TODO: Be IPX if application previously gave us an IPX address? */ + + char from_addr_ip_s[16]; + inet_ntop(AF_INET, &(from_addr->sin_addr), from_addr_ip_s, sizeof(from_addr_ip_s)); + + sender_address->AddComponent(DPNA_KEY_HOSTNAME, + from_addr_ip_s, strlen(from_addr_ip_s) + 1, DPNA_DATATYPE_STRING_ANSI); + + char from_addr_port_s[8]; + snprintf(from_addr_port_s, sizeof(from_addr_port_s), "%u", (unsigned)(ntohs(from_addr->sin_port))); + + sender_address->AddComponent(DPNA_KEY_PORT, + from_addr_port_s, strlen(from_addr_port_s) + 1, DPNA_DATATYPE_STRING_ANSI); + + /* Build a DirectPlay8Address with the interface we received the response on. + * TODO: Actually do this. + */ + + IDirectPlay8Address *device_address = new DirectPlay8Address(global_refcount); + device_address->AddRef(); + device_address->SetSP(&CLSID_DP8SP_TCPIP); /* TODO: Be IPX if application previously gave us an IPX address? */ + + DPNMSG_ENUM_HOSTS_RESPONSE message; + memset(&message, 0, sizeof(message)); + + message.dwSize = sizeof(message); + message.pAddressSender = sender_address; + message.pAddressDevice = device_address; + message.pApplicationDescription = &app_desc; + message.pvResponseData = (void*)(response_data); + message.dwResponseDataSize = response_data_size; + message.pvUserContext = user_context; + message.dwRoundTripLatencyMS = GetTickCount() - request_tick_count; + + message_handler(message_handler_ctx, DPN_MSGID_ENUM_HOSTS_RESPONSE, &message); + + device_address->Release(); + sender_address->Release(); +} + +void HostEnumerator::cancel() +{ + req_cancel = true; + SetEvent(wake_thread); +} + +void HostEnumerator::wait() +{ + if(thread->joinable()) + { + thread->join(); + } +} diff --git a/src/HostEnumerator.hpp b/src/HostEnumerator.hpp new file mode 100644 index 0000000..52aaba7 --- /dev/null +++ b/src/HostEnumerator.hpp @@ -0,0 +1,82 @@ +#ifndef DPLITE_HOSTENUMERATOR_HPP +#define DPLITE_HOSTENUMERATOR_HPP + +#include +#include +#include +#include +#include +#include + +#include "network.hpp" + +#define DEFAULT_ENUM_COUNT 3 +#define DEFAULT_ENUM_INTERVAL 1000 +#define DEFAULT_ENUM_TIMEOUT 1000 + +class HostEnumerator +{ + private: + /* No copy c'tor. */ + HostEnumerator(const HostEnumerator&) = delete; + + /* Pointer to the global refcount (if in use), for instantiating DirectPlay8Address objects. */ + std::atomic * const global_refcount; + + /* DirectPlay message handler and context value */ + PFNDPNMESSAGEHANDLER message_handler; + PVOID message_handler_ctx; + + std::function complete_cb; + + struct sockaddr_in send_addr; + + GUID application_guid; /* GUID of application to search for, or GUID_NULL */ + std::vector user_data; /* Data to include in request. */ + + DWORD tx_remain; /* Number of remaining requests to transmit, may be INFINITE. */ + DWORD tx_interval; /* Number of milliseconds to wait between transmits. */ + DWORD rx_timeout; /* Number of milliseconds to wait for replies to a request. */ + + void *user_context; /* DPNMSG_ENUM_HOSTS_REPONSE.pvUserContext */ + + DWORD next_tx_at; + DWORD stop_at; + + int sock; + HANDLE wake_thread; + std::thread *thread; + bool req_cancel; + + unsigned char recv_buf[MAX_PACKET_SIZE]; + + void main(); + void handle_packet(const void *data, size_t size, struct sockaddr_in *from_addr); + + public: + HostEnumerator( + std::atomic * const global_refcount, + + PFNDPNMESSAGEHANDLER message_handler, + PVOID message_handler_ctx, + + PDPN_APPLICATION_DESC const pApplicationDesc, + IDirectPlay8Address *const pdpaddrHost, + IDirectPlay8Address *const pdpaddrDeviceInfo, + PVOID const pvUserEnumData, + const DWORD dwUserEnumDataSize, + const DWORD dwEnumCount, + const DWORD dwRetryInterval, + const DWORD dwTimeOut, + PVOID const pvUserContext, + + std::function complete_cb + ); + + ~HostEnumerator(); + + void cancel(); + void wait(); +}; + +#endif /* !DPLITE_HOSTENUMERATOR_HPP */ diff --git a/src/Messages.hpp b/src/Messages.hpp new file mode 100644 index 0000000..8cc2d62 --- /dev/null +++ b/src/Messages.hpp @@ -0,0 +1,29 @@ +#ifndef DPLITE_MESSAGES_HPP +#define DPLITE_MESSAGES_HPP + +#define DPLITE_MSGID_HOST_ENUM_REQUEST 1 + +/* EnumHosts() request message. + * + * GUID - Application GUID, NULL to search for any + * DATA | NULL - User data + * DWORD - Current tick count, to be returned, for latency measurement +*/ + +#define DPLITE_MSGID_HOST_ENUM_RESPONSE 2 + +/* EnumHosts() response message. + * + * DWORD - DPN_APPLICATION_DESC.dwFlags + * GUID - DPN_APPLICATION_DESC.guidInstance + * GUID - DPN_APPLICATION_DESC.guidApplication + * DWORD - DPN_APPLICATION_DESC.dwMaxPlayers + * DWORD - DPN_APPLICATION_DESC.dwCurrentPlayers + * WSTRING - DPN_APPLICATION_DESC.pwszSessionName + * DATA | NULL - DPN_APPLICATION_DESC.pvApplicationReservedData + * + * DATA | NULL - DPN_MSGID_ENUM_HOSTS_RESPONSE.pvResponseData + * DWORD - Tick count from DPLITE_MSGID_HOST_ENUM_REQUEST +*/ + +#endif /* !DPLITE_MESSAGES_HPP */ diff --git a/src/SendQueue.cpp b/src/SendQueue.cpp new file mode 100644 index 0000000..028a632 --- /dev/null +++ b/src/SendQueue.cpp @@ -0,0 +1,78 @@ +#include + +#include "SendQueue.hpp" + +void SendQueue::send(SendPriority priority, Buffer *buffer) +{ + switch(priority) + { + case SEND_PRI_LOW: + low_queue.push_back(buffer); + break; + + case SEND_PRI_MEDIUM: + medium_queue.push_back(buffer); + break; + + case SEND_PRI_HIGH: + high_queue.push_back(buffer); + break; + } +} + +SendQueue::Buffer *SendQueue::get_next() +{ + if(current != NULL) + { + return current; + } + + if(!high_queue.empty()) + { + current = high_queue.front(); + high_queue.pop_front(); + } + else if(!medium_queue.empty()) + { + current = medium_queue.front(); + medium_queue.pop_front(); + } + else if(!low_queue.empty()) + { + current = low_queue.front(); + low_queue.pop_front(); + } + + return current; +} + +void SendQueue::complete(SendQueue::Buffer *buffer, HRESULT result) +{ + assert(buffer == current); + + current = NULL; + + buffer->complete(result); + delete buffer; +} + +SendQueue::Buffer::Buffer(const void *data, size_t data_size, const struct sockaddr *dest_addr, int dest_addr_len): + data((const unsigned char*)(data), (const unsigned char*)(data) + data_size) +{ + assert((size_t)(dest_addr_len) <= sizeof(this->dest_addr)); + + memcpy(&(this->dest_addr), dest_addr, dest_addr_len); + this->dest_addr_len = dest_addr_len; +} + +SendQueue::Buffer::~Buffer() {} + +std::pair SendQueue::Buffer::get_data() +{ + return std::make_pair(data.data(), data.size()); +} + +std::pair SendQueue::Buffer::get_dest_addr() +{ + return std::make_pair((struct sockaddr*)(&dest_addr), (int)(dest_addr_len)); +} diff --git a/src/SendQueue.hpp b/src/SendQueue.hpp new file mode 100644 index 0000000..3a16ff9 --- /dev/null +++ b/src/SendQueue.hpp @@ -0,0 +1,60 @@ +#ifndef DPLITE_SENDQUEUE_HPP +#define DPLITE_SENDQUEUE_HPP + +#include + +#include +#include +#include +#include +#include + +class SendQueue +{ + public: + enum SendPriority { + SEND_PRI_LOW = 1, + SEND_PRI_MEDIUM = 2, + SEND_PRI_HIGH = 4, + }; + + class Buffer { + private: + std::vector data; + + struct sockaddr_storage dest_addr; + int dest_addr_len; + + protected: + Buffer(const void *data, size_t data_size, const struct sockaddr *dest_addr = NULL, int dest_addr_len = 0); + + public: + virtual ~Buffer(); + + std::pair get_data(); + + std::pair get_dest_addr(); + + virtual void complete(HRESULT result) = 0; + }; + + private: + std::list low_queue; + std::list medium_queue; + std::list high_queue; + + Buffer *current; + + public: + SendQueue(): current(NULL) {} + + /* No copy c'tor. */ + SendQueue(const SendQueue &src) = delete; + + void send(SendPriority priority, Buffer *buffer); + + Buffer *get_next(); + void complete(Buffer *buffer, HRESULT result); +}; + +#endif /* !DPLITE_SENDQUEUE_HPP */ diff --git a/src/dpnet.cpp b/src/dpnet.cpp index 395b940..28ca0ab 100644 --- a/src/dpnet.cpp +++ b/src/dpnet.cpp @@ -1,3 +1,4 @@ +#include #include #include #include diff --git a/src/network.cpp b/src/network.cpp index dae0d25..272d852 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -12,6 +12,13 @@ int create_udp_socket(uint32_t ipaddr, uint16_t port) return -1; } + u_long non_blocking = 1; + if(ioctlsocket(sock, FIONBIO, &non_blocking) != 0) + { + closesocket(sock); + return -1; + } + BOOL broadcast = TRUE; if(setsockopt(sock, SOL_SOCKET, SO_BROADCAST, (char*)(&broadcast), sizeof(BOOL)) == -1) { @@ -35,12 +42,19 @@ int create_udp_socket(uint32_t ipaddr, uint16_t port) int create_listener_socket(uint32_t ipaddr, uint16_t port) { - int sock = socket(AF_INET, SOCK_DGRAM, 0); + int sock = socket(AF_INET, SOCK_STREAM, 0); if(sock == -1) { return -1; } + u_long non_blocking = 1; + if(ioctlsocket(sock, FIONBIO, &non_blocking) != 0) + { + closesocket(sock); + return -1; + } + struct sockaddr_in addr; addr.sin_family = AF_INET; addr.sin_addr.s_addr = ipaddr; @@ -69,6 +83,13 @@ int create_discovery_socket() return -1; } + u_long non_blocking = 1; + if(ioctlsocket(sock, FIONBIO, &non_blocking) != 0) + { + closesocket(sock); + return -1; + } + BOOL reuse = TRUE; if(setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)(&reuse), sizeof(BOOL)) == -1) { diff --git a/src/network.hpp b/src/network.hpp index 31cfdc1..13201e0 100644 --- a/src/network.hpp +++ b/src/network.hpp @@ -6,6 +6,7 @@ #define DISCOVERY_PORT 6073 #define DEFAULT_HOST_PORT 6072 #define LISTEN_QUEUE_SIZE 16 +#define MAX_PACKET_SIZE (60 * 1024) int create_udp_socket(uint32_t ipaddr, uint16_t port); int create_listener_socket(uint32_t ipaddr, uint16_t port); diff --git a/tests/DirectPlay8Peer.cpp b/tests/DirectPlay8Peer.cpp new file mode 100644 index 0000000..7522c1d --- /dev/null +++ b/tests/DirectPlay8Peer.cpp @@ -0,0 +1,888 @@ +#include +#include +#include +#include + +#include "../src/DirectPlay8Peer.hpp" + +static const GUID APP_GUID_1 = { 0xa6133957, 0x6f42, 0x46ce, { 0xa9, 0x88, 0x22, 0xf7, 0x79, 0x47, 0x08, 0x16 } }; +static const GUID APP_GUID_2 = { 0x5917faae, 0x7ab0, 0x42d2, { 0xae, 0x13, 0x9c, 0x54, 0x1b, 0x7f, 0xb5, 0xab } }; + +static HRESULT CALLBACK callback_shim(PVOID pvUserContext, DWORD dwMessageType, PVOID pMessage) +{ + std::function *callback = (std::function*)(pvUserContext); + return (*callback)(dwMessageType, pMessage); +} + +/* Wrapper around a DirectPlay8Peer which hosts a session. */ +struct SessionHost +{ + DirectPlay8Peer dp8p; + std::function cb; + + SessionHost( + GUID application_guid, + const wchar_t *session_description, + std::function cb = + [](DWORD dwMessageType, PVOID pMessage) + { + return DPN_OK; + }): + dp8p(NULL), + cb(cb) + { + if(dp8p.Initialize(&(this->cb), &callback_shim, 0) != S_OK) + { + throw std::runtime_error("DirectPlay8Peer::Initialize failed"); + } + + { + DPN_APPLICATION_DESC app_desc; + memset(&app_desc, 0, sizeof(app_desc)); + + app_desc.dwSize = sizeof(app_desc); + app_desc.guidApplication = application_guid; + app_desc.pwszSessionName = (wchar_t*)(session_description); + + if(dp8p.Host(&app_desc, NULL, 0, NULL, NULL, NULL, 0) != S_OK) + { + throw std::runtime_error("DirectPlay8Peer::Host failed"); + } + } + } +}; + +struct FoundSession +{ + GUID application_guid; + std::wstring session_description; + + FoundSession(GUID application_guid, const std::wstring &session_description): + application_guid(application_guid), + session_description(session_description) {} + + bool operator==(const FoundSession &rhs) const + { + return application_guid == rhs.application_guid + && session_description == rhs.session_description; + } +}; + +struct CompareGUID { + bool operator()(const GUID &a, const GUID &b) const + { + return memcmp(&a, &b, sizeof(GUID)) < 0; + } +}; + +static void EXPECT_SESSIONS(std::map got, const FoundSession *expect_begin, const FoundSession *expect_end) +{ + std::list expect(expect_begin, expect_end); + + for(auto gi = got.begin(); gi != got.end();) + { + for(auto ei = expect.begin(); ei != expect.end();) + { + if(gi->second == *ei) + { + ei = expect.erase(ei); + gi = got.erase(gi); + goto NEXT_GI; + } + else{ + ++ei; + } + } + + ++gi; + NEXT_GI: + {} + } + + for(auto gi = got.begin(); gi != got.end(); ++gi) + { + wchar_t application_guid_s[128]; + StringFromGUID2(gi->second.application_guid, application_guid_s, 128); + + ADD_FAILURE() << "Extra session:" << std::endl + << " application_guid = " << application_guid_s << std::endl + << " session_description = " << gi->second.session_description; + } + + for(auto ei = expect.begin(); ei != expect.end(); ++ei) + { + wchar_t application_guid_s[128]; + StringFromGUID2(ei->application_guid, application_guid_s, 128); + + ADD_FAILURE() << "Missing session:" << std::endl + << " application_guid = " << application_guid_s << std::endl + << " session_description = " << ei->session_description; + } + + if(got.empty() && expect.empty()) + { + SUCCEED(); + } +} + +TEST(DirectPlay8Peer, EnumHostsSync) +{ + SessionHost a1s1(APP_GUID_1, L"Application 1 Session 1"); + SessionHost a1s2(APP_GUID_1, L"Application 1 Session 2"); + SessionHost a2s1(APP_GUID_2, L"Application 2 Session 1"); + + std::map sessions; + + bool got_async_op_complete = false; + + std::function client_cb = + [&sessions, &got_async_op_complete] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_RESPONSE) + { + DPNMSG_ENUM_HOSTS_RESPONSE *ehr = (DPNMSG_ENUM_HOSTS_RESPONSE*)(pMessage); + + EXPECT_EQ(ehr->dwSize, sizeof(DPNMSG_ENUM_HOSTS_RESPONSE)); + EXPECT_EQ(ehr->pvUserContext, (void*)(0xBEEF)); + + sessions.emplace( + ehr->pApplicationDescription->guidInstance, + FoundSession( + ehr->pApplicationDescription->guidApplication, + ehr->pApplicationDescription->pwszSessionName)); + } + else if(dwMessageType == DPN_MSGID_ASYNC_OP_COMPLETE) + { + got_async_op_complete = true; + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&client_cb, &callback_shim, 0), S_OK); + + DWORD start = GetTickCount(); + + ASSERT_EQ(client.EnumHosts( + NULL, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + NULL, /* pvUserEnumData */ + 0, /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xBEEF), /* pvUserContext */ + NULL, /* pAsyncHandle */ + DPNENUMHOSTS_SYNC /* dwFlags */ + ), S_OK); + + DWORD end = GetTickCount(); + + FoundSession expect_sessions[] = { + FoundSession(APP_GUID_1, L"Application 1 Session 1"), + FoundSession(APP_GUID_1, L"Application 1 Session 2"), + FoundSession(APP_GUID_2, L"Application 2 Session 1"), + }; + + EXPECT_SESSIONS(sessions, expect_sessions, expect_sessions + 3); + + DWORD enum_time_ms = end - start; + EXPECT_TRUE((enum_time_ms >= 1250) && (enum_time_ms <= 1750)); + + EXPECT_FALSE(got_async_op_complete); +} + +TEST(DirectPlay8Peer, EnumHostsAsync) +{ + SessionHost a1s1(APP_GUID_1, L"Application 1 Session 1"); + SessionHost a1s2(APP_GUID_1, L"Application 1 Session 2"); + SessionHost a2s1(APP_GUID_2, L"Application 2 Session 1"); + + std::map sessions; + + bool got_async_op_complete = false; + DWORD got_async_op_complete_at; + DPNHANDLE async_handle; + + std::function callback = + [&sessions, &got_async_op_complete, &got_async_op_complete_at, &async_handle] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_RESPONSE) + { + DPNMSG_ENUM_HOSTS_RESPONSE *ehr = (DPNMSG_ENUM_HOSTS_RESPONSE*)(pMessage); + + EXPECT_EQ(ehr->dwSize, sizeof(DPNMSG_ENUM_HOSTS_RESPONSE)); + EXPECT_EQ(ehr->pvUserContext, (void*)(0xABCD)); + + sessions.emplace( + ehr->pApplicationDescription->guidInstance, + FoundSession( + ehr->pApplicationDescription->guidApplication, + ehr->pApplicationDescription->pwszSessionName)); + } + else if(dwMessageType == DPN_MSGID_ASYNC_OP_COMPLETE) + { + got_async_op_complete_at = GetTickCount(); + + /* We shouldn't get DPNMSG_ASYNC_OP_COMPLETE multiple times. */ + EXPECT_FALSE(got_async_op_complete); + + DPNMSG_ASYNC_OP_COMPLETE *oc = (DPNMSG_ASYNC_OP_COMPLETE*)(pMessage); + + EXPECT_EQ(oc->dwSize, sizeof(DPNMSG_ASYNC_OP_COMPLETE)); + EXPECT_EQ(oc->hAsyncOp, async_handle); + EXPECT_EQ(oc->pvUserContext, (void*)(0xABCD)); + EXPECT_EQ(oc->hResultCode, S_OK); + + got_async_op_complete = true; + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&callback, &callback_shim, 0), S_OK); + + DWORD start = GetTickCount(); + + ASSERT_EQ(client.EnumHosts( + NULL, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + NULL, /* pvUserEnumData */ + 0, /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xABCD), /* pvUserContext */ + &async_handle, /* pAsyncHandle */ + 0 /* dwFlags */ + ), DPNSUCCESS_PENDING); + + Sleep(3000); + + FoundSession expect_sessions[] = { + FoundSession(APP_GUID_1, L"Application 1 Session 1"), + FoundSession(APP_GUID_1, L"Application 1 Session 2"), + FoundSession(APP_GUID_2, L"Application 2 Session 1"), + }; + + EXPECT_SESSIONS(sessions, expect_sessions, expect_sessions + 3); + + EXPECT_TRUE(got_async_op_complete); + + if(got_async_op_complete) + { + DWORD enum_time_ms = got_async_op_complete_at - start; + EXPECT_TRUE((enum_time_ms >= 1250) && (enum_time_ms <= 1750)); + } +} + +TEST(DirectPlay8Peer, EnumHostsAsyncCancelByHandle) +{ + bool got_async_op_complete = false; + DWORD got_async_op_complete_at; + DPNHANDLE async_handle; + + std::function callback = + [&got_async_op_complete, &got_async_op_complete_at, &async_handle] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ASYNC_OP_COMPLETE) + { + got_async_op_complete_at = GetTickCount(); + + /* We shouldn't get DPNMSG_ASYNC_OP_COMPLETE multiple times. */ + EXPECT_FALSE(got_async_op_complete); + + DPNMSG_ASYNC_OP_COMPLETE *oc = (DPNMSG_ASYNC_OP_COMPLETE*)(pMessage); + + EXPECT_EQ(oc->dwSize, sizeof(DPNMSG_ASYNC_OP_COMPLETE)); + EXPECT_EQ(oc->hAsyncOp, async_handle); + EXPECT_EQ(oc->pvUserContext, (void*)(0xABCD)); + EXPECT_EQ(oc->hResultCode, DPNERR_USERCANCEL); + + got_async_op_complete = true; + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&callback, &callback_shim, 0), S_OK); + + DWORD start = GetTickCount(); + + ASSERT_EQ(client.EnumHosts( + NULL, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + NULL, /* pvUserEnumData */ + 0, /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xABCD), /* pvUserContext */ + &async_handle, /* pAsyncHandle */ + 0 /* dwFlags */ + ), DPNSUCCESS_PENDING); + + ASSERT_EQ(client.CancelAsyncOperation(async_handle, 0), S_OK); + + Sleep(500); + + EXPECT_TRUE(got_async_op_complete); + + if(got_async_op_complete) + { + DWORD enum_time_ms = got_async_op_complete_at - start; + EXPECT_TRUE(enum_time_ms <= 250); + } +} + +TEST(DirectPlay8Peer, EnumHostsAsyncCancelAllEnums) +{ + bool got_async_op_complete = false; + DWORD got_async_op_complete_at; + DPNHANDLE async_handle; + + std::function callback = + [&got_async_op_complete, &got_async_op_complete_at, &async_handle] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ASYNC_OP_COMPLETE) + { + got_async_op_complete_at = GetTickCount(); + + /* We shouldn't get DPNMSG_ASYNC_OP_COMPLETE multiple times. */ + EXPECT_FALSE(got_async_op_complete); + + DPNMSG_ASYNC_OP_COMPLETE *oc = (DPNMSG_ASYNC_OP_COMPLETE*)(pMessage); + + EXPECT_EQ(oc->dwSize, sizeof(DPNMSG_ASYNC_OP_COMPLETE)); + EXPECT_EQ(oc->hAsyncOp, async_handle); + EXPECT_EQ(oc->pvUserContext, (void*)(0xABCD)); + EXPECT_EQ(oc->hResultCode, DPNERR_USERCANCEL); + + got_async_op_complete = true; + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&callback, &callback_shim, 0), S_OK); + + DWORD start = GetTickCount(); + + ASSERT_EQ(client.EnumHosts( + NULL, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + NULL, /* pvUserEnumData */ + 0, /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xABCD), /* pvUserContext */ + &async_handle, /* pAsyncHandle */ + 0 /* dwFlags */ + ), DPNSUCCESS_PENDING); + + ASSERT_EQ(client.CancelAsyncOperation(0, DPNCANCEL_ENUM), S_OK); + + Sleep(500); + + EXPECT_TRUE(got_async_op_complete); + + if(got_async_op_complete) + { + DWORD enum_time_ms = got_async_op_complete_at - start; + EXPECT_TRUE(enum_time_ms <= 250); + } +} + +TEST(DirectPlay8Peer, EnumHostsAsyncCancelAllOperations) +{ + bool got_async_op_complete = false; + DWORD got_async_op_complete_at; + DPNHANDLE async_handle; + + std::function callback = + [&got_async_op_complete, &got_async_op_complete_at, &async_handle] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ASYNC_OP_COMPLETE) + { + got_async_op_complete_at = GetTickCount(); + + /* We shouldn't get DPNMSG_ASYNC_OP_COMPLETE multiple times. */ + EXPECT_FALSE(got_async_op_complete); + + DPNMSG_ASYNC_OP_COMPLETE *oc = (DPNMSG_ASYNC_OP_COMPLETE*)(pMessage); + + EXPECT_EQ(oc->dwSize, sizeof(DPNMSG_ASYNC_OP_COMPLETE)); + EXPECT_EQ(oc->hAsyncOp, async_handle); + EXPECT_EQ(oc->pvUserContext, (void*)(0xABCD)); + EXPECT_EQ(oc->hResultCode, DPNERR_USERCANCEL); + + got_async_op_complete = true; + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&callback, &callback_shim, 0), S_OK); + + DWORD start = GetTickCount(); + + ASSERT_EQ(client.EnumHosts( + NULL, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + NULL, /* pvUserEnumData */ + 0, /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xABCD), /* pvUserContext */ + &async_handle, /* pAsyncHandle */ + 0 /* dwFlags */ + ), DPNSUCCESS_PENDING); + + ASSERT_EQ(client.CancelAsyncOperation(0, DPNCANCEL_ALL_OPERATIONS), S_OK); + + Sleep(500); + + EXPECT_TRUE(got_async_op_complete); + + if(got_async_op_complete) + { + DWORD enum_time_ms = got_async_op_complete_at - start; + EXPECT_TRUE(enum_time_ms <= 250); + } +} + +TEST(DirectPlay8Peer, EnumHostsAsyncCancelByClose) +{ + bool got_async_op_complete = false; + DWORD got_async_op_complete_at; + DPNHANDLE async_handle; + + std::function callback = + [&got_async_op_complete, &got_async_op_complete_at, &async_handle] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ASYNC_OP_COMPLETE) + { + got_async_op_complete_at = GetTickCount(); + + /* We shouldn't get DPNMSG_ASYNC_OP_COMPLETE multiple times. */ + EXPECT_FALSE(got_async_op_complete); + + DPNMSG_ASYNC_OP_COMPLETE *oc = (DPNMSG_ASYNC_OP_COMPLETE*)(pMessage); + + EXPECT_EQ(oc->dwSize, sizeof(DPNMSG_ASYNC_OP_COMPLETE)); + EXPECT_EQ(oc->hAsyncOp, async_handle); + EXPECT_EQ(oc->pvUserContext, (void*)(0xABCD)); + EXPECT_EQ(oc->hResultCode, DPNERR_USERCANCEL); + + got_async_op_complete = true; + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&callback, &callback_shim, 0), S_OK); + + DWORD start = GetTickCount(); + + ASSERT_EQ(client.EnumHosts( + NULL, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + NULL, /* pvUserEnumData */ + 0, /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xABCD), /* pvUserContext */ + &async_handle, /* pAsyncHandle */ + 0 /* dwFlags */ + ), DPNSUCCESS_PENDING); + + ASSERT_EQ(client.Close(0), S_OK); + + EXPECT_TRUE(got_async_op_complete); + + if(got_async_op_complete) + { + DWORD enum_time_ms = got_async_op_complete_at - start; + EXPECT_TRUE(enum_time_ms <= 250); + } +} + +TEST(DirectPlay8Peer, EnumHostsFilterByApplicationGUID) +{ + bool right_app_got_host_enum_query = false; + bool wrong_app_got_host_enum_query = false; + + SessionHost a1s1(APP_GUID_1, L"Application 1 Session 1", + [&wrong_app_got_host_enum_query] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_QUERY) + { + wrong_app_got_host_enum_query = true; + } + + return DPN_OK; + }); + + SessionHost a1s2(APP_GUID_1, L"Application 1 Session 2", + [&wrong_app_got_host_enum_query] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_QUERY) + { + wrong_app_got_host_enum_query = true; + } + + return DPN_OK; + }); + + SessionHost a2s1(APP_GUID_2, L"Application 2 Session 1", + [&right_app_got_host_enum_query] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_QUERY) + { + DPNMSG_ENUM_HOSTS_QUERY *ehq = (DPNMSG_ENUM_HOSTS_QUERY*)(pMessage); + + EXPECT_EQ(ehq->dwSize, sizeof(DPNMSG_ENUM_HOSTS_QUERY)); + + /* TODO: Check pAddressSender, pAddressDevice */ + + EXPECT_EQ(ehq->pvReceivedData, nullptr); + EXPECT_EQ(ehq->dwReceivedDataSize, 0); + + EXPECT_EQ(ehq->pvResponseData, nullptr); + EXPECT_EQ(ehq->dwResponseDataSize, 0); + + right_app_got_host_enum_query = true; + } + + return DPN_OK; + }); + + std::map sessions; + + std::function client_cb = + [&sessions] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_RESPONSE) + { + DPNMSG_ENUM_HOSTS_RESPONSE *ehr = (DPNMSG_ENUM_HOSTS_RESPONSE*)(pMessage); + + EXPECT_EQ(ehr->dwSize, sizeof(DPNMSG_ENUM_HOSTS_RESPONSE)); + EXPECT_EQ(ehr->pvUserContext, (void*)(0xBEEF)); + + sessions.emplace( + ehr->pApplicationDescription->guidInstance, + FoundSession( + ehr->pApplicationDescription->guidApplication, + ehr->pApplicationDescription->pwszSessionName)); + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&client_cb, &callback_shim, 0), S_OK); + + DPN_APPLICATION_DESC app_desc; + memset(&app_desc, 0, sizeof(app_desc)); + + app_desc.dwSize = sizeof(app_desc); + app_desc.guidApplication = APP_GUID_2; + + ASSERT_EQ(client.EnumHosts( + &app_desc, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + NULL, /* pvUserEnumData */ + 0, /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xBEEF), /* pvUserContext */ + NULL, /* pAsyncHandle */ + DPNENUMHOSTS_SYNC /* dwFlags */ + ), S_OK); + + FoundSession expect_sessions[] = { + FoundSession(APP_GUID_2, L"Application 2 Session 1"), + }; + + EXPECT_SESSIONS(sessions, expect_sessions, expect_sessions + 1); + + EXPECT_TRUE(right_app_got_host_enum_query); + EXPECT_FALSE(wrong_app_got_host_enum_query); +} + +TEST(DirectPlay8Peer, EnumHostsFilterByNULLApplicationGUID) +{ + SessionHost a1s1(APP_GUID_1, L"Application 1 Session 1"); + SessionHost a1s2(APP_GUID_1, L"Application 1 Session 2"); + SessionHost a2s1(APP_GUID_2, L"Application 2 Session 1"); + + std::map sessions; + + std::function client_cb = + [&sessions] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_RESPONSE) + { + DPNMSG_ENUM_HOSTS_RESPONSE *ehr = (DPNMSG_ENUM_HOSTS_RESPONSE*)(pMessage); + + EXPECT_EQ(ehr->dwSize, sizeof(DPNMSG_ENUM_HOSTS_RESPONSE)); + EXPECT_EQ(ehr->pvUserContext, (void*)(0xBEEF)); + + sessions.emplace( + ehr->pApplicationDescription->guidInstance, + FoundSession( + ehr->pApplicationDescription->guidApplication, + ehr->pApplicationDescription->pwszSessionName)); + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&client_cb, &callback_shim, 0), S_OK); + + DPN_APPLICATION_DESC app_desc; + memset(&app_desc, 0, sizeof(app_desc)); + + app_desc.dwSize = sizeof(app_desc); + app_desc.guidApplication = GUID_NULL; + + ASSERT_EQ(client.EnumHosts( + &app_desc, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + NULL, /* pvUserEnumData */ + 0, /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xBEEF), /* pvUserContext */ + NULL, /* pAsyncHandle */ + DPNENUMHOSTS_SYNC /* dwFlags */ + ), S_OK); + + FoundSession expect_sessions[] = { + FoundSession(APP_GUID_1, L"Application 1 Session 1"), + FoundSession(APP_GUID_1, L"Application 1 Session 2"), + FoundSession(APP_GUID_2, L"Application 2 Session 1"), + }; + + EXPECT_SESSIONS(sessions, expect_sessions, expect_sessions + 3); +} + +TEST(DirectPlay8Peer, EnumHostsDataInQuery) +{ + static const unsigned char DATA[] = { 0x00, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF }; + + SessionHost a1s1(APP_GUID_1, L"Application 1 Session 1", + [] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_QUERY) + { + DPNMSG_ENUM_HOSTS_QUERY *ehq = (DPNMSG_ENUM_HOSTS_QUERY*)(pMessage); + + EXPECT_EQ(ehq->dwSize, sizeof(DPNMSG_ENUM_HOSTS_QUERY)); + + std::vector got_data( + (const unsigned char*)(ehq->pvReceivedData), + (const unsigned char*)(ehq->pvReceivedData) + ehq->dwReceivedDataSize); + + std::vector expect_data(DATA, DATA + sizeof(DATA)); + + EXPECT_EQ(got_data, expect_data); + + EXPECT_EQ(ehq->pvResponseData, nullptr); + EXPECT_EQ(ehq->dwResponseDataSize, 0); + } + + return DPN_OK; + }); + + std::map sessions; + + std::function client_cb = + [&sessions] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_RESPONSE) + { + DPNMSG_ENUM_HOSTS_RESPONSE *ehr = (DPNMSG_ENUM_HOSTS_RESPONSE*)(pMessage); + + EXPECT_EQ(ehr->dwSize, sizeof(DPNMSG_ENUM_HOSTS_RESPONSE)); + EXPECT_EQ(ehr->pvUserContext, (void*)(0xBEEF)); + + EXPECT_EQ(ehr->pvResponseData, nullptr); + EXPECT_EQ(ehr->dwResponseDataSize, 0); + + sessions.emplace( + ehr->pApplicationDescription->guidInstance, + FoundSession( + ehr->pApplicationDescription->guidApplication, + ehr->pApplicationDescription->pwszSessionName)); + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&client_cb, &callback_shim, 0), S_OK); + + ASSERT_EQ(client.EnumHosts( + NULL, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + (void*)(DATA), /* pvUserEnumData */ + sizeof(DATA), /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xBEEF), /* pvUserContext */ + NULL, /* pAsyncHandle */ + DPNENUMHOSTS_SYNC /* dwFlags */ + ), S_OK); + + FoundSession expect_sessions[] = { + FoundSession(APP_GUID_1, L"Application 1 Session 1"), + }; + + EXPECT_SESSIONS(sessions, expect_sessions, expect_sessions + 1); +} + +TEST(DirectPlay8Peer, EnumHostsDataInResponse) +{ + static const unsigned char DATA[] = { 0x00, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF }; + + bool got_return_buffer = false; + + SessionHost a1s1(APP_GUID_1, L"Application 1 Session 1", + [&got_return_buffer] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_QUERY) + { + DPNMSG_ENUM_HOSTS_QUERY *ehq = (DPNMSG_ENUM_HOSTS_QUERY*)(pMessage); + + EXPECT_EQ(ehq->dwSize, sizeof(DPNMSG_ENUM_HOSTS_QUERY)); + + EXPECT_EQ(ehq->pvReceivedData, nullptr); + EXPECT_EQ(ehq->dwReceivedDataSize, 0); + + EXPECT_EQ(ehq->pvResponseData, nullptr); + EXPECT_EQ(ehq->dwResponseDataSize, 0); + + ehq->pvResponseData = (void*)(DATA); + ehq->dwResponseDataSize = sizeof(DATA); + ehq->pvResponseContext = (void*)(0x1234); + } + else if(dwMessageType == DPN_MSGID_RETURN_BUFFER) + { + DPNMSG_RETURN_BUFFER *rb = (DPNMSG_RETURN_BUFFER*)(pMessage); + + EXPECT_EQ(rb->dwSize, sizeof(*rb)); + EXPECT_EQ(rb->hResultCode, S_OK); + EXPECT_EQ(rb->pvBuffer, DATA); + EXPECT_EQ(rb->pvUserContext, (void*)(0x1234)); + + got_return_buffer = true; + } + + return DPN_OK; + }); + + std::map sessions; + + std::function client_cb = + [&sessions] + (DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_ENUM_HOSTS_RESPONSE) + { + DPNMSG_ENUM_HOSTS_RESPONSE *ehr = (DPNMSG_ENUM_HOSTS_RESPONSE*)(pMessage); + + EXPECT_EQ(ehr->dwSize, sizeof(DPNMSG_ENUM_HOSTS_RESPONSE)); + EXPECT_EQ(ehr->pvUserContext, (void*)(0xBEEF)); + + std::vector got_data( + (const unsigned char*)(ehr->pvResponseData), + (const unsigned char*)(ehr->pvResponseData) + ehr->dwResponseDataSize); + + std::vector expect_data(DATA, DATA + sizeof(DATA)); + + EXPECT_EQ(got_data, expect_data); + + sessions.emplace( + ehr->pApplicationDescription->guidInstance, + FoundSession( + ehr->pApplicationDescription->guidApplication, + ehr->pApplicationDescription->pwszSessionName)); + } + + return DPN_OK; + }; + + DirectPlay8Peer client(NULL); + + ASSERT_EQ(client.Initialize(&client_cb, &callback_shim, 0), S_OK); + + ASSERT_EQ(client.EnumHosts( + NULL, /* pApplicationDesc */ + NULL, /* pdpaddrHost */ + NULL, /* pdpaddrDeviceInfo */ + NULL, /* pvUserEnumData */ + 0, /* dwUserEnumDataSize */ + 3, /* dwEnumCount */ + 500, /* dwRetryInterval */ + 500, /* dwTimeOut*/ + (void*)(0xBEEF), /* pvUserContext */ + NULL, /* pAsyncHandle */ + DPNENUMHOSTS_SYNC /* dwFlags */ + ), S_OK); + + FoundSession expect_sessions[] = { + FoundSession(APP_GUID_1, L"Application 1 Session 1"), + }; + + EXPECT_SESSIONS(sessions, expect_sessions, expect_sessions + 1); + + EXPECT_TRUE(got_return_buffer); +} + +/* TODO: Test enumerating a session directly. */