From b74b018ff205147d0ebaea0106c43821af97051e Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Mon, 15 Oct 2018 13:29:07 +0100 Subject: [PATCH] Implement IDirectPlay8Peer::TerminateSession() --- src/DirectPlay8Peer.cpp | 236 +++++++++++++++++++++++++++++++++++- src/DirectPlay8Peer.hpp | 1 + src/Messages.hpp | 8 ++ tests/DirectPlay8Peer.cpp | 244 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 488 insertions(+), 1 deletion(-) diff --git a/src/DirectPlay8Peer.cpp b/src/DirectPlay8Peer.cpp index a5ccc22..90f9733 100644 --- a/src/DirectPlay8Peer.cpp +++ b/src/DirectPlay8Peer.cpp @@ -1960,7 +1960,138 @@ HRESULT DirectPlay8Peer::RegisterLobby(CONST DPNHANDLE dpnHandle, struct IDirect HRESULT DirectPlay8Peer::TerminateSession(void* CONST pvTerminateData, CONST DWORD dwTerminateDataSize, CONST DWORD dwFlags) { - UNIMPLEMENTED("DirectPlay8Peer::TerminateSession"); + std::unique_lock l(lock); + + switch(state) + { + case STATE_NEW: return DPNERR_UNINITIALIZED; + case STATE_INITIALISED: return DPNERR_NOCONNECTION; + case STATE_HOSTING: break; + case STATE_CONNECTING_TO_HOST: return DPNERR_CONNECTING; + case STATE_CONNECTING_TO_PEERS: return DPNERR_CONNECTING; + case STATE_CONNECT_FAILED: return DPNERR_CONNECTING; + case STATE_CONNECTED: return DPNERR_NOTHOST; + case STATE_CLOSING: return DPNERR_CONNECTIONLOST; + case STATE_TERMINATED: return DPNERR_HOSTTERMINATEDSESSION; + } + + if(discovery_socket != -1) + { + closesocket(discovery_socket); + discovery_socket = -1; + } + + if(listener_socket != -1) + { + closesocket(listener_socket); + listener_socket = -1; + } + + if(udp_socket != -1) + { + closesocket(udp_socket); + udp_socket = -1; + } + + /* First, we iterate over all the peers. + * + * For connected peers: Notify them of the impending doom, and add them to closing_peers + * so we can raise a DPNMSG_DESTROY_PLAYER later. + * + * For other peers: Add them to destroy_peers, to be destroyed later. + * + * We defer these actions until later to ensure nothing can change the state of the peers + * underneath us, as dealing with peers potentially changing state while this runs would + * be rather horrible. + */ + + PacketSerialiser terminate_session(DPLITE_MSGID_TERMINATE_SESSION); + terminate_session.append_data(pvTerminateData, dwTerminateDataSize); + + std::list< std::pair > closing_peers; + std::list destroy_peers; + + for(auto pi = peers.begin(); pi != peers.end(); ++pi) + { + unsigned int peer_id = pi->first; + Peer *peer = pi->second; + + if(peer->state == Peer::PS_CONNECTED) + { + peer->sq.send(SendQueue::SEND_PRI_HIGH, terminate_session, NULL, [](std::unique_lock &l, HRESULT result){}); + peer->state = Peer::PS_CLOSING; + + closing_peers.push_back(std::make_pair(peer->player_id, peer->player_ctx)); + } + else if(peer->state == Peer::PS_CLOSING) + { + /* Do nothing. We're waiting for this peer to go away. */ + } + else{ + destroy_peers.push_back(peer_id); + } + } + + state = STATE_TERMINATED; + + /* Raise DPNMSG_TERMINATE_SESSION. */ + + { + DPNMSG_TERMINATE_SESSION ts; + memset(&ts, 0, sizeof(ts)); + + ts.dwSize = sizeof(DPNMSG_TERMINATE_SESSION); + ts.hResultCode = DPNERR_HOSTTERMINATEDSESSION; + ts.pvTerminateData = pvTerminateData; + ts.dwTerminateDataSize = dwTerminateDataSize; + + l.unlock(); + message_handler(message_handler_ctx, DPN_MSGID_TERMINATE_SESSION, &ts); + l.lock(); + } + + /* Raise a DPNMSG_DESTROY_PLAYER for ourself. */ + + { + DPNMSG_DESTROY_PLAYER dp; + memset(&dp, 0, sizeof(dp)); + + dp.dwSize = sizeof(DPNMSG_DESTROY_PLAYER); + dp.dpnidPlayer = local_player_id; + dp.pvPlayerContext = local_player_ctx; + dp.dwReason = DPNDESTROYPLAYERREASON_SESSIONTERMINATED; + + l.unlock(); + message_handler(message_handler_ctx, DPN_MSGID_DESTROY_PLAYER, &dp); + l.lock(); + } + + /* Raise a DPNMSG_DESTROY_PLAYER for each connected peer. */ + + for(auto cp = closing_peers.begin(); cp != closing_peers.end(); ++cp) + { + DPNMSG_DESTROY_PLAYER dp; + memset(&dp, 0, sizeof(dp)); + + dp.dwSize = sizeof(dp); + dp.dpnidPlayer = cp->first; + dp.pvPlayerContext = cp->second; + dp.dwReason = DPNDESTROYPLAYERREASON_SESSIONTERMINATED; + + l.unlock(); + message_handler(message_handler_ctx, DPN_MSGID_DESTROY_PLAYER, &dp); + l.lock(); + } + + /* Destroy any peers which weren't fully connected. */ + + for(auto dp = destroy_peers.begin(); dp != destroy_peers.end();) + { + unsigned int peer_id = *dp; + peer_destroy(l, *dp, DPNERR_USERCANCEL, DPNDESTROYPLAYERREASON_NORMAL); + } + + return S_OK; } DirectPlay8Peer::Peer *DirectPlay8Peer::get_peer_by_peer_id(unsigned int peer_id) @@ -2509,6 +2640,12 @@ void DirectPlay8Peer::io_peer_recv(std::unique_lock &l, unsigned int break; } + case DPLITE_MSGID_TERMINATE_SESSION: + { + handle_terminate_session(l, peer_id, *pd); + break; + } + default: log_printf( "Unexpected message type %u received from peer %u", @@ -3799,6 +3936,103 @@ void DirectPlay8Peer::handle_destroy_peer(std::unique_lock &l, unsig } } +void DirectPlay8Peer::handle_terminate_session(std::unique_lock &l, unsigned int peer_id, const PacketDeserialiser &pd) +{ + Peer *peer = get_peer_by_peer_id(peer_id); + assert(peer != NULL); + + try { + std::pair terminate_data = pd.get_data(0); + + if(peer->state != Peer::PS_CONNECTED) + { + log_printf("Received unexpected DPLITE_MSGID_TERMINATE_SESSION from peer %u, in state %u", + peer_id, (unsigned)(peer->state)); + return; + } + + /* host_player_id must be initialised by this point, as the host is always the + * first peer to enter state PS_CONNECTED, initialising it in the process. + */ + + if(peer->player_id != host_player_id) + { + log_printf("Received unexpected DPLITE_MSGID_TERMINATE_SESSION from non-host peer %u", + peer_id); + return; + } + + state = STATE_TERMINATED; + + DPNMSG_TERMINATE_SESSION ts; + memset(&ts, 0, sizeof(ts)); + + ts.dwSize = sizeof(DPNMSG_TERMINATE_SESSION); + ts.hResultCode = DPNERR_HOSTTERMINATEDSESSION; + ts.pvTerminateData = (void*)(terminate_data.first); /* TODO: Make non-const copy? */ + ts.dwTerminateDataSize = terminate_data.second; + + l.unlock(); + message_handler(message_handler_ctx, DPN_MSGID_TERMINATE_SESSION, &ts); + l.lock(); + + DPNMSG_DESTROY_PLAYER dp; + memset(&dp, 0, sizeof(dp)); + + dp.dwSize = sizeof(DPNMSG_DESTROY_PLAYER); + dp.dpnidPlayer = local_player_id; + dp.pvPlayerContext = local_player_ctx; + dp.dwReason = DPNDESTROYPLAYERREASON_SESSIONTERMINATED; + + l.unlock(); + message_handler(message_handler_ctx, DPN_MSGID_DESTROY_PLAYER, &dp); + l.lock(); + + for(auto pi = peers.begin(); pi != peers.end();) + { + unsigned int peer_id = pi->first; + Peer *peer = pi->second; + + if(peer->state == Peer::PS_CONNECTED) + { + DPNMSG_DESTROY_PLAYER dp; + memset(&dp, 0, sizeof(dp)); + + dp.dwSize = sizeof(dp); + dp.dpnidPlayer = peer->player_id; + dp.pvPlayerContext = peer->player_ctx; + dp.dwReason = DPNDESTROYPLAYERREASON_SESSIONTERMINATED; + + peer->state = Peer::PS_CLOSING; + + /* Wake up a worker to deal with closing the connection. */ + SetEvent(peer->event); + + l.unlock(); + message_handler(message_handler_ctx, DPN_MSGID_DESTROY_PLAYER, &dp); + l.lock(); + + pi = peers.begin(); + } + else if(peer->state == Peer::PS_CLOSING) + { + /* Do nothing. We're waiting for this peer to go away. */ + ++pi; + } + else{ + peer_destroy(l, peer_id, DPNERR_USERCANCEL, DPNDESTROYPLAYERREASON_NORMAL); + + pi = peers.begin(); + } + } + } + catch(const PacketDeserialiser::Error &e) + { + log_printf("Received invalid DPLITE_MSGID_TERMINATE_SESSION from peer %u: %s", + peer_id, e.what()); + } +} + /* Check if we have finished connecting and should enter STATE_CONNECTED. * * This is called after processing either of: diff --git a/src/DirectPlay8Peer.hpp b/src/DirectPlay8Peer.hpp index eda8375..676af05 100644 --- a/src/DirectPlay8Peer.hpp +++ b/src/DirectPlay8Peer.hpp @@ -212,6 +212,7 @@ class DirectPlay8Peer: public IDirectPlay8Peer void handle_ack(std::unique_lock &l, unsigned int peer_id, const PacketDeserialiser &pd); void handle_appdesc(std::unique_lock &l, unsigned int peer_id, const PacketDeserialiser &pd); void handle_destroy_peer(std::unique_lock &l, unsigned int peer_id, const PacketDeserialiser &pd); + void handle_terminate_session(std::unique_lock &l, unsigned int peer_id, const PacketDeserialiser &pd); void connect_check(std::unique_lock &l); void connect_fail(std::unique_lock &l, HRESULT hResultCode, const void *pvApplicationReplyData, DWORD dwApplicationReplyDataSize); diff --git a/src/Messages.hpp b/src/Messages.hpp index bf9f156..899f78e 100644 --- a/src/Messages.hpp +++ b/src/Messages.hpp @@ -149,4 +149,12 @@ * DATA - DPNMSG_TERMINATE_SESSION.pvTerminateData (only from host to victim) */ +#define DPLITE_MSGID_TERMINATE_SESSION 14 + +/* Host is destroying the session. This message is sent from only the host and to all peers in the + * session simultaneously. + * + * DATA - pvTerminateData passed to IDirectPlay8Peer::TerminateSession() +*/ + #endif /* !DPLITE_MESSAGES_HPP */ diff --git a/tests/DirectPlay8Peer.cpp b/tests/DirectPlay8Peer.cpp index da48a5d..f94674e 100644 --- a/tests/DirectPlay8Peer.cpp +++ b/tests/DirectPlay8Peer.cpp @@ -6795,3 +6795,247 @@ TEST(DirectPlay8Peer, DestroyPeer) EXPECT_EQ(p1_dp_dpnidPlayer, all_players); EXPECT_EQ(p1_ts, 1); } + +TEST(DirectPlay8Peer, TerminateSession) +{ + const unsigned char TERMINATE_DATA[] = { 0x01, 0x00, 0x02, 0x03, 0x04, 0x05, 0x06 }; + + DPN_APPLICATION_DESC app_desc; + memset(&app_desc, 0, sizeof(app_desc)); + + app_desc.dwSize = sizeof(app_desc); + app_desc.guidApplication = APP_GUID_1; + app_desc.pwszSessionName = L"Session 1"; + + IDP8AddressInstance host_addr(CLSID_DP8SP_TCPIP, PORT); + + TestPeer host("host"); + ASSERT_EQ(host->Host(&app_desc, &(host_addr.instance), 1, NULL, NULL, 0, 0), S_OK); + + IDP8AddressInstance connect_addr(CLSID_DP8SP_TCPIP, L"127.0.0.1", PORT); + + TestPeer peer1("peer1"); + ASSERT_EQ(peer1->Connect( + &app_desc, /* pdnAppDesc */ + connect_addr, /* pHostAddr */ + NULL, /* pDeviceInfo */ + NULL, /* pdnSecurity */ + NULL, /* pdnCredentials */ + NULL, /* pvUserConnectData */ + 0, /* dwUserConnectDataSize */ + 0, /* pvPlayerContext */ + NULL, /* pvAsyncContext */ + NULL, /* phAsyncHandle */ + DPNCONNECT_SYNC /* dwFlags */ + ), S_OK); + + TestPeer peer2("peer2"); + ASSERT_EQ(peer2->Connect( + &app_desc, /* pdnAppDesc */ + connect_addr, /* pHostAddr */ + NULL, /* pDeviceInfo */ + NULL, /* pdnSecurity */ + NULL, /* pdnCredentials */ + NULL, /* pvUserConnectData */ + 0, /* dwUserConnectDataSize */ + 0, /* pvPlayerContext */ + NULL, /* pvAsyncContext */ + NULL, /* phAsyncHandle */ + DPNCONNECT_SYNC /* dwFlags */ + ), S_OK); + + Sleep(100); + + std::set h_dp_dpnidPlayer; + int h_ts = 0; + + host.expect_begin(); + host.expect_push([&host, &peer1, &peer2, &h_dp_dpnidPlayer, &h_ts, &TERMINATE_DATA](DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_DESTROY_PLAYER) + { + DPNMSG_DESTROY_PLAYER *dp = (DPNMSG_DESTROY_PLAYER*)(pMessage); + + EXPECT_EQ(dp->dwSize, sizeof(DPNMSG_DESTROY_PLAYER)); + + EXPECT_TRUE((dp->dpnidPlayer == host.first_cp_dpnidPlayer || dp->dpnidPlayer == peer1.first_cc_dpnidLocal || dp->dpnidPlayer == peer2.first_cc_dpnidLocal)) + << "(dpnidPlayer = " << dp->dpnidPlayer + << ", host = " << host.first_cp_dpnidPlayer + << ", peer1 = " << peer1.first_cc_dpnidLocal + << ", peer2 = " << peer2.first_cc_dpnidLocal << ")"; + + EXPECT_EQ(dp->pvPlayerContext, (void*)~(uintptr_t)(dp->dpnidPlayer)); + + if(dp->dpnidPlayer == host.first_cp_dpnidPlayer) + { + EXPECT_EQ(dp->dwReason, DPNDESTROYPLAYERREASON_SESSIONTERMINATED); + } + else{ + EXPECT_TRUE((dp->dwReason == DPNDESTROYPLAYERREASON_SESSIONTERMINATED || dp->dwReason == DPNDESTROYPLAYERREASON_NORMAL)) + << "dwReason = " << dp->dwReason; + } + + h_dp_dpnidPlayer.insert(dp->dpnidPlayer); + } + else if(dwMessageType == DPN_MSGID_TERMINATE_SESSION) + { + DPNMSG_TERMINATE_SESSION *ts = (DPNMSG_TERMINATE_SESSION*)(pMessage); + + EXPECT_EQ(ts->dwSize, sizeof(DPNMSG_TERMINATE_SESSION)); + EXPECT_EQ(ts->hResultCode, DPNERR_HOSTTERMINATEDSESSION); + + EXPECT_EQ( + std::string((const char*)(ts->pvTerminateData), ts->dwTerminateDataSize), + std::string((const char*)(TERMINATE_DATA), sizeof(TERMINATE_DATA))); + + ++h_ts; + } + else{ + ADD_FAILURE() << "Unexpected message type: " << dwMessageType; + } + + return DPN_OK; + }, 4); + + std::set p1_dp_dpnidPlayer; + int p1_ts = 0; + + peer1.expect_begin(); + peer1.expect_push([&host, &peer1, &peer2, &p1_dp_dpnidPlayer, &p1_ts, &TERMINATE_DATA](DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_DESTROY_PLAYER) + { + DPNMSG_DESTROY_PLAYER *dp = (DPNMSG_DESTROY_PLAYER*)(pMessage); + + EXPECT_EQ(dp->dwSize, sizeof(DPNMSG_DESTROY_PLAYER)); + + EXPECT_TRUE((dp->dpnidPlayer == host.first_cp_dpnidPlayer || dp->dpnidPlayer == peer1.first_cc_dpnidLocal || dp->dpnidPlayer == peer2.first_cc_dpnidLocal)) + << "(dpnidPlayer = " << dp->dpnidPlayer + << ", host = " << host.first_cp_dpnidPlayer + << ", peer1 = " << peer1.first_cc_dpnidLocal + << ", peer2 = " << peer2.first_cc_dpnidLocal << ")"; + + EXPECT_EQ(dp->pvPlayerContext, (void*)~(uintptr_t)(dp->dpnidPlayer)); + + if(dp->dpnidPlayer == host.first_cp_dpnidPlayer) + { + EXPECT_EQ(dp->dwReason, DPNDESTROYPLAYERREASON_SESSIONTERMINATED); + } + else{ + EXPECT_TRUE((dp->dwReason == DPNDESTROYPLAYERREASON_SESSIONTERMINATED || dp->dwReason == DPNDESTROYPLAYERREASON_NORMAL)) + << "dwReason = " << dp->dwReason; + } + + p1_dp_dpnidPlayer.insert(dp->dpnidPlayer); + } + else if(dwMessageType == DPN_MSGID_TERMINATE_SESSION) + { + DPNMSG_TERMINATE_SESSION *ts = (DPNMSG_TERMINATE_SESSION*)(pMessage); + + EXPECT_EQ(ts->dwSize, sizeof(DPNMSG_TERMINATE_SESSION)); + EXPECT_EQ(ts->hResultCode, DPNERR_HOSTTERMINATEDSESSION); + + EXPECT_EQ( + std::string((const char*)(ts->pvTerminateData), ts->dwTerminateDataSize), + std::string((const char*)(TERMINATE_DATA), sizeof(TERMINATE_DATA))); + + ++p1_ts; + } + else{ + ADD_FAILURE() << "Unexpected message type: " << dwMessageType; + } + + return DPN_OK; + }, 4); + + std::set p2_dp_dpnidPlayer; + int p2_ts = 0; + + peer2.expect_begin(); + peer2.expect_push([&host, &peer1, &peer2, &p2_dp_dpnidPlayer, &p2_ts, &TERMINATE_DATA](DWORD dwMessageType, PVOID pMessage) + { + if(dwMessageType == DPN_MSGID_DESTROY_PLAYER) + { + DPNMSG_DESTROY_PLAYER *dp = (DPNMSG_DESTROY_PLAYER*)(pMessage); + + EXPECT_EQ(dp->dwSize, sizeof(DPNMSG_DESTROY_PLAYER)); + + EXPECT_TRUE((dp->dpnidPlayer == host.first_cp_dpnidPlayer || dp->dpnidPlayer == peer1.first_cc_dpnidLocal || dp->dpnidPlayer == peer2.first_cc_dpnidLocal)) + << "(dpnidPlayer = " << dp->dpnidPlayer + << ", host = " << host.first_cp_dpnidPlayer + << ", peer1 = " << peer1.first_cc_dpnidLocal + << ", peer2 = " << peer2.first_cc_dpnidLocal << ")"; + + EXPECT_EQ(dp->pvPlayerContext, (void*)~(uintptr_t)(dp->dpnidPlayer)); + + if(dp->dpnidPlayer == host.first_cp_dpnidPlayer) + { + EXPECT_EQ(dp->dwReason, DPNDESTROYPLAYERREASON_SESSIONTERMINATED); + } + else{ + EXPECT_TRUE((dp->dwReason == DPNDESTROYPLAYERREASON_SESSIONTERMINATED || dp->dwReason == DPNDESTROYPLAYERREASON_NORMAL)) + << "dwReason = " << dp->dwReason; + } + + p2_dp_dpnidPlayer.insert(dp->dpnidPlayer); + } + else if(dwMessageType == DPN_MSGID_TERMINATE_SESSION) + { + DPNMSG_TERMINATE_SESSION *ts = (DPNMSG_TERMINATE_SESSION*)(pMessage); + + EXPECT_EQ(ts->dwSize, sizeof(DPNMSG_TERMINATE_SESSION)); + EXPECT_EQ(ts->hResultCode, DPNERR_HOSTTERMINATEDSESSION); + + EXPECT_EQ( + std::string((const char*)(ts->pvTerminateData), ts->dwTerminateDataSize), + std::string((const char*)(TERMINATE_DATA), sizeof(TERMINATE_DATA))); + + ++p2_ts; + } + else{ + ADD_FAILURE() << "Unexpected message type: " << dwMessageType; + } + + return DPN_OK; + }, 4); + + ASSERT_EQ(host->TerminateSession((void*)(TERMINATE_DATA), sizeof(TERMINATE_DATA), 0), S_OK); + + Sleep(250); + + peer2.expect_end(); + peer1.expect_end(); + host.expect_end(); + + std::set all_players; + all_players.insert(host.first_cp_dpnidPlayer); + all_players.insert(peer1.first_cc_dpnidLocal); + all_players.insert(peer2.first_cc_dpnidLocal); + + EXPECT_EQ(h_dp_dpnidPlayer, all_players); + EXPECT_EQ(h_ts, 1); + + EXPECT_EQ(p1_dp_dpnidPlayer, all_players); + EXPECT_EQ(p1_ts, 1); + + EXPECT_EQ(p2_dp_dpnidPlayer, all_players); + EXPECT_EQ(p2_ts, 1); + + /* All peers should now be in a state where no further messages are raised by calling + * IDirectPlay8Peer::Close() + */ + + host.expect_begin(); + peer1.expect_begin(); + peer2.expect_begin(); + + host->Close(0); + peer1->Close(0); + peer2->Close(0); + + Sleep(250); + + peer2.expect_end(); + peer1.expect_end(); + host.expect_end(); +}