blob: e1caeee40d3837bf0c89250599800687c1ef332e [file] [log] [blame]
# Copyright 2018 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""A Fake Datastore Client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import pprint
import mock
class FakeKey(object):
  """A fake cloud datastore Key.

  A key is built from a kind followed by one or more path elements. The last
  path element determines whether the key is addressed by name (a string) or
  by id (an integer). Keys compare equal, and hash equally, when all of their
  attributes match, so they can be used as dictionary keys.
  """

  # `basestring` and `long` only exist on Python 2; fall back to the
  # Python 3 equivalents so the fake works under either interpreter.
  try:
    _STRING_TYPES = (basestring,)  # pylint: disable=undefined-variable
    _INTEGER_TYPES = (int, long)  # pylint: disable=undefined-variable
  except NameError:
    _STRING_TYPES = (str,)
    _INTEGER_TYPES = (int,)

  def __init__(self, kind, *path):
    """Creates the key.

    Args:
      kind: The kind of the entity this key refers to.
      *path: The remaining path elements. The last one must be a name
        (string) or an id (int).

    Raises:
      IndexError: If no path elements are given.
      TypeError: If the last path element is neither a string nor an int.
    """
    self.kind = kind
    self.path = tuple(path)
    self._flat_path = (kind,) + self.path

    last = self.path[-1]
    if isinstance(last, self._STRING_TYPES):
      self.name = last
      self.id = None
    elif isinstance(last, self._INTEGER_TYPES):
      self.name = None
      self.id = last
    else:
      raise TypeError(
          'Last element of path must be a name (string) or id (int).')

  def __repr__(self):
    return pprint.pformat(self.__dict__)

  def __str__(self):
    return pprint.pformat(self.__dict__)

  def __hash__(self):
    return hash(self._fields())

  def __eq__(self, other):
    # Defer to the other operand rather than blowing up with AttributeError
    # when compared against something that is not a key.
    if not isinstance(other, FakeKey):
      return NotImplemented
    return self._fields() == other._fields()

  def __ne__(self, other):
    # Python 2 does not derive != from ==, so it must be spelled out to keep
    # the two operators consistent.
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def _fields(self):
    """Returns the instance attributes as a sorted tuple of (name, value)."""
    # Attribute names are unique, so sorting never has to compare values.
    return tuple(sorted(self.__dict__.items()))
class FakeDatastoreClient(object):
  """A fake in-memory datastore client.

  Rows live in a class-level mapping of namespace -> {key: entity}, so every
  client created for the same namespace operates on the same rows.
  """

  # Maps namespace -> {key: entity}. Shared by all client instances.
  _NAMESPACE_ROWS = {}

  # Key factory, so that keys can be created via `client.key(...)`.
  key = FakeKey

  def __init__(self, namespace=None):
    """Creates a client bound to |namespace|.

    Args:
      namespace: The namespace whose rows this client reads and writes.
    """
    self.namespace = namespace
    self._rows = self._NAMESPACE_ROWS.setdefault(namespace, {})

  def put(self, entity):
    """Stores the entity.

    A copy is stored, so changes the caller makes to |entity| after this call
    do not affect the stored row.

    Args:
      entity: The entity to store. It is indexed by its `key` attribute.
    """
    self._rows[entity.key] = copy.deepcopy(entity)

  def put_multi(self, entities):
    """Stores the entities.

    Args:
      entities: The entities to store.
    """
    for entity in entities:
      self.put(entity)

  def get(self, key):
    """Retrieves the entity indexed by |key|.

    Args:
      key: The key of the entity to retrieve.

    Returns:
      A copy of the stored entity, or None if there is no entity for |key|.
    """
    return copy.deepcopy(self._rows.get(key))

  def delete(self, key):
    """Deletes the entity with |key|.

    Deleting a key that is not stored is a no-op.

    Args:
      key: The key of the entity to delete.
    """
    self._rows.pop(key, None)

  def delete_multi(self, keys):
    """Deletes the entities with |keys|.

    Args:
      keys: The keys of the entities to delete.
    """
    for k in keys:
      self.delete(k)

  def query(self, kind=None):
    """Returns rows matching some constraint.

    Args:
      kind: Return rows with this kind.

    Returns:
      A mock query object whose fetch() returns copies of the stored
      entities whose key kind equals |kind|. The rows are looked up when
      fetch() is called, not when the query is created.
    """
    def fetch():
      # Hand out copies, like get(), so callers cannot mutate stored rows.
      return [
          copy.deepcopy(entity)
          for row_key, entity in self._rows.items()
          if row_key.kind == kind
      ]

    return mock.Mock(fetch=fetch)

  @classmethod
  def reset_store(cls):
    """Removes all rows from every namespace."""
    # Clear in place instead of rebinding the class attribute: existing
    # clients hold references to the per-namespace dicts and must observe the
    # reset too, and rebinding on a subclass would fork the shared store.
    for rows in cls._NAMESPACE_ROWS.values():
      rows.clear()