diff --git a/src/packet.cpp b/src/packet.cpp index 1235be9..299fcdd 100644 --- a/src/packet.cpp +++ b/src/packet.cpp @@ -10,6 +10,7 @@ const uint32_t FIELD_TYPE_NULL = 0; const uint32_t FIELD_TYPE_DWORD = 1; const uint32_t FIELD_TYPE_DATA = 2; const uint32_t FIELD_TYPE_WSTRING = 3; +const uint32_t FIELD_TYPE_GUID = 4; PacketSerialiser::PacketSerialiser(uint32_t type) { @@ -77,6 +78,18 @@ void PacketSerialiser::append_wstring(const std::wstring &string) ((TLVChunk*)(sbuf.data()))->value_length += sizeof(header) + string_bytes; } +void PacketSerialiser::append_guid(const GUID &guid) +{ + TLVChunk header; + header.type = FIELD_TYPE_GUID; + header.value_length = sizeof(GUID); + + sbuf.insert(sbuf.end(), (unsigned char*)(&header), (unsigned char*)(&header + 1)); + sbuf.insert(sbuf.end(), (unsigned char*)(&guid), (unsigned char*)(&guid) + sizeof(GUID)); + + ((TLVChunk*)(sbuf.data()))->value_length += sizeof(header) + sizeof(GUID); +} + PacketDeserialiser::PacketDeserialiser(const void *serialised_packet, size_t packet_size) { header = (const TLVChunk*)(serialised_packet); @@ -105,17 +118,17 @@ PacketDeserialiser::PacketDeserialiser(const void *serialised_packet, size_t pac } } -uint32_t PacketDeserialiser::packet_type() +uint32_t PacketDeserialiser::packet_type() const { return header->type; } -size_t PacketDeserialiser::num_fields() +size_t PacketDeserialiser::num_fields() const { return fields.size(); } -bool PacketDeserialiser::is_null(size_t index) +bool PacketDeserialiser::is_null(size_t index) const { if(fields.size() <= index) { @@ -125,7 +138,7 @@ bool PacketDeserialiser::is_null(size_t index) return (fields[index]->type == FIELD_TYPE_NULL); } -DWORD PacketDeserialiser::get_dword(size_t index) +DWORD PacketDeserialiser::get_dword(size_t index) const { if(fields.size() <= index) { @@ -145,7 +158,7 @@ DWORD PacketDeserialiser::get_dword(size_t index) return *(DWORD*)(fields[index]->value); } -std::pair PacketDeserialiser::get_data(size_t index) +std::pair PacketDeserialiser::get_data(size_t index) const { if(fields.size() <= index) { @@ -160,7 +173,7 @@ std::pair PacketDeserialiser::get_data(size_t index) return std::make_pair((const void*)(fields[index]->value), (size_t)(fields[index]->value_length)); } -std::wstring PacketDeserialiser::get_wstring(size_t index) +std::wstring PacketDeserialiser::get_wstring(size_t index) const { if(fields.size() <= index) { @@ -179,3 +192,23 @@ std::wstring PacketDeserialiser::get_wstring(size_t index) return std::wstring((const wchar_t*)(fields[index]->value), (fields[index]->value_length / sizeof(wchar_t))); } + +GUID PacketDeserialiser::get_guid(size_t index) const +{ + if(fields.size() <= index) + { + throw Error::MissingField(); + } + + if(fields[index]->type != FIELD_TYPE_GUID) + { + throw Error::TypeMismatch(); + } + + if(fields[index]->value_length != sizeof(GUID)) + { + throw Error::Malformed(); + } + + return *(GUID*)(fields[index]->value); +} diff --git a/src/packet.hpp b/src/packet.hpp index 8daec2a..df95735 100644 --- a/src/packet.hpp +++ b/src/packet.hpp @@ -30,6 +30,7 @@ class PacketSerialiser void append_dword(DWORD value); void append_data(const void *data, size_t size); void append_wstring(const std::wstring &string); + void append_guid(const GUID &guid); }; class PacketDeserialiser @@ -53,13 +54,14 @@ class PacketDeserialiser PacketDeserialiser(const void *serialised_packet, size_t packet_size); - uint32_t packet_type(); - size_t num_fields(); + uint32_t packet_type() const; + size_t num_fields() const; - bool is_null(size_t index); - DWORD get_dword(size_t index); - std::pair get_data(size_t index); - std::wstring get_wstring(size_t index); + bool is_null(size_t index) const; + DWORD get_dword(size_t index) const; + std::pair get_data(size_t index) const; + std::wstring get_wstring(size_t index) const; + GUID get_guid(size_t index) const; }; class PacketDeserialiser::Error::Incomplete: public Error diff --git a/tests/PacketDeserialiser.cpp b/tests/PacketDeserialiser.cpp index 754d55d..d89c626 100644 --- a/tests/PacketDeserialiser.cpp +++ b/tests/PacketDeserialiser.cpp @@ -57,6 +57,11 @@ TEST_F(PacketDeserialiserEmpty, GetWString) EXPECT_THROW({ pd->get_wstring(0); }, PacketDeserialiser::Error::MissingField); } +TEST_F(PacketDeserialiserEmpty, GetGUID) +{ + EXPECT_THROW({ pd->get_guid(0); }, PacketDeserialiser::Error::MissingField); +} + class PacketDeserialiserNull: public PacketDeserialiserTest { protected: virtual void SetUp() override @@ -107,6 +112,12 @@ TEST_F(PacketDeserialiserNull, GetWString) EXPECT_THROW({ pd->get_wstring(1); }, PacketDeserialiser::Error::MissingField); } +TEST_F(PacketDeserialiserNull, GetGUID) +{ + EXPECT_THROW({ pd->get_guid(0); }, PacketDeserialiser::Error::TypeMismatch); + EXPECT_THROW({ pd->get_guid(1); }, PacketDeserialiser::Error::MissingField); +} + class PacketDeserialiserDWORD: public PacketDeserialiserTest { protected: virtual void SetUp() override @@ -158,6 +169,12 @@ TEST_F(PacketDeserialiserDWORD, GetWString) EXPECT_THROW({ pd->get_wstring(1); }, PacketDeserialiser::Error::MissingField); } +TEST_F(PacketDeserialiserDWORD, GetGUID) +{ + EXPECT_THROW({ pd->get_guid(0); }, PacketDeserialiser::Error::TypeMismatch); + EXPECT_THROW({ pd->get_guid(1); }, PacketDeserialiser::Error::MissingField); +} + class PacketDeserialiserData: public PacketDeserialiserTest { protected: virtual void SetUp() override @@ -219,6 +236,12 @@ TEST_F(PacketDeserialiserData, GetWString) EXPECT_THROW({ pd->get_wstring(1); }, PacketDeserialiser::Error::MissingField); } +TEST_F(PacketDeserialiserData, GetGUID) +{ + EXPECT_THROW({ pd->get_guid(0); }, PacketDeserialiser::Error::TypeMismatch); + EXPECT_THROW({ pd->get_guid(1); }, PacketDeserialiser::Error::MissingField); +} + class PacketDeserialiserWString: public PacketDeserialiserTest { protected: virtual void SetUp() override @@ -272,6 +295,74 @@ TEST_F(PacketDeserialiserWString, GetWString) EXPECT_THROW({ pd->get_wstring(1); }, PacketDeserialiser::Error::MissingField); } +TEST_F(PacketDeserialiserWString, GetGUID) +{ + EXPECT_THROW({ pd->get_guid(0); }, PacketDeserialiser::Error::TypeMismatch); + EXPECT_THROW({ pd->get_guid(1); }, PacketDeserialiser::Error::MissingField); +} + +class PacketDeserialiserGUID: public PacketDeserialiserTest { + protected: + virtual void SetUp() override + { + static const unsigned char RAW[] = { + 0x06, 0x00, 0x00, 0x00, /* type */ + 0x18, 0x00, 0x00, 0x00, /* value_length */ + + 0x04, 0x00, 0x00, 0x00, /* type */ + 0x10, 0x00, 0x00, 0x00, /* value_length */ + 0x01, 0x23, 0x45, 0x67, /* value */ + 0x89, 0x1A, 0xBC, 0xDE, + 0xF0, 0x12, 0x34, 0x56, + 0x78, 0x91, 0xAB, 0xCD, + }; + + ASSERT_NO_THROW({ pd = new PacketDeserialiser(RAW, sizeof(RAW)); }); + } +}; + +TEST_F(PacketDeserialiserGUID, Type) +{ + EXPECT_EQ(pd->packet_type(), (uint32_t)(6)); +} + +TEST_F(PacketDeserialiserGUID, NumFields) +{ + EXPECT_EQ(pd->num_fields(), (size_t)(1)); +} + +TEST_F(PacketDeserialiserGUID, IsNull) +{ + EXPECT_NO_THROW({ EXPECT_EQ(pd->is_null(0), false); }); + EXPECT_THROW({ pd->is_null(1); }, PacketDeserialiser::Error::MissingField); +} + +TEST_F(PacketDeserialiserGUID, GetDWORD) +{ + EXPECT_THROW({ pd->get_dword(0); }, PacketDeserialiser::Error::TypeMismatch); + EXPECT_THROW({ pd->get_dword(1); }, PacketDeserialiser::Error::MissingField); +} + +TEST_F(PacketDeserialiserGUID, GetData) +{ + EXPECT_THROW({ pd->get_data(0); }, PacketDeserialiser::Error::TypeMismatch); + EXPECT_THROW({ pd->get_data(1); }, PacketDeserialiser::Error::MissingField); +} + +TEST_F(PacketDeserialiserGUID, GetWString) +{ + EXPECT_THROW({ pd->get_wstring(0); }, PacketDeserialiser::Error::TypeMismatch); + EXPECT_THROW({ pd->get_wstring(1); }, PacketDeserialiser::Error::MissingField); +} + +TEST_F(PacketDeserialiserGUID, GetGUID) +{ + const GUID EXPECT = { 0x67452301, 0x1A89, 0xDEBC, { 0xF0, 0x12, 0x34, 0x56, 0x78, 0x91, 0xAB, 0xCD } }; + + EXPECT_NO_THROW({ EXPECT_EQ(pd->get_guid(0), EXPECT); }); + EXPECT_THROW({ pd->get_guid(1); }, PacketDeserialiser::Error::MissingField); +} + class PacketDeserialiserNullDWORDDataWString: public PacketDeserialiserTest { protected: virtual void SetUp() override @@ -558,3 +649,66 @@ TEST(PacketDeserialiser, OneByteWString) delete pd; } + +TEST(PacketDeserialiser, ZeroLengthGUID) +{ + const unsigned char RAW[] = { + 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + + 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }; + + PacketDeserialiser *pd = NULL; + + ASSERT_NO_THROW({ pd = new PacketDeserialiser(RAW, sizeof(RAW)); }); + EXPECT_THROW({ pd->get_guid(0); }, PacketDeserialiser::Error::Malformed); + + delete pd; +} + +TEST(PacketDeserialiser, UndersizeGUID) +{ + static const unsigned char RAW[] = { + 0x06, 0x00, 0x00, 0x00, /* type */ + 0x17, 0x00, 0x00, 0x00, /* value_length */ + + 0x04, 0x00, 0x00, 0x00, /* type */ + 0x0F, 0x00, 0x00, 0x00, /* value_length */ + 0x01, 0x23, 0x45, 0x67, /* value */ + 0x89, 0x1A, 0xBC, 0xDE, + 0xF0, 0x12, 0x34, 0x56, + 0x78, 0x91, 0xAB, + }; + + PacketDeserialiser *pd = NULL; + + ASSERT_NO_THROW({ pd = new PacketDeserialiser(RAW, sizeof(RAW)); }); + EXPECT_THROW({ pd->get_guid(0); }, PacketDeserialiser::Error::Malformed); + + delete pd; +} + +TEST(PacketDeserialiser, OversizeGUID) +{ + static const unsigned char RAW[] = { + 0x06, 0x00, 0x00, 0x00, /* type */ + 0x19, 0x00, 0x00, 0x00, /* value_length */ + + 0x04, 0x00, 0x00, 0x00, /* type */ + 0x11, 0x00, 0x00, 0x00, /* value_length */ + 0x01, 0x23, 0x45, 0x67, /* value */ + 0x89, 0x1A, 0xBC, 0xDE, + 0xF0, 0x12, 0x34, 0x56, + 0x78, 0x91, 0xAB, 0xCD, + 0xAA, + }; + + PacketDeserialiser *pd = NULL; + + ASSERT_NO_THROW({ pd = new PacketDeserialiser(RAW, sizeof(RAW)); }); + EXPECT_THROW({ pd->get_guid(0); }, PacketDeserialiser::Error::Malformed); + + delete pd; +} diff --git a/tests/PacketSerialiser.cpp b/tests/PacketSerialiser.cpp index 9041dc3..415cc20 100644 --- a/tests/PacketSerialiser.cpp +++ b/tests/PacketSerialiser.cpp @@ -124,6 +124,33 @@ TEST(PacketSerialiser, WString) ASSERT_EQ(got, expect); } +TEST(PacketSerialiser, GUID) +{ + const GUID guid = { 0x67452301, 0x1A89, 0xDEBC, { 0xF0, 0x12, 0x34, 0x56, 0x78, 0x91, 0xAB, 0xCD } }; + + PacketSerialiser p(0x1234); + p.append_guid(guid); + + std::pair raw = p.raw_packet(); + + const unsigned char EXPECT[] = { + 0x34, 0x12, 0x00, 0x00, /* type */ + 0x18, 0x00, 0x00, 0x00, /* value_length */ + + 0x04, 0x00, 0x00, 0x00, /* type */ + 0x10, 0x00, 0x00, 0x00, /* value_length */ + 0x01, 0x23, 0x45, 0x67, /* value */ + 0x89, 0x1A, 0xBC, 0xDE, + 0xF0, 0x12, 0x34, 0x56, + 0x78, 0x91, 0xAB, 0xCD, + }; + + std::vector got((unsigned char*)(raw.first), (unsigned char*)(raw.first) + raw.second); + std::vector expect(EXPECT, EXPECT + sizeof(EXPECT)); + + ASSERT_EQ(got, expect); +} + TEST(PacketSerialiser, NullDWORDDataWString) { PacketSerialiser p(0x1234);