forked from mirrors/gecko-dev
		
	 acc75d964c
			
		
	
	
		acc75d964c
		
	
	
	
	
		
			
			Depends on D174432 Differential Revision: https://phabricator.services.mozilla.com/D174433
		
			
				
	
	
		
			216 lines
		
	
	
	
		
			7.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			216 lines
		
	
	
	
		
			7.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
 | |
| /* This Source Code Form is subject to the terms of the Mozilla Public
 | |
|  * License, v. 2.0. If a copy of the MPL was not distributed with this
 | |
|  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
 | |
| 
 | |
| #include "mozilla/Base64.h"
 | |
| #include "nsUrlClassifierUtils.h"
 | |
| #include "safebrowsing.pb.h"
 | |
| 
 | |
| #include "Common.h"
 | |
| 
 | |
| using namespace mozilla;
 | |
| 
 | |
| template <size_t N>
 | |
| static void ToBase64EncodedStringArray(nsCString (&aInput)[N],
 | |
|                                        nsTArray<nsCString>& aEncodedArray) {
 | |
|   for (size_t i = 0; i < N; i++) {
 | |
|     nsCString encoded;
 | |
|     nsresult rv = mozilla::Base64Encode(aInput[i], encoded);
 | |
|     NS_ENSURE_SUCCESS_VOID(rv);
 | |
|     aEncodedArray.AppendElement(std::move(encoded));
 | |
|   }
 | |
| }
 | |
| 
 | |
| TEST(UrlClassifierFindFullHash, Request)
 | |
| {
 | |
|   nsUrlClassifierUtils* urlUtil = nsUrlClassifierUtils::GetInstance();
 | |
| 
 | |
|   nsTArray<nsCString> listNames;
 | |
|   listNames.AppendElement("moztest-phish-proto");
 | |
|   listNames.AppendElement("moztest-unwanted-proto");
 | |
| 
 | |
|   nsCString listStates[] = {nsCString("sta\x00te1", 7),
 | |
|                             nsCString("sta\x00te2", 7)};
 | |
|   nsTArray<nsCString> listStateArray;
 | |
|   ToBase64EncodedStringArray(listStates, listStateArray);
 | |
| 
 | |
|   nsCString prefixes[] = {nsCString("\x00\x00\x00\x01", 4),
 | |
|                           nsCString("\x00\x00\x00\x00\x01", 5),
 | |
|                           nsCString("\x00\xFF\x00\x01", 4),
 | |
|                           nsCString("\x00\xFF\x00\x01\x11\x23\xAA\xBC", 8),
 | |
|                           nsCString("\x00\x00\x00\x01\x00\x01\x98", 7)};
 | |
|   nsTArray<nsCString> prefixArray;
 | |
|   ToBase64EncodedStringArray(prefixes, prefixArray);
 | |
| 
 | |
|   nsCString requestBase64;
 | |
|   nsresult rv;
 | |
|   rv = urlUtil->MakeFindFullHashRequestV4(listNames, listStateArray,
 | |
|                                           prefixArray, requestBase64);
 | |
|   ASSERT_NS_SUCCEEDED(rv);
 | |
| 
 | |
|   // Base64 URL decode first.
 | |
|   FallibleTArray<uint8_t> requestBinary;
 | |
|   rv = Base64URLDecode(requestBase64, Base64URLDecodePaddingPolicy::Require,
 | |
|                        requestBinary);
 | |
|   ASSERT_NS_SUCCEEDED(rv);
 | |
| 
 | |
|   // Parse the FindFullHash binary and compare with the expected values.
 | |
|   FindFullHashesRequest r;
 | |
|   ASSERT_TRUE(r.ParseFromArray(&requestBinary[0], requestBinary.Length()));
 | |
| 
 | |
|   // Compare client states.
 | |
|   ASSERT_EQ(r.client_states_size(), (int)ArrayLength(listStates));
 | |
|   for (int i = 0; i < r.client_states_size(); i++) {
 | |
|     auto s = r.client_states(i);
 | |
|     ASSERT_TRUE(listStates[i].Equals(nsCString(s.c_str(), s.size())));
 | |
|   }
 | |
| 
 | |
|   auto threatInfo = r.threat_info();
 | |
| 
 | |
|   // Compare threat types.
 | |
|   ASSERT_EQ(threatInfo.threat_types_size(), (int)ArrayLength(listStates));
 | |
|   for (int i = 0; i < threatInfo.threat_types_size(); i++) {
 | |
|     uint32_t expectedThreatType;
 | |
|     rv =
 | |
|         urlUtil->ConvertListNameToThreatType(listNames[i], &expectedThreatType);
 | |
|     ASSERT_NS_SUCCEEDED(rv);
 | |
|     ASSERT_EQ(threatInfo.threat_types(i), (int)expectedThreatType);
 | |
|   }
 | |
| 
 | |
|   // Compare prefixes.
 | |
|   ASSERT_EQ(threatInfo.threat_entries_size(), (int)ArrayLength(prefixes));
 | |
|   for (int i = 0; i < threatInfo.threat_entries_size(); i++) {
 | |
|     auto p = threatInfo.threat_entries(i).hash();
 | |
|     ASSERT_TRUE(prefixes[i].Equals(nsCString(p.c_str(), p.size())));
 | |
|   }
 | |
| }
 | |
| 
 | |
| /////////////////////////////////////////////////////////////
 | |
| // Following is to test parsing the gethash response.
 | |
| 
 | |
| namespace {
 | |
| 
 | |
| // safebrowsing::Duration manipulation.
 | |
| struct MyDuration {
 | |
|   uint32_t mSecs;
 | |
|   uint32_t mNanos;
 | |
| };
 | |
| void PopulateDuration(Duration& aDest, const MyDuration& aSrc) {
 | |
|   aDest.set_seconds(aSrc.mSecs);
 | |
|   aDest.set_nanos(aSrc.mNanos);
 | |
| }
 | |
| 
 | |
| // The expected match data.
 | |
| static MyDuration EXPECTED_MIN_WAIT_DURATION = {12, 10};
 | |
| static MyDuration EXPECTED_NEG_CACHE_DURATION = {120, 9};
 | |
| static const struct ExpectedMatch {
 | |
|   nsCString mCompleteHash;
 | |
|   ThreatType mThreatType;
 | |
|   MyDuration mPerHashCacheDuration;
 | |
| } EXPECTED_MATCH[] = {
 | |
|     {nsCString("01234567890123456789012345678901"),
 | |
|      SOCIAL_ENGINEERING_PUBLIC,
 | |
|      {8, 500}},
 | |
|     {nsCString("12345678901234567890123456789012"),
 | |
|      SOCIAL_ENGINEERING_PUBLIC,
 | |
|      {7, 100}},
 | |
|     {nsCString("23456789012345678901234567890123"),
 | |
|      SOCIAL_ENGINEERING_PUBLIC,
 | |
|      {1, 20}},
 | |
| };
 | |
| 
 | |
| class MyParseCallback final : public nsIUrlClassifierParseFindFullHashCallback {
 | |
|  public:
 | |
|   NS_DECL_ISUPPORTS
 | |
| 
 | |
|   explicit MyParseCallback(uint32_t& aCallbackCount)
 | |
|       : mCallbackCount(aCallbackCount) {}
 | |
| 
 | |
|   NS_IMETHOD
 | |
|   OnCompleteHashFound(const nsACString& aCompleteHash,
 | |
|                       const nsACString& aTableNames,
 | |
|                       uint32_t aPerHashCacheDuration) override {
 | |
|     Verify(aCompleteHash, aTableNames, aPerHashCacheDuration);
 | |
| 
 | |
|     return NS_OK;
 | |
|   }
 | |
| 
 | |
|   NS_IMETHOD
 | |
|   OnResponseParsed(uint32_t aMinWaitDuration,
 | |
|                    uint32_t aNegCacheDuration) override {
 | |
|     VerifyDuration(aMinWaitDuration / 1000, EXPECTED_MIN_WAIT_DURATION);
 | |
|     VerifyDuration(aNegCacheDuration, EXPECTED_NEG_CACHE_DURATION);
 | |
| 
 | |
|     return NS_OK;
 | |
|   }
 | |
| 
 | |
|  private:
 | |
|   void Verify(const nsACString& aCompleteHash, const nsACString& aTableNames,
 | |
|               uint32_t aPerHashCacheDuration) {
 | |
|     auto expected = EXPECTED_MATCH[mCallbackCount];
 | |
| 
 | |
|     ASSERT_TRUE(aCompleteHash.Equals(expected.mCompleteHash));
 | |
| 
 | |
|     // Verify aTableNames
 | |
|     nsUrlClassifierUtils* urlUtil = nsUrlClassifierUtils::GetInstance();
 | |
| 
 | |
|     nsCString tableNames;
 | |
|     nsresult rv =
 | |
|         urlUtil->ConvertThreatTypeToListNames(expected.mThreatType, tableNames);
 | |
|     ASSERT_NS_SUCCEEDED(rv);
 | |
|     ASSERT_TRUE(aTableNames.Equals(tableNames));
 | |
| 
 | |
|     VerifyDuration(aPerHashCacheDuration, expected.mPerHashCacheDuration);
 | |
| 
 | |
|     mCallbackCount++;
 | |
|   }
 | |
| 
 | |
|   void VerifyDuration(uint32_t aToVerify, const MyDuration& aExpected) {
 | |
|     ASSERT_TRUE(aToVerify == aExpected.mSecs);
 | |
|   }
 | |
| 
 | |
|   ~MyParseCallback() = default;
 | |
| 
 | |
|   uint32_t& mCallbackCount;
 | |
| };
 | |
| 
 | |
| NS_IMPL_ISUPPORTS(MyParseCallback, nsIUrlClassifierParseFindFullHashCallback)
 | |
| 
 | |
| }  // end of unnamed namespace.
 | |
| 
 | |
| TEST(UrlClassifierFindFullHash, ParseRequest)
 | |
| {
 | |
|   // Build response.
 | |
|   FindFullHashesResponse r;
 | |
| 
 | |
|   // Init response-wise durations.
 | |
|   auto minWaitDuration = r.mutable_minimum_wait_duration();
 | |
|   PopulateDuration(*minWaitDuration, EXPECTED_MIN_WAIT_DURATION);
 | |
|   auto negCacheDuration = r.mutable_negative_cache_duration();
 | |
|   PopulateDuration(*negCacheDuration, EXPECTED_NEG_CACHE_DURATION);
 | |
| 
 | |
|   // Init matches.
 | |
|   for (uint32_t i = 0; i < ArrayLength(EXPECTED_MATCH); i++) {
 | |
|     auto expected = EXPECTED_MATCH[i];
 | |
|     auto match = r.mutable_matches()->Add();
 | |
|     match->set_threat_type(expected.mThreatType);
 | |
|     match->mutable_threat()->set_hash(expected.mCompleteHash.BeginReading(),
 | |
|                                       expected.mCompleteHash.Length());
 | |
|     auto perHashCacheDuration = match->mutable_cache_duration();
 | |
|     PopulateDuration(*perHashCacheDuration, expected.mPerHashCacheDuration);
 | |
|   }
 | |
|   std::string s;
 | |
|   r.SerializeToString(&s);
 | |
| 
 | |
|   uint32_t callbackCount = 0;
 | |
|   nsCOMPtr<nsIUrlClassifierParseFindFullHashCallback> callback =
 | |
|       new MyParseCallback(callbackCount);
 | |
| 
 | |
|   nsUrlClassifierUtils* urlUtil = nsUrlClassifierUtils::GetInstance();
 | |
|   nsresult rv = urlUtil->ParseFindFullHashResponseV4(
 | |
|       nsCString(s.c_str(), s.size()), callback);
 | |
|   NS_ENSURE_SUCCESS_VOID(rv);
 | |
| 
 | |
|   ASSERT_EQ(callbackCount, ArrayLength(EXPECTED_MATCH));
 | |
| }
 |