# blob: 467037557f967b4e07239e33586bb7a27aa222b9 [file] [log] [blame]
# (c) 2005 Ian Bicking and contributors; written for Paste
# (http://pythonpaste.org) Licensed under the MIT license:
# http://www.opensource.org/licenses/mit-license.php
"""
Gives a multi-value dictionary object (MultiDict) plus several wrappers
"""
from collections import MutableMapping
import binascii
import warnings
from webob.compat import (
PY3,
iteritems_,
itervalues_,
url_encode,
)
# Names exported by ``from <this module> import *``.
__all__ = ['MultiDict', 'NestedMultiDict', 'NoVars', 'GetDict']
class MultiDict(MutableMapping):
    """
    An ordered dictionary that can have multiple values for each key.
    Adds the methods getall, getone, mixed and extend and add to the normal
    dictionary interface.

    The data is kept as a single list of ``(key, value)`` tuples in
    ``self._items``; insertion order is preserved and a key may appear in
    that list more than once.
    """

    def __init__(self, *args, **kw):
        """
        Build the dict from at most one positional source plus keywords.

        The positional argument may be an object with ``iteritems()``, an
        object with ``items()``, or any iterable of ``(key, value)`` pairs.
        Keyword arguments are appended after the positional items.

        Raises ``TypeError`` when more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError("MultiDict can only be called with one positional "
                            "argument")
        if args:
            source = args[0]
            if hasattr(source, 'iteritems'):
                items = list(source.iteritems())
            elif hasattr(source, 'items'):
                items = list(source.items())
            else:
                items = list(source)
            self._items = items
        else:
            self._items = []
        if kw:
            self._items.extend(kw.items())

    @classmethod
    def view_list(cls, lst):
        """
        Create a dict that is a view on the given list

        The list object itself (not a copy) becomes the backing store, so
        changes made through the dict are visible in ``lst`` and vice versa.
        Raises ``TypeError`` if ``lst`` is not a ``list`` instance.
        """
        if not isinstance(lst, list):
            raise TypeError(
                "%s.view_list(obj) takes only actual list objects, not %r"
                % (cls.__name__, lst))
        obj = cls()
        obj._items = lst
        return obj

    @classmethod
    def from_fieldstorage(cls, fs):
        """
        Create a dict from a cgi.FieldStorage instance

        Fields with a filename are stored as the field object itself (with
        its ``filename`` decoded); other fields are stored as their decoded
        value, after undoing a ``base64`` or ``quoted-printable``
        Content-Transfer-Encoding when one is declared.
        """
        obj = cls()
        # Built once: the table does not depend on the individual field.
        supported_transfer_encoding = {
            'base64': binascii.a2b_base64,
            'quoted-printable': binascii.a2b_qp,
        }
        # fs.list can be None when there's nothing to parse
        for field in fs.list or ():
            charset = field.type_options.get('charset', 'utf8')
            transfer_encoding = field.headers.get(
                'Content-Transfer-Encoding', None)
            if PY3:  # pragma: no cover
                if charset == 'utf8':
                    decode = lambda b: b
                else:
                    decode = lambda b: b.encode('utf8').decode(charset)
            else:
                decode = lambda b: b.decode(charset)
            if field.filename:
                field.filename = decode(field.filename)
                obj.add(field.name, field)
            else:
                value = field.value
                if transfer_encoding in supported_transfer_encoding:
                    if PY3:  # pragma: no cover
                        # binascii accepts bytes
                        value = value.encode('utf8')
                    value = supported_transfer_encoding[
                        transfer_encoding](value)
                    if PY3:  # pragma: no cover
                        # binascii returns bytes
                        value = value.decode('utf8')
                obj.add(field.name, decode(value))
        return obj

    def __getitem__(self, key):
        """
        Return the most recently added value for ``key``.

        Raises ``KeyError`` when the key is absent.
        """
        for k, v in reversed(self._items):
            if k == key:
                return v
        raise KeyError(key)

    def __setitem__(self, key, value):
        """
        Replace every existing value of ``key`` with the single ``value``.

        The new pair is appended at the end of the item list.
        """
        try:
            del self[key]
        except KeyError:
            pass
        self._items.append((key, value))

    def add(self, key, value):
        """
        Add the key and value, not overwriting any previous value.
        """
        self._items.append((key, value))

    def getall(self, key):
        """
        Return a list of all values matching the key (may be an empty list)
        """
        return [v for k, v in self._items if k == key]

    def getone(self, key):
        """
        Get one value matching the key, raising a KeyError if multiple
        values were found.

        A ``KeyError`` is also raised when the key is not present at all.
        """
        v = self.getall(key)
        if not v:
            raise KeyError('Key not found: %r' % key)
        if len(v) > 1:
            raise KeyError('Multiple values match %r: %r' % (key, v))
        return v[0]

    def mixed(self):
        """
        Returns a dictionary where the values are either single
        values, or a list of values when a key/value appears more than
        once in this dictionary.  This is similar to the kind of
        dictionary often used to represent the variables in a web
        request.
        """
        result = {}
        # Keys whose entry in ``result`` is a list created by this method.
        # We track them to not clobber any lists that are *actual* values
        # in this dictionary.
        multi = set()
        for key, value in self.items():
            if key in result:
                if key in multi:
                    result[key].append(value)
                else:
                    result[key] = [result[key], value]
                    multi.add(key)
            else:
                result[key] = value
        return result

    def dict_of_lists(self):
        """
        Returns a dictionary where each key is associated with a list of
        values.
        """
        r = {}
        for key, val in self.items():
            r.setdefault(key, []).append(val)
        return r

    def __delitem__(self, key):
        """
        Remove every item stored under ``key``.

        Raises ``KeyError`` when no item has that key.
        """
        items = self._items
        remaining = [item for item in items if item[0] != key]
        if len(remaining) == len(items):
            raise KeyError(key)
        # Assign in place so a list shared through view_list() stays shared.
        items[:] = remaining

    def __contains__(self, key):
        """Return True when at least one item is stored under ``key``."""
        for k, v in self._items:
            if k == key:
                return True
        return False

    has_key = __contains__

    def clear(self):
        """Remove all items, emptying the backing list in place."""
        del self._items[:]

    def copy(self):
        """Return a new instance of the same class holding the same items."""
        return self.__class__(self)

    def setdefault(self, key, default=None):
        """
        Return the first value stored for ``key``; when the key is absent,
        append ``(key, default)`` and return ``default``.
        """
        for k, v in self._items:
            if key == k:
                return v
        self._items.append((key, default))
        return default

    def pop(self, key, *args):
        """
        Remove and return the first value stored for ``key``.

        Only that one item is removed; later items with the same key stay.
        When the key is absent the optional default is returned, or
        ``KeyError`` is raised if no default was given.  Passing more than
        one default raises ``TypeError``.
        """
        if len(args) > 1:
            raise TypeError("pop expected at most 2 arguments, got %s"
                            % repr(1 + len(args)))
        for i, (k, v) in enumerate(self._items):
            if k == key:
                del self._items[i]
                return v
        if args:
            return args[0]
        else:
            raise KeyError(key)

    def popitem(self):
        """
        Remove and return the last ``(key, value)`` item.

        Delegates to ``list.pop()``, so an empty dict raises ``IndexError``.
        """
        return self._items.pop()

    def update(self, *args, **kw):
        """
        Update via ``MutableMapping.update``, i.e. through ``__setitem__``,
        so existing values of an updated key are replaced.

        A ``UserWarning`` is emitted when the positional argument itself
        contains duplicate keys, because only the last one survives.
        """
        if args:
            lst = args[0]
            if not (hasattr(lst, 'keys') or hasattr(lst, '__len__')):
                # One-shot iterable (generator, iterator): len() would raise
                # TypeError and dict() would exhaust it before the real
                # update runs, so materialize it once and pass the list on.
                lst = list(lst)
                args = (lst,) + args[1:]
            if len(lst) != len(dict(lst)):
                # this does not catch the cases where we overwrite existing
                # keys, but those would produce too many warning
                msg = ("Behavior of MultiDict.update() has changed "
                    "and overwrites duplicate keys. Consider using .extend()"
                )
                warnings.warn(msg, UserWarning, stacklevel=2)
        MutableMapping.update(self, *args, **kw)

    def extend(self, other=None, **kwargs):
        """
        Append the items of ``other`` without removing existing values.

        ``other`` may be an object with ``items()``, an object with
        ``keys()`` and item access, or an iterable of ``(key, value)``
        pairs.  Keyword arguments are applied through ``update()``.
        """
        if other is None:
            pass
        elif hasattr(other, 'items'):
            self._items.extend(other.items())
        elif hasattr(other, 'keys'):
            for k in other.keys():
                self._items.append((k, other[k]))
        else:
            for k, v in other:
                self._items.append((k, v))
        if kwargs:
            self.update(kwargs)

    def __repr__(self):
        """Return ``ClassName([(key, value), ...])`` built from the items
        after they have been passed through ``_hide_passwd``."""
        items = map('(%r, %r)'.__mod__, _hide_passwd(self.items()))
        return '%s([%s])' % (self.__class__.__name__, ', '.join(items))

    def __len__(self):
        """Return the number of stored items (duplicate keys are counted)."""
        return len(self._items)

    ##
    ## All the iteration:
    ##

    def iterkeys(self):
        """Iterate over the key of every item, duplicates included."""
        for k, v in self._items:
            yield k

    if PY3:  # pragma: no cover
        keys = iterkeys
    else:
        def keys(self):
            """Return a list with the key of every item."""
            return [k for k, v in self._items]

    __iter__ = iterkeys

    def iteritems(self):
        """Iterate over the stored ``(key, value)`` tuples."""
        return iter(self._items)

    if PY3:  # pragma: no cover
        items = iteritems
    else:
        def items(self):
            """Return a copy of the list of ``(key, value)`` tuples."""
            return self._items[:]

    def itervalues(self):
        """Iterate over the value of every item."""
        for k, v in self._items:
            yield v

    if PY3:  # pragma: no cover
        values = itervalues
    else:
        def values(self):
            """Return a list with the value of every item."""
            return [v for k, v in self._items]
# Unique sentinel passed as the ``get`` default in
# NestedMultiDict.__getitem__, where it is compared by identity to tell a
# missing key apart from any value the wrapped dict may return.
_dummy = object()
class GetDict(MultiDict):
    """
    A MultiDict bound to an environ mapping: after every mutation the
    items are re-encoded into ``env['QUERY_STRING']``.
    """

    def __init__(self, data, env):
        """Remember ``env`` and load the initial items from ``data``."""
        self.env = env
        super(GetDict, self).__init__(data)

    def on_change(self):
        """
        Write the current items back to the environ.

        Keys and values are UTF-8 encoded, serialized with ``url_encode``
        and stored under ``QUERY_STRING``; the ``(self, query_string)``
        pair is stored under ``webob._parsed_query_vars``.
        """
        encoded_pairs = [
            (name.encode('utf8'), val.encode('utf8'))
            for name, val in self.items()
        ]
        query_string = url_encode(encoded_pairs)
        self.env['QUERY_STRING'] = query_string
        self.env['webob._parsed_query_vars'] = (self, query_string)

    # Each mutator below runs the MultiDict implementation first and then
    # publishes the new state through on_change().

    def __setitem__(self, key, value):
        super(GetDict, self).__setitem__(key, value)
        self.on_change()

    def add(self, key, value):
        super(GetDict, self).add(key, value)
        self.on_change()

    def __delitem__(self, key):
        super(GetDict, self).__delitem__(key)
        self.on_change()

    def clear(self):
        super(GetDict, self).clear()
        self.on_change()

    def setdefault(self, key, default=None):
        found = super(GetDict, self).setdefault(key, default)
        self.on_change()
        return found

    def pop(self, key, *args):
        removed = super(GetDict, self).pop(key, *args)
        self.on_change()
        return removed

    def popitem(self):
        removed = super(GetDict, self).popitem()
        self.on_change()
        return removed

    def update(self, *args, **kwargs):
        super(GetDict, self).update(*args, **kwargs)
        self.on_change()

    def extend(self, *args, **kwargs):
        super(GetDict, self).extend(*args, **kwargs)
        self.on_change()

    def __repr__(self):
        """Return ``GET([...])`` with the items passed through
        ``_hide_passwd``."""
        rendered = ', '.join(
            '(%r, %r)' % pair for pair in _hide_passwd(self.items()))
        # TODO: GET -> GetDict
        return 'GET([%s])' % rendered

    def copy(self):
        """Return a plain MultiDict with the same items."""
        # Copies shouldn't be tracked
        return MultiDict(self)
class NestedMultiDict(MultiDict):
    """
    Wraps several MultiDict objects, treating it as one large MultiDict

    The wrapped dicts are kept in ``self.dicts`` and consulted in the order
    given.  No item list of its own is created, and every mutating method
    raises ``KeyError``.
    """

    def __init__(self, *dicts):
        """Store the dicts to wrap; they are referenced, not copied."""
        self.dicts = dicts

    def __getitem__(self, key):
        """
        Return the value from the first wrapped dict that has ``key``.

        Raises ``KeyError`` when no wrapped dict has it.
        """
        for d in self.dicts:
            value = d.get(key, _dummy)
            if value is not _dummy:
                return value
        raise KeyError(key)

    def _readonly(self, *args, **kw):
        """Shared implementation of every mutator: always raises."""
        raise KeyError("NestedMultiDict objects are read-only")

    __setitem__ = _readonly
    add = _readonly
    __delitem__ = _readonly
    clear = _readonly
    setdefault = _readonly
    pop = _readonly
    popitem = _readonly
    update = _readonly
    # The inherited MultiDict.extend would reach for ``self._items``, which
    # this class never creates, and fail with AttributeError; report the
    # same read-only error as the other mutators instead.
    extend = _readonly

    def getall(self, key):
        """Return the values for ``key`` from all wrapped dicts, in order."""
        result = []
        for d in self.dicts:
            result.extend(d.getall(key))
        return result

    # Inherited:
    # getone
    # mixed
    # dict_of_lists

    def copy(self):
        """Return a plain MultiDict holding all items of the wrapped dicts."""
        return MultiDict(self)

    def __contains__(self, key):
        """Return True when any wrapped dict contains ``key``."""
        for d in self.dicts:
            if key in d:
                return True
        return False

    has_key = __contains__

    def __len__(self):
        """Return the sum of the lengths of the wrapped dicts."""
        v = 0
        for d in self.dicts:
            v += len(d)
        return v

    def __nonzero__(self):
        """Return True when at least one wrapped dict is truthy."""
        for d in self.dicts:
            if d:
                return True
        return False

    # Python 3 looks up __bool__ rather than __nonzero__.
    __bool__ = __nonzero__

    def iteritems(self):
        """Iterate over the items of each wrapped dict in turn."""
        for d in self.dicts:
            for item in iteritems_(d):
                yield item

    if PY3:  # pragma: no cover
        items = iteritems
    else:
        def items(self):
            """Return a list of the items of all wrapped dicts."""
            return list(self.iteritems())

    def itervalues(self):
        """Iterate over the values of each wrapped dict in turn."""
        for d in self.dicts:
            for value in itervalues_(d):
                yield value

    if PY3:  # pragma: no cover
        values = itervalues
    else:
        def values(self):
            """Return a list of the values of all wrapped dicts."""
            return list(self.itervalues())

    def __iter__(self):
        """Iterate over the keys of each wrapped dict in turn."""
        for d in self.dicts:
            for key in d:
                yield key

    iterkeys = __iter__

    if PY3:  # pragma: no cover
        keys = iterkeys
    else:
        def keys(self):
            """Return a list of the keys of all wrapped dicts."""
            return list(self.iterkeys())
class NoVars(object):
    """
    Represents no variables; used when no variables
    are applicable.

    This is read-only

    Lookups behave like an empty mapping, and every mutating method raises
    ``KeyError`` with a message that includes ``self.reason``.
    """

    def __init__(self, reason=None):
        """Store the explanation shown in error messages and in repr();
        a missing or empty reason becomes ``'N/A'``."""
        self.reason = reason or 'N/A'

    def __getitem__(self, key):
        """Always raise ``KeyError``: there are no keys."""
        raise KeyError("No key %r: %s" % (key, self.reason))

    def __setitem__(self, *args, **kw):
        """Always raise ``KeyError``: nothing can be added."""
        raise KeyError("Cannot add variables: %s" % self.reason)

    add = __setitem__
    setdefault = __setitem__
    update = __setitem__
    # MultiDict also offers extend(); without this alias the call would
    # fail with AttributeError instead of the explanatory KeyError.
    extend = __setitem__

    def __delitem__(self, *args, **kw):
        """Always raise ``KeyError``: there is nothing to remove."""
        raise KeyError("No keys to delete: %s" % self.reason)

    clear = __delitem__
    pop = __delitem__
    popitem = __delitem__

    def get(self, key, default=None):
        """Return ``default``; no key is ever present."""
        return default

    def getall(self, key):
        """Return an empty list."""
        return []

    def getone(self, key):
        """Raise ``KeyError`` through ``__getitem__``."""
        return self[key]

    def mixed(self):
        """Return an empty dict."""
        return {}

    dict_of_lists = mixed

    def __contains__(self, key):
        """Return False for every key."""
        return False

    has_key = __contains__

    def copy(self):
        """Return this same object rather than a new one."""
        return self

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__,
                             self.reason)

    def __len__(self):
        return 0

    def __cmp__(self, other):
        """Compare like an empty dict.

        Relies on the ``cmp`` builtin, which exists only on Python 2.
        """
        return cmp({}, other)

    def iterkeys(self):
        """Return an empty iterator."""
        return iter([])

    if PY3:  # pragma: no cover
        keys = iterkeys
        items = iterkeys
        values = iterkeys
    else:
        def keys(self):
            """Return an empty list."""
            return []
        items = keys
        values = keys
        itervalues = iterkeys
        iteritems = iterkeys

    __iter__ = iterkeys
def _hide_passwd(items):
for k, v in items:
if ('password' in k
or 'passwd' in k
or 'pwd' in k
):
yield k, '******'
else:
yield k, v