blob: 0c968dfa722700dbd8c5bd8527254c2f78384aa8 [file] [log] [blame]
# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
"""Request rate limiting implementation.

This is intended to be used for automatic DDoS protection.
"""
import datetime
import logging
import settings
import time
from infra_libs import ts_mon
from google.appengine.api import memcache
from google.appengine.api.modules import modules
from google.appengine.api import users
# Number of minutes in the rate-limiting window; also the number of
# per-minute counter buckets summed when a limit is checked.
# (300 requests per N_MINUTES=5 minutes is the 1 QPS noted below.)
N_MINUTES = 5

# Lifetime of each per-minute counter in memcache.
# NOTE(review): value reconstructed — original definition was lost in
# extraction; confirm against deployment config.
EXPIRE_AFTER_SECS = 60 * 60

DEFAULT_LIMIT = 300  # 300 requests in 5 minutes is 1 QPS.

# Prefix element used in cache keys for requests with no logged-in user.
ANON_USER = 'anon'

# App Engine adds this header with the request's two-letter country code.
COUNTRY_HEADER = 'X-AppEngine-Country'

COUNTRY_LIMITS = {
    # Two-letter country code: max requests per N_MINUTES
    # This limit will apply to all requests coming
    # from this country.
    # To add a country code, see GAE logs and use the
    # appropriate code from
    # https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2
    # E.g., 'cn': 300,  # Limit to 1 QPS.
}

# Modules not in this list will not have rate limiting applied by this
# class.
MODULE_WHITELIST = ['default']
def _CacheKeys(request, now_sec):
  """Returns the memcache keysets to check for this request.

  Returns an array of arrays. Each inner array contains strings with
  the same prefix and a timestamp suffix, starting with the most
  recent minute bucket and decrementing by 1 minute each time.

  Args:
    request: the incoming request; .headers and .remote_addr are read.
    now_sec: float seconds since the epoch.

  Returns:
    A (keysets, country, ip, user_email) tuple.
  """
  now = datetime.datetime.fromtimestamp(now_sec)
  country = request.headers.get(COUNTRY_HEADER, 'ZZ')
  ip = request.remote_addr
  minute_buckets = [now - datetime.timedelta(minutes=m) for m in
                    range(N_MINUTES)]
  user = users.get_current_user()
  user_email = user.email() if user else ANON_USER

  # <IP, country, user_email> to be rendered into each key prefix.
  prefixes = []

  # All logged-in users get a per-user rate limit, regardless of IP and country.
  if user:
    prefixes.append(['ALL', 'ALL', user_email])
  else:
    # All anon requests get a per-IP ratelimit.
    prefixes.append([ip, 'ALL', 'ALL'])

  # All requests from a problematic country get a per-country rate limit,
  # regardless of the user (even a non-logged-in one) or IP.
  if country in COUNTRY_LIMITS:
    prefixes.append(['ALL', country, 'ALL'])

  keysets = []
  for prefix in prefixes:
    # One key per minute bucket, e.g. 'ratelimit-1.2.3.4-ALL-ALL-<minute>'.
    # Seconds/microseconds are zeroed so all requests in the same minute
    # share a bucket.
    keysets.append(['ratelimit-%s-%s' % ('-'.join(prefix),
                    str(minute_bucket.replace(second=0, microsecond=0)))
                    for minute_bucket in minute_buckets])

  return keysets, country, ip, user_email
class RateLimiter:
  """Checks and updates per-user/IP/country request counters in memcache.

  Counter buckets are generated by _CacheKeys. A request is rejected
  (RateLimitExceeded) when the sum of its buckets over the last N_MINUTES
  exceeds the applicable limit.
  """

  # NOTE(review): metric name strings were lost in extraction and are
  # reconstructed here — confirm against the deployed ts_mon config.
  blocked_requests = ts_mon.CounterMetric(
      'monorail/ratelimiter/blocked_request')
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/rate_exceeded')
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/cost_thresh_exceeded')
  checks = ts_mon.CounterMetric(
      'monorail/ratelimiter/check')

  def __init__(self, _cache=memcache, fail_open=True, **_kwargs):
    # fail_open: when True, memcache errors are treated as "no limit hit"
    # instead of rejecting the request.
    self.fail_open = fail_open

  def CheckStart(self, request, now=None):
    """Check the rate-limit counters at the start of request handling.

    Args:
      request: the incoming request object.
      now: float seconds since the epoch; defaults to time.time().

    Raises:
      RateLimitExceeded: a limit was exceeded (or a counter read failed
          while configured to fail closed) and ratelimiting is enabled.
    """
    # NOTE(review): the second guard condition was lost in extraction;
    # admins being exempt matches the surrounding design — confirm.
    if (modules.get_current_module_name() not in MODULE_WHITELIST or
        users.is_current_user_admin()):
      return
    logging.info('X-AppEngine-Country: %s' %
                 request.headers.get(COUNTRY_HEADER, 'ZZ'))

    if now is None:
      now = time.time()

    keysets, country, ip, user_email = _CacheKeys(request, now)
    # There are either two or three sets of keys in keysets.
    # Three if the user's country is in COUNTRY_LIMITS, otherwise two.
    for keys in keysets:
      count = 0
      try:
        counters = memcache.get_multi(keys)
        count = sum(counters.values())
        self.checks.increment({'type': 'success'})
      except Exception as e:
        logging.error(e)
        if self.fail_open:
          # Best-effort: treat the unreadable counters as zero.
          self.checks.increment({'type': 'fail_open'})
        else:
          self.checks.increment({'type': 'fail_closed'})
          raise RateLimitExceeded(
              country=country, ip=ip, user_email=user_email)

      limit = COUNTRY_LIMITS.get(country, DEFAULT_LIMIT)
      if count > limit:
        # Since webapp2 won't let us return a 429 error code
        # we can't monitor rate limit exceeded events with our standard
        # tools. We return a 400 with a custom error message to the
        # client, and this logging is so we can monitor it internally.
        logging.info('Rate Limit Exceeded: %s, %s, %s, %d' % (
            country, ip, user_email, count))
        if settings.ratelimiting_enabled:
          raise RateLimitExceeded(
              country=country, ip=ip, user_email=user_email)

      k = keys[0]
      # Only update the latest *time* bucket for each prefix (reverse chron).
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, initial_value=0)

  def CheckEnd(self, request, now, start_time):
    """If a request was expensive to process, charge some extra points
    against this set of buckets.

    We pass in both now and start_time so we can update the buckets
    based on keys created from start_time instead of now.

    now and start_time are float seconds.
    """
    if (modules.get_current_module_name() not in MODULE_WHITELIST or
        not settings.ratelimiting_cost_enabled):
      return

    elapsed_ms = (now - start_time) * 1000
    # Would it kill the python lib maintainers to have timedelta.total_ms()?
    if elapsed_ms < settings.ratelimiting_cost_thresh_ms:
      return

    # TODO: Look into caching the keys instead of generating them twice
    # for every request. Say, return them from CheckStart so they can
    # be passed back in here later.
    keysets, country, ip, user_email = _CacheKeys(request, start_time)
    for keys in keysets:
      logging.info('Rate Limit Cost Threshold Exceeded: %s, %s, %s' % (
          country, ip, user_email))
      # Only update the latest *time* bucket for each prefix (reverse chron).
      k = keys[0]
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, initial_value=0)
class RateLimitExceeded(Exception):
  """Raised when a request exceeds its rate limit.

  Carries the <country, ip, user_email> triple that identifies which
  rate-limit bucket was tripped, for logging by callers.
  """

  def __init__(self, country=None, ip=None, user_email=None, **_kwargs):
    self.country = country
    self.ip = ip
    self.user_email = user_email

  def __str__(self):
    return 'RateLimitExceeded: %s, %s, %s' % (
        self.country, self.ip, self.user_email)