Revision created by MOE tool init_codebases.
Revision created by MOE tool push_codebase.
MOE_MIGRATION=
git-svn-id: http://google-safe-browsing.googlecode.com/svn/trunk@103 2195c2fd-d934-0410-ae3f-cd772a4098b8
diff --git a/dashboard/dashboard.py b/dashboard/dashboard.py
new file mode 100755
index 0000000..d080fc7
--- /dev/null
+++ b/dashboard/dashboard.py
@@ -0,0 +1,197 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+A web server demonstrating the use of googlesafebrowsing.client.
+"""
+
+from googlesafebrowsing import client
+from googlesafebrowsing import datastore
+
+import BaseHTTPServer
+import cgi
+import datetime
+import getopt
+import logging
+import SocketServer
+import sys
+import threading
+
+
+class ListStats(object):
+ """
+ ListStats objects have the following fields:
+ sbl: sblist.List object
+ chunk_range_str: A string representing the chunk ranges for sbl
+ num_expressions: Number of expressions in sbl
+ num_addchunks: Number of addchunks in sbl
+ num_subchunks: Number of subchunks in sbl
+ """
+ def __init__(self, sbl):
+ self.sbl = sbl
+ self.chunk_range_str = sbl.DownloadRequest()
+ self.num_expressions = sbl.NumPrefixes()
+ self.num_addchunks = len(sbl.AddChunkMap())
+ self.num_subchunks = len(sbl.SubChunkMap())
+
+
+class DashboardServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
+ def __init__(self, host_port, handler_class, ds, apikey):
+ BaseHTTPServer.HTTPServer.__init__(self, host_port, handler_class)
+ self.lists_stats = []
+ self.stats_time = None
+ self.sync_time = None
+ self.sync_updates = None
+ self.stats_lock = threading.RLock()
+ self.sbc = client.Client(ds,
+ apikey=apikey,
+ use_mac=True,
+ post_update_hook=self._ListsUpdated)
+
+ def _ListsUpdated(self, sbc):
+ """
+ This runs in the Client's updater thread. Compute some list statistics.
+ """
+    # Note: sbc is the same Client instance as self.sbc.
+ lists_stats = []
+ for sbl in sbc.Lists().values():
+ lists_stats.append(ListStats(sbl))
+
+ now = datetime.datetime.now()
+ self.stats_lock.acquire()
+ self.lists_stats = lists_stats
+ self.stats_time = now
+ if sbc.InSync() and self.sync_time is None:
+ self.sync_time = now
+ self.sync_updates = sbc.FirstSyncUpdates()
+ self.stats_lock.release()
+
+
+class DashboardRequest(BaseHTTPServer.BaseHTTPRequestHandler):
+ PARAM_URL = 'url'
+
+ def do_GET(self):
+ query_start = self.path.find('?')
+ self.query_params = {}
+ if query_start >= 0:
+ query = self.path[query_start + 1:]
+ self.path = self.path[0:query_start]
+ self.query_params = cgi.parse_qs(query)
+
+ {'/' : self.HandleStatus,
+ '/check_url' : self.HandleCheckUrl,
+ '/quitquitquit' : self.Quit}.get(self.path,
+ lambda: self.send_error(404, '%s not found' % self.path))()
+
+ def HandleStatus(self):
+ write = self.wfile.write
+ write('<html><head><title>Safe Browsing Client</title></head><body>')
+
+ self.server.stats_lock.acquire()
+ lists_stats = self.server.lists_stats
+ stats_time = self.server.stats_time
+ sync_time = self.server.sync_time
+ sync_updates = self.server.sync_updates
+ self.server.stats_lock.release()
+
+ if sync_time is None:
+ write('Client waiting for initial sync.<br/>')
+ else:
+ write('Client completed initial sync at %s after %d downloads.<br/>' % (
+ sync_time, sync_updates))
+ write('Client received last update at %s.<br/>' % (stats_time,))
+
+ for s in lists_stats:
+ write('<table border=1><tr><th align=left>%s</th></tr></table>' % (
+ s.chunk_range_str,))
+ write('<table border=1><tr><th>Expressions</th>' +
+ '<th>Add Chunks</th><th>Sub Chunks</th>' +
+ '<th>Expressions / Chunk</th></tr>')
+ write(('<tr align=right><td>%d</td><td>%d</td><td>%d</td>' +
+ '<td>%f</td></tr></table><br/>') % (
+ s.num_expressions, s.num_addchunks, s.num_subchunks,
+          # Guard against lists that have no add chunks yet.
+          float(s.num_expressions) / max(s.num_addchunks, 1)))
+
+ write(('<hr/><form action="/check_url"><input type=text name="%s" />'
+ '<input type="submit" value="Check URL" /></form>') % (
+ DashboardRequest.PARAM_URL,))
+ write('</body></html>\n')
+
+ def HandleCheckUrl(self):
+ """
+ Show if/why a URL is blocked.
+ """
+ write = self.wfile.write
+ write('<html><head><title>Check URL</title></head><body>')
+ url_param = self.query_params.get(DashboardRequest.PARAM_URL, [])
+ if len(url_param) != 1:
+ write('bad url query param: "%s"</body></html>' % (url_param,))
+ return
+ url = url_param[0]
+ matches = self.server.sbc.CheckUrl(url, debug_info=True)
+ if len(matches) == 0:
+ write('No matches for "%s"</body></html>' % (url,))
+ return
+
+ write('<ul>')
+ for listname, match, addchunknum in matches:
+ write('<li>%s, addchunk number %d: %s</li>' % (
+ listname, addchunknum, match))
+ write('</ul></body></html>')
+
+ def Quit(self):
+ self.server.sbc.ExitUpdater()
+ self.server.server_close()
+
+
+def Usage():
+ print >>sys.stderr, ('dashboard --port <port> --apikey <apikey> ' +
+ '[--datastore <datastore>]')
+ sys.exit(1)
+
+
+def main(argv):
+ try:
+ optlist = getopt.getopt(sys.argv[1:], None,
+ ['port=', 'apikey=', 'datastore='])[0]
+ except getopt.GetoptError, e:
+ print >>sys.stderr, str(e)
+ Usage()
+ print 'optlist:', optlist
+ port = None
+ apikey = None
+ dspath = '/tmp/dashboard_datastore'
+ for argname, argvalue in optlist:
+ if argname == '--port':
+ try:
+ port = int(argvalue)
+ except ValueError:
+ Usage()
+ elif argname == '--datastore':
+ dspath = argvalue
+ elif argname == '--apikey':
+ apikey = argvalue
+ if port is None or apikey is None:
+ Usage()
+
+ ds = datastore.DataStore(dspath)
+ http_server = DashboardServer(('', port), DashboardRequest, ds, apikey)
+ http_server.serve_forever()
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+ main(sys.argv)
diff --git a/python/COPYING b/python/COPYING
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/python/COPYING
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/python/__init__.py b/python/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/python/__init__.py
diff --git a/python/client.py b/python/client.py
new file mode 100755
index 0000000..692ff6c
--- /dev/null
+++ b/python/client.py
@@ -0,0 +1,428 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Google Safe Browsing protocol version 2.2 client."""
+
+import datetime
+import logging
+import os
+import sys
+import tempfile
+import threading
+
+import datastore
+import expression
+import hashprefix_trie
+import server
+import util
+
+
+class Error(Exception):
+ pass
+
+
+def FullHashIsCurrent(list_entry, sbl, now):
+ """
+  Returns True if the full hash should be considered valid.
+ """
+ updated = list_entry.GetHashTimestamp()
+  # All hashes are considered valid if we have made a downloads request
+  # within UPDATED_MAX. Failing that, a hash is still considered valid if a
+  # hashserver request has verified it within UPDATED_MAX.
+ return (((sbl.UpdateTime() is not None) and
+ (now - sbl.UpdateTime() < Client.UPDATED_MAX)) or
+ ((updated is not None) and
+ (now - updated < Client.UPDATED_MAX)))
+
+
+def ExternalCheckUrl(url, sbls, server, debug_info=False):
+ """
+  Return a list of 2-tuples [(blacklist, matching expression), ...] for all
+  matches; if debug_info is True, each tuple also carries the add chunk number.
+  url must be ASCII. Encode url with Punycode if necessary.
+ """
+ gen = expression.ExpressionGenerator(url)
+
+ # A trie that maps hash prefixes to expression objects.
+ # This tracks prefixes for which we need to do a gethash request.
+ prefix_matches = hashprefix_trie.HashprefixTrie()
+
+ # Return value.
+ matches = []
+
+ # Keep a list of all lists which we want to check as we should only need
+ # to look for one match per list. Once we've found a matching hash we can
+ # ignore the rest.
+ # TODO(gcasto): Is this really worth it? The increase in code complexity
+  # is nontrivial and in most cases it probably doesn't save any work. Also,
+  # this is inefficient as we are copying around the entire list contents.
+ check_sbls = set(sbls.itervalues())
+
+ now = datetime.datetime.now()
+
+ for expr in gen.Expressions():
+ if len(check_sbls) == 0:
+ break
+ logging.debug('Checking expression: "%s"', expr.Value())
+
+ # Cast check_sbls to list so that we can modify elements.
+ for sbl in list(check_sbls):
+ for list_entry in sbl.GetPrefixMatches(expr.HashValue()):
+ fullhash = list_entry.FullHash()
+ if fullhash is None or not FullHashIsCurrent(list_entry, sbl, now):
+ # Multiple prefix matches per list are rare, but they do happen.
+ # Make sure to keep track of all matches.
+ prefix_matches.Insert(list_entry.Prefix(), expr)
+ elif fullhash == expr.HashValue():
+ if debug_info:
+ matches.append((sbl.Name(), expr.Value(),
+ list_entry.AddChunkNum()))
+ else:
+ matches.append((sbl.Name(), expr.Value()))
+ check_sbls.remove(sbl)
+ break # Found a match. Continue to the next list.
+  # TODO(gcasto): This is not technically correct: the remaining prefixes may
+  # come from one list while check_sbls is populated by a different list, in
+  # which case we proceed with the lookup even though it doesn't matter.
+ if len(check_sbls) == 0 or prefix_matches.Size() == 0:
+ return matches
+
+ # Get full length hashes for cases where we only had a matching prefix or
+ # had a full length hash that was too old.
+ ghresp = server.GetAllFullHashes(prefix_matches.PrefixIterator())
+  logging.debug('GetAllFullHashes response: %s', ghresp)
+
+ # Check these full length hashes for a match.
+ for listname, addchunknum_map in ghresp.listmap.iteritems():
+ sbl = sbls.get(listname, None)
+ if sbl is None:
+ logging.info("No Listname")
+ # listname showed up on the gethash server before the downloads server.
+ continue
+ for addchunknum, hashes_set in addchunknum_map.iteritems():
+ for fullhash in hashes_set:
+ for expr in prefix_matches.GetPrefixMatches(fullhash):
+ if (sbl.AddFullHash(fullhash, addchunknum, ghresp.timestamp)
+ and sbl in check_sbls and expr.HashValue() == fullhash):
+ if debug_info:
+ matches.append((sbl.Name(), expr.Value(), addchunknum))
+ else:
+ matches.append((sbl.Name(), expr.Value()))
+ check_sbls.remove(sbl)
+
+ return matches
+
+
+class Client(object):
+ """
+ An automatically self-updating container for safebrowsing lists. Uses a
+ background thread to update the local list cache.
+
+ ds: DataStore instance
+ apikey: SafeBrowsing API key.
+ hp: 2-tuple with host and port for HTTP connections
+  ssl_hp: 2-tuple with host and port for HTTPS (SSL) connections
+ base_path: Base of HTTP path on host:port.
+ use_mac: True to enable verification with MACs.
+ size_limit: Preferred maximum download size in bytes. Intended for slow
+ connections.
+ force_delay: Use this value as the server polling delay until the client is
+ fully in sync.
+  pre_update_hook: A function that is called immediately before an update
+  starts. This function must accept this Client object as an argument. The
+  function is called from the Client's updater thread.
+ post_update_hook: A function that is called immediately after an update
+ finishes. This function must accept this Client object as an argument. The
+ function is called from the Client's updater thread.
+ gethash_server: 2-tuple of host and port for gethash requests. If unspecified,
+ hp will be used.
+  update_lists: If True, refresh the set of lists to download from the
+    safebrowsing servers on each update; otherwise just use the lists that are
+    already present in the datastore. If the datastore has no information, we
+    ask the server for the lists to download regardless.
+ sb_server: If not None the client uses this server instance instead of
+ creating its own server instance.
+ sb_lists: If not None, will use these lists instead of asking server what
+ lists are available.
+ """
+
+ DEFAULT_DELAY = 60 * 15
+
+ UPDATED_MAX = datetime.timedelta(minutes=45)
+
+ def __init__(self, ds, apikey, hp=('safebrowsing.clients.google.com', 80),
+ ssl_hp=('sb-ssl.google.com', 443), base_path='/safebrowsing',
+ use_mac=True, size_limit=None, force_delay=None,
+ pre_update_hook=lambda cl: None,
+ post_update_hook=lambda cl: None, gethash_server=None,
+ update_lists=False, sb_server=None, sb_lists=None):
+ self._force_delay = force_delay
+ self._post_update_hook = post_update_hook
+ self._pre_update_hook = pre_update_hook
+ self._datastore = ds
+ self._update_lists = update_lists
+ self._size_limit = size_limit
+
+ # A dict of listnames and sblist.Lists.
+ if sb_lists:
+ self._sbls = dict([(x.Name(), x) for x in sb_lists])
+ else:
+ self._sbls = self._datastore.GetLists()
+ clientkey = self._datastore.GetClientKey()
+ wrkey = self._datastore.GetWrKey()
+ if not sb_server:
+ self._server = server.Server(hp, ssl_hp, base_path,
+ clientkey=clientkey, wrkey=wrkey,
+ apikey=apikey, gethash_server=gethash_server)
+ else:
+ self._server = sb_server
+
+ if use_mac and (clientkey is None or wrkey is None):
+ self._Rekey()
+
+ if not self._sbls:
+ self._sbls = dict(
+ [(x.Name(), x) for x in self._server.GetLists()])
+
+ # This lock prevents concurrent access from the background updater thread
+ # and user threads.
+ self._lock = threading.RLock()
+
+ self._in_sync = False
+ self._first_sync_updates = 0
+ self._exit_cond = threading.Condition()
+ self._exit_updater = False
+
+ self._thr = threading.Thread(target=self._PollForData,
+ args=(Client.DEFAULT_DELAY,))
+ self._thr.setDaemon(True)
+ self._thr.start()
+
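+  # _MakeLockedMethod wraps an unbound method so that the wrapped version
+  # runs while holding self._lock, serializing user threads against the
+  # background updater thread (see the CheckUrl and Lists wrappers below).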
+ def _MakeLockedMethod(unbound):
+ def LockedMethod(self, *args, **kwargs):
+ self._lock.acquire()
+ try:
+ return unbound(self, *args, **kwargs)
+ finally:
+ self._lock.release()
+ return LockedMethod
+
+ def InSync(self):
+ return self._in_sync
+
+ def FirstSyncUpdates(self):
+ return self._first_sync_updates
+
+ def CheckUrl(self, url, debug_info=False):
+ return ExternalCheckUrl(url, self._sbls, self._server, debug_info)
+ ### Block updates from the background thread while checking a URL.
+ CheckUrl = _MakeLockedMethod(CheckUrl)
+
+ def Lists(self):
+ """
+ Return a map of listnames -> sblist.List objects.
+ """
+ return self._sbls
+ Lists = _MakeLockedMethod(Lists)
+
+ def ExitUpdater(self):
+ """
+ Call this to get a proper shutdown with a sync to the datastore.
+ """
+ self._exit_cond.acquire()
+ self._exit_updater = True
+ self._exit_cond.notify()
+ self._exit_cond.release()
+ self._thr.join()
+
+ def Server(self):
+ return self._server
+
+ def _PollForData(self, requested_delay):
+ """
+ Continuously poll the safe browsing server for updates.
+ """
+ num_updates = 0
+ while True:
+ try:
+ self._pre_update_hook(self)
+ num_updates += 1
+ requested_delay, updates_done = self._Update()
+ logging.info('Finished update number %d, next delay: %d',
+ num_updates, requested_delay)
+ if updates_done:
+ logging.info('Fully in sync')
+ self._force_delay = None
+ self._in_sync = True
+ if self._first_sync_updates == 0:
+ self._first_sync_updates = num_updates
+ else:
+ self._in_sync = False
+ self._post_update_hook(self)
+ except:
+ logging.exception('exception in client update thread')
+ logging.debug('requested_delay: %d, force_delay: %s', requested_delay,
+ self._force_delay)
+ if self._force_delay is None:
+ delay = requested_delay
+ else:
+ delay = self._force_delay
+
+ self._exit_cond.acquire()
+ try:
+ if not self._exit_updater:
+        logging.info('Waiting %d seconds', delay)
+ self._exit_cond.wait(delay)
+ if self._exit_updater:
+ logging.info('Exiting')
+ self._datastore.Sync()
+ return
+ finally:
+ self._exit_cond.release()
+
+ def _Update(self):
+ """
+ Update the client state (blacklists, keys). Return 2-tuple (poll_delay,
+ updates_done). poll_delay is the minimum delay requested by the server.
+ updates_done is True if no changes were received from the server.
+ """
+ # Possibly check for new or deleted lists.
+ if self._update_lists:
+ self._UpdateLists()
+
+ # Get new data.
+ logging.debug('lists: "%s"', ','.join(self._sbls.iterkeys()))
+ download = self._server.Download(self._sbls.values(),
+ size_limit_bytes=self._size_limit)
+ logging.debug('Minimum delay: %d', download.min_delay_sec)
+
+ if download.rekey:
+ self._Rekey()
+ return (download.min_delay_sec, False)
+
+ if download.reset:
+ self._sbls.clear()
+ return (download.min_delay_sec, False)
+
+ updates_done = True
+ for name, ops in download.listops.iteritems():
+ for op in ops:
+ # Make sure that we actually received data before claiming that
+ # we aren't up to date.
+ updates_done = False
+ logging.debug('Applying operation to list %s: %s', name, op)
+ op.Apply()
+ # Update the List's timestamp after successfully updating it.
+ self._sbls[name].SetUpdateTime(download.timestamp)
+ return (download.min_delay_sec, updates_done)
+ _Update = _MakeLockedMethod(_Update)
+
+ # This should only be called from locked methods.
+ def _Rekey(self):
+ logging.debug('rekey')
+ clientkey, wrkey = self._server.Rekey()
+ self._datastore.SetClientKey(clientkey)
+ self._datastore.SetWrKey(wrkey)
+
+ def _UpdateLists(self):
+ sbls = self._server.GetLists()
+ deleted = set(self._sbls.iterkeys())
+ for server_l in sbls:
+ logging.debug('server returned list: "%s"', server_l.Name())
+ if server_l.Name() in deleted:
+ deleted.remove(server_l.Name())
+      if server_l.Name() not in self._sbls:
+ logging.debug('added list: "%s"', server_l.Name())
+ self._sbls[server_l.Name()] = server_l
+ for name in deleted:
+ logging.debug('Deleting list: %s', name)
+ del self._sbls[name]
+
+
+class UrlChecker(object):
+ def __init__(self, urls):
+ self._urls = urls
+ self._event = threading.Event()
+
+ def Updated(self, cl):
+ """
+ This runs in the client's updater thread.
+ """
+ logging.debug('List states:')
+ for sbl in cl.Lists().itervalues():
+ logging.debug('%s: %d prefixes, %s', sbl.Name(), sbl.NumPrefixes(),
+ sbl.DownloadRequest())
+
+ if not cl.InSync():
+ logging.info('Waiting to complete updates...')
+ return
+ for url in self._urls:
+ matches = cl.CheckUrl(url)
+ logging.info('CheckUrl %s: %s', url, matches)
+ print '%s:' % (url,)
+ if len(matches) == 0:
+ print '\t(no matches)'
+ else:
+ for listname, matching in matches:
+ print '\t%s: %s' % (listname, matching)
+ self._event.set()
+
+ def WaitForFinish(self):
+ self._event.wait()
+
+
+def PrintUsage(argv):
+ print >>sys.stderr, ('Usage: %s <APIKey> [check <URL> <URL> ...]\n'
+ 'Visit "http://code.google.com/apis/safebrowsing/'
+ 'key_signup.html" to obtain an APIKey'
+ % (argv[0],))
+
+
+def CheckForUrl(apikey, urls):
+ checking_datastore_loc = os.path.join(tempfile.mkdtemp(), 'datastore_checker')
+ ds = datastore.DataStore(checking_datastore_loc)
+
+ checker = UrlChecker(urls)
+
+ cl = Client(ds,
+ apikey,
+ post_update_hook=checker.Updated)
+ checker.WaitForFinish()
+ cl.ExitUpdater()
+
+
+def main(argv):
+ """
+  A command-line Google Safe Browsing client. Usage:
+ client.py <APIKey> [check <URLs>]
+ """
+ logging.basicConfig(level=logging.INFO)
+ if len(argv) < 3:
+ PrintUsage(argv)
+ return 1
+
+ apikey = argv[1]
+ command = argv[2]
+ if command == "check":
+    CheckForUrl(apikey, argv[3:])
+ else:
+ PrintUsage(argv)
+ return 1
+
+
+if __name__ == '__main__':
+ main(sys.argv)
diff --git a/python/client_test.py b/python/client_test.py
new file mode 100755
index 0000000..ed94ea6
--- /dev/null
+++ b/python/client_test.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""test for googlesafebrowsing.client.
+
+TODO: write some tests for the SafeBrowsing client. Indirectly, the
+client code is tested with the client testing framework but it would
+be nice to have a few tests here as well.
+"""
+
+import unittest
+
+
+class ClientTest(unittest.TestCase):
+ def testClient(self):
+ pass
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/datastore.py b/python/datastore.py
new file mode 100755
index 0000000..e504dc3
--- /dev/null
+++ b/python/datastore.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A very simple (and slow) persistence mechanism based on the shelve module.
+"""
+
+import logging
+import shelve
+
+
+class Error(Exception):
+ pass
+
+
+class DataStore(object):
+ """
+ Changes made to mutable objects returned from a DataStore are automatically
+ written back to the persistent store. Python strs are not mutable, so SetWrKey
+ and SetClientKey must be used. For the List objects, use GetLists to get a
+ mutable dict. Any changes to this dict and the List objects it contains will
+ be written back to persistent storage when Sync is called.
+ """
+
+ # Value is a dict of listname:sblist.List.
+ LISTS = 'lists'
+
+ WRKEY = 'wrkey'
+ CLIENTKEY = 'clientkey'
+
+ def __init__(self, basefile, create=True):
+ flags = 'w'
+ if create:
+ flags = 'c'
+ try:
+ self._db = shelve.open(basefile, flag=flags, writeback=True)
+ except Exception, e:
+ raise Error(e)
+
+ self._db.setdefault(DataStore.LISTS, {})
+ self._db.setdefault(DataStore.WRKEY, None)
+ self._db.setdefault(DataStore.CLIENTKEY, None)
+
+ def Sync(self):
+ """
+ This is very slow. Also, it will replace the objects in the database with
+ new copies so that existing references to the old objects will no longer
+ update the datastore. E.g., you must call GetLists() again after calling
+ this.
+ """
+ self._db.sync()
+
+ def GetLists(self):
+ """
+ Return a dict of listname:sblist.List. Changes to this dict and the List
+ objects in it are written back to the data store when Sync is called.
+ """
+ return self._db[DataStore.LISTS]
+
+ def GetWrKey(self):
+ return self._db[DataStore.WRKEY]
+
+ def SetWrKey(self, wrkey):
+ self._db[DataStore.WRKEY] = wrkey
+
+ def GetClientKey(self):
+ """
+ Unescaped client key.
+ """
+ return self._db[DataStore.CLIENTKEY]
+
+ def SetClientKey(self, clientkey):
+ self._db[DataStore.CLIENTKEY] = clientkey
diff --git a/python/expression.py b/python/expression.py
new file mode 100755
index 0000000..1a3a135
--- /dev/null
+++ b/python/expression.py
@@ -0,0 +1,352 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2010 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper classes which help converting a url to a list of SB expressions."""
+
+import array
+import logging
+import re
+import string
+import urllib
+import urlparse
+
+import util
+
+
+class UrlParseError(Exception):
+ pass
+
+
+def GenerateSafeChars():
+ """
+ Return a string containing all 'safe' characters that shouldn't be escaped
+ for url encoding. This includes all printable characters except '#%' and
+ whitespace characters.
+ """
+ unfiltered_chars = string.digits + string.ascii_letters + string.punctuation
+ filtered_list = [c for c in unfiltered_chars if c not in '%#']
+ return array.array('c', filtered_list).tostring()
+
+
+class ExpressionGenerator(object):
+ """Class does the conversion url -> list of SafeBrowsing expressions.
+
+ This class converts a given url into the list of all SafeBrowsing host-suffix,
+ path-prefix expressions for that url. These are expressions that are on the
+ SafeBrowsing lists.
+ """
+ HEX = re.compile(r'^0x([a-fA-F0-9]+)$')
+ OCT = re.compile(r'^0([0-7]+)$')
+ DEC = re.compile(r'^(\d+)$')
+ IP_WITH_TRAILING_SPACE = re.compile(r'^(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}) ')
+ POSSIBLE_IP = re.compile(r'^(?i)((?:0x[0-9a-f]+|[0-9\\.])+)$')
+ FIND_BAD_OCTAL_REGEXP = re.compile(r'(^|\.)0\d*[89]')
+ # This regular expression parses the host and port from a hostname. Note: any
+ # user and password are removed from the hostname.
+ HOST_PORT_REGEXP = re.compile(r'^(?:.*@)?(?P<host>[^:]*)(:(?P<port>\d+))?$')
+ SAFE_CHARS = GenerateSafeChars()
+ # Dict that maps supported schemes to their default port number.
+ DEFAULT_PORTS = {'http': '80', 'https': '443', 'ftp': '21'}
+
+ def __init__(self, url):
+ parse_exception = UrlParseError('failed to parse URL "%s"' % (url,))
+ canonical_url = ExpressionGenerator.CanonicalizeUrl(url)
+ if not canonical_url:
+ raise parse_exception
+
+ # Each element is a list of host components used to build expressions.
+ self._host_lists = []
+ # A list of paths used to build expressions.
+ self._path_exprs = []
+
+ url_split = urlparse.urlsplit(canonical_url)
+ canonical_host, canonical_path = url_split[1], url_split[2]
+ self._MakeHostLists(canonical_host, parse_exception)
+
+ if url_split[3]:
+ # Include canonicalized path with query arguments
+ self._path_exprs.append(canonical_path + '?' + url_split[3])
+ self._path_exprs.append(canonical_path)
+
+ # Get the first three directory path components and create the 4 path
+ # expressions starting at the root (/) and successively appending directory
+ # path components, including the trailing slash. E.g.:
+ # /a/b/c/d.html -> [/, /a/, /a/b/, /a/b/c/]
+ path_parts = canonical_path.rstrip('/').lstrip('/').split('/')[:3]
+ if canonical_path.count('/') < 4:
+      # If the last component is not a directory we remove it.
+ path_parts.pop()
+ while path_parts:
+ self._path_exprs.append('/' + '/'.join(path_parts) + '/')
+ path_parts.pop()
+
+ if canonical_path != '/':
+ self._path_exprs.append('/')
+
+ @staticmethod
+ def CanonicalizeUrl(url):
+ """Canonicalize the given URL for the SafeBrowsing protocol.
+
+ Args:
+ url: URL to canonicalize.
+ Returns:
+ A canonical URL or None if the URL could not be canonicalized.
+ """
+ # Start by stripping off the fragment identifier.
+ tmp_pos = url.find('#')
+ if tmp_pos >= 0:
+ url = url[0:tmp_pos]
+
+    # Strip off leading and trailing whitespace.
+    url = url.strip()
+
+ # Remove any embedded tabs and CR/LF characters which aren't escaped.
+ url = url.replace('\t', '').replace('\r', '').replace('\n', '')
+
+    # Un-escape and re-escape the URL just in case there are some encoded
+    # characters in the url scheme, for example.
+ url = ExpressionGenerator._Escape(url)
+
+ url_split = urlparse.urlsplit(url)
+ if not url_split[0]:
+ # URL had no scheme. In this case we assume it is http://.
+ url = 'http://' + url
+ url_split = urlparse.urlsplit(url)
+
+ url_scheme = url_split[0].lower()
+ if url_scheme not in ExpressionGenerator.DEFAULT_PORTS:
+ return None # Unsupported scheme.
+
+ # Note: applying HOST_PORT_REGEXP also removes any user and password.
+ m = ExpressionGenerator.HOST_PORT_REGEXP.match(url_split[1])
+ if not m:
+ return None
+ host, port = m.group('host'), m.group('port')
+
+ canonical_host = ExpressionGenerator.CanonicalizeHost(host)
+ if not canonical_host:
+ return None
+
+ # Now that the host is canonicalized we add the port back if it's not the
+ # default port for that url scheme.
+ if port and port != ExpressionGenerator.DEFAULT_PORTS[url_scheme]:
+ canonical_host += ':' + port
+
+ canonical_path = ExpressionGenerator.CanonicalizePath(url_split[2])
+
+ # If the URL ends with ? we want to keep the ?.
+ canonical_url = url_split[0] + '://' + canonical_host + canonical_path
+ if url_split[3] != '' or url.endswith('?'):
+ canonical_url += '?' + url_split[3]
+ return canonical_url
+
+ @staticmethod
+ def CanonicalizePath(path):
+ """Canonicalize the given path."""
+ if not path:
+ return '/'
+
+ # There are some cases where the path will not start with '/'. Example:
+ # "ftp://host.com?q" -- the hostname is 'host.com' and the path '%3Fq'.
+    # Browsers typically prepend a leading slash to the path in this case;
+    # we do the same.
+ if path[0] != '/':
+ path = '/' + path
+
+ path = ExpressionGenerator._Escape(path)
+
+ path_components = []
+ for path_component in path.split('/'):
+ # If the path component is '..' we skip it and remove the preceding path
+      # component, if there is one.
+ if path_component == '..':
+ if len(path_components) > 0:
+ path_components.pop()
+ # We skip empty path components to remove successive slashes (i.e.,
+ # // -> /). Note: this means that the leading and trailing slash will
+ # also be removed and need to be re-added afterwards.
+ #
+ # If the path component is '.' we also skip it (i.e., /./ -> /).
+ elif path_component != '.' and path_component != '':
+ path_components.append(path_component)
+
+ # Put the path components back together and re-add the leading slash which
+  # got stripped by removing empty path components.
+ canonical_path = '/' + '/'.join(path_components)
+ # If necessary we also re-add the trailing slash.
+ if path.endswith('/') and not canonical_path.endswith('/'):
+ canonical_path += '/'
+
+ return canonical_path
+
+ @staticmethod
+ def CanonicalizeHost(host):
+ """Canonicalize the given host. Returns None in case of an error."""
+ if not host:
+ return None
+ host = ExpressionGenerator._Escape(host.lower())
+
+ ip = ExpressionGenerator.CanonicalizeIp(host)
+ if ip:
+ # Host is an IP address.
+ host = ip
+ else:
+ # Host is a normal hostname.
+ # Skip trailing, leading and consecutive dots.
+ host_split = [part for part in host.split('.') if part]
+ if len(host_split) < 2:
+ return None
+ host = '.'.join(host_split)
+
+ return host
+
+ @staticmethod
+ def CanonicalizeIp(host):
+ """
+ Return a canonicalized IP if host can represent an IP and None otherwise.
+ """
+ if len(host) <= 15:
+ # The Windows resolver allows a 4-part dotted decimal IP address to have a
+ # space followed by any old rubbish, so long as the total length of the
+ # string doesn't get above 15 characters. So, "10.192.95.89 xy" is
+ # resolved to 10.192.95.89.
+ # If the string length is greater than 15 characters,
+ # e.g. "10.192.95.89 xy.wildcard.example.com", it will be resolved through
+ # DNS.
+ m = ExpressionGenerator.IP_WITH_TRAILING_SPACE.match(host)
+ if m:
+ host = m.group(1)
+
+ if not ExpressionGenerator.POSSIBLE_IP.match(host):
+ return None
+
+ # Basically we should parse octal if we can, but if there are illegal octal
+    # numbers, e.g. 08 or 09, then we should just look at decimal and hex.
+ allow_octal = not ExpressionGenerator.FIND_BAD_OCTAL_REGEXP.search(host)
+
+ # Skip trailing, leading and consecutive dots.
+ host_split = [part for part in host.split('.') if part]
+ if len(host_split) > 4:
+ return None
+
+ ip = []
+ for i in xrange(len(host_split)):
+ m = ExpressionGenerator.HEX.match(host_split[i])
+ if m:
+ base = 16
+ else:
+ m = ExpressionGenerator.OCT.match(host_split[i])
+ if m and allow_octal:
+ base = 8
+ else:
+ m = ExpressionGenerator.DEC.match(host_split[i])
+ if m:
+ base = 10
+ else:
+ return None
+ n = long(m.group(1), base)
+ if n > 255:
+ if i < len(host_split) - 1:
+ n &= 0xff
+ ip.append(n)
+ else:
+ bytes = []
+ shift = 0
+ while n > 0 and len(bytes) < 4:
+ bytes.append(n & 0xff)
+ n >>= 8
+ if len(ip) + len(bytes) > 4:
+ return None
+ bytes.reverse()
+ ip.extend(bytes)
+ else:
+ ip.append(n)
+
+ while len(ip) < 4:
+ ip.append(0)
+ return '%u.%u.%u.%u' % tuple(ip)
+
+ def Expressions(self):
+ """
+ A generator of the possible expressions.
+ """
+ for host_parts in self._host_lists:
+ host = '.'.join(host_parts)
+ for p in self._path_exprs:
+ yield Expression(host, p)
+
+ @staticmethod
+ def _Escape(unescaped_str):
+ """Fully unescape the given string, then re-escape once.
+
+ Args:
+ unescaped_str: string that should be escaped.
+ Returns:
+ Escaped string according to the SafeBrowsing protocol.
+ """
+ unquoted = urllib.unquote(unescaped_str)
+ while unquoted != unescaped_str:
+ unescaped_str = unquoted
+ unquoted = urllib.unquote(unquoted)
+
+ return urllib.quote(unquoted, ExpressionGenerator.SAFE_CHARS)
+
+ def _MakeHostLists(self, host, parse_exception):
+ """
+ Canonicalize host and build self._host_lists.
+ """
+ ip = ExpressionGenerator.CanonicalizeIp(host)
+ if ip is not None:
+ # Is an IP.
+ self._host_lists.append([ip])
+ return
+
+ # Is a hostname.
+ # Skip trailing, leading and consecutive dots.
+ host_split = [part for part in host.split('.') if part]
+ if len(host_split) < 2:
+ raise parse_exception
+ start = len(host_split) - 5
+ stop = len(host_split) - 1
+ if start <= 0:
+ start = 1
+ self._host_lists.append(host_split)
+ for i in xrange(start, stop):
+ self._host_lists.append(host_split[i:])
+
+
+class Expression(object):
+ """Class which represents a host-suffix, path-prefix expression."""
+ def __init__(self, host, path):
+ self._host = host
+ self._path = path
+ self._value = host + path
+ self._hash_value = util.GetHash256(self._value)
+
+ def __str__(self):
+ return self.Value()
+
+ def __repr__(self):
+ """
+ Not really a good repr. This is for debugging.
+ """
+ return self.Value()
+
+ def Value(self):
+ return self._value
+
+ def HashValue(self):
+ return self._hash_value
diff --git a/python/expression_test.py b/python/expression_test.py
new file mode 100755
index 0000000..fe4b5de
--- /dev/null
+++ b/python/expression_test.py
@@ -0,0 +1,249 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2010 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Test for googlesafebrowsing.expression."""
+
+import expression
+
+import logging
+import unittest
+
+
+class CanonicalizationTest(unittest.TestCase):
+ def testCanonicalizeIp(self):
+ ips = [
+ ('1.2.3.4', '1.2.3.4'),
+ ('012.034.01.055', '10.28.1.45'),
+ ('0x12.0x43.0x44.0x01', '18.67.68.1'),
+ ('167838211', '10.1.2.3'),
+ ('12.0x12.01234', '12.18.2.156'),
+ ('0x10000000b', '0.0.0.11'),
+ ('asdf.com', None),
+ ('0x120x34', None),
+ ('123.123.0.0.1', None),
+ ('1.2.3.00x0', None),
+ ('fake ip', None),
+ ('123.123.0.0.1', None),
+ ('255.0.0.1', '255.0.0.1'),
+ ('12.0x12.01234', '12.18.2.156'),
+ # TODO: Make this test case work.
+ # This doesn't seem very logical to me, but it might be how microsoft's
+ # dns works. Certainly it's how Netcraft does it.
+ #('276.2.3', '20.2.0.3'),
+ ('012.034.01.055', '10.28.1.45'),
+ ('0x12.0x43.0x44.0x01', '18.67.68.1'),
+ ('167838211', '10.1.2.3'),
+ ('3279880203', '195.127.0.11'),
+ ('4294967295', '255.255.255.255'),
+ ('10.192.95.89 xy', '10.192.95.89'),
+ ('1.2.3.00x0', None),
+ # If we find bad octal parse the whole IP as decimal or hex.
+ ('012.0xA0.01.089', '12.160.1.89')]
+
+ for testip, expected in ips:
+ actual = expression.ExpressionGenerator.CanonicalizeIp(testip)
+ self.assertEqual(actual, expected,
+ 'test input: %s, actual: %s, expected: %s' % (testip,
+ actual,
+ expected))
+
+ def testCanonicalizeUrl(self):
+ urls = [
+ ('http://google.com/', 'http://google.com/'),
+ ('http://google.com:80/a/b', 'http://google.com/a/b'),
+ ('http://google.com:80/a/b/c/', 'http://google.com/a/b/c/'),
+ ('http://GOOgle.com', 'http://google.com/'),
+ ('http://..google..com../', 'http://google.com/'),
+ ('http://google.com/%25%34%31%25%31%46', 'http://google.com/A%1F'),
+ ('http://google^.com/', 'http://google^.com/'),
+ ('http://google.com/1/../2/././', 'http://google.com/2/'),
+ ('http://google.com/1//2?3//4', 'http://google.com/1/2?3//4'),
+ # Some more examples of our url lib unittest.
+ ('http://host.com/%25%32%35', 'http://host.com/%25'),
+ ('http://host.com/%25%32%35%25%32%35', 'http://host.com/%25%25'),
+ ('http://host.com/%2525252525252525', 'http://host.com/%25'),
+ ('http://host.com/asdf%25%32%35asd', 'http://host.com/asdf%25asd'),
+ ('http://host.com/%%%25%32%35asd%%',
+ 'http://host.com/%25%25%25asd%25%25'),
+ ('http://www.google.com/', 'http://www.google.com/'),
+ ('http://%31%36%38%2e%31%38%38%2e%39%39%2e%32%36/%2E%73%65%63%75%72%65/%77%77%77%2E%65%62%61%79%2E%63%6F%6D/', 'http://168.188.99.26/.secure/www.ebay.com/'),
+ ('http://195.127.0.11/uploads/%20%20%20%20/.verify/.eBaysecure=updateuserdataxplimnbqmn-xplmvalidateinfoswqpcmlx=hgplmcx/', 'http://195.127.0.11/uploads/%20%20%20%20/.verify/.eBaysecure=updateuserdataxplimnbqmn-xplmvalidateinfoswqpcmlx=hgplmcx/'),
+ ('http://host%23.com/%257Ea%2521b%2540c%2523d%2524e%25f%255E00%252611%252A22%252833%252944_55%252B', 'http://host%23.com/~a!b@c%23d$e%25f^00&11*22(33)44_55+'),
+ ('http://3279880203/blah', 'http://195.127.0.11/blah'),
+ ('http://www.google.com/blah/..', 'http://www.google.com/'),
+ ('http://a.com/../b', 'http://a.com/b'),
+ ('www.google.com/', 'http://www.google.com/'),
+ ('www.google.com', 'http://www.google.com/'),
+ ('http://www.evil.com/blah#frag', 'http://www.evil.com/blah'),
+ ('http://www.GOOgle.com/', 'http://www.google.com/'),
+ ('http://www.google.com.../', 'http://www.google.com/'),
+ ('http://www.google.com/foo\tbar\rbaz\n2', 'http://www.google.com/foobarbaz2'),
+ ('http://www.google.com/q?', 'http://www.google.com/q?'),
+ ('http://www.google.com/q?r?', 'http://www.google.com/q?r?'),
+ ('http://www.google.com/q?r?s', 'http://www.google.com/q?r?s'),
+ ('http://evil.com/foo#bar#baz', 'http://evil.com/foo'),
+ ('http://evil.com/foo;', 'http://evil.com/foo;'),
+ ('http://evil.com/foo?bar;', 'http://evil.com/foo?bar;'),
+ ('http://\x01\x80.com/', 'http://%01%80.com/'),
+ ('http://notrailingslash.com', 'http://notrailingslash.com/'),
+ ('http://www.gotaport.com:1234/', 'http://www.gotaport.com:1234/'),
+ ('http://www.google.com:443/', 'http://www.google.com:443/'),
+ (' http://www.google.com/ ', 'http://www.google.com/'),
+ ('http:// leadingspace.com/', 'http://%20leadingspace.com/'),
+ ('http://%20leadingspace.com/', 'http://%20leadingspace.com/'),
+ ('%20leadingspace.com/', 'http://%20leadingspace.com/'),
+ ('https://www.securesite.com:443/', 'https://www.securesite.com/'),
+ ('ftp://ftp.myfiles.com:21/', 'ftp://ftp.myfiles.com/'),
+ ('http://some%1Bhost.com/%1B', 'http://some%1Bhost.com/%1B'),
+ # Test NULL character
+ ('http://test%00\x00.com/', 'http://test%00%00.com/'),
+ # Username and password should be removed
+ ('http://user:password@google.com/', 'http://google.com/'),
+        # All of these cases are missing a valid hostname and should return None.
+ ('', None),
+ (':', None),
+ ('/blah', None),
+ ('#ref', None),
+ ('/blah#ref', None),
+ ('?query#ref', None),
+ ('/blah?query#ref', None),
+ ('/blah;param', None),
+ ('http://#ref', None),
+ ('http:///blah#ref', None),
+ ('http://?query#ref', None),
+ ('http:///blah?query#ref', None),
+ ('http:///blah;param', None),
+ ('http:///blah;param?query#ref', None),
+ ('mailto:bryner@google.com', None),
+ # If the protocol is unrecognized, the URL class does not parse out
+ # a hostname.
+ ('myprotocol://site.com/', None),
+ # This URL should _not_ have hostname shortening applied to it.
+ ('http://i.have.way.too.many.dots.com/', 'http://i.have.way.too.many.dots.com/'),
+ # WholeSecurity escapes parts of the scheme
+ ('http%3A%2F%2Fwackyurl.com:80/', 'http://wackyurl.com/'),
+ ('http://W!eird<>Ho$^.com/', 'http://w!eird<>ho$^.com/'),
+ # The path should have a leading '/' even if the hostname was terminated
+ # by something other than a '/'.
+ ('ftp://host.com?q', 'ftp://host.com/?q')]
+
+ for testin, expected in urls:
+ actual = expression.ExpressionGenerator.CanonicalizeUrl(testin)
+ self.assertEqual(
+ actual, expected,
+ 'test input: %s, actual: %s, expected: %s' % (
+ testin, actual, expected))
+
+
+class ExprGenTest(unittest.TestCase):
+ def CheckExpr(self, url, expected):
+ gen = expression.ExpressionGenerator(url)
+ exprs = list(gen.Expressions())
+ self.assertEqual(len(exprs), len(expected),
+ 'Length mismatch.\nExpected: %s\nActual: %s' % (
+ expected, exprs))
+ for i in xrange(len(exprs)):
+ self.assertEqual(exprs[i].Value(), expected[i],
+                       'List mismatch.\nExpected: %s\nActual: %s' % (expected,
+ exprs))
+
+ def testExpressionGenerator(self):
+ self.CheckExpr('http://12.0x12.01234/a/b/cde/f?g=foo&h=bar#quux',
+ [
+ '12.18.2.156/a/b/cde/f?g=foo&h=bar',
+ '12.18.2.156/a/b/cde/f',
+ '12.18.2.156/a/b/cde/',
+ '12.18.2.156/a/b/',
+ '12.18.2.156/a/',
+ '12.18.2.156/',])
+
+ self.CheckExpr('http://www.google.com/a/b/cde/f?g=foo&h=bar#quux',
+ [
+ 'www.google.com/a/b/cde/f?g=foo&h=bar',
+ 'www.google.com/a/b/cde/f',
+ 'www.google.com/a/b/cde/',
+ 'www.google.com/a/b/',
+ 'www.google.com/a/',
+ 'www.google.com/',
+
+ 'google.com/a/b/cde/f?g=foo&h=bar',
+ 'google.com/a/b/cde/f',
+ 'google.com/a/b/cde/',
+ 'google.com/a/b/',
+ 'google.com/a/',
+ 'google.com/'])
+
+ self.CheckExpr('http://a.b.c.d.e.f.g/h/i/j/k/l/m/n/o?p=foo&q=bar#quux',
+ [
+ 'a.b.c.d.e.f.g/h/i/j/k/l/m/n/o?p=foo&q=bar',
+ 'a.b.c.d.e.f.g/h/i/j/k/l/m/n/o',
+ 'a.b.c.d.e.f.g/h/i/j/',
+ 'a.b.c.d.e.f.g/h/i/',
+ 'a.b.c.d.e.f.g/h/',
+ 'a.b.c.d.e.f.g/',
+
+ 'c.d.e.f.g/h/i/j/k/l/m/n/o?p=foo&q=bar',
+ 'c.d.e.f.g/h/i/j/k/l/m/n/o',
+ 'c.d.e.f.g/h/i/j/',
+ 'c.d.e.f.g/h/i/',
+ 'c.d.e.f.g/h/',
+ 'c.d.e.f.g/',
+
+ 'd.e.f.g/h/i/j/k/l/m/n/o?p=foo&q=bar',
+ 'd.e.f.g/h/i/j/k/l/m/n/o',
+ 'd.e.f.g/h/i/j/',
+ 'd.e.f.g/h/i/',
+ 'd.e.f.g/h/',
+ 'd.e.f.g/',
+
+ 'e.f.g/h/i/j/k/l/m/n/o?p=foo&q=bar',
+ 'e.f.g/h/i/j/k/l/m/n/o',
+ 'e.f.g/h/i/j/',
+ 'e.f.g/h/i/',
+ 'e.f.g/h/',
+ 'e.f.g/',
+
+ 'f.g/h/i/j/k/l/m/n/o?p=foo&q=bar',
+ 'f.g/h/i/j/k/l/m/n/o',
+ 'f.g/h/i/j/',
+ 'f.g/h/i/',
+ 'f.g/h/',
+ 'f.g/'])
+
+ self.CheckExpr('http://www.phisher.co.uk/a/b',
+ [
+ 'www.phisher.co.uk/a/b',
+ 'www.phisher.co.uk/a/',
+ 'www.phisher.co.uk/',
+
+ 'phisher.co.uk/a/b',
+ 'phisher.co.uk/a/',
+ 'phisher.co.uk/',
+
+ 'co.uk/a/b',
+ 'co.uk/a/',
+ 'co.uk/'])
+
+ self.CheckExpr('http://a.b/?', ['a.b/'])
+ self.CheckExpr('http://1.2.3.4/a/b',
+ ['1.2.3.4/a/b', '1.2.3.4/a/', '1.2.3.4/'])
+ self.CheckExpr('foo.com', ['foo.com/'])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/python/hashprefix_trie.py b/python/hashprefix_trie.py
new file mode 100755
index 0000000..638603f
--- /dev/null
+++ b/python/hashprefix_trie.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2010 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple trie implementation that is used by the SB client."""
+
+import itertools
+
+class HashprefixTrie(object):
+ """Trie that maps hash prefixes to a list of values."""
+
+ # Prefixes shorter than this will not be stored in the HashprefixTrie for
+ # performance reasons. Insertion, Lookup and Deletion will fail on prefixes
+ # shorter than this value.
+ MIN_PREFIX_LEN = 4
+
+ class Node(object):
+ """Represents a node in the trie.
+
+ Holds a list of values and a dict that maps char -> Node.
+ """
+ __slots__ = ('values', 'children', 'parent')
+
+ def __init__(self, parent=None):
+ self.values = []
+ self.children = {} # Maps char -> HashprefixTrie.Node
+ self.parent = parent
+
+ def __init__(self):
+ self._root = HashprefixTrie.Node()
+ self._size = 0 # Number of hash prefixes in the trie.
+
+ def _GetPrefixComponents(self, hashprefix):
+ # For performance reasons we will not store any prefixes that are shorter
+ # than 4B. The SafeBrowsing protocol will most probably never serve
+ # prefixes shorter than 4B because it would lead to a high number of
+ # collisions.
+ assert(len(hashprefix) >= HashprefixTrie.MIN_PREFIX_LEN)
+ # Collapse the first 4B together to reduce the number of nodes we have to
+ # store in memory.
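+    # E.g., with MIN_PREFIX_LEN == 4, 'abcdef' yields 'abcd', 'e', 'f'.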
+ yield hashprefix[:HashprefixTrie.MIN_PREFIX_LEN]
+ for char in hashprefix[HashprefixTrie.MIN_PREFIX_LEN:]:
+ yield char
+
+ def _GetNode(self, hashprefix, create_if_necessary=False):
+ """Returns the trie node that will contain hashprefix.
+
+ If create_if_necessary is True this method will create the necessary
+ trie nodes to store hashprefix in the trie.
+ """
+ node = self._root
+ for char in self._GetPrefixComponents(hashprefix):
+ if char in node.children:
+ node = node.children[char]
+ elif create_if_necessary:
+ node = node.children.setdefault(char, HashprefixTrie.Node(node))
+ else:
+ return None
+ return node
+
+ def Insert(self, hashprefix, entry):
+ """Insert entry with a given hash prefix."""
+ self._GetNode(hashprefix, True).values.append(entry)
+ self._size += 1
+
+ def Delete(self, hashprefix, entry):
+ """Delete a given entry with hash prefix."""
+ node = self._GetNode(hashprefix)
+ if node and entry in node.values:
+ node.values.remove(entry)
+ self._size -= 1
+
+      # Prune the trie: walk back up and delete any nodes that are left with
+      # no values and no children.
+ while not node.values and not node.children and node.parent:
+ node = node.parent
+
+ if len(hashprefix) == HashprefixTrie.MIN_PREFIX_LEN:
+ del node.children[hashprefix]
+ break
+
+ char, hashprefix = hashprefix[-1], hashprefix[:-1]
+ del node.children[char]
+
+ def Size(self):
+ """Returns the number of values stored in the trie."""
+    return self._size
+
+ def GetPrefixMatches(self, fullhash):
+ """Yields all values that have a prefix of the given fullhash."""
+ node = self._root
+ for char in self._GetPrefixComponents(fullhash):
+ node = node.children.get(char, None)
+ if not node:
+ break
+ for value in node.values:
+ yield value
+
+ def PrefixIterator(self):
+ """Iterator over all the hash prefixes that have values."""
+ stack = [('', self._root)]
+ while stack:
+ hashprefix, node = stack.pop()
+ if node.values:
+ yield hashprefix
+
+ for char, child in node.children.iteritems():
+ stack.append((hashprefix + char, child))
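+
+
+# A minimal usage sketch (mirroring hashprefix_trie_test.py): values inserted
+# under a prefix are yielded for every full hash that starts with that prefix.
+#
+#   trie = HashprefixTrie()
+#   trie.Insert('aabc', 1)
+#   print list(trie.GetPrefixMatches('aabc' + 28 * 'x'))  # prints [1]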
diff --git a/python/hashprefix_trie_test.py b/python/hashprefix_trie_test.py
new file mode 100755
index 0000000..03852fe
--- /dev/null
+++ b/python/hashprefix_trie_test.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2010 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unittest for googlesafebrowsing.hashprefix_trie."""
+
+import hashprefix_trie
+import unittest
+
+class HashPrefixTrieTest(unittest.TestCase):
+
+ def assertSameElements(self, a, b):
+ a = sorted(list(a))
+ b = sorted(list(b))
+ self.assertEqual(a, b)
+
+ def testSimple(self):
+ trie = hashprefix_trie.HashprefixTrie()
+ trie.Insert('aabc', 1)
+ trie.Insert('aabcd', 2)
+ trie.Insert('acde', 3)
+ trie.Insert('abcdefgh', 4)
+
+ self.assertSameElements([1, 2], trie.GetPrefixMatches('aabcdefg'))
+ self.assertSameElements([1, 2], trie.GetPrefixMatches('aabcd'))
+ self.assertSameElements([1], trie.GetPrefixMatches('aabc'))
+ self.assertSameElements([3], trie.GetPrefixMatches('acde'))
+ self.assertEqual(4, trie.Size())
+
+ trie.Delete('abcdefgh', 4)
+    # Make sure that all nodes between abcd and abcdefgh were deleted because
+    # they were empty.
+ self.assertEqual(None, trie._GetNode('abcd'))
+
+ trie.Delete('aabc', 2) # No such prefix, value pair.
+ trie.Delete('aaaa', 1) # No such prefix, value pair.
+ self.assertEqual(3, trie.Size())
+ trie.Delete('aabc', 1)
+ self.assertEqual(2, trie.Size())
+
+ self.assertSameElements(['aabcd', 'acde'], trie.PrefixIterator())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/sblist.py b/python/sblist.py
new file mode 100755
index 0000000..a934bf3
--- /dev/null
+++ b/python/sblist.py
@@ -0,0 +1,376 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""List objects represent a Google safe browsing blacklist."""
+
+import hashprefix_trie
+import util
+
+import logging
+import urlparse
+
+
+class List(object):
+ """
+  Represents a Google Safe Browsing blacklist.
+ """
+ def __init__(self, name):
+ self._name = name
+
+ # Time this list was last successfully updated from a download request.
+ self._update_time = None
+
+ # Trie that maps hashprefix to AddEntries.
+ self._prefix_trie = hashprefix_trie.HashprefixTrie()
+ # Map of addchunknum to map of hashprefix to AddEntry.
+ # Keys are only deleted from this when we get an AddDel.
+ self._chunknum_map = {}
+
+ # Maps addchunknum -> prefix -> SubEntry
+ self._subbed = {}
+ # Map of subchunknum to a list of SubEntry. Sometimes different subchunks
+ # will sub the same expression. In that case, _subbed will reference the
+ # most recent subchunk, and _subchunks will store all of the subchunks.
+ self._subchunks = {}
+
+ def Name(self):
+ return self._name
+
+ def SetUpdateTime(self, timestamp):
+ self._update_time = timestamp
+
+ def UpdateTime(self):
+ return self._update_time
+
+ def AddChunkMap(self):
+ """
+ Returns the mapping of add chunks -> prefix -> AddEntry
+ """
+ return self._chunknum_map
+
+ def SubChunkMap(self):
+ """
+ Returns the mapping of sub chunks -> SubEntry
+ """
+ return self._subchunks
+
+ def NumPrefixes(self):
+ """
+ Return the number of prefixes in this list.
+ """
+ return self._prefix_trie.Size()
+
+ def GetPrefixMatches(self, fullhash):
+ """
+ Returns all AddEntry objects whose hash is a prefix of the given fullhash.
+ """
+ return self._prefix_trie.GetPrefixMatches(fullhash)
+
+ def GotAddChunk(self, chunknum):
+ return self._chunknum_map.has_key(chunknum)
+
+ def GotSubChunk(self, chunknum):
+ return self._subchunks.has_key(chunknum)
+
+ def AddFullHash(self, fullhash, addchunknum, timestamp):
+ """
+    Record the full-length hash for an existing prefix.
+    Return True if a matching prefix from addchunknum was found and updated,
+    or False otherwise (e.g. the prefix was never received or was subbed).
+ """
+ for entry in self._prefix_trie.GetPrefixMatches(fullhash):
+ if entry.AddChunkNum() == addchunknum:
+ entry.SetFullHash(fullhash, timestamp)
+ return True
+ return False
+
+ def AddPrefix(self, hash, addchunknum):
+ """Try to add a prefix for the list.
+
+ Args:
+ hash: either a hash-prefix or a full-hash.
+ addchunknum: the add chunk number.
+
+    Returns:
+      True if the prefix was added, or False if it was previously subbed or
+      already present in this add chunk.
+ """
+ if util.IsFullHash(hash):
+ prefix, fullhash = hash, hash
+ else:
+ prefix, fullhash = hash, None
+
+ # Check to see whether that add entry was previously subbed.
+ sub_entry = self._subbed.get(addchunknum, {}).get(prefix, None)
+ if sub_entry:
+ # This expression has been subbed.
+ logging.debug('Presubbed: %s:%d:%s', self.Name(), addchunknum,
+ util.Bin2Hex(prefix))
+ # We have to create an empty add chunk in case it doesn't exist so that we
+ # record that we have received the chunk.
+ self._chunknum_map.setdefault(addchunknum, {})
+
+ # We no longer need this sub entry since we received its corresponding add
+ # entry.
+ del self._subbed[addchunknum][prefix]
+ if not self._subbed[addchunknum]:
+ del self._subbed[addchunknum]
+ self._subchunks[sub_entry.SubNum()].remove(sub_entry)
+ return False
+
+ chunknum_prefixes = self._chunknum_map.setdefault(addchunknum, {})
+ if prefix in chunknum_prefixes:
+ logging.warning('Prefix %s already added from add chunk %d. Ignoring',
+ util.Bin2Hex(prefix), addchunknum)
+ return False
+
+ add_entry = AddEntry(prefix, addchunknum, fullhash=fullhash)
+ chunknum_prefixes[prefix] = add_entry
+ self._prefix_trie.Insert(prefix, add_entry)
+ return True
+
+ def RemovePrefix(self, prefix, subchunknum, addchunknum):
+ """
+ Return True iff there is a prefix to remove.
+ """
+ logmsg = '%s:%d:%s' % (self.Name(), addchunknum, util.Bin2Hex(prefix))
+ logging.debug('attempted sub: %s', logmsg)
+
+    # Let's see if we already have the corresponding add entry.
+ if addchunknum in self._chunknum_map:
+ # We have to create an empty sub chunk in case it does not exist so that
+ # we record that we received the sub chunk.
+ self._subchunks.setdefault(subchunknum, [])
+
+ # If an add entry exists we need to remove it. If the entry does not
+ # exist but the add chunk is empty we don't have to do anything.
+ add_entry = self._chunknum_map[addchunknum].get(prefix, None)
+ if add_entry is not None:
+ logging.debug('successful sub: %s', logmsg)
+ self._prefix_trie.Delete(prefix, add_entry)
+ # Now delete entry from the chunknum map as well.
+ del self._chunknum_map[addchunknum][prefix]
+ elif self._chunknum_map[addchunknum]:
+ # The prefix does not exist in this add chunk and the add chunk is not
+ # empty. This should never happen.
+ logging.warning('Unable to remove missing prefix:%s sub:%d add:%s',
+ util.Bin2Hex(prefix), subchunknum, addchunknum)
+ return False
+ return True
+
+ # We have not yet received the corresponding add entry. Store the
+ # sub entry for later.
+ entry = SubEntry(prefix, subchunknum, addchunknum)
+ self._subbed.setdefault(addchunknum, {})[prefix] = entry
+ self._subchunks.setdefault(subchunknum, []).append(entry)
+ return False
+
+ def AddEmptyAddChunk(self, addchunknum):
+ """
+ Adds the addchunknum to the list of known chunks but without any associated
+ data. If data currently exists for the chunk it is removed.
+ """
+ if self.DeleteAddChunk(addchunknum):
+ logging.debug("Removing data that was associated with add chunk %d" %
+ addchunknum)
+ self._chunknum_map[addchunknum] = {}
+
+ def AddEmptySubChunk(self, subchunknum):
+ """
+ Adds the subchunknum to the list of known chunks but without any associated
+ data. If data currently exists for the chunk it is removed.
+ """
+ if subchunknum in self._subchunks:
+ self.DeleteSubChunk(subchunknum)
+ self._subchunks[subchunknum] = []
+
+ def DeleteAddChunk(self, addchunknum):
+ # No matter what, we remove sub expressions that point to this chunk as they
+ # will never need to be applied.
+ if addchunknum in self._subbed:
+ for sub_entry in self._subbed[addchunknum].itervalues():
+ # Remove the sub entry from the subchunks map.
+ self._subchunks[sub_entry.SubNum()].remove(sub_entry)
+ del self._subbed[addchunknum]
+
+ if addchunknum not in self._chunknum_map:
+ # Never received or already AddDel-ed this add chunk.
+ return False
+
+ # Remove entries from _chunknum_map
+ chunknum_prefixes = self._chunknum_map[addchunknum]
+ del self._chunknum_map[addchunknum]
+ if not len(chunknum_prefixes):
+ # Add chunk was already empty.
+ return True
+
+ # Remove entries from _prefix_trie
+ for prefix, add_entry in chunknum_prefixes.iteritems():
+ self._prefix_trie.Delete(prefix, add_entry)
+
+ return True
+
+ def DeleteSubChunk(self, subchunknum):
+ """Deletes the sub chunk with the given sub chunk number.
+
+ Returns:
+ True iff the sub chunk was removed. Note: this method returns true when
+ an empty sub chunk gets removed.
+ """
+ if subchunknum not in self._subchunks:
+ return False
+ entries = self._subchunks.pop(subchunknum)
+ for entry in entries:
+ del self._subbed[entry.AddNum()][entry.Prefix()]
+ if not self._subbed[entry.AddNum()]:
+ # No more subs for that add chunk.
+ del self._subbed[entry.AddNum()]
+ return True
+
+ def DownloadRequest(self, should_mac=False):
+ """
+ Return the state of this List as a string as required for download requests.
+ """
+ addnums = self._chunknum_map.keys()
+ addnums.sort()
+ subnums = self._subchunks.keys()
+ subnums.sort()
+ dlreq = '%s;' % (self.Name(),)
+ if addnums:
+ dlreq = '%sa:%s' % (dlreq, self._GetRangeStr(addnums))
+ if subnums:
+ if addnums:
+ dlreq = '%s:' % (dlreq,)
+ dlreq = '%ss:%s' % (dlreq, self._GetRangeStr(subnums))
+ if should_mac:
+ if addnums or subnums:
+ dlreq = '%s:mac' % (dlreq,)
+ else:
+ dlreq = '%smac' % (dlreq,)
+ return dlreq
+
+ def _GetRangeStr(self, nums):
+ """
+ nums: sorted list of integers.
+ """
+ if len(nums) == 0:
+ return ''
+ output = []
+ i = 0
+ while i < len(nums):
+ output.append(str(nums[i]))
+ use_range = False
+ while i < len(nums) - 1 and nums[i + 1] - nums[i] == 1:
+ i += 1
+ use_range = True
+ if use_range:
+ output.append('-')
+ output.append(str(nums[i]))
+ if i < len(nums) - 1:
+ output.append(',')
+ i += 1
+ return ''.join(output)
+
+
+class AddEntry(object):
+ __slots__ = ('_prefix', '_addchunknum', '_fulllength', '_gethash_timestamp')
+
+ def __init__(self, prefix, addchunknum, fullhash=None):
+ """
+ Create an add entry with the given prefix and addchunknum. Fullhash
+ is set to the full-length hash, if one is present for this entry.
+ """
+ self._prefix = prefix
+ self._addchunknum = addchunknum
+ # Full length hash associated with this AddEntry, if any.
+ self._fulllength = fullhash
+
+ # Timestamp associated with the most recent gethash response that set
+ # self._fulllength, if any.
+ self._gethash_timestamp = None
+
+ def __str__(self):
+ p = self._prefix
+ if p is not None:
+ p = util.Bin2Hex(p)
+ f = self._fulllength
+ if f is not None:
+ f = util.Bin2Hex(f)
+ return 'AddEntry(%s, %s, %d)' % (p, f, self._addchunknum)
+
+ def __eq__(self, other):
+ return str(self) == str(other)
+
+ def __repr__(self):
+ return self.__str__()
+
+ def __cmp__(self, other):
+ if self._addchunknum == other._addchunknum:
+ if self._prefix == other._prefix:
+ return cmp(self._fulllength, other._fulllength)
+ return cmp(self._prefix, other._prefix)
+ return cmp(self._addchunknum, other._addchunknum)
+
+ def Prefix(self):
+ return self._prefix
+
+ def FullHash(self):
+ """
+ Return the full length hash if we have it. Otherwise, None.
+ """
+ return self._fulllength
+
+ def GetHashTimestamp(self):
+ return self._gethash_timestamp
+
+ def SetFullHash(self, fullhash, timestamp):
+ self._fulllength = fullhash
+ self._gethash_timestamp = timestamp
+
+ def AddChunkNum(self):
+ return self._addchunknum
+
+
+class SubEntry(object):
+ __slots__ = ('_prefix', '_subnum', '_addnum')
+
+ def __init__(self, hash_prefix, subchunknum, addchunknum):
+ """
+ hash_prefix: None to sub a full domain add.
+ """
+ self._prefix = hash_prefix
+ self._subnum = subchunknum
+ self._addnum = addchunknum
+
+ def __str__(self):
+ return 'SubEntry(%s, sub:%d, add:%d)' % (util.Bin2Hex(self.Prefix()),
+ self.SubNum(), self.AddNum())
+
+ def __cmp__(self, other):
+ if self._prefix == other._prefix:
+ if self._subnum == other._subnum:
+ return cmp(self._addnum, other._addnum)
+ return cmp(self._subnum, other._subnum)
+ return cmp(self._prefix, other._prefix)
+
+ def Prefix(self):
+ return self._prefix
+
+ def SubNum(self):
+ return self._subnum
+
+ def AddNum(self):
+ return self._addnum
diff --git a/python/sblist_test.py b/python/sblist_test.py
new file mode 100755
index 0000000..3c520f7
--- /dev/null
+++ b/python/sblist_test.py
@@ -0,0 +1,250 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unittest for googlesafebrowsing.sblist."""
+
+import sblist
+import server
+import util
+
+import unittest
+
+
+class ListTest(unittest.TestCase):
+
+ def assertSameElements(self, a, b):
+ a = sorted(list(a))
+ b = sorted(list(b))
+ self.assertEqual(a, b, 'Expected: [%s], Found: [%s]' %
+ (', '.join(map(str, a)), ', '.join(map(str, b))))
+
+ def setUp(self):
+ self._list = sblist.List('goog-malware-shavar')
+ self._list.AddPrefix('aprefix', 1)
+ self._list.AddPrefix('bprefix', 2)
+ self._list.AddPrefix('aprefix', 3)
+ self._list.AddPrefix('cprefix', 1)
+ self._list.AddPrefix('0000', 4)
+ self.assertTrue(self._list.AddFullHash('0000fullhash', 4, 10))
+
+ self._list.AddPrefix('dprefix', 5)
+ self._list.AddEmptyAddChunk(5) # should remove dprefix
+
+ self._list.AddPrefix('eprefix', 6)
+ self._list.AddPrefix('fprefix', 6)
+ # After this add chunk 6 should still be around
+ self.assertTrue(self._list.RemovePrefix('eprefix', 1, 6))
+
+ self._list.AddPrefix('gprefix', 7)
+ self._list.AddPrefix('hprefix', 8)
+ self.assertTrue(self._list.RemovePrefix('gprefix', 2, 7))
+ self.assertTrue(self._list.RemovePrefix('hprefix', 2, 8))
+ # Subs for adds we have not yet received.
+ self.assertFalse(self._list.RemovePrefix('iprefix', 2, 11))
+ self.assertFalse(self._list.RemovePrefix('jprefix', 2, 12))
+
+ # Test prefix matching
+ self._list.AddPrefix('prefixaa', 9)
+ self._list.AddPrefix('prefixab', 9)
+ self._list.AddPrefix('prefixaa', 10)
+
+ self._list.AddEmptySubChunk(3)
+
+ # Add some empty sub chunks to see that we would support fragmentation.
+ self._list.AddEmptySubChunk(5)
+ self._list.AddEmptySubChunk(6)
+ self._list.AddEmptySubChunk(7)
+
+ def testGetRangeStr(self):
+ sbl = sblist.List('foo')
+
+ s = sbl._GetRangeStr([1, 2, 3, 4])
+ self.assertEqual(s, '1-4')
+
+ s = sbl._GetRangeStr([1, 2, 4, 5, 7, 8, 9, 10, 11, 13, 15, 17])
+ self.assertEqual(s, '1-2,4-5,7-11,13,15,17')
+
+ s = sbl._GetRangeStr([1])
+ self.assertEqual(s, '1')
+
+ def testName(self):
+ self.assertEqual('goog-malware-shavar', self._list.Name())
+
+ def testGetSetUpdateTime(self):
+ self.assertEqual(None, self._list.UpdateTime())
+ self._list.SetUpdateTime(42)
+ self.assertEqual(42, self._list.UpdateTime())
+
+  def testAddChunkMap(self):
+ self.assertSameElements([1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ self._list.AddChunkMap())
+ self.assertSameElements(['aprefix', 'cprefix'],
+ self._list.AddChunkMap()[1])
+ self.assertSameElements(['bprefix',], self._list.AddChunkMap()[2])
+ self.assertSameElements(['aprefix',], self._list.AddChunkMap()[3])
+ self.assertSameElements(['0000',], self._list.AddChunkMap()[4])
+ self.assertSameElements([], self._list.AddChunkMap()[5])
+ self.assertSameElements(['fprefix',], self._list.AddChunkMap()[6])
+ self.assertSameElements([], self._list.AddChunkMap()[7])
+ self.assertSameElements([], self._list.AddChunkMap()[8])
+ self.assertSameElements(['prefixaa', 'prefixab'],
+ self._list.AddChunkMap()[9])
+ self.assertSameElements(['prefixaa'], self._list.AddChunkMap()[10])
+
+ def testSubChunkMap(self):
+ self.assertSameElements([1, 2, 3, 5, 6, 7],
+ self._list.SubChunkMap())
+ self.assertEqual(0, len(self._list.SubChunkMap()[1]))
+ self.assertSameElements([sblist.SubEntry('iprefix', 2, 11),
+ sblist.SubEntry('jprefix', 2, 12)],
+ self._list.SubChunkMap()[2])
+ self.assertSameElements([], self._list.SubChunkMap()[3])
+ self.assertSameElements([], self._list.SubChunkMap()[5])
+ self.assertSameElements([], self._list.SubChunkMap()[6])
+ self.assertSameElements([], self._list.SubChunkMap()[7])
+
+ def testNumPrefixes(self):
+ self.assertEqual(9, self._list.NumPrefixes())
+
+ def testGotAddChunk(self):
+ for i in [1, 2, 3, 4, 5, 6, 7]:
+ self.assertTrue(self._list.GotAddChunk(i))
+ self.assertFalse(self._list.GotAddChunk(100))
+
+ def testGotSubChunk(self):
+ for i in [1, 2, 3, 5, 6, 7]:
+ self.assertTrue(self._list.GotSubChunk(i))
+ self.assertFalse(self._list.GotSubChunk(4))
+
+ def testAddFullHash(self):
+ # The prefix must be present in the list.
+ self.assertFalse(self._list.AddFullHash('noprefix', 4, 10))
+ self.assertTrue(self._list.AddFullHash('0000full', 4, 42))
+
+ entry = sblist.AddEntry('0000', 4, '0000full')
+ self.assertSameElements([entry], self._list.GetPrefixMatches('0000'))
+
+ def testAddPrefix(self):
+ # This should return false because this add chunk is already subbed.
+ self.assertFalse(self._list.AddPrefix('iprefix', 11))
+
+ # Test adding a prefix to a new chunk.
+ self.assertTrue(self._list.AddPrefix('asdf', 10))
+ entry = sblist.AddEntry('asdf', 10)
+ self.assertSameElements([entry], self._list.GetPrefixMatches('asdfasdf'))
+ self.assertSameElements([entry], self._list.GetPrefixMatches('asdf'))
+
+ # Test adding a prefix to an existing chunk.
+ self.assertTrue(self._list.AddPrefix('asdfasdf', 3))
+ other_entry = sblist.AddEntry('asdfasdf', 3)
+ self.assertSameElements([entry, other_entry],
+ self._list.GetPrefixMatches('asdfasdf'))
+
+ # Check to see if it supports full hashes correctly.
+ fullhash = util.GetHash256('asdf')
+ self.assertTrue(self._list.AddPrefix(fullhash, 11))
+ self.assertEqual(1, len(list(self._list.GetPrefixMatches(fullhash))))
+
+ def testRemovePrefix(self):
+    # Can't remove a nonexistent prefix.
+ self.assertFalse(self._list.RemovePrefix('some_prefix', 8, 1))
+ # Remove first of two prefixes.
+ self.assertTrue(self._list.RemovePrefix('aprefix', 8, 1))
+ entry = sblist.AddEntry('aprefix', 3)
+ self.assertSameElements([entry], self._list.GetPrefixMatches('aprefix'))
+ # Remove second prefix.
+ self.assertTrue(self._list.RemovePrefix('aprefix', 8, 3))
+ self.assertSameElements([], self._list.GetPrefixMatches('aprefix'))
+
+ def testDeleteAddChunk(self):
+ # Delete add chunk that does not exist.
+ self.assertFalse(self._list.DeleteAddChunk(11))
+ # Delete empty add chunk
+ self.assertTrue(self._list.DeleteAddChunk(5))
+ self.assertFalse(self._list.GotAddChunk(5))
+ self.assertSameElements([], self._list.GetPrefixMatches('dprefix'))
+ # Delete normal add chunk
+ self.assertTrue(self._list.DeleteAddChunk(1))
+ self.assertFalse(self._list.GotAddChunk(1))
+ entry = sblist.AddEntry('aprefix', 3)
+ self.assertSameElements([entry], self._list.GetPrefixMatches('aprefix'))
+ self.assertSameElements([], self._list.GetPrefixMatches('cprefix'))
+
+ def testDeleteSubChunk(self):
+ # Delete sub chunk that does not exist.
+ self.assertFalse(self._list.DeleteSubChunk(8))
+ # Delete empty sub chunk.
+ self.assertTrue(self._list.DeleteSubChunk(7))
+ self.assertFalse(self._list.GotSubChunk(7))
+ # Delete non-empty sub chunk
+ self.assertTrue(self._list.DeleteSubChunk(2))
+ self.assertFalse(self._list.GotSubChunk(2))
+
+ def testDownloadRequest(self):
+ self.assertEqual('goog-malware-shavar;a:1-10:s:1-3,5-7',
+ self._list.DownloadRequest(False))
+ self.assertEqual('goog-malware-shavar;a:1-10:s:1-3,5-7:mac',
+ self._list.DownloadRequest(True))
+
+    # Make sure that this works properly on an empty list as well.
+    sbl = sblist.List('empty-testing-list')
+    self.assertEqual('empty-testing-list;', sbl.DownloadRequest(False))
+    self.assertEqual('empty-testing-list;mac', sbl.DownloadRequest(True))
+
+ def testGetPrefixMatches(self):
+ self.assertSameElements([self._list.AddChunkMap()[9]['prefixaa'],
+ self._list.AddChunkMap()[10]['prefixaa']],
+ self._list.GetPrefixMatches('prefixaa'))
+ self.assertSameElements([self._list.AddChunkMap()[9]['prefixaa'],
+ self._list.AddChunkMap()[10]['prefixaa']],
+ self._list.GetPrefixMatches('prefixaaasdfasdf'))
+ self.assertSameElements([], self._list.GetPrefixMatches('prefixa'))
+ self.assertSameElements([self._list.AddChunkMap()[9]['prefixab']],
+ self._list.GetPrefixMatches('prefixabasdasdf'))
+
+
+class SubEntryTest(unittest.TestCase):
+ def testAccessors(self):
+ entry = sblist.SubEntry('hash_prefix', 1, 2)
+ self.assertEqual('hash_prefix', entry.Prefix())
+ self.assertEqual(1, entry.SubNum())
+ self.assertEqual(2, entry.AddNum())
+
+
+class AddEntryTest(unittest.TestCase):
+ def testSimple(self):
+ # Test with no full-hash.
+ entry = sblist.AddEntry('prefix', 1)
+ self.assertEqual('prefix', entry.Prefix())
+ self.assertEqual(None, entry.FullHash())
+ self.assertEqual(None, entry.GetHashTimestamp())
+ self.assertEqual(1, entry.AddChunkNum())
+ # Now set a full-hash and check that the accessors return the right thing.
+ entry.SetFullHash('fullhash', 42)
+ self.assertEqual('fullhash', entry.FullHash())
+ self.assertEqual(42, entry.GetHashTimestamp())
+
+ # Test with full-hash
+ entry = sblist.AddEntry('another_prefix', 2, 'fullhash')
+ self.assertEqual('another_prefix', entry.Prefix())
+ self.assertEqual('fullhash', entry.FullHash())
+ self.assertEqual(None, entry.GetHashTimestamp())
+ self.assertEqual(2, entry.AddChunkNum())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/server.py b/python/server.py
new file mode 100755
index 0000000..0050491
--- /dev/null
+++ b/python/server.py
@@ -0,0 +1,798 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Encapsulates interaction with the safebrowsing servers."""
+
+import sblist
+import util
+
+import base64
+import datetime
+import hmac
+import httplib
+import itertools
+import logging
+import re
+import sha
+import socket
+import StringIO
+import struct
+import sys
+import urllib2
+
+
+class Error(Exception):
+ def __init__(self, msg, original_error=None, *args, **kwargs):
+ Exception.__init__(self, msg, *args, **kwargs)
+    self._ServerError_original = original_error
+
+ def OriginalError(self):
+ return self._ServerError_original
+
+class ServerError(Error):
+ pass
+
+class ResponseError(Error):
+ pass
+
+def UrllibRequest(url, postdata):
+ return urllib2.urlopen(url, postdata)
+
+class Server(object):
+ """
+ This is the main interface to the Google Safe Browsing servers.
+
+ server = Server()
+ googlelists = server.GetLists()
+ download = server.Download(googlelists)
+ for googlelist, listops in download:
+ for listop in listops:
+ listop.Apply()
+ """
+
+ CLIENT = 'api'
+ APPVER = '1.0'
+ PVER = '2.2'
+
+ # Request types
+ LIST = 'list'
+ DOWNLOADS = 'downloads'
+ NEWKEY = 'newkey'
+ GETHASH = 'gethash'
+
+ MAC = re.compile(r'm:(.+)')
+ NEXT = re.compile(r'n:(\d+)')
+ PLEASEREKEY = re.compile(r'e:pleaserekey')
+ PLEASERESET = re.compile(r'r:pleasereset')
+ LISTRESP = re.compile(r'i:(.+)')
+
+ URLRESP = re.compile(r'u:(.+)')
+
+ ADDORSUB = re.compile(r'([as]):(\d+):(\d+):(\d+)')
+
+ ADDDELRESP = re.compile(r'ad:(.+)')
+ SUBDELRESP = re.compile(r'sd:(.+)')
+
+ # Bytes in a full length hash (full sha256).
+ FULLHASHLEN = 32
+
+ def __init__(self, hp, ssl_hp, base_path, clientkey=None, wrkey=None,
+ apikey=None, timeout=20, gethash_server=None,
+ url_request_function=UrllibRequest):
+ assert callable(url_request_function)
+ self._host, self._port = hp
+ self._ssl_host, self._ssl_port = ssl_hp
+ self._base_path = base_path
+ self._base_qry = 'client=%s&appver=%s&pver=%s' % (
+ Server.CLIENT, Server.APPVER, Server.PVER)
+ if gethash_server is None:
+ self._gethash_host, self._gethash_port = hp
+ else:
+ self._gethash_host, self._gethash_port = gethash_server
+
+ # Unescaped client key.
+ self._clientkey = clientkey
+ self._wrkey = wrkey
+
+ self._apikey = apikey
+ self._timeout = timeout
+ self._url_request_function = url_request_function
+
+ def WillUseMac(self):
+ return self._wrkey is not None and self._clientkey is not None
+
+ def Rekey(self):
+ """
+ Get a new set of keys, replacing any existing keys. Returns (clientkey,
+ wrkey). The keys are stored in the Server object.
+ """
+ self._clientkey, self._wrkey = self._GetMacKeys()
+ return self.Keys()
+
+ def Keys(self):
+ """
+ Return (clientkey, wrkey).
+ """
+ return (self._clientkey, self._wrkey)
+
+ def GetLists(self):
+ """
+ Get the available blacklists. Returns a list of List objects.
+ """
+ resp = self._MakeRequest(Server.LIST, use_apikey=True)
+ mac = None
+ if self.WillUseMac():
+ mac = resp.readline().strip()
+ sbls = []
+ raw_data = []
+ for line in resp:
+ raw_data.append(line)
+ sbls.append(sblist.List(line.strip()))
+ resp.close()
+ self._CheckMac(mac, ''.join(raw_data))
+ return sbls
+
+ def Download(self, sbls, size_limit_bytes=None):
+ """
+ Download updates for safebrowsing Lists. sbls is a sequence of sblist.List
+    objects. size_limit_bytes specifies an approximate maximum on the number of
+    bytes of data we are willing to download. Returns a DownloadResponse object.
+ """
+ # Build the request.
+ req_lines = []
+ if size_limit_bytes is not None:
+ # Convert to kilobytes for the server.
+ size_limit_kb = int(size_limit_bytes / 1024)
+ if size_limit_kb == 0:
+ size_limit_kb = 1
+ req_lines.append('s;%d' % (size_limit_kb,))
+ for sbl in sbls:
+ dlreq = sbl.DownloadRequest(self.WillUseMac())
+ req_lines.append(dlreq)
+ req_lines.append('') # Terminating newline.
+
+ # Process the response.
+ linereader = LineReader(
+ self._MakeRequest(Server.DOWNLOADS,
+ postdata='\n'.join(req_lines),
+ use_apikey=True))
+    # Make DownloadResponse contain listops for each list, even if no ops are
+    # present. This is so that the client will know the last time we made a
+    # request for each list.
+ dlresp = DownloadResponse(datetime.datetime.now())
+ for sbl in sbls:
+ dlresp.listops.setdefault(sbl.Name(), [])
+
+ line = linereader.ReadLine()
+ main_body_escaped_mac = None
+ if self.WillUseMac():
+ m = Server.MAC.match(line)
+ if not m:
+ raise ResponseError('Could not parse MAC for downloads: "%s"' % (line,))
+ main_body_escaped_mac = m.group(1)
+ logging.debug('Parsed main body MAC: "%s"', main_body_escaped_mac)
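+    # The main body MAC covers every line after the 'm:' line, so discard the
+    # lines consumed so far before accumulating the rest of the response.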
+ linereader.ClearLinesRead()
+ line = linereader.ReadLine()
+ m = Server.NEXT.match(line)
+ if not m:
+ raise ResponseError('Could not parse next for downloads: "%s"' % (line,))
+ try:
+ dlresp.min_delay_sec = int(m.group(1))
+ except ValueError, e:
+ raise ResponseError('Could not parse next for downloads: "%s"' % (line,))
+ active_sbl = None
+ sblist_map = dict([(l.Name(), l) for l in sbls])
+ logging.debug('valid list names: "%s"', ','.join(sblist_map.iterkeys()))
+ while linereader.ReadLine() != '':
+ line = linereader.LastLine().strip()
+ logging.debug('download response line: "%s"', line)
+
+ if Server.PLEASEREKEY.match(line):
+ dlresp.rekey = True
+ return dlresp
+
+ if Server.PLEASERESET.match(line):
+ dlresp.reset = True
+ return dlresp
+
+ m = Server.LISTRESP.match(line)
+ if m:
+ if not sblist_map.has_key(m.group(1)):
+ raise ResponseError('invalid list in response: "%s"' % (m.group(1),))
+ active_sbl = sblist_map[m.group(1)]
+ continue
+
+ if active_sbl is None:
+ raise ResponseError('no list set: "%s"' % (line,))
+
+ m = Server.URLRESP.match(line)
+ if m:
+ url = m.group(1)
+ mac = None
+ if self.WillUseMac():
+ trailing_comma_index = url.rfind(',')
+ mac = url[trailing_comma_index+1:]
+ url = url[:trailing_comma_index]
+ self._GetRedirect(active_sbl, url, dlresp, mac)
+ continue
+
+ m = Server.ADDDELRESP.match(line)
+ if m:
+ dlresp.listops[active_sbl.Name()].append(
+ AddDel(active_sbl, Server._GetSequence(m.group(1))))
+ continue
+
+ m = Server.SUBDELRESP.match(line)
+ if m:
+ dlresp.listops[active_sbl.Name()].append(
+ SubDel(active_sbl, Server._GetSequence(m.group(1))))
+ continue
+
+ # Clients are supposed to ignore unrecognized command keywords.
+ logging.info('Unrecognized response line: "%s"', line)
+
+ # Check the main body MAC.
+ self._CheckMac(main_body_escaped_mac, ''.join(linereader.LinesRead()))
+ return dlresp
+
+ def GetAllFullHashes(self, prefixes):
+ """Get full length hashes for all prefixes in prefx. If prefixes are
+ not all of the same length we have to do multiple gethash requests.
+ Returns a merged GetHashResponse.
+ """
+ prefix_sizes = {} # prefix length -> list of prefixes.
+ for prefix in prefixes:
+ prefix_sizes.setdefault(len(prefix), []).append(prefix)
+
+ response = GetHashResponse(datetime.datetime.now())
+ for prefix_length, prefix_list in prefix_sizes.iteritems():
+ ghresp = self.GetFullHashes(prefix_list, prefix_length)
+ logging.debug('gethash response: %s', ghresp)
+ if ghresp.rekey:
+ self.Rekey()
+ # Try request again once we rekeyed.
+ ghresp = self.GetFullHashes(prefix_list, prefix_length)
+ if ghresp.rekey:
+ raise Error('cannot get a valid key')
+ response.MergeWith(ghresp)
+ return response
+
+ def GetFullHashes(self, prefixes, prefix_length):
+ """
+ Get the full length hashes that correspond to prefixes. prefixes is a
+ list of the prefixes to look up. All prefixes must have a length equal
+ to prefix_length.
+ Returns a GetHashResponse.
+ """
+ ghresp = GetHashResponse(datetime.datetime.now())
+ if len(prefixes) == 0:
+ # Empty response for empty input.
+ return ghresp
+ for pre in prefixes:
+ if len(pre) != prefix_length:
+ raise Error('All prefixes must have length: %d' % prefix_length)
+
+ try:
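+      # Request body: '<prefix length>:<total prefix bytes>' on the first line,
+      # then the binary prefixes back to back (e.g. '4:12\n0123456789AB').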
+ resp = self._MakeRequest(
+ Server.GETHASH,
+ postdata='%d:%d\n%s' % (prefix_length, prefix_length * len(prefixes),
+ ''.join(prefixes)),
+ use_apikey=True,
+ hp=(self._gethash_host, self._gethash_port))
+ except ServerError, e:
+ orig = e.OriginalError()
+ if hasattr(orig, 'code') and orig.code == httplib.NO_CONTENT:
+ # No Content is not an error. Return an empty response.
+ return ghresp
+ else:
+ # Re-raise for other errors.
+ raise e
+
+ mac = None
+ if self._wrkey is not None:
+ line = resp.readline().rstrip()
+ if Server.PLEASEREKEY.match(line):
+ ghresp.rekey = True
+ return ghresp
+ mac = line
+
+ raw_data = []
+ for line in resp:
+ raw_data.append(line)
+ bad_header = ResponseError('gethash: bad hashentry header: "%s"' % (
+ line,))
+ spl = line.rstrip().split(':')
+ if len(spl) != 3:
+ raise bad_header
+ listname, addchunk, hashdatalen = spl
+ try:
+ addchunk = int(addchunk)
+ hashdatalen = int(hashdatalen)
+ except ValueError:
+ raise bad_header
+ datareader = BlockReader(hashdatalen, resp)
+ while not datareader.End():
+ ghresp.listmap.setdefault(listname, {}).setdefault(addchunk, set()).add(
+ datareader.Read(Server.FULLHASHLEN))
+ raw_data.extend(datareader.DataList())
+ # Verify the MAC.
+ self._CheckMac(mac, ''.join(raw_data))
+ return ghresp
+
+ def _CheckMac(self, escaped_mac, data):
+ """
+ Raise a ResponseError if the MAC is not valid.
+ """
+ if not self.WillUseMac() or escaped_mac is None:
+ return
+ try:
+ computed_mac = hmac.new(self._clientkey, data, sha).digest()
+ given_mac = base64.urlsafe_b64decode(escaped_mac)
+ except Exception, e:
+ logging.exception(e)
+ raise ResponseError('Bad MAC: %s' % (e,), e)
+ if computed_mac != given_mac:
+ raise ResponseError('Bad MAC. Computed: "%s", received: "%s"' % (
+ base64.urlsafe_b64encode(computed_mac), escaped_mac))
+
+ def _MakeRequest(self, path, postdata=None, hp=None, use_wrkey=True,
+ use_apikey=False, extra_params="", protocol="http"):
+ if hp is None:
+ hp = (self._host, self._port)
+
+ wrkey = ''
+ if use_wrkey and self._wrkey is not None:
+ wrkey = '&wrkey=%s' % self._wrkey
+ apikey_param = ''
+ if use_apikey and self._apikey:
+ apikey_param = '&apikey=' + self._apikey
+ url = '%s://%s:%d%s/%s?%s%s%s%s' % (
+ protocol, hp[0], hp[1], self._base_path,
+ path, self._base_qry, wrkey, apikey_param, extra_params)
+ logging.debug('http url: "%s"', url)
+ try:
+ resp = self._url_request_function(url, postdata)
+ except Exception, e:
+ raise ServerError('%s failed: %s' % (path, e), original_error=e)
+ return resp
+
+ def _GetMacKeys(self):
+ """
+ Request a new key from the server.
+ """
+ resp = self._MakeRequest(Server.NEWKEY,
+ hp = (self._ssl_host, self._ssl_port),
+ protocol = 'https')
+ clientkey = None
+ wrkey = None
+ for line in resp:
+ split = line.split(':')
+ if len(split) != 3:
+ raise ResponseError('newkey: "%s"' % (line,))
+ try:
+ length = int(split[1])
+ except ValueError:
+ raise ResponseError('newkey: "%s"' % (line,))
+ if len(split[2]) < length:
+ raise ResponseError('newkey: "%s"' % (line,))
+ if split[0] == 'clientkey':
+ try:
+ clientkey = split[2][:length]
+ clientkey = base64.urlsafe_b64decode(clientkey)
+ except TypeError:
+ raise ResponseError('could not decode clientkey: "%s", "%s"' % (
+ line, clientkey))
+ elif split[0] == 'wrappedkey':
+ wrkey = split[2][:length]
+ else:
+ raise ResponseError('newkey: "%s"' % (line,))
+ resp.close()
+ if clientkey is None or wrkey is None:
+ raise ResponseError('response is missing wrappedkey or clientkey')
+ return clientkey, wrkey
+
+ def _GetRedirect(self, sbl, redirect, dlresp, mac=None, extra_params=''):
+ """
+ sbl: Safebrowsing list object that we append data from the redirect to.
+ redirect: URL to fetch with HTTP. Should not include the http:// part.
+ dlresp: DownloadResponse object. The results of the redirect are stored
+ here.
+ mac: If set, the mac to verify the redirect request with.
+ extra_params: string to use as CGI args for the redirect request.
+ """
+ url = 'http://%s%s' % (redirect, extra_params)
+ logging.debug('Getting redirect: "%s"', url)
+ try:
+ resp = self._url_request_function(url, None)
+ except Exception, e:
+ raise ServerError('Redirect to "%s" failed: %s' % (url, e),
+ original_error=e)
+
+ # Verify mac
+ if mac:
+ total_response = resp.read()
+ self._CheckMac(mac, total_response)
+ resp = StringIO.StringIO(total_response)
+
+ # Get the chunks.
+ empty_adds = []
+ empty_subs = []
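+    # Chunks with no data only record their chunk numbers; collect them here
+    # and apply them in single EmptyAddChunks/EmptySubChunks ops at the end.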
+ for line in resp:
+ line = line.strip()
+ if line == '':
+ continue
+ bad_header = ResponseError('bad add or sub header: "%s"' % (line,))
+ m = Server.ADDORSUB.match(line)
+ if not m:
+ raise bad_header
+ typechar = m.group(1)
+ try:
+ chunknum = int(m.group(2))
+ prefixlen = int(m.group(3))
+ chunklen = int(m.group(4))
+ except ValueError:
+ raise bad_header
+ logging.debug('chunk header: "%s"', line)
+ reader = BlockReader(chunklen, resp)
+ if typechar == 'a':
+ if chunklen == 0:
+ empty_adds.append(chunknum)
+ continue
+ else:
+ chunk = AddChunk(sbl, chunknum, prefixlen, reader)
+ elif typechar == 's':
+ if chunklen == 0:
+ empty_subs.append(chunknum)
+ continue
+ else:
+ chunk = SubChunk(sbl, chunknum, prefixlen, reader)
+ else:
+ raise bad_header
+ dlresp.listops.setdefault(sbl.Name(), []).append(chunk)
+
+ if empty_adds:
+ chunk = EmptyAddChunks(sbl, empty_adds)
+ dlresp.listops.setdefault(sbl.Name(), []).append(chunk)
+ if empty_subs:
+ chunk = EmptySubChunks(sbl, empty_subs)
+ dlresp.listops.setdefault(sbl.Name(), []).append(chunk)
+
+ @staticmethod
+ def _GetSequence(seq_str):
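+    """Parse a chunk sequence string like '1-2,4-5,7' into a ChunkSequence."""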
+ # TODO: This doesn't check for errors like overlap and invalid ranges.
+ ranges = seq_str.split(',')
+ iters = []
+ ex = ResponseError('invalid sequence: "%s"' % (seq_str,))
+ for r in ranges:
+ low_high = r.split('-')
+ if len(low_high) == 1:
+ try:
+ x = int(low_high[0])
+ except ValueError:
+ raise ex
+ iters.append(xrange(x, x + 1))
+ elif len(low_high) == 2:
+ try:
+ l = int(low_high[0])
+ h = int(low_high[1])
+ except ValueError:
+ raise ex
+ iters.append(xrange(l, h + 1))
+ else:
+ raise ex
+ return ChunkSequence(iters)
+
+
+class DownloadResponse(object):
+ """
+ timestamp: A datetime object that marks the time of this transaction.
+ min_delay_sec: Number of seconds clients should wait before downloading again.
+ listops: A dict mapping listnames to lists of ListOps.
+ rekey: True iff client should request a new set of keys (see Server.Rekey()).
+ reset: True iff client should clear all list data.
+ """
+ def __init__(self, timestamp):
+ self.timestamp = timestamp
+ self.min_delay_sec = None
+ self.listops = {}
+ self.rekey = False
+ self.reset = False
+
+
+class GetHashResponse(object):
+ """
+ listmap: {<listname> : {<addchunknum> : <set of hashes>}}
+ """
+ def __init__(self, timestamp):
+ self.timestamp = timestamp
+ self.rekey = False
+ self.listmap = {}
+
+ def MergeWith(self, gethash_response):
+ self.rekey = gethash_response.rekey or self.rekey
+ for listname in gethash_response.listmap:
+ addchunks = self.listmap.setdefault(listname, {})
+ for chunk, hashes in gethash_response.listmap[listname].iteritems():
+ addchunks[chunk] = addchunks.get(chunk, set()).union(hashes)
+
+ def __str__(self):
+ def cmp_first(a, b):
+ return cmp(a[0], b[0])
+
+ listmap_str = ''
+ listmap = sorted(self.listmap.items(), cmp=cmp_first)
+ for listname, chunk_set in listmap:
+ listmap_str += '\t\t%s:\n' % (listname,)
+ for chunknum, prefixes in sorted(chunk_set.items(), cmp=cmp_first):
+ listmap_str += '\t\t%d: %s\n' % (
+ chunknum, ', '.join(
+ [util.Bin2Hex(pre) for pre in prefixes]))
+ return 'GetHashResponse:\n\trekey: %s\n\tlistmap:\n%s' % (
+ self.rekey, listmap_str)
+
+ def __repr__(self):
+ return self.__str__()
+
+
+class ChunkSequence(object):
+ """
+ A repeatable iterator over a list of chunk ranges.
+ """
+ def __init__(self, iters):
+ self._iters = iters
+
+ def __iter__(self):
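+    # Build a fresh chain on each call; the underlying xrange objects can be
+    # iterated repeatedly, which is what makes this sequence re-iterable.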
+ return itertools.chain(*self._iters)
+
+ def __str__(self):
+ return ','.join([str(x) for x in self])
+
+
+class BlockReader(object):
+ """
+ A BlockReader allows reading at most maxbytes from the given file.
+ """
+ def __init__(self, maxbytes, fh):
+ self._maxbytes = maxbytes
+ self._fh = fh
+ self._consumed = 0
+ # List of strings representing all data read.
+ self._data = []
+
+ def Consumed(self):
+ """
+ Return number of bytes that have been read.
+ """
+ return self._consumed
+
+ def DataList(self):
+ return self._data
+
+ def Read(self, n):
+ """
+ Read n bytes and return as a string.
+ """
+ if self._consumed + n > self._maxbytes:
+ raise Error('attempt to read more than %s bytes (%s)' %
+ (self._maxbytes, self._consumed + n))
+ s = self._fh.read(n)
+ self._consumed += len(s)
+ self._data.append(s)
+ if len(s) != n:
+ raise ResponseError('unable to read %d bytes' % (n,))
+ return s
+
+ def ReadChunkNum(self):
+ """
+ Read a chunk number encoded as a 32-bit network byte order value and return
+ a long.
+ """
+ numbin = self.Read(4)
+ num = struct.unpack('>L', numbin)[0]
+ return num
+
+ def ReadPrefixCount(self):
+ """
+ Read the 1-byte per-hostkey prefix count and return an int.
+ """
+ count = struct.unpack('B', self.Read(1))[0]
+ return count
+
+ def ReadHostKey(self):
+ """
+ Read the 4-byte hostkey and return a str.
+ """
+ hk = self.Read(4)
+ return hk
+
+ def End(self):
+ return self._consumed >= self._maxbytes
+
+
+class LineReader(object):
+ def __init__(self, fh):
+ self._fh = fh
+ self._data = []
+
+ def ReadLine(self):
+ """
+ Return the line read, or empty string at end of file.
+ """
+ line = self._fh.readline()
+ self._data.append(line)
+ return line
+
+ def LinesRead(self):
+ return self._data
+
+ def ClearLinesRead(self):
+ self._data = []
+
+ def LastLine(self):
+ if len(self._data) == 0:
+ raise ResponseError('no line read')
+ return self._data[-1]
+
+
+class ListOp(object):
+ def Apply(self):
+ """
+ Apply the changes from this ListOp.
+ """
+ raise NotImplementedError
+
+
+class AddChunk(ListOp):
+ def __init__(self, sbl, chunknum, prefixlen, reader):
+ self._sbl = sbl
+ self._chunknum = chunknum
+ self._prefixlen = prefixlen
+ self._prefixes = []
+ while not reader.End():
+ # We read the hostkey, but ignore it unless it specifies a whole host
+ # block. We don't need the locality it can provide since we are not trying
+ # to optimize this implementation.
+ hostkey = reader.ReadHostKey()
+ numkeys = reader.ReadPrefixCount()
+ if numkeys == 0:
+        # A prefix count of zero means that the hostkey is the prefix.
+ self._prefixes.append(hostkey)
+ else:
+ for i in xrange(0, numkeys):
+ self._prefixes.append(reader.Read(prefixlen))
+
+ def Apply(self):
+ if self._sbl.GotAddChunk(self._chunknum):
+ if not len(self._prefixes):
+ # This might apply an empty chunk over an empty chunk, but that
+ # shouldn't matter.
+ logging.debug('Applying empty add chunk over current chunk')
+ else:
+ # A chunk should always be the same after it's created until it's
+ # emptied, so this is safe to ignore.
+        logging.debug('Received duplicate chunk, ignoring')
+ return
+ assert len(self._prefixes), \
+ 'AddChunk objects should only be created for non-empty chunks'
+ for prefix in self._prefixes:
+ self._sbl.AddPrefix(prefix, self._chunknum)
+
+ def __str__(self):
+ return 'AddChunk %d, list %s: %d prefixes' % (
+ self._chunknum, self._sbl.Name(), len(self._prefixes))
+
+
+class SubChunk(ListOp):
+ def __init__(self, sbl, chunknum, prefixlen, reader):
+ self._sbl = sbl
+ self._chunknum = chunknum
+ self._prefixlen = prefixlen
+ self._prefixes = []
+ while not reader.End():
+ # We read the hostkey, but ignore it unless it specifies a whole host
+ # block. We don't need the locality it can provide since we are not trying
+ # to optimize this implementation.
+ hostkey = reader.ReadHostKey()
+ numkeys = reader.ReadPrefixCount()
+ if numkeys == 0:
+        # No prefix means that the hostkey is the prefix.
+ self._prefixes.append((hostkey, reader.ReadChunkNum()))
+ else:
+ for i in xrange(0, numkeys):
+ addchunknum = reader.ReadChunkNum()
+ prefix = reader.Read(prefixlen)
+ self._prefixes.append((prefix, addchunknum))
+
+ def Apply(self):
+ if self._sbl.GotSubChunk(self._chunknum):
+ if not len(self._prefixes):
+ logging.debug('Applying empty sub chunk over current chunk')
+ else:
+        logging.debug('Received duplicate chunk, ignoring')
+      return
+ assert len(self._prefixes), \
+ 'SubChunk objects should only be created for non-empty chunks'
+ for prefix, addchunknum in self._prefixes:
+ self._sbl.RemovePrefix(prefix, self._chunknum, addchunknum)
+
+ def __str__(self):
+ return 'SubChunk %d, list %s: %d prefixes' % (
+ self._chunknum, self._sbl.Name(), len(self._prefixes))
+
+
+class AddDel(ListOp):
+ def __init__(self, sbl, chunknums):
+ """
+ chunknums: a sequence of chunk numbers.
+ """
+ self._sbl = sbl
+ self._chunknums = chunknums
+
+ def Apply(self):
+ for num in self._chunknums:
+ self._sbl.DeleteAddChunk(num)
+
+ def __str__(self):
+ return 'AddDel: (%s, %s)' % (self._sbl.Name(), self._chunknums)
+
+
+class SubDel(ListOp):
+ def __init__(self, sbl, chunknums):
+ """
+ chunknums: a sequence of chunk numbers.
+ """
+ self._sbl = sbl
+ self._chunknums = chunknums
+
+ def Apply(self):
+ for num in self._chunknums:
+ self._sbl.DeleteSubChunk(num)
+
+ def __str__(self):
+ return 'SubDel: (%s, %s)' % (self._sbl.Name(), self._chunknums)
+
+
+class EmptyAddChunks(ListOp):
+ """Applies a series of empty add chunks to a List object."""
+
+ def __init__(self, sbl, chunknums):
+ """chunknums: a sequence of chunk numbers."""
+ self._sbl = sbl
+ self._chunknums = chunknums
+
+ def Apply(self):
+ for num in self._chunknums:
+ self._sbl.AddEmptyAddChunk(num)
+
+ def __str__(self):
+ return 'EmptyAddChunks: (%s, %s)' % (self._sbl.Name(), self._chunknums)
+
+
+class EmptySubChunks(ListOp):
+ """Applies a series of empty sub chunks to a List object."""
+
+ def __init__(self, sbl, chunknums):
+ """chunknums: a sequence of chunk numbers."""
+ self._sbl = sbl
+ self._chunknums = chunknums
+
+ def Apply(self):
+ for num in self._chunknums:
+ self._sbl.AddEmptySubChunk(num)
+
+ def __str__(self):
+ return 'EmptySubChunks: (%s, %s)' % (self._sbl.Name(), self._chunknums)
diff --git a/python/server_test.py b/python/server_test.py
new file mode 100755
index 0000000..1b688eb
--- /dev/null
+++ b/python/server_test.py
@@ -0,0 +1,273 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""These tests run against the prod safebrowsing servers."""
+
+import sblist
+import server
+
+import base64
+import cgi
+import hmac
+import logging
+import sha
+import StringIO
+import sys
+import unittest
+import urlparse
+
+
+class HTTPError(Exception):
+ """Fake HTTP error used to test gethash requests that return a 204."""
+ def __init__(self, code):
+ self.code = code
+
+ def __str__(self):
+ return 'HTTPError code:%d' % self.code
+
+
+class FakeServer(object):
+ """Helper class which simulates the SafeBrowsing server."""
+ def __init__(self):
+ # array of (url prefix, exp. params, request, response data or exception)
+ self._responses = []
+
+ def SetResponse(self, url_prefix, params, request, response):
+ """Set a fake response for a particular request.
+
+    If a request comes in to the fake server with a URL that matches the given
+    url_prefix, a request body that matches the given request body, and CGI
+    arguments that include the given params as a subset, the fake server will
+    serve the canned response, or raise it if the response is an exception
+    object.
+
+ Args:
+ url_prefix: url prefix that has to match for the response to be sent.
+ params: sub-set of CGI parameters that have to be present for the request
+ to be valid and the response to be sent. Can be None.
+ Format: [(arg1, value1), (arg2, value2), ...].
+ request: request body that has to be set for the response to be sent.
+ Can be None if no request body is expected for a particular
+ request.
+ response: Response data to send or exception to raise if the conditions
+ above are met.
+ """
+ self._responses.append((url_prefix, params, request, response))
+
+ def _HasExpectedParams(self, url, expected_params):
+ """Returns true if the expected CGI parameters are set in the given URL."""
+ if expected_params:
+ actual_params = cgi.parse_qs(urlparse.urlparse(url)[4])
+ for key, value in expected_params:
+ if key not in actual_params or actual_params[key][0] != value:
+ return False
+ return True
+
+ def HandleRequest(self, url, data):
+ for url_prefix, params, request, response in self._responses:
+ if (url.startswith(url_prefix) and
+ request == data and
+ self._HasExpectedParams(url, params)):
+ if isinstance(response, Exception):
+ raise response
+ else:
+ return StringIO.StringIO(response)
+ raise Exception('No such data: %s' % url)
+
+
+class ServerTest(unittest.TestCase):
+ def setUp(self):
+ self._fake_sb_server = FakeServer()
+ self._server = server.Server(
+ ('safebrowsing.clients.google.com', 80),
+ ('sb-ssl.google.com', 443),
+ '/safebrowsing',
+ clientkey="BOGUS_CLIENT_KEY",
+ wrkey="BOGUS_WRAPPED_KEY",
+ apikey="BOGUS_API_KEY",
+ url_request_function=self._fake_sb_server.HandleRequest)
+
+ self._base_url = 'http://safebrowsing.clients.google.com:80/safebrowsing'
+
+ def _Mac(self, data):
+ clientkey, wrkey = self._server.Keys()
+ return base64.urlsafe_b64encode(hmac.new(clientkey, data, sha).digest())
+
+ def testGetLists(self):
+ response = 'lista\nlistb\nlistc'
+ response = '%s\n%s' % (self._Mac(response), response)
+ self._fake_sb_server.SetResponse(url_prefix='%s/list?' % self._base_url,
+ params=[('wrkey', 'BOGUS_WRAPPED_KEY'),
+ ('apikey', 'BOGUS_API_KEY')],
+ request=None,
+ response=response)
+ self.assertEqual(['lista', 'listb', 'listc'],
+ [l.Name() for l in self._server.GetLists()])
+
+ def testKeys(self):
+ self.assertEqual(('BOGUS_CLIENT_KEY', 'BOGUS_WRAPPED_KEY'),
+ self._server.Keys())
+
+ def testRekey(self):
+ self._fake_sb_server.SetResponse(
+ url_prefix='https://sb-ssl.google.com:443/safebrowsing/newkey?',
+ params=None,
+ request=None,
+ response=('clientkey:28:TkVXX0JPR1VTX0NMSUVOVF9LRVk=\n' +
+ 'wrappedkey:15:NEW_BOGUS_WRKEY'))
+ self.assertEqual(('NEW_BOGUS_CLIENT_KEY', 'NEW_BOGUS_WRKEY'),
+ self._server.Rekey())
+ self.assertEqual(('NEW_BOGUS_CLIENT_KEY', 'NEW_BOGUS_WRKEY'),
+ self._server.Keys())
+
+ def testDownload(self):
+    # First we set up the redirect requests.
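+    # Chunk wire format (parsed by AddChunk/SubChunk in server.py): a header
+    # line '<a|s>:<chunknum>:<prefixlen>:<chunklen>' followed by chunklen
+    # bytes of hostkey (4B) + prefix count (1B) + prefixes; each sub chunk
+    # entry is preceded by the 4-byte add chunk number it refers to.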
+ lista_a_redirect_response = ('a:10:4:27\n' +
+ '1234\x00' +
+ '5678\x01ABCD' +
+ 'EFGH\x02EFGHIJKL' +
+ # Empty add chunk.
+ 'a:7:4:0\n')
+ lista_s_redirect_response = ('s:2:4:22\n' +
+ '5678\x01\x00\x00\x00\x0AABCD' +
+ # Special case where there is no prefix.
+ 'EFGH\x00\x00\x00\x00\x0A' +
+ # Empty sub chunk
+ 's:3:4:0\n')
+ listb_a_redirect_response = (
+ 'a:1:6:1546\n' +
+ # Test an edge case where there are more than 255 entries for
+ # the same host key
+ '1234\xFF%s' % ''.join(map(str, range(100000, 100255))) +
+ '1234\x01100255')
+
+ self._fake_sb_server.SetResponse(
+ url_prefix='http://rd.com/lista-a',
+ params=None,
+ request=None,
+ response=lista_a_redirect_response)
+ self._fake_sb_server.SetResponse(
+ url_prefix='http://rd.com/lista-s',
+ params=None,
+ request=None,
+ response=lista_s_redirect_response)
+ self._fake_sb_server.SetResponse(
+ url_prefix='http://rd.com/listb-a',
+ params=None,
+ request=None,
+ # Make sure we can handle prefixes that are >4B.
+ response=listb_a_redirect_response)
+
+ response = '\n'.join(['n:1800',
+ 'i:lista',
+ 'u:rd.com/lista-s,%s' %
+ self._Mac(lista_s_redirect_response),
+ 'u:rd.com/lista-a,%s' %
+ self._Mac(lista_a_redirect_response),
+ 'ad:1-2,4-5,7',
+ 'i:listb',
+ 'u:rd.com/listb-a,%s' %
+ self._Mac(listb_a_redirect_response),
+ 'sd:2-6'])
+ self._fake_sb_server.SetResponse(
+ url_prefix='%s/downloads?' % self._base_url,
+ params=[('wrkey', 'BOGUS_WRAPPED_KEY'),
+ ('apikey', 'BOGUS_API_KEY')],
+ request='s;%d\nlista;mac\nlistb;mac\n' % (1<<10),
+ response='m:%s\n%s' % (self._Mac(response), response))
+
+ # Perform the actual download request.
+ sblists = [sblist.List('lista'), sblist.List('listb')]
+ response = self._server.Download(sblists, 1<<20)
+
+ #### Test that the download response contains the correct list ops ####
+ self.assertEqual(1800, response.min_delay_sec)
+ self.assertFalse(response.rekey)
+ self.assertFalse(response.reset)
+ self.assertEqual(['lista', 'listb'], response.listops.keys())
+
+ self.assertTrue(isinstance(response.listops['lista'][0], server.SubChunk))
+ self.assertEqual(2, response.listops['lista'][0]._chunknum)
+ self.assertEqual(4, response.listops['lista'][0]._prefixlen)
+ self.assertEqual([('ABCD', 10), ('EFGH', 10)],
+ response.listops['lista'][0]._prefixes)
+
+ self.assertTrue(isinstance(response.listops['lista'][1],
+ server.EmptySubChunks))
+ self.assertEqual([3], response.listops['lista'][1]._chunknums)
+
+ self.assertTrue(isinstance(response.listops['lista'][2], server.AddChunk))
+ self.assertEqual(10, response.listops['lista'][2]._chunknum)
+ self.assertEqual(4, response.listops['lista'][2]._prefixlen)
+ self.assertEqual(['1234', 'ABCD', 'EFGH', 'IJKL'],
+ response.listops['lista'][2]._prefixes)
+
+ self.assertTrue(isinstance(response.listops['lista'][3],
+ server.EmptyAddChunks))
+ self.assertEqual([7], response.listops['lista'][3]._chunknums)
+
+ self.assertTrue(isinstance(response.listops['lista'][4], server.AddDel))
+ self.assertEqual([1, 2, 4, 5, 7],
+ list(response.listops['lista'][4]._chunknums))
+
+ self.assertTrue(isinstance(response.listops['listb'][0], server.AddChunk))
+ self.assertEqual(1, response.listops['listb'][0]._chunknum)
+ self.assertEqual(6, response.listops['listb'][0]._prefixlen)
+ self.assertEqual(map(str, range(100000, 100256)),
+ response.listops['listb'][0]._prefixes)
+
+ self.assertTrue(isinstance(response.listops['listb'][1], server.SubDel))
+ self.assertEqual([2, 3, 4, 5, 6],
+ list(response.listops['listb'][1]._chunknums))
+
+ def testGetFullHashes(self):
+ response = 'lista:123:32\n89AB%s' % (28 * 'A')
+ self._fake_sb_server.SetResponse(
+ url_prefix='%s/gethash?' % self._base_url,
+ params=[('wrkey', 'BOGUS_WRAPPED_KEY'),
+ ('apikey', 'BOGUS_API_KEY')],
+ request='4:12\n0123456789AB',
+ response='%s\n%s' % (self._Mac(response), response))
+
+ self._fake_sb_server.SetResponse(
+ url_prefix='%s/gethash?' % self._base_url,
+ params=[('wrkey', 'BOGUS_WRAPPED_KEY'),
+ ('apikey', 'BOGUS_API_KEY')],
+ request='10:10\n0123456789',
+ response=HTTPError(204))
+
+ resp = self._server.GetFullHashes(['0123', '4567', '89AB'], 4)
+ self.assertTrue(isinstance(resp, server.GetHashResponse))
+ self.assertFalse(resp.rekey)
+ self.assertEqual({'lista': { 123: set(['89AB%s' % (28 * 'A')])}},
+ resp.listmap)
+
+ resp = self._server.GetFullHashes(['0123456789'], 10)
+ self.assertTrue(isinstance(resp, server.GetHashResponse))
+ self.assertFalse(resp.rekey)
+ self.assertEqual({}, resp.listmap)
+
+ def testGetSequence(self):
+ chunkseq = server.Server._GetSequence('1-2,4-5,7-10,11')
+ expected = [1, 2, 4, 5, 7, 8, 9, 10, 11]
+ # Should be able to iterate over chunkseq multiple times.
+ for i in xrange(0, 5):
+ self.assertEqual(expected, list(chunkseq))
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/python/util.py b/python/util.py
new file mode 100755
index 0000000..ccbf5f1
--- /dev/null
+++ b/python/util.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python2.5
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common utilities.
+"""
+
+import hashlib
+import struct
+
+
+def Bin2Hex(hash_str):
+ """Returns the lowercase hex encoding of the given binary string."""
+ hexchars = []
+ for i in struct.unpack('%dB' % (len(hash_str),), hash_str):
+ hexchars.append('%02x' % (i,))
+ return ''.join(hexchars)
+
+def GetHash256(expr):
+ """Returns the 32-byte SHA-256 digest of expr."""
+ return hashlib.sha256(expr).digest()
+
+def IsFullHash(expr):
+ """Returns True if expr has the length of a full SHA-256 digest."""
+ return len(expr) == 32
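+
+
+# A minimal usage sketch (illustrative only; the expression below is a
+# made-up example, not data from this repository): hash a URL expression
+# and hex-encode it for display.
+if __name__ == '__main__':
+ expression = 'malware.example.com/'
+ full_hash = GetHash256(expression)
+ assert IsFullHash(full_hash)
+ print 'sha256(%s) = %s' % (expression, Bin2Hex(full_hash))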
diff --git a/testing/LICENSE b/testing/LICENSE
deleted file mode 100644
index 229ff03..0000000
--- a/testing/LICENSE
+++ /dev/null
@@ -1,13 +0,0 @@
-Copyright 2009 Google Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
diff --git a/testing/external_test_pb2.py b/testing/external_test_pb2.py
deleted file mode 100755
index 3b86228..0000000
--- a/testing/external_test_pb2.py
+++ /dev/null
@@ -1,202 +0,0 @@
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-
-from google.protobuf import descriptor
-from google.protobuf import message
-from google.protobuf import reflection
-from google.protobuf import service
-from google.protobuf import service_reflection
-from google.protobuf import descriptor_pb2
-
-
-
-_CGIPARAM = descriptor.Descriptor(
- name='CGIParam',
- full_name='CGIParam',
- filename='external_test.proto',
- containing_type=None,
- fields=[
- descriptor.FieldDescriptor(
- name='Name', full_name='CGIParam.Name', index=0,
- number=1, type=12, cpp_type=9, label=2,
- default_value="",
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- descriptor.FieldDescriptor(
- name='Value', full_name='CGIParam.Value', index=1,
- number=2, type=12, cpp_type=9, label=2,
- default_value="",
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- ],
- extensions=[
- ],
- nested_types=[], # TODO(robinson): Implement.
- enum_types=[
- ],
- options=None)
-
-
-_REQUESTDATA = descriptor.Descriptor(
- name='RequestData',
- full_name='RequestData',
- filename='external_test.proto',
- containing_type=None,
- fields=[
- descriptor.FieldDescriptor(
- name='RequestPath', full_name='RequestData.RequestPath', index=0,
- number=1, type=12, cpp_type=9, label=2,
- default_value="",
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- descriptor.FieldDescriptor(
- name='Params', full_name='RequestData.Params', index=1,
- number=2, type=11, cpp_type=10, label=3,
- default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- descriptor.FieldDescriptor(
- name='PostData', full_name='RequestData.PostData', index=2,
- number=3, type=12, cpp_type=9, label=1,
- default_value="",
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- descriptor.FieldDescriptor(
- name='ServerResponse', full_name='RequestData.ServerResponse', index=3,
- number=4, type=12, cpp_type=9, label=2,
- default_value="",
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- ],
- extensions=[
- ],
- nested_types=[], # TODO(robinson): Implement.
- enum_types=[
- ],
- options=None)
-
-
-_HASHREQUEST = descriptor.Descriptor(
- name='HashRequest',
- full_name='HashRequest',
- filename='external_test.proto',
- containing_type=None,
- fields=[
- descriptor.FieldDescriptor(
- name='HashPrefix', full_name='HashRequest.HashPrefix', index=0,
- number=1, type=12, cpp_type=9, label=2,
- default_value="",
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- descriptor.FieldDescriptor(
- name='ServerResponse', full_name='HashRequest.ServerResponse', index=1,
- number=2, type=12, cpp_type=9, label=2,
- default_value="",
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- descriptor.FieldDescriptor(
- name='Expression', full_name='HashRequest.Expression', index=2,
- number=3, type=12, cpp_type=9, label=1,
- default_value="",
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- ],
- extensions=[
- ],
- nested_types=[], # TODO(robinson): Implement.
- enum_types=[
- ],
- options=None)
-
-
-_STEPDATA = descriptor.Descriptor(
- name='StepData',
- full_name='StepData',
- filename='external_test.proto',
- containing_type=None,
- fields=[
- descriptor.FieldDescriptor(
- name='Requests', full_name='StepData.Requests', index=0,
- number=1, type=11, cpp_type=10, label=3,
- default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- descriptor.FieldDescriptor(
- name='Hashes', full_name='StepData.Hashes', index=1,
- number=2, type=11, cpp_type=10, label=3,
- default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- ],
- extensions=[
- ],
- nested_types=[], # TODO(robinson): Implement.
- enum_types=[
- ],
- options=None)
-
-
-_TESTDATA = descriptor.Descriptor(
- name='TestData',
- full_name='TestData',
- filename='external_test.proto',
- containing_type=None,
- fields=[
- descriptor.FieldDescriptor(
- name='Steps', full_name='TestData.Steps', index=0,
- number=1, type=11, cpp_type=10, label=3,
- default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- descriptor.FieldDescriptor(
- name='ClientKey', full_name='TestData.ClientKey', index=1,
- number=2, type=12, cpp_type=9, label=2,
- default_value="",
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- ],
- extensions=[
- ],
- nested_types=[], # TODO(robinson): Implement.
- enum_types=[
- ],
- options=None)
-
-
-_REQUESTDATA.fields_by_name['Params'].message_type = _CGIPARAM
-_STEPDATA.fields_by_name['Requests'].message_type = _REQUESTDATA
-_STEPDATA.fields_by_name['Hashes'].message_type = _HASHREQUEST
-_TESTDATA.fields_by_name['Steps'].message_type = _STEPDATA
-
-class CGIParam(message.Message):
- __metaclass__ = reflection.GeneratedProtocolMessageType
- DESCRIPTOR = _CGIPARAM
-
-class RequestData(message.Message):
- __metaclass__ = reflection.GeneratedProtocolMessageType
- DESCRIPTOR = _REQUESTDATA
-
-class HashRequest(message.Message):
- __metaclass__ = reflection.GeneratedProtocolMessageType
- DESCRIPTOR = _HASHREQUEST
-
-class StepData(message.Message):
- __metaclass__ = reflection.GeneratedProtocolMessageType
- DESCRIPTOR = _STEPDATA
-
-class TestData(message.Message):
- __metaclass__ = reflection.GeneratedProtocolMessageType
- DESCRIPTOR = _TESTDATA
-
diff --git a/testing/safebrowsing_test_server.py b/testing/safebrowsing_test_server.py
deleted file mode 100755
index d730f2a..0000000
--- a/testing/safebrowsing_test_server.py
+++ /dev/null
@@ -1,418 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2009 Google Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Test server for Safebrowsing protocol v2.
-
-To test an implementation of the safebrowsing protocol, this server should
-be run on the same machine as the client implementation. The client should
-connect to this server at localhost:port where port is specified as a command
-line flag (--port) and perform updates normally, except that each request
-should have an additional CGI param "test_step" that specifies which update
-request this is for the client. That is, test_step should be incremented after
-the complete parsing of a downloads request, so a downloads request and its
-associated redirects should all have the same test_step. The client should
-also make a newkey request and then a getlists request before making the
-first update request, and should use test_step=1 for these requests (test_step
-is 1-indexed). When the client believes that it is done with testing (because
-it receives a response from an update request with no new data), it should
-make a "/test_complete" request, which returns "yes" or "no" depending on
-whether the test is complete.
-"""
-
-__author__ = 'gcasto@google.com (Garrett Casto)'
-
-import BaseHTTPServer
-import binascii
-import base64
-import cgi
-import hmac
-from optparse import OptionParser
-import re
-import sha
-import sys
-from threading import Timer
-import time
-import urlparse
-
-import external_test_pb2
-
-DEFAULT_PORT = 40101
-DEFAULT_DATAFILE_LOCATION = "testing_input.dat"
-POST_DATA_KEY = "post_data"
-GETHASH_PATH = "/safebrowsing/gethash"
-RESET_PATH="/reset"
-DOWNLOADS_PATH = "/safebrowsing/downloads"
-TEST_COMPLETE_PATH = "/test_complete"
-DATABASE_VALIDATION_PATH = "/safebrowsing/verify_database"
-
-# Dict of step -> List of (request_path, param key, response)
-response_data_by_step = {}
-# Dict of step -> Dict of hash_prefix ->
-# (full length hash response, expression, num times requested)
-hash_data_by_step = {}
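-# An illustrative (made-up) hash_data_by_step entry:
-# {1: {'\x01\x02\x03\x04': ('lista:1:32\n' + 32 * 'A', 'evil.com/', 0)}}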
-client_key = ''
-enforce_caching = False
-validate_database = True
-server_port = -1
-datafile_location = ''
-
-def EndServer():
- sys.exit(0)
-
-def CGIParamsToListOfTuples(cgi_params):
- return [(param.Name, param.Value) for param in cgi_params]
-
-def SortedTupleFromParamsAndPostData(params,
- post_data):
- """ Make a sorted tuple from the request such that it can be inserted as
- a key in a map. params is a list of (name, value) tuples and post_data is
- a string (or None).
- """
- if post_data:
- params.append((POST_DATA_KEY, tuple(sorted(post_data.split('\n')))))
- return tuple(sorted(params))
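-
- # For example (illustrative values only): params [('pwd', 'k'),
- # ('apikey', 'x')] with post_data 'b\na' yields the hashable key
- # (('apikey', 'x'), ('post_data', ('a', 'b')), ('pwd', 'k')).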
-
-def LoadData(filename):
- """ Load data from filename to be used by the testing server.
- """
- global response_data_by_step
- global client_key
- data_file = open(filename, 'rb')
- str_data = data_file.read()
- test_data = external_test_pb2.TestData()
- test_data.ParseFromString(str_data)
- print "Data Loaded"
- client_key = test_data.ClientKey
- step = 0
- for step_data in test_data.Steps:
- step += 1
- step_list = []
- for request_data in step_data.Requests:
- params_tuple = SortedTupleFromParamsAndPostData(
- CGIParamsToListOfTuples(request_data.Params),
- request_data.PostData)
- step_list.append((request_data.RequestPath,
- params_tuple,
- request_data.ServerResponse))
- response_data_by_step[step] = step_list
-
- hash_step_dict = {}
- for hash_request in step_data.Hashes:
- hash_step_dict[hash_request.HashPrefix] = (hash_request.ServerResponse,
- hash_request.Expression,
- 0)
- hash_data_by_step[step] = hash_step_dict
- print "Data Parsed"
-
-def VerifyTestComplete():
- """ Returns true if all the necessary requests have been made by the client.
- """
- global response_data_by_step
- global hash_data_by_step
- global enforce_caching
-
- complete = True
- for (step, step_list) in response_data_by_step.iteritems():
- if len(step_list):
- print ("Step %s has %d request(s) that were not made %s" %
- (step, len(step_list), step_list))
- complete = False
-
- for (step, hash_step_dict) in hash_data_by_step.iteritems():
- for (prefix,
- (response, expression, num_requests)) in hash_step_dict.iteritems():
- if ((enforce_caching and num_requests != 1) or
- num_requests == 0):
- print ("Hash prefix %s not requested the correct number of times"
- "(%d requests). Requests originated because of expression"
- " %s. Prefix is located in the following locations" %
- (binascii.hexlify(prefix),
- num_requests,
- expression))
- cur_index = 0
- while cur_index < len(response):
- end_header_index = response.find('\n', cur_index + 1)
- header = response[cur_index:end_header_index]
- (listname, chunk_num, hashdatalen) = header.split(":")
- print " List '%s' in add chunk num %d" % (listname, int(chunk_num))
- cur_index = end_header_index + int(hashdatalen) + 1
-
- complete = False
-
- # TODO(gcasto): Have a check here that verifies that the client doesn't
- # make too many hash requests during the test run.
-
- return complete
-
-class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
- def ParamDictToListOfTuples(self, params):
- # params is a dict mapping each cgi param name to its list of specified
- # values. Since we never expect a parameter to be specified multiple
- # times, we just take the first one.
- return [(name, value[0]) for (name, value) in params.iteritems()]
-
- def MakeParamKey(self, params, post_data=None):
- """ Make a lookup key from the request.
- """
- return SortedTupleFromParamsAndPostData(
- self.ParamDictToListOfTuples(params),
- post_data)
-
- def MACResponse(self, response, is_downloads_request):
- """ Returns the response wrapped with a MAC. Formatting will change
- if this is a downloads_request or hashserver_request.
- """
- unescaped_mac = hmac.new(client_key, response, sha).digest()
- return "%s%s\n%s" % (is_downloads_request and "m:" or "",
- base64.urlsafe_b64encode(unescaped_mac),
- response)
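-
- # Illustrative framing (values made up): a downloads response becomes
- # 'm:BASE64MAC=\nresponse-body', while a gethash response becomes
- # 'BASE64MAC=\nresponse-body'.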
-
- def VerifyRequest(self, is_post_request):
- """ Verify that the request matches one loaded from the datafile and
- give the corresponding response. If there is no match, try and give a
- descriptive error message in the response.
- """
- parsed_url = urlparse.urlparse(self.path)
- path = parsed_url[2]
- params = cgi.parse_qs(parsed_url[4])
-
- step = params.get("test_step")
- if step is None or len(step) != 1:
- self.send_response(400)
- self.end_headers()
- print "No test step present."
- return
- step = int(step[0])
-
- if path == TEST_COMPLETE_PATH:
- self.send_response(200)
- self.end_headers()
- if VerifyTestComplete():
- self.wfile.write('yes')
- else:
- self.wfile.write('no')
- elif path == GETHASH_PATH:
- self.SynthesizeGethashResponse(step)
- elif path == RESET_PATH:
- LoadData(datafile_location)
- self.send_response(200)
- self.end_headers()
- self.wfile.write('done')
- else:
- self.GetCannedResponse(path, params, step, is_post_request)
-
- def SynthesizeGethashResponse(self, step):
- """ Create a gethash response. This will possibly combine an arbitrary
- number of hash requests from the protocol buffer.
- """
- global hash_data_by_step
-
- hashes_for_step = hash_data_by_step.get(step, {})
- if not hashes_for_step:
- self.send_response(400)
- self.end_headers()
- print "No response for step %d" % step
- return
-
- post_data = self.rfile.read(int(self.headers['Content-Length']))
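- # For example, a well-formed body is '4:12\n0123456789AB': three 4-byte
- # prefixes, 12 bytes in total.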
- match = re.match(
- r'(?P<prefixsize>\d+):(?P<totalsize>\d+)\n(?P<prefixes>.+)',
- post_data,
- re.MULTILINE | re.IGNORECASE | re.DOTALL)
- if not match:
- self.send_response(400)
- self.end_headers()
- print "Gethash request is malformed %s" % post_data
- return
-
- prefixsize = int(match.group('prefixsize'))
- total_length = int(match.group('totalsize'))
- if total_length % prefixsize != 0:
- self.send_response(400)
- self.end_headers()
- print ("Gethash request is malformed, length should be a multiple of the "
- " prefix size%s" % post_data)
- return
-
- response = ""
- for n in range(total_length/prefixsize):
- prefix = match.group('prefixes')[n*prefixsize:n*prefixsize + prefixsize]
- hash_data = hashes_for_step.get(prefix)
- if hash_data is not None:
- # Reply with the correct response
- response += hash_data[0]
- # Remember that this hash has now been requested.
- hashes_for_step[prefix] = (hash_data[0], hash_data[1], hash_data[2] + 1)
-
- if not response:
- self.send_response(204)
- self.end_headers()
- return
-
- # Need to perform MACing before sending response out.
- self.send_response(200)
- self.end_headers()
- self.wfile.write(self.MACResponse(response, False))
-
- def GetCannedResponse(self, path, params, step, is_post_request):
- """ Given the parameters of a request, see if a matching response is
- found. If one is found, respond with with it, else respond with a 400.
- """
- responses_for_step = response_data_by_step.get(step)
- if not responses_for_step:
- self.send_response(400)
- self.end_headers()
- print "No responses for step %d" % step
- return
-
- # Delete unnecessary params
- del params["test_step"]
- if "client" in params:
- del params["client"]
- if "appver" in params:
- del params["appver"]
-
- param_key = self.MakeParamKey(
- params,
- is_post_request and
- self.rfile.read(int(self.headers['Content-Length'])) or
- None)
-
- (expected_path, expected_params, server_response) = responses_for_step[0]
- if expected_path != path or param_key != expected_params:
- self.send_response(400)
- self.end_headers()
- print "Expected request with path %s and params %s." % (expected_path,
- expected_params)
- print "Actual request path %s and params %s" % (path, param_key)
- return
-
- # Remove request that was just made
- responses_for_step.pop(0)
-
- # If the next request is not needed for this test run, remove it now.
- # We do this after processing instead of before for cases where the
- # data we are removing is the last request in a step.
- if responses_for_step:
- (expected_path, _, _) = responses_for_step[0]
- if expected_path == DATABASE_VALIDATION_PATH and not validate_database:
- responses_for_step.pop(0)
-
- if path == DOWNLOADS_PATH:
- # Need to have the redirects point to the current port.
- server_response = re.sub(r'localhost:\d+',
- 'localhost:%d' % server_port,
- server_response)
- # Remove the current MAC, because it's going to be wrong now.
- server_response = server_response[server_response.find('\n')+1:]
- # Add a new correct MAC.
- server_response = self.MACResponse(server_response, True)
-
- self.send_response(200)
- self.end_headers()
- self.wfile.write(server_response)
-
- def do_GET(self):
- self.VerifyRequest(False)
-
- def do_POST(self):
- self.VerifyRequest(True)
-
-
-def SetupServer(datafile_location,
- port,
- opt_enforce_caching,
- opt_validate_database):
- """Sets up the safebrowsing test server.
-
- Arguments:
- datafile_location: The file to load testing data from.
- port: Port that the server runs on.
- opt_enforce_caching: Whether to require the client to implement caching.
- opt_validate_database: Whether to require that the client make database
- verification requests.
-
- Returns:
- An HTTPServer object which the caller should call serve_forever() on.
- """
- LoadData(datafile_location)
- # TODO(gcasto): Look into extending HTTPServer to remove global variables.
- global enforce_caching
- global validate_database
- global server_port
- enforce_caching = opt_enforce_caching
- validate_database = opt_validate_database
- server_port = port
- return BaseHTTPServer.HTTPServer(('', port), RequestHandler)
-
-if __name__ == '__main__':
- parser = OptionParser()
- parser.add_option("--datafile", dest="datafile_location",
- default=DEFAULT_DATAFILE_LOCATION,
- help="Location to load testing data from.")
- parser.add_option("--port", dest="port", type="int",
- default=DEFAULT_PORT, help="Port to run the server on.")
- parser.add_option("--enforce_caching", dest="enforce_caching",
- action="store_true", default=False,
- help="Whether to require that the client"
- "has implemented caching or not.")
- parser.add_option("--ignore_database_validation", dest="validate_database",
- action="store_false", default=True,
- help="Whether to requires that the client makes verify "
- "database requests or not.")
- parser.add_option("--server_timeout_sec", dest="server_timeout_sec",
- type="int", default=600,
- help="How long to let the server run before shutting it "
- "down. If <=0, the server will never be down")
- (options, _) = parser.parse_args()
-
- datafile_location = options.datafile_location
- server = SetupServer(options.datafile_location,
- options.port,
- options.enforce_caching,
- options.validate_database)
-
- if options.server_timeout_sec > 0:
- tm = Timer(options.server_timeout_sec, EndServer)
- tm.start()
-
- server.serve_forever()
diff --git a/testing/testing_input.dat b/testing/testing_input.dat
deleted file mode 100644
index 9eb760f..0000000
--- a/testing/testing_input.dat
+++ /dev/null
Binary files differ