| # Copyright 2016 The Chromium Authors. All rights reserved. | 
 | # Use of this source code is governed by a BSD-style license that can be | 
 | # found in the LICENSE file. | 
 |  | 
 | import unittest | 
 |  | 
 | import loading_graph_view | 
 | import request_dependencies_lens | 
 | from request_dependencies_lens_unittest import TestRequests | 
 |  | 
 |  | 
 | class MockContentClassificationLens(object): | 
 |   def __init__(self, ad_request_ids, tracking_request_ids): | 
 |     self._ad_requests_ids = ad_request_ids | 
 |     self._tracking_request_ids = tracking_request_ids | 
 |  | 
 |   def IsAdRequest(self, request): | 
 |     return request.request_id in self._ad_requests_ids | 
 |  | 
 |   def IsTrackingRequest(self, request): | 
 |     return request.request_id in self._tracking_request_ids | 
 |  | 
 |  | 
 | class LoadingGraphViewTestCase(unittest.TestCase): | 
 |   def setUp(self): | 
 |     super(LoadingGraphViewTestCase, self).setUp() | 
 |     self.trace = TestRequests.CreateLoadingTrace() | 
 |     self.deps_lens = request_dependencies_lens.RequestDependencyLens(self.trace) | 
 |  | 
 |   def testAnnotateNodesNoLenses(self): | 
 |     graph_view = loading_graph_view.LoadingGraphView(self.trace, self.deps_lens) | 
 |     for node in graph_view.deps_graph.graph.Nodes(): | 
 |       self.assertFalse(node.is_ad) | 
 |       self.assertFalse(node.is_tracking) | 
 |     for edge in graph_view.deps_graph.graph.Edges(): | 
 |       self.assertFalse(edge.is_timing) | 
 |  | 
 |   def testAnnotateNodesContentLens(self): | 
 |     ad_request_ids = set([TestRequests.JS_REQUEST_UNRELATED_FRAME.request_id]) | 
 |     tracking_request_ids = set([TestRequests.JS_REQUEST.request_id]) | 
 |     content_lens = MockContentClassificationLens( | 
 |         ad_request_ids, tracking_request_ids) | 
 |     graph_view = loading_graph_view.LoadingGraphView(self.trace, self.deps_lens, | 
 |                                                      content_lens) | 
 |     for node in graph_view.deps_graph.graph.Nodes(): | 
 |       request_id = node.request.request_id | 
 |       self.assertEqual(request_id in ad_request_ids, node.is_ad) | 
 |       self.assertEqual(request_id in tracking_request_ids, node.is_tracking) | 
 |  | 
 |   def testRemoveAds(self): | 
 |     ad_request_ids = set([TestRequests.JS_REQUEST_UNRELATED_FRAME.request_id]) | 
 |     tracking_request_ids = set([TestRequests.JS_REQUEST.request_id]) | 
 |     content_lens = MockContentClassificationLens( | 
 |         ad_request_ids, tracking_request_ids) | 
 |     graph_view = loading_graph_view.LoadingGraphView(self.trace, self.deps_lens, | 
 |                                                      content_lens) | 
 |     graph_view.RemoveAds() | 
 |     request_ids = set([n.request.request_id | 
 |                        for n in graph_view.deps_graph.graph.Nodes()]) | 
 |     expected_request_ids = set([r.request_id for r in [ | 
 |         TestRequests.FIRST_REDIRECT_REQUEST, | 
 |         TestRequests.SECOND_REDIRECT_REQUEST, | 
 |         TestRequests.REDIRECTED_REQUEST, | 
 |         TestRequests.REQUEST, | 
 |         TestRequests.JS_REQUEST_OTHER_FRAME]]) | 
 |     self.assertSetEqual(expected_request_ids, request_ids) | 
 |  | 
 |   def testRemoveAdsPruneGraph(self): | 
 |     ad_request_ids = set([TestRequests.SECOND_REDIRECT_REQUEST.request_id]) | 
 |     tracking_request_ids = set([]) | 
 |     content_lens = MockContentClassificationLens( | 
 |         ad_request_ids, tracking_request_ids) | 
 |     graph_view = loading_graph_view.LoadingGraphView( | 
 |         self.trace, self.deps_lens, content_lens) | 
 |     graph_view.RemoveAds() | 
 |     request_ids = set([n.request.request_id | 
 |                        for n in graph_view.deps_graph.graph.Nodes()]) | 
 |     expected_request_ids = set( | 
 |         [TestRequests.FIRST_REDIRECT_REQUEST.request_id]) | 
 |     self.assertSetEqual(expected_request_ids, request_ids) | 
 |  | 
 |   def testEventInversion(self): | 
 |     self._UpdateRequestTiming({ | 
 |         '1234.redirect.1': (0, 0), | 
 |         '1234.redirect.2': (0, 0), | 
 |         '1234.1': (10, 100), | 
 |         '1234.12': (20, 50), | 
 |         '1234.42': (40, 70), | 
 |         '1234.56': (40, 150)}) | 
 |     graph_view = loading_graph_view.LoadingGraphView( | 
 |         self.trace, self.deps_lens) | 
 |     self.assertEqual(None, graph_view.GetInversionsAtTime(40)) | 
 |     self.assertEqual('1234.1', graph_view.GetInversionsAtTime(60)[0].request_id) | 
 |     self.assertEqual('1234.1', graph_view.GetInversionsAtTime(80)[0].request_id) | 
 |     self.assertEqual(None, graph_view.GetInversionsAtTime(110)) | 
 |     self.assertEqual(None, graph_view.GetInversionsAtTime(160)) | 
 |  | 
 |   def _UpdateRequestTiming(self, changes): | 
 |     for rq in self.trace.request_track.GetEvents(): | 
 |       if rq.request_id in changes: | 
 |         start_msec, end_msec = changes[rq.request_id] | 
 |         rq.timing.request_time = float(start_msec) / 1000 | 
 |         rq.timing.loading_finished = end_msec - start_msec | 
 |  | 
 |  | 
# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()