blob: 07ae88993a0afa61511d39660ef8becdff8241bf [file] [log] [blame]
#!/usr/bin/env python
#
# Copyright 2007 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Imports data over HTTP.
Usage:
%(arg0)s [flags]
--debug Show debugging information. (Optional)
--application=<string> Application ID of endpoint (Optional for
*.appspot.com)
--auth_domain=<domain> The auth domain to use for logging in and for
UserProperties. (Default: gmail.com)
--bandwidth_limit=<int> The maximum number of bytes per second for the
aggregate transfer of data to/from the server.
Bursts may exceed this, but overall transfer rate is
restricted to this rate. (Default: 250000)
--batch_size=<int> Number of Entity objects to include in each request
to/from the URL endpoint. The more data per
row/Entity, the smaller the batch size should be.
(Default: downloads 100, uploads 10)
--config_file=<path> File containing Model and Loader definitions or
bulkloader.yaml transforms. (Required unless --dump,
--restore, or --create_config are used.)
--create_config Write a bulkloader.yaml configuration file to
--filename based on the server side datastore state.
--db_filename=<path> Specific progress database to write to, or to
resume from. If not supplied, then a new database
will be started, named:
bulkloader-progress-TIMESTAMP.
The special filename "skip" may be used to simply
skip reading/writing any progress information.
--download Export entities to a file.
--dry_run Do not execute any remote_api calls.
--dump Use zero-configuration dump format.
--email=<string> The username to use. Will prompt if omitted.
--exporter_opts=<string>
A string to pass to the Exporter.initialize method.
--filename=<path> Path to the file to import/export. (Required when
importing or exporting, not mapping.)
--has_header Skip the first row of the input.
--http_limit=<int> The maximum number of HTTP requests per second to
send to the server. (Default: 8)
--kind=<string> Name of the Entity object kind to put in the
datastore. (Required)
--loader_opts=<string> A string to pass to the Loader.initialize method.
--log_file=<path> File to write bulkloader logs. If not supplied
then a new log file will be created, named:
bulkloader-log-TIMESTAMP.
--map Map an action across datastore entities.
--mapper_opts=<string> A string to pass to the Mapper.Initialize method.
--num_threads=<int> Number of threads to use for uploading/downloading
entities (Default: 10)
--passin Read the login password from stdin.
--restore Restore from zero-configuration dump format.
--result_db_filename=<path>
Result database to write to for downloads.
--rps_limit=<int> The maximum number of records per second to
transfer to/from the server. (Default: 20)
--url=<string> URL endpoint to post to for importing/exporting
data. (Required)
--namespace=<string> Use specified namespace instead of the default one
for all datastore operations.
The exit status will be 0 on success, non-zero on import failure.
Works with the remote_api mix-in library for google.appengine.ext.remote_api.
Please look there for documentation about how to set up the server side.
Example:
%(arg0)s --url=http://app.appspot.com/remote_api --kind=Model \
--filename=data.csv --config_file=loader_config.py
"""
import csv
import errno
import getopt
import getpass
import imp
import logging
import os
import Queue
import re
import shutil
import signal
import StringIO
import sys
import threading
import time
import traceback
import urllib2
import urlparse
from google.appengine.datastore import entity_pb
from google.appengine.api import apiproxy_stub_map
from google.appengine.api import datastore
from google.appengine.api import datastore_errors
from google.appengine.api.namespace_manager import namespace_manager
from google.appengine.datastore import datastore_pb
from google.appengine.ext import db
from google.appengine.ext import key_range as key_range_module
from google.appengine.ext.bulkload import bulkloader_config
from google.appengine.ext.db import polymodel
from google.appengine.ext.db import stats
from google.appengine.ext.remote_api import remote_api_stub
from google.appengine.ext.remote_api import throttle as remote_api_throttle
from google.appengine.runtime import apiproxy_errors
from google.appengine.tools import adaptive_thread_pool
from google.appengine.tools import appengine_rpc
from google.appengine.tools.requeue import ReQueue
try:
import sqlite3
except ImportError:
pass
logger = logging.getLogger('google.appengine.tools.bulkloader')

# Short alias used throughout this module.
KeyRange = key_range_module.KeyRange

# Worker pool / batching defaults.
DEFAULT_THREAD_COUNT = 10
DEFAULT_BATCH_SIZE = 10
DEFAULT_DOWNLOAD_BATCH_SIZE = 100
DEFAULT_QUEUE_SIZE = DEFAULT_THREAD_COUNT * 10

_THREAD_SHOULD_EXIT = '_THREAD_SHOULD_EXIT'

# Upload work-item states.
STATE_READ, STATE_SENDING, STATE_SENT, STATE_NOT_SENT = 0, 1, 2, 3

# Download/map work-item states. These deliberately share numeric values
# with the upload states: GETTING == SENDING, GOT == SENT, ERROR == NOT_SENT.
STATE_GETTING, STATE_GOT, STATE_ERROR = 1, 2, 3

# Progress-database marker recording how far the input has been consumed.
DATA_CONSUMED_TO_HERE = 'DATA_CONSUMED_TO_HERE'

# Retry backoff parameters.
INITIAL_BACKOFF, BACKOFF_FACTOR = 1.0, 2.0

# Default throttling limits (see the --bandwidth_limit, --rps_limit and
# --http_limit flags).
DEFAULT_BANDWIDTH_LIMIT = 250000
DEFAULT_RPS_LIMIT = 20
DEFAULT_REQUEST_LIMIT = 8

# Transfer-time thresholds (seconds) used to tune the thread gate: at or
# below MAXIMUM_INCREASE_DURATION the pool may grow, at or below
# MAXIMUM_HOLD_DURATION it holds steady, and slower transfers shrink it.
MAXIMUM_INCREASE_DURATION = 5.0
MAXIMUM_HOLD_DURATION = 12.0

AUTH_FAILED_MESSAGE = (
    'Authentication Failed: Incorrect credentials or '
    'unsupported authentication type (e.g. OpenId).')
def ImportStateMessage(state):
  """Returns the human-readable status message for an upload state.

  Args:
    state: One of STATE_READ, STATE_SENDING, STATE_SENT or STATE_NOT_SENT.

  Returns:
    A status message string.

  Raises:
    KeyError: if state is not a recognized upload state.
  """
  messages = {
      STATE_READ: 'Batch read from file.',
      STATE_SENDING: 'Sending batch to server.',
      STATE_SENT: 'Batch successfully sent.',
      STATE_NOT_SENT: 'Error while sending batch.',
  }
  return messages[state]
def ExportStateMessage(state):
  """Returns the human-readable status message for a download state.

  Args:
    state: One of STATE_READ, STATE_GETTING, STATE_GOT or STATE_ERROR.

  Returns:
    A status message string.

  Raises:
    KeyError: if state is not a recognized download state.
  """
  messages = {
      STATE_READ: 'Batch read from file.',
      STATE_GETTING: 'Fetching batch from server',
      STATE_GOT: 'Batch successfully fetched.',
      STATE_ERROR: 'Error while fetching batch',
  }
  return messages[state]
def MapStateMessage(state):
  """Returns the human-readable status message for a map-operation state.

  Args:
    state: One of STATE_READ, STATE_GETTING, STATE_GOT or STATE_ERROR.

  Returns:
    A status message string.

  Raises:
    KeyError: if state is not a recognized state.
  """
  messages = {
      STATE_READ: 'Batch read from file.',
      STATE_GETTING: 'Querying for batch from server',
      STATE_GOT: 'Batch successfully fetched.',
      STATE_ERROR: 'Error while fetching or mapping.',
  }
  return messages[state]
def ExportStateName(state):
  """Returns the short symbolic name of a download state.

  Args:
    state: One of STATE_READ, STATE_GETTING, STATE_GOT or STATE_ERROR.

  Returns:
    'READ', 'GETTING', 'GOT' or 'NOT_GOT'.

  Raises:
    KeyError: if state is not a recognized download state.
  """
  names = {
      STATE_READ: 'READ',
      STATE_GETTING: 'GETTING',
      STATE_GOT: 'GOT',
      STATE_ERROR: 'NOT_GOT',
  }
  return names[state]
def ImportStateName(state):
  """Converts a numeric upload state identifier to a string.

  Keyed by the upload-state constants (matching ImportStateMessage) rather
  than the download ones; the numeric values are identical, so the mapping
  is unchanged.

  Args:
    state: One of STATE_READ, STATE_SENDING, STATE_SENT or STATE_NOT_SENT.

  Returns:
    'READ', 'SENDING', 'SENT' or 'NOT_SENT'.

  Raises:
    KeyError: if state is not a recognized upload state.
  """
  return ({
      STATE_READ: 'READ',
      STATE_SENDING: 'SENDING',
      STATE_SENT: 'SENT',
      STATE_NOT_SENT: 'NOT_SENT'
  }[state])
class Error(Exception):
  """Root exception type for every error raised by this module."""


class MissingPropertyError(Error):
  """Raised when an entity lacks an expected field that has no default."""


class FatalServerError(Error):
  """Raised when posting data to the server fails unrecoverably."""


class ResumeError(Error):
  """Raised when a partially completed upload cannot be resumed."""


class ConfigurationError(Error):
  """Raised for invalid configuration options."""


class AuthenticationError(Error):
  """Raised when authenticating with the server fails."""


class FileNotFoundError(Error):
  """Raised when a user-supplied input filename does not exist."""


class FileNotReadableError(Error):
  """Raised when a user-supplied input file cannot be read."""


class FileExistsError(Error):
  """Raised when a user-supplied output filename already exists."""


class FileNotWritableError(Error):
  """Raised when a user-supplied output file cannot be written."""


class BadStateError(Error):
  """Raised when a work item is found in an unexpected state."""


class KeyRangeError(Error):
  """Raised on an error while building a KeyRangeItem."""


class KindStatError(Error):
  """Raised when kind statistics for an all-kinds download are missing."""
class FieldSizeLimitError(Error):
  """The csv module tried to read a field larger than the size limit.

  Attributes:
    limit: The field size limit that was exceeded.
    message: A user-facing explanation of how to raise the limit.
  """

  def __init__(self, limit):
    """Initialize the error.

    Args:
      limit: The csv field size limit in force when the oversized field was
        read (an int, interpolated with %d).
    """
    # Exposed so callers can inspect the limit without parsing the message.
    self.limit = limit
    self.message = """
A field in your CSV input file has exceeded the current limit of %d.
You can raise this limit by adding the following lines to your config file:
import csv
csv.field_size_limit(new_limit)
where new_limit is a number larger than the size in bytes of the largest
field in your CSV.
""" % limit
    Error.__init__(self, self.message)
class NameClashError(Error):
  """Raised when aliasing an old method name would collide with another name.

  Attributes:
    old_name: The legacy method name being aliased.
    new_name: The current method name.
    klass: The class in which the clash was detected.
  """

  def __init__(self, old_name, new_name, klass):
    Error.__init__(self, old_name, new_name, klass)
    self.old_name, self.new_name, self.klass = old_name, new_name, klass
def GetCSVGeneratorFactory(kind, csv_filename, batch_size, csv_has_header,
                           openfile=open, create_csv_reader=csv.reader):
  """Builds a factory producing UploadWorkItemGenerators over a CSV file.

  The registered Loader for `kind` is configured to read `csv_filename`
  and its record generator is captured by the returned factory.

  Args:
    kind: Kind of the entities being uploaded.
    csv_filename: Path of the CSV data file.
    batch_size: Maximum number of rows per UploadWorkItem.
    csv_has_header: If true, the first CSV row is skipped.
    openfile: Replacement for `open` (dependency injection).
    create_csv_reader: Replacement for `csv.reader` (dependency injection).

  Returns:
    A callable taking (request_manager, progress_queue, progress_generator,
    kinds) that returns an UploadWorkItemGenerator.
  """
  loader = Loader.RegisteredLoader(kind)
  # Set the Loader's name-mangled private hooks from outside the class.
  loader._Loader__openfile = openfile
  loader._Loader__create_csv_reader = create_csv_reader
  records = loader.generate_records(csv_filename)

  def CreateGenerator(request_manager, progress_queue, progress_generator,
                      unused_kinds):
    """Creates the UploadWorkItemGenerator for one upload run.

    Args:
      request_manager: A RequestManager instance.
      progress_queue: A ProgressQueue receiving progress information.
      progress_generator: Prior progress information, or None.
      unused_kinds: Ignored; the kind was fixed when the factory was built.

    Returns:
      An UploadWorkItemGenerator instance.
    """
    return UploadWorkItemGenerator(request_manager=request_manager,
                                   progress_queue=progress_queue,
                                   progress_generator=progress_generator,
                                   record_generator=records,
                                   skip_first=csv_has_header,
                                   batch_size=batch_size)

  return CreateGenerator
class UploadWorkItemGenerator(object):
  """Reads rows from a row generator and generates UploadWorkItems.

  Records are numbered from 1 (not counting a skipped header) and grouped
  into UploadWorkItems covering inclusive [key_start, key_end] line-number
  ranges. When prior progress information is supplied, the ranges it names
  are replayed before any fresh batches are produced.
  """

  def __init__(self,
               request_manager,
               progress_queue,
               progress_generator,
               record_generator,
               skip_first,
               batch_size):
    """Initialize a WorkItemGenerator.

    Args:
      request_manager: A RequestManager instance with which to associate
        WorkItems.
      progress_queue: A progress queue with which to associate WorkItems.
      progress_generator: A generator of progress information, or None if
        there is no prior progress to resume from.
      record_generator: A generator of data records.
      skip_first: Whether to skip the first data record.
      batch_size: The number of data records per WorkItem.
    """
    self.request_manager = request_manager
    self.progress_queue = progress_queue
    self.progress_generator = progress_generator
    # Advanced with .next() (Python 2 iterator protocol).
    self.reader = record_generator
    self.skip_first = skip_first
    self.batch_size = batch_size
    # Number of the next record to be read from self.reader (1-based).
    self.line_number = 1
    # Width of the first row read. It is recorded but never compared here.
    self.column_count = None
    # (line_number, row) pairs filled by the most recent _ReadRows call.
    self.read_rows = []
    # Incremented for every record consumed, whether read or skipped.
    self.row_count = 0
    # Incremented only for records skipped by _AdvanceTo (presumably rows
    # already transferred in a prior run; confirm).
    self.xfer_count = 0

  def _AdvanceTo(self, line):
    """Advance the reader to the given line.

    Records before `line` are consumed and discarded.

    Args:
      line: A line number to advance to.

    Raises:
      StopIteration: if the reader runs out of rows first.
    """
    while self.line_number < line:
      self.reader.next()
      self.line_number += 1
      self.row_count += 1
      self.xfer_count += 1

  def _ReadRows(self, key_start, key_end):
    """Attempts to read rows [key_start, key_end].

    The rows are stored in self.read_rows as (line_number, row) pairs. If the
    reader is exhausted part-way, self.read_rows keeps the rows read so far.

    Args:
      key_start: The starting line number; must equal self.line_number.
      key_end: The ending line number.

    Raises:
      StopIteration: if the reader runs out of rows.
    """
    # NOTE(review): column_count is only recorded, never checked, so rows
    # with differing column counts are not rejected here.
    assert self.line_number == key_start
    self.read_rows = []
    while self.line_number <= key_end:
      row = self.reader.next()
      self.row_count += 1
      if self.column_count is None:
        self.column_count = len(row)
      self.read_rows.append((self.line_number, row))
      self.line_number += 1

  def _MakeItem(self, key_start, key_end, rows, progress_key=None):
    """Makes a UploadWorkItem containing the given rows, with the given keys.

    Args:
      key_start: The start key for the UploadWorkItem.
      key_end: The end key for the UploadWorkItem.
      rows: A non-empty list of the rows for the UploadWorkItem.
      progress_key: The progress key for the UploadWorkItem

    Returns:
      An UploadWorkItem instance for the given batch.
    """
    assert rows
    item = UploadWorkItem(self.request_manager, self.progress_queue, rows,
                          key_start, key_end, progress_key=progress_key)
    return item

  def Batches(self):
    """Reads from the record_generator and generates UploadWorkItems.

    Ranges recorded in the progress information are replayed first. Fresh
    batches of self.batch_size rows are produced afterwards only if there
    was no progress information, or the last progress entry was a
    DATA_CONSUMED_TO_HERE marker and input remains.

    Yields:
      Instances of class UploadWorkItem

    Raises:
      ResumeError: If the progress database and data file indicate a different
        number of rows.
    """
    if self.skip_first:
      logger.info('Skipping header line.')
      try:
        self.reader.next()
      except StopIteration:
        # Empty input (not even a header): nothing to upload.
        return

    exhausted = False

    # Line numbering restarts at 1 after the optional header.
    self.line_number = 1
    self.column_count = None

    logger.info('Starting import; maximum %d entities per post',
                self.batch_size)

    # After the loop below, `state` holds the state of the last progress
    # entry (or None); it decides whether fresh batches are generated.
    state = None
    if self.progress_generator:
      # `kind` is unpacked but unused for uploads.
      for progress_key, state, kind, key_start, key_end in (
          self.progress_generator):
        if key_start:
          # A recorded range that must be re-sent.
          try:
            self._AdvanceTo(key_start)
            self._ReadRows(key_start, key_end)
            yield self._MakeItem(key_start,
                                 key_end,
                                 self.read_rows,
                                 progress_key=progress_key)
          except StopIteration:
            logger.error('Mismatch between data file and progress database')
            raise ResumeError(
                'Mismatch between data file and progress database')
        elif state == DATA_CONSUMED_TO_HERE:
          # Skip input up to and including key_end.
          try:
            self._AdvanceTo(key_end + 1)
          except StopIteration:
            # Input ended before the marker: no fresh batches remain.
            state = None

    if self.progress_generator is None or state == DATA_CONSUMED_TO_HERE:
      while not exhausted:
        key_start = self.line_number
        key_end = self.line_number + self.batch_size - 1
        try:
          self._ReadRows(key_start, key_end)
        except StopIteration:
          exhausted = True
          # Shrink the range to the rows actually read before exhaustion.
          key_end = self.line_number - 1
        if key_start <= key_end:
          yield self._MakeItem(key_start, key_end, self.read_rows)
class CSVGenerator(object):
  """Reads a CSV file and generates data records."""

  def __init__(self,
               csv_filename,
               openfile=open,
               create_csv_reader=csv.reader):
    """Initializes a CSV generator.

    Args:
      csv_filename: File on disk containing CSV data.
      openfile: Used for dependency injection of 'open'.
      create_csv_reader: Used for dependency injection of 'csv.reader'.
    """
    self.csv_filename = csv_filename
    self.openfile = openfile
    self.create_csv_reader = create_csv_reader

  def Records(self):
    """Reads the CSV data file and generates row records.

    The file is closed when the records are exhausted, when an error is
    raised, or when the generator is closed or discarded early.

    Yields:
      Lists of strings, one list per CSV row (leading whitespace after each
      delimiter is skipped).

    Raises:
      FieldSizeLimitError: If a field exceeds csv.field_size_limit().
      csv.Error: For any other CSV parsing error.
    """
    csv_file = self.openfile(self.csv_filename, 'rb')
    # try/finally rather than `with` so the file is released on every exit
    # path, including GeneratorExit when the consumer stops early.
    try:
      reader = self.create_csv_reader(csv_file, skipinitialspace=True)
      try:
        for record in reader:
          yield record
      except csv.Error:
        # sys.exc_info() avoids the Python-2-only "except E, e" binding.
        error = sys.exc_info()[1]
        if error.args and error.args[0].startswith(
            'field larger than field limit'):
          raise FieldSizeLimitError(csv.field_size_limit())
        raise
    finally:
      csv_file.close()
class KeyRangeItemGenerator(object):
  """Produces the key ranges that still need to be downloaded.

  With prior progress information, one item is produced for every recorded
  range that was not fully fetched. Without it, one unbounded range is
  produced per kind.
  """

  def __init__(self, request_manager, kinds, progress_queue, progress_generator,
               key_range_item_factory):
    """Initialize the KeyRangeItemGenerator.

    Args:
      request_manager: A RequestManager instance.
      kinds: A single kind name, or a list of kind names.
      progress_queue: Queue used for tracking progress information.
      progress_generator: Generator of prior progress entries, or None when
        there is no prior status.
      key_range_item_factory: Callable that builds KeyRangeItems.
    """
    self.request_manager = request_manager
    self.kinds = [kinds] if isinstance(kinds, basestring) else kinds
    self.row_count = 0
    self.xfer_count = 0
    self.progress_queue = progress_queue
    self.progress_generator = progress_generator
    self.key_range_item_factory = key_range_item_factory

  def Batches(self):
    """Yields KeyRangeItems for every key range not yet downloaded."""
    if self.progress_generator is None:
      # Fresh export: scan each kind's whole key space.
      for kind in self.kinds:
        yield self.key_range_item_factory(self.request_manager,
                                          self.progress_queue,
                                          kind,
                                          KeyRange())
      return

    for progress_key, state, kind, key_start, key_end in (
        self.progress_generator):
      # Skip entries without a state or start key, and completed ranges.
      if state is None or state == STATE_GOT or key_start is None:
        continue
      resumed_range = KeyRange(key_start=ParseKey(key_start),
                               key_end=ParseKey(key_end))
      yield self.key_range_item_factory(self.request_manager,
                                        self.progress_queue,
                                        kind,
                                        resumed_range,
                                        progress_key=progress_key,
                                        state=STATE_READ)
class DownloadResult(object):
  """Holds the result of an entity download.

  Attributes:
    continued: Whether more entities remain in the queried range.
    direction: KeyRange.ASC or KeyRange.DESC, the order of keys/entities.
    keys: The downloaded keys, in `direction` order.
    entities: The downloaded entities, parallel to `keys`.
    count: Number of entities downloaded.
    key_start, key_end: Smallest and largest downloaded key; only set when
      count > 0.
  """

  def __init__(self, continued, direction, keys, entities):
    self.continued = continued
    self.direction = direction
    self.keys = keys
    self.entities = entities
    self.count = len(keys)
    assert self.count == len(entities)
    assert direction in (key_range_module.KeyRange.ASC,
                         key_range_module.KeyRange.DESC)
    if self.count > 0:
      # For descending results the first key is the largest.
      if direction == key_range_module.KeyRange.ASC:
        self.key_start = keys[0]
        self.key_end = keys[-1]
      else:
        self.key_start = keys[-1]
        self.key_end = keys[0]

  def Entities(self):
    """Returns the list of entities for this result in key order."""
    if self.direction == key_range_module.KeyRange.ASC:
      return list(self.entities)
    else:
      result = list(self.entities)
      result.reverse()
      return result

  def __str__(self):
    # One entity per line. Joining str(self.entities) would instead put
    # every character of the list's repr on its own line.
    return 'continued = %s\n%s' % (
        str(self.continued),
        '\n'.join(str(entity) for entity in self.entities))
class _WorkItem(adaptive_thread_pool.WorkItem):
  """Holds a description of a unit of upload or download work.

  Subclasses implement _TransferItem (the actual transfer) and
  MarkAsTransferred. This base class drives the state transitions and reports
  each one through the progress queue.
  """

  def __init__(self, progress_queue, key_start, key_end, state_namer,
               state=STATE_READ, progress_key=None):
    """Initialize the _WorkItem instance.

    Args:
      progress_queue: A queue used for tracking progress information.
      key_start: The start key of the work item.
      key_end: The end key of the work item.
      state_namer: Function to describe work item states.
      state: The initial state of the work item.
      progress_key: If this WorkItem represents state from a prior run,
        then this will be the key within the progress database.
    """
    adaptive_thread_pool.WorkItem.__init__(self,
                                           '[%s-%s]' % (key_start, key_end))
    self.progress_queue = progress_queue
    self.state_namer = state_namer
    self.state = state
    self.progress_key = progress_key
    # _StateTransition waits on this for acknowledgement of a state change.
    self.progress_event = threading.Event()
    self.key_start = key_start
    self.key_end = key_end
    # First exception recorded by SetError, and its traceback.
    self.error = None
    self.traceback = None
    # Subclasses such as KeyRangeItem assign their entity kind here.
    self.kind = None

  def _TransferItem(self, thread_pool):
    """Performs the transfer; must be overridden by subclasses.

    Implementations return the elapsed transfer time in seconds, or None to
    request a retry (see PerformWork).
    """
    raise NotImplementedError()

  def SetError(self):
    """Sets the error and traceback information for this thread.

    Only the first error is kept; later calls are ignored.
    This must be called from an exception handler.
    """
    if not self.error:
      exc_info = sys.exc_info()
      self.error = exc_info[1]
      self.traceback = exc_info[2]

  def PerformWork(self, thread_pool):
    """Perform the work of this work item and report the results.

    Transient datastore/RPC errors, 5xx HTTP errors and non-fatal URL errors
    are turned into RETRY. Other HTTP and fatal URL errors are recorded with
    SetError and reported as FAILURE. Any other exception propagates after
    the item has been marked as errored.

    Args:
      thread_pool: An AdaptiveThreadPool instance.

    Returns:
      A tuple (status, instruction) of the work status and an instruction
      for the ThreadGate.
    """
    # Pessimistic defaults, used when an unhandled exception escapes.
    status = adaptive_thread_pool.WorkItem.FAILURE
    instruction = adaptive_thread_pool.ThreadGate.DECREASE

    try:
      self.MarkAsTransferring()

      try:
        transfer_time = self._TransferItem(thread_pool)
        if transfer_time is None:
          # The subclass requested a retry: HOLD rather than DECREASE.
          status = adaptive_thread_pool.WorkItem.RETRY
          instruction = adaptive_thread_pool.ThreadGate.HOLD
        else:
          logger.debug('[%s] %s Transferred %d entities in %0.1f seconds',
                       threading.currentThread().getName(), self, self.count,
                       transfer_time)
          sys.stdout.write('.')
          sys.stdout.flush()
          status = adaptive_thread_pool.WorkItem.SUCCESS
          # Fast transfers let the pool grow and moderate ones hold it
          # steady. Slower ones keep the DECREASE default.
          if transfer_time <= MAXIMUM_INCREASE_DURATION:
            instruction = adaptive_thread_pool.ThreadGate.INCREASE
          elif transfer_time <= MAXIMUM_HOLD_DURATION:
            instruction = adaptive_thread_pool.ThreadGate.HOLD
      except (db.InternalError, db.NotSavedError, db.Timeout,
              db.TransactionFailedError,
              apiproxy_errors.OverQuotaError,
              apiproxy_errors.DeadlineExceededError,
              apiproxy_errors.ApplicationError), e:
        status = adaptive_thread_pool.WorkItem.RETRY
        logger.exception('Retrying on non-fatal datastore error: %s', e)
      except urllib2.HTTPError, e:
        http_status = e.code
        # 5xx responses are retried; other HTTP errors are fatal.
        if http_status >= 500 and http_status < 600:
          status = adaptive_thread_pool.WorkItem.RETRY
          logger.exception('Retrying on non-fatal HTTP error: %d %s',
                           http_status, e.msg)
        else:
          self.SetError()
          status = adaptive_thread_pool.WorkItem.FAILURE
      except urllib2.URLError, e:
        if IsURLErrorFatal(e):
          self.SetError()
          status = adaptive_thread_pool.WorkItem.FAILURE
        else:
          status = adaptive_thread_pool.WorkItem.RETRY
          logger.exception('Retrying on non-fatal URL error: %s', e.reason)

    finally:
      # NOTE(review): if MarkAsTransferring raises before changing state,
      # MarkAsError's own assertions may raise BadStateError and mask the
      # original exception.
      if status == adaptive_thread_pool.WorkItem.SUCCESS:
        self.MarkAsTransferred()
      else:
        self.MarkAsError()

    return (status, instruction)

  def _AssertInState(self, *states):
    """Raises an Error if the state of this range is not in states.

    Raises:
      BadStateError: if self.state is not one of `states`.
    """
    if not self.state in states:
      raise BadStateError('%s:%s not in %s' %
                          (str(self),
                           self.state_namer(self.state),
                           map(self.state_namer, states)))

  def _AssertProgressKey(self):
    """Raises an Error if the progress key is None.

    Raises:
      BadStateError: if self.progress_key is None.
    """
    if self.progress_key is None:
      raise BadStateError('%s: Progress key is missing' % str(self))

  def MarkAsRead(self):
    """Mark this _WorkItem as read, updating the progress database."""
    self._AssertInState(STATE_READ)
    self._StateTransition(STATE_READ, blocking=True)

  def MarkAsTransferring(self):
    """Mark this _WorkItem as transferring, updating the progress database."""
    self._AssertInState(STATE_READ, STATE_ERROR)
    self._AssertProgressKey()
    # For uploads STATE_GETTING has the same value as STATE_SENDING.
    self._StateTransition(STATE_GETTING, blocking=True)

  def MarkAsTransferred(self):
    """Mark this _WorkItem as transferred, updating the progress database."""
    raise NotImplementedError()

  def MarkAsError(self):
    """Mark this _WorkItem as failed, updating the progress database."""
    self._AssertInState(STATE_GETTING)
    self._AssertProgressKey()
    self._StateTransition(STATE_ERROR, blocking=True)

  def _StateTransition(self, new_state, blocking=False):
    """Transition the work item to a new state, storing progress information.

    The item itself is put on the progress queue. When blocking, this waits
    until progress_event is set, presumably by the progress-tracking thread
    once it has recorded the change (the setter is not shown here; confirm),
    then clears it for the next transition.

    Args:
      new_state: The state to transition to.
      blocking: Whether to block for the progress thread to acknowledge the
        transition.
    """
    assert not self.progress_event.isSet()

    self.state = new_state

    self.progress_queue.put(self)

    if blocking:
      self.progress_event.wait()

      self.progress_event.clear()
class UploadWorkItem(_WorkItem):
  """Holds a unit of uploading work.

  A UploadWorkItem represents a number of entities that need to be uploaded to
  Google App Engine. These entities are encoded in the "content" field of
  the UploadWorkItem, and will be POST'd as-is to the server.

  The entities are identified by a range of numeric keys, inclusively. In
  the case of a resumption of an upload, or a replay to correct errors,
  these keys must be able to identify the same set of entities.

  Note that keys specify a range. The entities do not have to sequentially
  fill the entire range, they must simply bound a range of valid keys.
  """

  def __init__(self, request_manager, progress_queue, rows, key_start, key_end,
               progress_key=None):
    """Initialize the UploadWorkItem instance.

    Args:
      request_manager: A RequestManager instance.
      progress_queue: A queue used for tracking progress information.
      rows: A list of pairs of a line number and a list of column values.
      key_start: The (numeric) starting key, inclusive.
      key_end: The (numeric) ending key, inclusive.
      progress_key: If this UploadWorkItem represents state from a prior run,
        then this will be the key within the progress database.
    """
    _WorkItem.__init__(self, progress_queue, key_start, key_end,
                       ImportStateName, state=STATE_READ,
                       progress_key=progress_key)

    assert isinstance(key_start, (int, long))
    assert isinstance(key_end, (int, long))
    assert key_start <= key_end

    self.request_manager = request_manager
    self.rows = rows
    # Encoded payload, produced lazily on the first transfer attempt.
    self.content = None
    self.count = len(rows)

  def __str__(self):
    return '[%s-%s]' % (self.key_start, self.key_end)

  def _TransferItem(self, thread_pool, get_time=time.time):
    """Transfers the entities associated with an item.

    The rows are encoded once and cached in self.content, so a retried item
    re-sends the same payload.

    Args:
      thread_pool: An AdaptiveThreadPool instance (unused here).
      get_time: Used for dependency injection.

    Returns:
      The elapsed time, as measured by get_time, spent encoding and posting.
    """
    t = get_time()
    if not self.content:
      self.content = self.request_manager.EncodeContent(self.rows)
    # Exceptions from PostEntities propagate to the caller (PerformWork),
    # which decides whether to retry, so no local handler is needed.
    self.request_manager.PostEntities(self.content)
    return get_time() - t

  def MarkAsTransferred(self):
    """Mark this UploadWorkItem as successfully sent to the server."""

    self._AssertInState(STATE_SENDING)
    self._AssertProgressKey()

    self._StateTransition(STATE_SENT, blocking=False)
def GetImplementationClass(kind_or_class_key):
  """Returns the implementation class for a given kind or class key.

  Args:
    kind_or_class_key: A kind string or a tuple of kind strings.

  Returns:
    A db.Model subclass for the given kind or class key.

  Raises:
    db.KindError: if a tuple class key has no registered polymodel class.
  """
  if isinstance(kind_or_class_key, tuple):
    try:
      implementation_class = polymodel._class_map[kind_or_class_key]
    except KeyError:
      # Wrap the key in a 1-tuple. Formatting the bare tuple would spread its
      # elements over the single %s and raise TypeError instead of KindError.
      raise db.KindError('No implementation for class \'%s\'' %
                         (kind_or_class_key,))
  else:
    implementation_class = db.class_for_kind(kind_or_class_key)
  return implementation_class
def KeyLEQ(key1, key2):
  """Tests key1 <= key2, treating None as an unbounded end-point.

  A None on either side compares as "in range" with anything, so the result
  is True. Otherwise the keys' own ordering is used (for datastore keys,
  numeric ids sort before names).

  Args:
    key1: An int, datastore.Key, or None.
    key2: An int, datastore.Key, or None.

  Returns:
    True if key1 <= key2 or either key is None.
  """
  unbounded = key1 is None or key2 is None
  return unbounded or key1 <= key2
class KeyRangeItem(_WorkItem):
  """Represents an item of work that scans over a key range.

  A KeyRangeItem object holds a KeyRange
  and has an associated state: STATE_READ, STATE_GETTING, STATE_GOT,
  and STATE_ERROR.

  - STATE_READ indicates the range ready to be downloaded by a worker thread.
  - STATE_GETTING indicates the range is currently being downloaded.
  - STATE_GOT indicates that the range was successfully downloaded
  - STATE_ERROR indicates that an error occurred during the last download
    attempt

  KeyRangeItems not in the STATE_GOT state are stored in the progress database.
  When a piece of KeyRangeItem work is downloaded, the download may cover only
  a portion of the range.  In this case, the old KeyRangeItem is removed from
  the progress database and ranges covering the undownloaded range are
  generated and stored as STATE_READ in the export progress database.
  """

  def __init__(self,
               request_manager,
               progress_queue,
               kind,
               key_range,
               progress_key=None,
               state=STATE_READ,
               first=False):
    """Initialize a KeyRangeItem object.

    Args:
      request_manager: A RequestManager instance.
      progress_queue: A queue used for tracking progress information.
      kind: The kind of entities for this range.
      key_range: A KeyRange instance for this work item.
      progress_key: The key for this range within the progress database.
      state: The initial state of this range.
      first: boolean, default False, whether this is the first WorkItem
        of its kind.
    """
    _WorkItem.__init__(self, progress_queue, key_range.key_start,
                       key_range.key_end, ExportStateName, state=state,
                       progress_key=progress_key)
    assert KeyLEQ(key_range.key_start, key_range.key_end), (
        '%s not less than %s' %
        (repr(key_range.key_start), repr(key_range.key_end)))
    self.request_manager = request_manager
    self.kind = kind
    self.key_range = key_range
    # Set by Process to the DownloadResult of the last transfer.
    self.download_result = None
    self.count = 0
    # Already set by _WorkItem.__init__; reassigned with the same values.
    self.key_start = key_range.key_start
    self.key_end = key_range.key_end
    self.first = first

  def __str__(self):
    return '%s-%s' % (self.kind, self.key_range)

  def __repr__(self):
    return self.__str__()

  def MarkAsTransferred(self):
    """Mark this KeyRangeItem as transferred, updating the progress database.

    Intentionally a no-op: Process already performs the transition to the
    completed state.
    """
    pass

  def Process(self, download_result, thread_pool, batch_size,
              new_state=STATE_GOT):
    """Mark this KeyRangeItem as success, updating the progress database.

    Process will split this KeyRangeItem based on the content of
    download_result and adds the unfinished ranges to the work queue.
    If the download was partial (download_result.continued), only the
    finished sub-range is transitioned to new_state, under this item's
    progress key. The remainder is submitted to thread_pool as new items.

    Args:
      download_result: A DownloadResult instance.
      thread_pool: An AdaptiveThreadPool instance.
      batch_size: The number of entities to transfer per request.
      new_state: The state to transition the completed range to.

    Raises:
      BadStateError: if this item is not in STATE_GETTING or lacks a
        progress key.
    """
    self._AssertInState(STATE_GETTING)
    self._AssertProgressKey()

    self.download_result = download_result
    self.count = len(download_result.keys)
    if download_result.continued:
      self._FinishedRange()._StateTransition(new_state, blocking=True)
      self._AddUnfinishedRanges(thread_pool, batch_size)
    else:
      self._StateTransition(new_state, blocking=True)

  def _FinishedRange(self):
    """Returns the range completed by the download_result.

    The returned item shares this item's progress_key and current state.

    Returns:
      A KeyRangeItem representing a completed range.
    """
    assert self.download_result is not None

    # Ascending scans complete a prefix of the range; descending scans
    # complete a suffix. A non-continued result covers the whole range.
    if self.key_range.direction == key_range_module.KeyRange.ASC:
      key_start = self.key_range.key_start
      if self.download_result.continued:
        key_end = self.download_result.key_end
      else:
        key_end = self.key_range.key_end
    else:
      key_end = self.key_range.key_end
      if self.download_result.continued:
        key_start = self.download_result.key_start
      else:
        key_start = self.key_range.key_start

    key_range = KeyRange(key_start=key_start,
                         key_end=key_end,
                         direction=self.key_range.direction)

    result = self.__class__(self.request_manager,
                            self.progress_queue,
                            self.kind,
                            key_range,
                            progress_key=self.progress_key,
                            state=self.state)

    result.download_result = self.download_result
    result.count = self.count
    return result

  def _SplitAndAddRanges(self, thread_pool, batch_size):
    """Split the key range [key_start, key_end] into a list of ranges.

    The undownloaded remainder of this range is split (unless the pool is
    already busy), and each piece is marked read and submitted to
    thread_pool.
    """
    # The remainder excludes the boundary key that was already downloaded.
    if self.download_result.direction == key_range_module.KeyRange.ASC:
      key_range = KeyRange(
          key_start=self.download_result.key_end,
          key_end=self.key_range.key_end,
          include_start=False)
    else:
      key_range = KeyRange(
          key_start=self.key_range.key_start,
          key_end=self.download_result.key_start,
          include_end=False)

    # With a long queue already, submit the remainder whole instead of
    # splitting it further.
    if thread_pool.QueuedItemCount() > 2 * thread_pool.num_threads():
      ranges = [key_range]
    else:
      ranges = key_range.split_range(batch_size=batch_size)

    for key_range in ranges:
      key_range_item = self.__class__(self.request_manager,
                                      self.progress_queue,
                                      self.kind,
                                      key_range)
      # New items start without a progress key. MarkAsTransferring later
      # requires one, so the progress thread presumably assigns it while
      # handling this transition (confirm).
      key_range_item.MarkAsRead()
      thread_pool.SubmitItem(key_range_item, block=True)

  def _AddUnfinishedRanges(self, thread_pool, batch_size):
    """Adds incomplete KeyRanges to the thread_pool.

    Args:
      thread_pool: An AdaptiveThreadPool instance.
      batch_size: The number of entities to transfer per request.

    Returns:
      None; the unfinished ranges are submitted to thread_pool directly.

    Raises:
      KeyRangeError: if this key range has already been completely transferred.
    """
    assert self.download_result is not None
    if self.download_result.continued:
      self._SplitAndAddRanges(thread_pool, batch_size)
    else:
      raise KeyRangeError('No unfinished part of key range.')
class DownloadItem(KeyRangeItem):
  """KeyRangeItem whose transfer step downloads entities in its key range."""

  def _TransferItem(self, thread_pool, get_time=time.time):
    """Downloads one batch from this range and records the outcome.

    Args:
      thread_pool: An AdaptiveThreadPool receiving any unfinished ranges.
      get_time: Used for dependency injection.

    Returns:
      Seconds spent in GetEntities (the Process step is not timed).
    """
    started = get_time()
    result = self.request_manager.GetEntities(self,
                                              retry_parallel=self.first)
    elapsed = get_time() - started
    self.Process(result, thread_pool, self.request_manager.batch_size)
    return elapsed
class MapperItem(KeyRangeItem):
  """KeyRangeItem whose transfer step applies a mapper to its key range."""

  def _TransferItem(self, thread_pool, get_time=time.time):
    """Fetches one batch from this range and applies the kind's mapper.

    Args:
      thread_pool: An AdaptiveThreadPool receiving any unfinished ranges.
      get_time: Used for dependency injection.

    Returns:
      Seconds spent fetching the batch, or None if the mapper raised
      MapperRetry (PerformWork treats None as a retry).
    """
    started = get_time()
    mapper = self.request_manager.GetMapper(self.kind)
    result = self.request_manager.GetEntities(
        self, keys_only=mapper.map_over_keys_only(), retry_parallel=self.first)
    elapsed = get_time() - started
    try:
      mapper.batch_apply(result.Entities())
    except MapperRetry:
      return None
    else:
      self.Process(result, thread_pool, self.request_manager.batch_size)
      return elapsed
def ConvertKeys(keys):
  """Re-homes keys onto the app_id of the calling context.

  Args:
    keys: A list of datastore Entity Keys.

  Returns:
    A new list, in input order, where every key carries the default app_id
    of the calling context. Keys that already have that app_id are reused
    as-is, and the others are rebuilt from their namespace and path.
  """
  # A freshly built key reveals the default app_id in this context.
  local_app = datastore.Key.from_path('kind', 'name').app()
  converted = []
  for key in keys:
    if key.app() != local_app:
      key = datastore.Key.from_path(namespace=key.namespace(),
                                    _app=local_app, *key.to_path())
    converted.append(key)
  return converted
def ReserveKeys(keys):
  """Advances the id allocators past every id found in the given key paths.

  Args:
    keys: A list of keys with ids in their paths, for which the corresponding
      id sequences should be advanced to prevent id collisions.
  """
  connection = datastore._GetConnection()
  connection._reserve_keys(ConvertKeys(keys))
def _AuthFunction(host, email, passin, raw_input_fn, password_input_fn):
  """Internal method shared between RequestManager and _GetRemoteAppId.

  Args:
    host: Hostname to present to the user.
    email: Existing email address to use; if none, will prompt the user.
    passin: Value of the --passin command line flag. If true, will get the
      password using raw_input_fn instead of password_input_fn.
    raw_input_fn: Method to get a string, typically raw_input.
    password_input_fn: Method to get a string, typically getpass.

  Returns:
    Pair, (email, password). The password is None (and no password prompt
    is shown) when no email was supplied or entered.
  """
  if not email:
    # A parenthesized single argument prints identically under Python 2's
    # print statement and Python 3's print function.
    print('Please enter login credentials for %s' % host)
    email = raw_input_fn('Email: ')

  if not email:
    return email, None

  password_prompt = 'Password for %s: ' % email
  if passin:
    password = raw_input_fn(password_prompt)
  else:
    password = password_input_fn(password_prompt)
  return email, password
class RequestManager(object):
"""A class which wraps a connection to the server."""
def __init__(self,
app_id,
host_port,
url_path,
kind,
throttle,
batch_size,
secure,
email,
passin,
dry_run=False,
server=None,
throttle_class=None):
"""Initialize a RequestManager object.
Args:
app_id: String containing the application id for requests.
host_port: String containing the "host:port" pair; the port is optional.
url_path: partial URL (path) to post entity data to.
kind: Kind of the Entity records being posted.
throttle: A Throttle instance.
batch_size: The number of entities to transfer per request.
secure: Use SSL when communicating with server.
email: If not none, the username to log in with.
passin: If True, the password will be read from standard in.
server: An existing AbstractRpcServer to reuse.
throttle_class: A class to use instead of the default
ThrottledHttpRpcServer.
"""
self.app_id = app_id
self.host_port = host_port
self.host = host_port.split(':')[0]
if url_path and url_path[0] != '/':
url_path = '/' + url_path
self.url_path = url_path
self.kind = kind
self.throttle = throttle
self.batch_size = batch_size
self.secure = secure
self.authenticated = False
self.auth_called = False
self.parallel_download = True
self.email = email
self.passin = passin
self.mapper = None
self.dry_run = dry_run
if self.dry_run:
logger.info('Running in dry run mode, skipping remote_api setup')
return
logger.debug('Configuring remote_api. url_path = %s, '
'servername = %s' % (url_path, host_port))
throttled_rpc_server_factory = (
remote_api_throttle.ThrottledHttpRpcServerFactory(
self.throttle, throttle_class=throttle_class))
if server:
remote_api_stub.ConfigureRemoteApiFromServer(server, url_path, app_id)
else:
remote_api_stub.ConfigureRemoteApi(
app_id,
url_path,
self.AuthFunction,
servername=host_port,
rpc_server_factory=throttled_rpc_server_factory,
secure=self.secure)
remote_api_throttle.ThrottleRemoteDatastore(self.throttle)
logger.debug('Bulkloader using app_id: %s', os.environ['APPLICATION_ID'])
def Authenticate(self):
"""Invoke authentication if necessary."""
logger.info('Connecting to %s%s', self.host_port, self.url_path)
if self.dry_run:
self.authenticated = True
return
remote_api_stub.MaybeInvokeAuthentication()
self.authenticated = True
def AuthFunction(self,
raw_input_fn=raw_input,
password_input_fn=getpass.getpass):
"""Prompts the user for a username and password.
Caches the results the first time it is called and returns the
same result every subsequent time.
Args:
raw_input_fn: Used for dependency injection.
password_input_fn: Used for dependency injection.
Returns:
A pair of the username and password.
"""
self.auth_called = True
return _AuthFunction(self.host, self.email, self.passin,
raw_input_fn, password_input_fn)
def ReserveKeys(self, keys):
"""Reserve all ids in the paths of the given keys.
Args:
keys: A list of keys with ids in their paths, for which the corresponding
id sequences should be advanced to prevent id collisions.
"""
if self.dry_run:
return
ReserveKeys(keys)
def GetSchemaKinds(self):
"""Returns the list of kinds for this app.
There can be 3 possible cases using namespaces:
a.) No namespace specified and Datastore has only default namespace ->
Query GlobalStat and KindStat.
b.) No namespace specified but Datastore has multiple namespace ->
Query NamespaceGlobalStat and NamespaceKindStat.
c.) Namespace specified and Datastore has multiple namespaces ->
Query NamespaceGlobalStat and NamespaceKindStat.
Returns:
A list of kinds.
"""
namespaces = False
if (namespace_manager.get_namespace() or
stats.NamespaceStat.all().count() > 1):
namespaces = True
if namespaces:
global_kind = stats.NamespaceGlobalStat
else:
global_kind = stats.GlobalStat
kinds_kind = stats.NamespaceKindStat if namespaces else stats.KindStat
global_stat = global_kind.all().get()
if not global_stat:
raise KindStatError()
timestamp = global_stat.timestamp
kind_stat = kinds_kind.all().filter(
"timestamp =", timestamp).fetch(1000)
kind_list = [stat.kind_name for stat in kind_stat
if stat.kind_name and not stat.kind_name.startswith('__')]
return list(set(kind_list))
def EncodeContent(self, rows, loader=None):
"""Encodes row data to the wire format.
Args:
rows: A list of pairs of a line number and a list of column values.
loader: Used for dependency injection.
Returns:
A list of datastore.Entity instances.
Raises:
ConfigurationError: if no loader is defined for self.kind
"""
if not loader:
try:
loader = Loader.RegisteredLoader(self.kind)
except KeyError:
logger.error('No Loader defined for kind %s.' % self.kind)
raise ConfigurationError('No Loader defined for kind %s.' % self.kind)
entities = []
for line_number, values in rows:
key = loader.generate_key(line_number, values)
if isinstance(key, datastore.Key):
parent = key.parent()
key = key.name()
else:
parent = None
entity = loader.create_entity(values, key_name=key, parent=parent)
def ToEntity(entity):
if isinstance(entity, db.Model):
return entity._populate_entity()
else:
return entity
if not entity:
continue
if isinstance(entity, list):
entities.extend(map(ToEntity, entity))
elif entity:
entities.append(ToEntity(entity))
return entities
def PostEntities(self, entities):
"""Posts Entity records to a remote endpoint over HTTP.
Args:
entities: A list of datastore entities.
"""
if self.dry_run:
return
datastore.Put(entities)
def _QueryForPbs(self, query):
"""Perform the given query and return a list of entity_pb's."""
try:
query_pb = query._ToPb(limit=self.batch_size, count=self.batch_size)
result_pb = datastore_pb.QueryResult()
apiproxy_stub_map.MakeSyncCall('datastore_v3', 'RunQuery', query_pb,
result_pb)
results = result_pb.result_list()
while result_pb.more_results():
next_pb = datastore_pb.NextRequest()
next_pb.set_count(self.batch_size - len(results))
next_pb.mutable_cursor().CopyFrom(result_pb.cursor())
result_pb = datastore_pb.QueryResult()
apiproxy_stub_map.MakeSyncCall('datastore_v3', 'Next', next_pb,
result_pb)
results += result_pb.result_list()
return results
except apiproxy_errors.ApplicationError, e:
raise datastore._ToDatastoreError(e)
def GetEntities(
self, key_range_item, key_factory=datastore.Key, keys_only=False,
retry_parallel=False):
"""Gets Entity records from a remote endpoint over HTTP.
Args:
key_range_item: Range of keys to get.
key_factory: Used for dependency injection.
keys_only: bool, default False, only get keys values
retry_parallel: bool, default False, to try a parallel download despite
past parallel download failures.
Returns:
A DownloadResult instance.
Raises:
ConfigurationError: if no Exporter is defined for key_range_item.kind
"""
keys = []
entities = []
kind = key_range_item.kind
if retry_parallel:
self.parallel_download = True
if self.parallel_download:
query = key_range_item.key_range.make_directed_datastore_query(
kind, keys_only=keys_only)
try:
results = self._QueryForPbs(query)
except datastore_errors.NeedIndexError:
logger.info('%s: No descending index on __key__, '
'performing serial download', kind)
self.parallel_download = False
if not self.parallel_download:
key_range_item.key_range.direction = key_range_module.KeyRange.ASC
query = key_range_item.key_range.make_ascending_datastore_query(
kind, keys_only=keys_only)
results = self._QueryForPbs(query)
size = len(results)
for entity in results:
key = key_factory()
key._Key__reference = entity.key()
entities.append(entity)
keys.append(key)
continued = (size == self.batch_size)
key_range_item.count = size
return DownloadResult(continued, key_range_item.key_range.direction,
keys, entities)
def GetMapper(self, kind):
"""Returns a mapper for the registered kind.
Returns:
A Mapper instance.
Raises:
ConfigurationError: if no Mapper is defined for kind
"""
if not self.mapper:
try:
self.mapper = Mapper.RegisteredMapper(kind)
except KeyError:
logger.error('No Mapper defined for kind %s.' % kind)
raise ConfigurationError('No Mapper defined for kind %s.' % kind)
return self.mapper
def InterruptibleSleep(sleep_time):
  """Puts the thread to sleep, checking this thread's exit_flag twice a second.

  The sleep happens in slices of at most 0.5 seconds. After each slice the
  current thread's exit_flag is checked and, if set, the sleep ends early.
  A thread without an exit_flag attribute (for example the main thread,
  which is not a _ThreadBase) is treated as never asked to exit, rather
  than raising AttributeError.

  Args:
    sleep_time: Time to sleep, in seconds.
  """
  slept = 0.0
  # Tolerance so floating-point accumulation does not cause one extra,
  # vanishingly small sleep slice.
  epsilon = .0001
  thread = threading.currentThread()
  while slept < sleep_time - epsilon:
    this_sleep_time = min(sleep_time - slept, 0.5)
    time.sleep(this_sleep_time)
    slept += this_sleep_time
    if getattr(thread, 'exit_flag', False):
      return
class _ThreadBase(threading.Thread):
  """Common scaffolding for the worker threads used by the uploader.

  What subclasses get from this base class:
    * exit_flag, which asks the thread to stop as soon as it reasonably can.
    * The first exception raised by PerformWork() is captured (error and
      traceback) so the primary thread can pick it up via CheckError().
    * Debug logging when the thread starts and when it exits.
    * Every instance is a daemon thread.
    * A human-friendly name via GetFriendlyName().

  Subclasses must implement PerformWork(). They should either set self.NAME
  or override GetFriendlyName() to return a human-friendly name. run()
  starts the work and logs start/exit messages. PerformWork() should check
  self.exit_flag whenever it has the opportunity to exit gracefully.
  """

  def __init__(self):
    threading.Thread.__init__(self)
    self.setDaemon(True)
    # Set by other threads to request that this one stop.
    self.exit_flag = False
    # The first exception raised in this thread and its traceback.
    self.error = None
    self.traceback = None

  def run(self):
    """Runs PerformWork(), capturing any exception it raises."""
    cls_name = type(self).__name__
    logger.debug('[%s] %s: started', self.getName(), cls_name)
    try:
      self.PerformWork()
    except:
      # Deliberately broad: any failure is recorded for the primary thread
      # instead of propagating out of the thread.
      self.SetError()
      logger.exception('[%s] %s:', self.getName(), cls_name)
    logger.debug('[%s] %s: exiting', self.getName(), cls_name)

  def SetError(self):
    """Records the exception currently being handled, keeping only the first.

    Must be called from inside an exception handler.
    """
    if self.error:
      return
    _, self.error, self.traceback = sys.exc_info()

  def PerformWork(self):
    """Does the thread-specific work; subclasses must override this."""
    raise NotImplementedError()

  def CheckError(self):
    """Logs the recorded error, if any, with its traceback at debug level."""
    if not self.error:
      return
    logger.error('Error in %s: %s', self.GetFriendlyName(), self.error)
    if not self.traceback:
      return
    formatted = traceback.format_exception(self.error.__class__,
                                           self.error,
                                           self.traceback)
    logger.debug(''.join(formatted))

  def GetFriendlyName(self):
    """Returns self.NAME, or 'unknown thread' when no NAME is defined."""
    return getattr(self, 'NAME', 'unknown thread')
# errno values that indicate a possibly transient network failure.
non_fatal_error_codes = set([errno.EAGAIN,
                             errno.ENETUNREACH,
                             errno.ENETRESET,
                             errno.ECONNRESET,
                             errno.ETIMEDOUT,
                             errno.EHOSTUNREACH])


def IsURLErrorFatal(error):
  """Returns False if the given URLError may be from a transient failure.

  The error is treated as non-fatal only when its reason carries an integer
  error code listed in non_fatal_error_codes.

  Args:
    error: A urllib2.URLError instance.

  Returns:
    True if the error should be treated as fatal, False if it may be
    transient.
  """
  assert isinstance(error, urllib2.URLError)
  if not hasattr(error, 'reason'):
    return True

  reason = error.reason
  # The reason is usually an exception such as socket.error(errno, msg), but
  # may also be a tuple or a plain string. Exceptions are read through .args
  # because they are not indexable on Python 3. An empty reason has no code.
  if isinstance(reason, Exception):
    details = reason.args
  else:
    details = reason
  try:
    code = details[0]
  except (IndexError, KeyError, TypeError):
    return True

  if not isinstance(code, int):
    return True
  return code not in non_fatal_error_codes
class DataSourceThread(_ThreadBase):
  """A thread which reads WorkItems and pushes them into queue.

  This thread will read/consume WorkItems from a generator (produced by
  the generator factory). These WorkItems will then be pushed into the
  thread_pool. Note that reading will block if/when the thread_pool becomes
  full. Information on content consumed from the generator will be pushed
  into the progress_queue.
  """

  NAME = 'data source thread'

  def __init__(self,
               request_manager,
               kinds,
               thread_pool,
               progress_queue,
               workitem_generator_factory,
               progress_generator_factory):
    """Initialize the DataSourceThread instance.

    Args:
      request_manager: A RequestManager instance.
      kinds: The kinds of entities being transferred.
      thread_pool: An AdaptiveThreadPool instance.
      progress_queue: A queue used for tracking progress information.
      workitem_generator_factory: A factory that creates a WorkItem generator
      progress_generator_factory: A factory that creates a generator which
        produces prior progress status, or None if there is no prior status
        to use.
    """
    _ThreadBase.__init__(self)

    self.request_manager = request_manager
    self.kinds = kinds
    self.thread_pool = thread_pool
    self.progress_queue = progress_queue
    self.workitem_generator_factory = workitem_generator_factory
    self.progress_generator_factory = progress_generator_factory
    self.entity_count = 0
    # Results published by PerformWork(). Initialized here so they can be
    # read even if PerformWork() fails before it reaches them.
    self.xfer_count = 0
    self.read_count = 0
    self.read_all = False

  def PerformWork(self):
    """Reads work items from the generator and submits them to the pool.

    Submission uses a one-second timeout in a retry loop so that exit_flag
    is noticed even while the thread pool is full. read_all is set only if
    the generator was consumed without an exit request.
    """
    if self.progress_generator_factory:
      progress_gen = self.progress_generator_factory()
    else:
      progress_gen = None

    content_gen = self.workitem_generator_factory(self.request_manager,
                                                  self.progress_queue,
                                                  progress_gen,
                                                  self.kinds)

    self.xfer_count = 0
    self.read_count = 0
    self.read_all = False

    for item in content_gen.Batches():
      item.MarkAsRead()

      while not self.exit_flag:
        try:
          self.thread_pool.SubmitItem(item, block=True, timeout=1.0)
          self.entity_count += item.count
          break
        except Queue.Full:
          pass

      if self.exit_flag:
        break

    if not self.exit_flag:
      self.read_all = True
    self.read_count = content_gen.row_count
    self.xfer_count = content_gen.xfer_count
def _RunningInThread(thread):
  """Return True if we are running within the specified thread.

  Threads are compared by object identity: thread names are not guaranteed
  to be unique, so two distinct threads may share a name.

  Args:
    thread: A threading.Thread instance.
  """
  return threading.currentThread() is thread
class _Database(object):
"""Base class for database connections in this module.
The table is created by a primary thread (the python main thread)
but all future lookups and updates are performed by a secondary
thread.
"""
SIGNATURE_TABLE_NAME = 'bulkloader_database_signature'
def __init__(self,
db_filename,
create_table,
signature,
index=None,
commit_periodicity=100):
"""Initialize the _Database instance.
Args:
db_filename: The sqlite3 file to use for the database.
create_table: A string containing the SQL table creation command.
signature: A string identifying the important invocation options,
used to make sure we are not using an old database.
index: An optional string to create an index for the database.
commit_periodicity: Number of operations between database commits.
"""
self.db_filename = db_filename
logger.info('Opening database: %s', db_filename)
self.primary_conn = sqlite3.connect(db_filename, isolation_level=None)
self.primary_thread = threading.currentThread()
self.secondary_conn = None
self.secondary_thread = None
self.operation_count = 0
self.commit_periodicity = commit_periodicity
try:
self.primary_conn.execute(create_table)
except sqlite3.OperationalError, e:
if 'already exists' not in e.message:
raise
if index:
try:
self.primary_conn.execute(index)
except sqlite3.OperationalError, e:
if 'already exists' not in e.message:
raise
self.existing_table = False
signature_cursor = self.primary_conn.cursor()
create_signature = """
create table %s (
value TEXT not null)
""" % _Database.SIGNATURE_TABLE_NAME
try:
self.primary_conn.execute(create_signature)
self.primary_conn.cursor().execute(
'insert into %s (value) values (?)' % _Database.SIGNATURE_TABLE_NAME,
(signature,))
except sqlite3.OperationalError, e:
if 'already exists' not in e.message:
logger.exception('Exception creating table:')
raise
else:
self.existing_table = True
signature_cursor.execute(
'select * from %s' % _Database.SIGNATURE_TABLE_NAME)
(result,) = signature_cursor.fetchone()
if result and result != signature:
logger.error('Database signature mismatch:\n\n'
'Found:\n'
'%s\n\n'
'Expecting:\n'
'%s\n',
result, signature)
raise ResumeError('Database signature mismatch: %s != %s' % (
signature, result))
def ThreadComplete(self):
"""Finalize any operations the secondary thread has performed.
The database aggregates lots of operations into a single commit, and
this method is used to commit any pending operations as the thread
is about to shut down.
"""
if self.secondary_conn:
self._MaybeCommit(force_commit=True)
def _MaybeCommit(self, force_commit=False):
"""Periodically commit changes into the SQLite database.
Committing every operation is quite expensive, and slows down the
operation of the script. Thus, we only commit after every N operations,
as determined by the self.commit_periodicity value. Optionally, the
caller can force a commit.
Args:
force_commit: Pass True in order for a commit to occur regardless
of the current operation count.
"""
self.operation_count += 1
if force_commit or (self.operation_count % self.commit_periodicity) == 0:
self.secondary_conn.commit()
def _OpenSecondaryConnection(self):
"""Possibly open a database connection for the secondary thread.
If the connection is not open (for the calling thread, which is assumed
to be the unique secondary thread), then open it. We also open a couple
cursors for later use (and reuse).
"""
if self.secondary_conn:
return
assert not _RunningInThread(self.primary_thread)
self.secondary_thread = threading.currentThread()
self.secondary_conn = sqlite3.connect(self.db_filename)
self.insert_cursor = self.secondary_conn.cursor()
self.update_cursor = self.secondary_conn.cursor()
# Matches a NUL character; KeyStr escapes each one as NUL followed by SOH.
zero_matcher = re.compile(r'\x00')
# Matches the NUL SOH escape sequence, which StrKey turns back into NUL.
zero_one_matcher = re.compile(r'\x00\x01')


def KeyStr(key):
  r"""Encodes a datastore.Key as a string whose ordering matches the key's.

  Unlike datastore.Key.__str__(), this encoding satisfies:
    key1 < key2 ==> KeyStr(key1) < KeyStr(key2)

  Each element of the key path is encoded as follows:
    (1) Numeric ids become 20-digit zero-padded decimals; string names get a
        ':' prefix.
    (2) Every NUL character (u'\0') is escaped as u'\0\1'.
  The encoded elements are then joined with the separator u'\0\0'.

  Step (1) makes names and ids compare correctly against each other; the
  escaping in (2) together with the separator makes the comparison proceed
  element by element along the path.

  Args:
    key: A datastore.Key instance.

  Returns:
    The order-preserving string representation of key.
  """
  assert isinstance(key, datastore.Key)

  def _EncodeElement(element):
    if isinstance(element, (int, long)):
      text = '%020d' % element
    else:
      text = ':%s' % element
    return zero_matcher.sub(u'\0\1', text)

  return u'\0\0'.join([_EncodeElement(element) for element in key.to_path()])
def StrKey(key_str):
  """Decodes a string produced by KeyStr back into a datastore.Key.

  Args:
    key_str: A string in the range of KeyStr.

  Returns:
    A datastore.Key instance k, such that KeyStr(k) == key_str.
  """
  path = []
  for element in key_str.split(u'\0\0'):
    if element[0] == ':':
      # A name: strip the ':' prefix and undo the NUL escaping.
      path.append(zero_one_matcher.sub(u'\0', element[1:]))
    else:
      # A numeric id, stored as a zero-padded decimal.
      path.append(int(element))
  return datastore.Key.from_path(*path)
class ResultDatabase(_Database):
  """Persistently record all the entities downloaded during an export.

  The entities are held in the database by their unique datastore key
  in order to avoid duplication if an export is restarted.
  """

  def __init__(self, db_filename, signature, commit_periodicity=1,
               exporter=None):
    """Initialize a ResultDatabase object.

    Args:
      db_filename: The name of the SQLite database to use.
      signature: A string identifying the important invocation options,
        used to make sure we are not using an old database.
      commit_periodicity: How many operations to perform between commits.
      exporter: Exporter instance; if exporter.calculate_sort_key_from_entity
        is true then exporter.sort_key_from_entity(entity) will be called.
    """
    self.complete = False
    create_table = ('create table result (\n'
                    'id BLOB primary key,\n'
                    'value BLOB not null,\n'
                    'sort_key BLOB)')

    _Database.__init__(self,
                       db_filename,
                       create_table,
                       signature,
                       commit_periodicity=commit_periodicity)
    if self.existing_table:
      cursor = self.primary_conn.cursor()
      cursor.execute('select count(*) from result')
      self.existing_count = int(cursor.fetchone()[0])
    else:
      self.existing_count = 0
    self.count = self.existing_count
    if exporter and getattr(exporter, 'calculate_sort_key_from_entity', False):
      self.sort_key_from_entity = exporter.sort_key_from_entity
    else:
      self.sort_key_from_entity = None

  def _StoreEntity(self, entity_id, entity):
    """Store an entity in the result database.

    Any existing row for the same key is deleted and replaced.

    Args:
      entity_id: A datastore.Key for the entity.
      entity: The entity protocol buffer to store (must support Encode()).

    Returns:
      True if this entity was not already present in the result database.
    """
    assert _RunningInThread(self.secondary_thread)
    assert isinstance(entity_id, datastore.Key), (
        'expected a datastore.Key, got a %s' % entity_id.__class__.__name__)

    key_str = buffer(KeyStr(entity_id).encode('utf-8'))
    self.insert_cursor.execute(
        'select count(*) from result where id = ?', (key_str,))

    already_present = self.insert_cursor.fetchone()[0]
    result = True
    if already_present:
      result = False
      self.insert_cursor.execute('delete from result where id = ?',
                                 (key_str,))
    else:
      self.count += 1
    if self.sort_key_from_entity:
      sort_key = self.sort_key_from_entity(datastore.Entity._FromPb(entity))
    else:
      sort_key = ''
    value = entity.Encode()
    self.insert_cursor.execute(
        'insert into result (id, value, sort_key) values (?, ?, ?)',
        (key_str, buffer(value), sort_key))
    return result

  def StoreEntities(self, keys, entities):
    """Store a group of entities in the result database.

    Args:
      keys: A list of entity keys.
      entities: A list of entities.

    Returns:
      The number of new entities stored in the result database.
    """
    self._OpenSecondaryConnection()
    t = time.time()
    count = 0
    for entity_id, entity in zip(keys,
                                 entities):
      if self._StoreEntity(entity_id, entity):
        count += 1
    logger.debug('%s insert: delta=%.3f',
                 self.db_filename,
                 time.time() - t)
    logger.debug('Entities transferred total: %s', self.count)
    self._MaybeCommit()
    return count

  def ResultsComplete(self):
    """Marks the result database as containing complete results."""
    self.complete = True

  def AllEntities(self):
    """Yields every stored entity as a datastore.Entity.

    Entities are ordered by sort_key, then by encoded key. A dedicated
    connection is used and closed when the generator is exhausted or closed.
    """
    conn = sqlite3.connect(self.db_filename, isolation_level=None)
    try:
      cursor = conn.cursor()
      cursor.execute(
          'select id, value from result order by sort_key, id')
      for unused_entity_id, entity in cursor:
        entity_proto = entity_pb.EntityProto(contents=entity)
        yield datastore.Entity._FromPb(entity_proto)
    finally:
      conn.close()
class _ProgressDatabase(_Database):
  """Persistently record all progress information during an upload.

  This class wraps a very simple SQLite database which records each of
  the relevant details from a chunk of work. If the loader is
  resumed, then data is replayed out of the database.
  """

  def __init__(self,
               db_filename,
               sql_type,
               py_type,
               signature,
               commit_periodicity=100):
    """Initialize the ProgressDatabase instance.

    Args:
      db_filename: The name of the SQLite database to use.
      sql_type: A string of the SQL type to use for entity keys.
      py_type: The python type of entity keys.
      signature: A string identifying the important invocation options,
        used to make sure we are not using an old database.
      commit_periodicity: How many operations to perform between commits.
    """
    self.prior_key_end = None

    create_table = ('create table progress (\n'
                    'id integer primary key autoincrement,\n'
                    'state integer not null,\n'
                    'kind text not null,\n'
                    'key_start %s,\n'
                    'key_end %s)'
                    % (sql_type, sql_type))
    self.py_type = py_type

    index = 'create index i_state on progress (state)'
    _Database.__init__(self,
                       db_filename,
                       create_table,
                       signature,
                       index=index,
                       commit_periodicity=commit_periodicity)

  def UseProgressData(self):
    """Returns True if the database has progress information.

    Note there are two basic cases for progress information:
    1) All saved records indicate a successful upload. In this case, we
       need to skip everything transmitted so far and then send the rest.
    2) Some records for incomplete transfer are present. These need to be
       sent again, and then we resume sending after all the successful
       data.

    Returns:
      True: if the database has progress information.

    Raises:
      ResumeError: if there is an error retrieving rows from the database.
    """
    assert _RunningInThread(self.primary_thread)

    cursor = self.primary_conn.cursor()
    cursor.execute('select count(*) from progress')
    row = cursor.fetchone()
    if row is None:
      raise ResumeError('Cannot retrieve progress information from database.')

    return row[0] != 0

  def StoreKeys(self, kind, key_start, key_end):
    """Record a new progress record, returning a key for later updates.

    The specified progress information will be persisted into the database.
    A unique key will be returned that identifies this progress state. The
    key is later used to (quickly) update this record.

    For the progress resumption to proceed properly, calls to StoreKeys
    MUST specify monotonically increasing key ranges. This will result in
    a database whereby the ID, KEY_START, and KEY_END rows are all
    increasing (rather than having ranges out of order).

    NOTE: the above precondition is NOT tested by this method (since it
    would imply an additional table read or two on each invocation).

    Args:
      kind: The kind for the WorkItem
      key_start: The starting key of the WorkItem (inclusive)
      key_end: The end key of the WorkItem (inclusive)

    Returns:
      A string to later be used as a unique key to update this state.
    """
    self._OpenSecondaryConnection()

    assert _RunningInThread(self.secondary_thread)
    assert (not key_start) or isinstance(key_start, self.py_type), (
        '%s is a %s, %s expected %s' % (key_start,
                                        key_start.__class__,
                                        self.__class__.__name__,
                                        self.py_type))
    assert (not key_end) or isinstance(key_end, self.py_type), (
        '%s is a %s, %s expected %s' % (key_end,
                                        key_end.__class__,
                                        self.__class__.__name__,
                                        self.py_type))
    assert KeyLEQ(key_start, key_end), '%s not less than %s' % (
        repr(key_start), repr(key_end))

    # Keys are stored via unicode(); a None bound is stored as u'None'.
    self.insert_cursor.execute(
        'insert into progress (state, kind, key_start, key_end)'
        ' values (?, ?, ?, ?)',
        (STATE_READ, unicode(kind), unicode(key_start), unicode(key_end)))

    progress_key = self.insert_cursor.lastrowid

    self._MaybeCommit()

    return progress_key

  def UpdateState(self, key, new_state):
    """Update a specified progress record with new information.

    Args:
      key: The key for this progress record, returned from StoreKeys
      new_state: The new state to associate with this progress record.
    """
    self._OpenSecondaryConnection()

    assert _RunningInThread(self.secondary_thread)
    assert isinstance(new_state, int)

    self.update_cursor.execute('update progress set state=? where id=?',
                               (new_state, key))

    self._MaybeCommit()

  def DeleteKey(self, progress_key):
    """Delete the progress record with the given key from the progress table.

    Args:
      progress_key: The key returned by StoreKeys for the record to delete.
    """
    self._OpenSecondaryConnection()

    assert _RunningInThread(self.secondary_thread)

    t = time.time()
    self.insert_cursor.execute(
        'delete from progress where rowid = ?', (progress_key,))

    logger.debug('delete: delta=%.3f', time.time() - t)

    self._MaybeCommit()

  def GetProgressStatusGenerator(self):
    """Get a generator which yields progress information.

    The returned generator will yield a series of 5-tuples that specify
    progress information about a prior run of the uploader. The 5-tuples
    have the following values:

      progress_key: The unique key to later update this record with new
                    progress information.
      state: The last state saved for this progress record.
      kind: The datastore kind of the items for uploading.
      key_start: The starting key of the items for uploading (inclusive).
      key_end: The ending key of the items for uploading (inclusive).

    After all incompletely-transferred records are provided, then one
    more 5-tuple will be generated:

      None
      DATA_CONSUMED_TO_HERE: A unique string value indicating this record
                             is being provided.
      None
      None
      key_end: The largest key_end recorded in the progress table, i.e. the
               last data source key that was handled by the previous run of
               the uploader.

    The caller should begin uploading records which occur after key_end.

    Yields:
      Five-tuples of (progress_key, state, kind, key_start, key_end)
    """
    conn = sqlite3.connect(self.db_filename, isolation_level=None)
    try:
      cursor = conn.cursor()

      cursor.execute('select max(key_end) from progress')
      result = cursor.fetchone()
      if result is None:
        # NOTE(review): an aggregate query returns one row even for an empty
        # table (with a NULL value), so this branch is defensive only.
        logger.debug('No rows in progress database.')
        return
      max_key_end = result[0]
      self.prior_key_end = max_key_end

      cursor.execute(
          'select id, state, kind, key_start, key_end from progress'
          ' where state != ?'
          ' order by id',
          (STATE_SENT,))

      for row in cursor.fetchall():
        progress_key, state, kind, key_start, key_end = row
        yield progress_key, state, kind, key_start, key_end

      # Report the maximum recorded key_end, not the key_end of the last
      # unfinished row: rows sent after that row must not be re-read.
      yield None, DATA_CONSUMED_TO_HERE, None, None, max_key_end
    finally:
      conn.close()
def ProgressDatabase(db_filename, signature):
  """Creates the upload progress database, whose keys are integers.

  Args:
    db_filename: The name of the SQLite database to use.
    signature: A string identifying the important invocation options.

  Returns:
    A _ProgressDatabase storing keys as SQL INTEGER / python int.
  """
  return _ProgressDatabase(db_filename,
                           sql_type='INTEGER',
                           py_type=int,
                           signature=signature)
class ExportProgressDatabase(_ProgressDatabase):
  """Progress database for downloads, keyed by datastore.Key values."""

  def __init__(self, db_filename, signature):
    """Creates the database with TEXT key columns, committing every operation.

    Args:
      db_filename: The name of the SQLite database to use.
      signature: A string identifying the important invocation options.
    """
    _ProgressDatabase.__init__(self,
                               db_filename=db_filename,
                               sql_type='TEXT',
                               py_type=datastore.Key,
                               signature=signature,
                               commit_periodicity=1)

  def UseProgressData(self):
    """Reports whether a prior run left progress data in this database.

    Returns:
      True if the signature table already existed when the database was
      opened (self.existing_table).
    """
    return self.existing_table
class StubProgressDatabase(object):
  """A no-op stand-in for ProgressDatabase that records nothing."""

  def UseProgressData(self):
    """Always reports that there is no prior progress information."""
    return False

  def StoreKeys(self, unused_kind, unused_key_start, unused_key_end):
    """Ignores the record and hands back a fixed placeholder key."""
    return 'fake-key'

  def UpdateState(self, unused_key, unused_new_state):
    """Ignores the state update."""
    return None

  def ThreadComplete(self):
    """Nothing to finalize for the stub database."""
    return None

  def DeleteKey(self, unused_key):
    """Ignores the deletion request."""
    return None
class _ProgressThreadBase(_ThreadBase):
  """Base class for threads that record transfer progress information.

  Work items are read from a progress queue, and their progress is stored
  into the provided progress database. Subclasses implement
  UpdateProgress() and WorkFinished(). This class is not responsible for
  replaying a prior run's progress information out of the database.
  Separate mechanisms must be used to resume a prior transfer attempt.
  """

  NAME = 'progress tracking thread'

  def __init__(self, progress_queue, progress_db):
    """Initialize the _ProgressThreadBase instance.

    Args:
      progress_queue: A Queue used for tracking progress information.
      progress_db: The database for tracking progress information; should
        be an instance of ProgressDatabase.
    """
    _ThreadBase.__init__(self)

    self.progress_queue = progress_queue
    self.db = progress_db
    self.entities_transferred = 0

  def EntitiesTransferred(self):
    """Return the total number of unique entities transferred."""
    return self.entities_transferred

  def UpdateProgress(self, item):
    """Updates the progress information for the given item.

    Args:
      item: A work item whose new state will be recorded
    """
    raise NotImplementedError()

  def WorkFinished(self):
    """Performs final actions after the entity transfer is complete."""
    raise NotImplementedError()

  def PerformWork(self):
    """Drains the progress queue until asked to exit.

    Stops when exit_flag is set or the _THREAD_SHOULD_EXIT sentinel is
    dequeued. A newly read item (STATE_READ with no progress_key) gets a
    new progress record via db.StoreKeys; any other item is passed to
    UpdateProgress(). Each processed item's progress_event is then set and
    the queue is told the task is done. Finally db.ThreadComplete() is
    called.
    """
    while not self.exit_flag:
      try:
        # The timeout lets the loop notice exit_flag while the queue is idle.
        item = self.progress_queue.get(block=True, timeout=1.0)
      except Queue.Empty:
        continue
      if item == _THREAD_SHOULD_EXIT:
        break

      if item.state == STATE_READ and item.progress_key is None:
        item.progress_key = self.db.StoreKeys(item.kind,
                                              item.key_start,
                                              item.key_end)
      else:
        assert item.progress_key is not None
        self.UpdateProgress(item)

      # Wake up whoever is waiting for this item's progress to be recorded.
      item.progress_event.set()

      self.progress_queue.task_done()

    self.db.ThreadComplete()
class ProgressTrackerThread(_ProgressThreadBase):
  """Progress thread for uploads.

  Every state change is written to the progress database. Items that reach
  STATE_SENT add their count to the transferred total. Resuming a prior
  upload from the database is handled elsewhere.
  """

  NAME = 'progress tracking thread'

  def __init__(self, progress_queue, progress_db):
    """Initialize the ProgressTrackerThread instance.

    Args:
      progress_queue: A Queue used for tracking progress information.
      progress_db: The database for tracking progress information; should
        be an instance of ProgressDatabase.
    """
    _ProgressThreadBase.__init__(self, progress_queue, progress_db)

  def UpdateProgress(self, item):
    """Persist item's state and count it once it has been sent.

    Args:
      item: A WorkItem instance.
    """
    self.db.UpdateState(item.progress_key, item.state)
    if item.state != STATE_SENT:
      return
    self.entities_transferred += item.count

  def WorkFinished(self):
    """Uploads need no post-transfer step."""
    pass
class ExportProgressThread(_ProgressThreadBase):
  """Progress thread for exports that also collects the downloaded entities.

  Progress is kept in the progress database. Entities are accumulated in the
  result database and handed to the exporter when the download finishes.
  """

  def __init__(self, exporter, progress_queue, progress_db, result_db):
    """Initialize the ExportProgressThread instance.

    Args:
      exporter: An Exporter instance for the download.
      progress_queue: A Queue used for tracking progress information.
      progress_db: The database for tracking progress information; should
        be an instance of ProgressDatabase.
      result_db: The database for holding exported entities; should be an
        instance of ResultDatabase.
    """
    _ProgressThreadBase.__init__(self, progress_queue, progress_db)
    self.exporter = exporter
    # Snapshot of what the result database already held before this run.
    self.existing_count = result_db.existing_count
    self.result_db = result_db

  def EntitiesTransferred(self):
    """Return the number of entities held by the result database."""
    return self.result_db.count

  def WorkFinished(self):
    """Write every entity in the result database via the exporter."""
    all_entities = self.result_db.AllEntities()
    self.exporter.output_entities(all_entities)

  def UpdateProgress(self, item):
    """Store fetched entities or persist an intermediate state.

    Args:
      item: A KeyRangeItem instance.
    """
    if item.state != STATE_GOT:
      self.db.UpdateState(item.progress_key, item.state)
      return

    result = item.download_result
    stored = self.result_db.StoreEntities(result.keys, result.entities)
    # The range is complete, so its progress record is no longer needed.
    self.db.DeleteKey(item.progress_key)
    self.entities_transferred += stored
class MapperProgressThread(_ProgressThreadBase):
  """Progress thread for mapper runs over the datastore."""

  def __init__(self, mapper, progress_queue, progress_db):
    """Initialize the MapperProgressThread instance.

    Args:
      mapper: A Mapper object for this map run.
      progress_queue: A Queue used for tracking progress information.
      progress_db: The database for tracking progress information; should
        be an instance of ProgressDatabase.
    """
    _ProgressThreadBase.__init__(self, progress_queue, progress_db)
    self.mapper = mapper

  def EntitiesTransferred(self):
    """Return how many entities the mapper has been applied to."""
    return self.entities_transferred

  def WorkFinished(self):
    """Map runs need no post-transfer step."""
    pass

  def UpdateProgress(self, item):
    """Count a finished range and drop its record, or persist its state.

    Args:
      item: A KeyRangeItem instance.
    """
    if item.state != STATE_GOT:
      self.db.UpdateState(item.progress_key, item.state)
      return
    self.entities_transferred += item.count
    self.db.DeleteKey(item.progress_key)
def ParseKey(key_string):
  """Convert a key string read back from the database into a Key.

  Args:
    key_string: An encoded datastore key, or an empty value or the literal
      string 'None' when no key was stored.

  Returns:
    A datastore.Key instance, or None for an empty value or 'None'.
  """
  if not key_string or key_string == 'None':
    return None
  return datastore.Key(encoded=key_string)
def Validate(value, typ):
  """Checks that value is non-empty and of the right type.

  Args:
    value: any value
    typ: a type or tuple of types, as accepted by isinstance.

  Raises:
    ValueError: if value is None or empty (any falsy value).
    TypeError: if it's not the given type.
  """
  if not value:
    # Wrap value in a 1-tuple. A bare tuple on the right of % is treated as
    # the argument list, so an empty tuple would raise TypeError here
    # instead of the documented ValueError.
    raise ValueError('Value should not be empty; received %s.' % (value,))
  if not isinstance(value, typ):
    raise TypeError('Expected a %s, but received %s (a %s).' %
                    (typ, value, value.__class__))
def CheckFile(filename):
  """Verify that filename exists and that this process may read it.

  Args:
    filename: Path of the file to check.

  Raises:
    FileNotFoundError: if no file exists at filename.
    FileNotReadableError: if the file exists but is not readable.
  """
  if not os.path.exists(filename):
    raise FileNotFoundError('%s: file not found' % filename)
  if not os.access(filename, os.R_OK):
    raise FileNotReadableError('%s: file not readable' % filename)
class Loader(object):
  """A base class for creating datastore entities from input data.

  To add a handler for bulk loading a new entity kind into your datastore,
  write a subclass of this class that calls Loader.__init__ from your
  class's __init__.

  If you need to run extra code to convert entities from the input
  data, create new properties, or otherwise modify the entities before
  they're inserted, override handle_entity.

  See the create_entity method for the creation of entities from the
  (parsed) input data.
  """

  # Registry of Loader instances keyed by kind. It is always accessed as
  # Loader.__loaders, so all subclasses share it.
  __loaders = {}
  kind = None
  __properties = None

  def __init__(self, kind, properties):
    """Constructor.

    Populates this Loader's kind and properties map.

    Args:
      kind: a string containing the entity kind that this loader handles
        (the validation below also accepts a tuple).

      properties: list of (name, converter) tuples.

        This is used to automatically convert the input columns into
        properties.  The converter should be a function that takes one
        argument, a string value from the input file, and returns a
        correctly typed property value that should be inserted. The
        tuples in this list should match the columns in your input file,
        in order.

        For example:
          [('name', str),
           ('id_number', int),
           ('email', datastore_types.Email),
           ('user', users.User),
           ('birthdate', lambda x: datetime.datetime.fromtimestamp(float(x))),
           ('description', datastore_types.Text),
           ]

    Raises:
      ValueError, TypeError: from Validate, if kind, properties, or a
        property name is empty or of the wrong type.
      AssertionError: if a converter is not callable.
    """
    Validate(kind, (basestring, tuple))
    self.kind = kind
    # Kept on the instance; generate_records passes these to CSVGenerator.
    self.__openfile = open
    self.__create_csv_reader = csv.reader

    # Result is discarded. Presumably called to fail early when no model
    # class is known for kind (TODO confirm against GetImplementationClass).
    GetImplementationClass(kind)

    Validate(properties, list)
    for name, fn in properties:
      Validate(name, basestring)
      assert callable(fn), (
          'Conversion function %s for property %s is not callable.' % (fn, name))

    self.__properties = properties

  @staticmethod
  def RegisterLoader(loader):
    """Register loader and the Loader instance for its kind.

    A later registration for the same kind replaces the earlier one.

    Args:
      loader: A Loader instance.
    """
    Loader.__loaders[loader.kind] = loader

  def get_keys_to_reserve(self):
    """Returns keys with ids in their paths to be reserved.

    The base implementation reserves nothing.

    Returns:
      A list of keys used to advance the id sequences associated with
      each id to prevent collisions with future ids.
    """
    return []

  def alias_old_names(self):
    """Aliases method names so that Loaders defined with old names work."""
    aliases = (
        ('CreateEntity', 'create_entity'),
        ('HandleEntity', 'handle_entity'),
        ('GenerateKey', 'generate_key'),
        )
    for old_name, new_name in aliases:
      # Expose the old name on Loader itself, bound to the default
      # implementation. After this, hasattr(..., old_name) is always true, so
      # the check below detects an override by comparing functions.
      setattr(Loader, old_name, getattr(Loader, new_name))
      if hasattr(self.__class__, old_name) and not (
          getattr(self.__class__, old_name).im_func ==
          getattr(Loader, new_name).im_func):
        # The subclass overrides the old name. If it also overrides the new
        # name, it is ambiguous which one to use.
        if hasattr(self.__class__, new_name) and not (
            getattr(self.__class__, new_name).im_func ==
            getattr(Loader, new_name).im_func):
          raise NameClashError(old_name, new_name, self.__class__)
        # Route calls made through the new name to the old-name override.
        setattr(self, new_name, getattr(self, old_name))

  def create_entity(self, values, key_name=None, parent=None):
    """Creates an entity from a list of property values.

    Args:
      values: list/tuple of str
      key_name: if provided, the name for the (single) resulting entity
      parent: A datastore.Key instance for the parent, or None

    Returns:
      list of db.Model

      The returned entities are populated with the property values from the
      argument, converted to native types using the properties map given in
      the constructor, and passed through handle_entity. They're ready to be
      inserted. If handle_entity returns a falsy value, that value is
      returned unchanged.

    Raises:
      AssertionError: if the number of values doesn't match the number
        of properties in the properties map.
      ValueError: if any element of values is None or empty.
      TypeError: if values is not a list or tuple, or if handle_entity
        returns something other than db.Model instances.
    """
    Validate(values, (list, tuple))
    assert len(values) == len(self.__properties), (
        'Expected %d columns, found %d.' %
        (len(self.__properties), len(values)))

    model_class = GetImplementationClass(self.kind)

    properties = {
        'key_name': key_name,
        'parent': parent,
        }
    for (name, converter), val in zip(self.__properties, values):
      # bool() of any non-empty string is True, so common spellings of
      # "false" are mapped to False before conversion.
      if converter is bool and val.lower() in ('0', 'false', 'no'):
        val = False
      properties[name] = converter(val)

    entity = model_class(**properties)
    entities = self.handle_entity(entity)

    if entities:
      # A single entity is normalized to a one-element list.
      if not isinstance(entities, (list, tuple)):
        entities = [entities]

      for entity in entities:
        if not isinstance(entity, db.Model):
          raise TypeError('Expected a db.Model, received %s (a %s).' %
                          (entity, entity.__class__))

    return entities

  def generate_key(self, i, values):
    """Generates a key_name to be used in creating the underlying object.

    The default implementation returns None.

    This method can be overridden to control the key generation for
    uploaded entities. The value returned should be None (to use a
    server generated numeric key), or a string which neither starts
    with a digit nor has the form __*__ (see
    https://developers.google.com/appengine/docs/python/datastore/entities),
    or a datastore.Key instance.

    If you generate your own string keys, keep in mind:

    1. The key name for each entity must be unique.
    2. If an entity of the same kind and key already exists in the
       datastore, it will be overwritten.

    Args:
      i: Number corresponding to this object (assume it's run in a loop;
        this is your current count).
      values: list/tuple of str.

    Returns:
      A string to be used as the key_name for an entity.
    """
    return None

  def handle_entity(self, entity):
    """Subclasses can override this to add custom entity conversion code.

    This is called for each entity, after its properties are populated
    from the input but before it is stored. Subclasses can override
    this to add custom entity handling code.

    The entity to be inserted should be returned. If multiple entities
    should be inserted, return a list of entities. If no entities
    should be inserted, return None or [].

    Args:
      entity: db.Model

    Returns:
      db.Model or list of db.Model
    """
    return entity

  def initialize(self, filename, loader_opts):
    """Performs initialization and validation of the input file.

    This implementation checks that the input file exists and can be
    opened for reading.

    Args:
      filename: The string given as the --filename flag argument.
      loader_opts: The string given as the --loader_opts flag argument.
    """
    CheckFile(filename)

  def finalize(self):
    """Performs finalization actions after the upload completes."""
    pass

  def generate_records(self, filename):
    """Subclasses can override this to add custom data input code.

    This method must yield fixed-length lists of strings.

    The default implementation uses csv.reader to read CSV rows
    from filename.

    Args:
      filename: The string input for the --filename option.

    Yields:
      Lists of strings.
    """
    csv_generator = CSVGenerator(filename, openfile=self.__openfile,
                                 create_csv_reader=self.__create_csv_reader
                                ).Records()
    return csv_generator

  @staticmethod
  def RegisteredLoaders():
    """Returns a copy of the dict of registered Loader instances, by kind."""
    return dict(Loader.__loaders)

  @staticmethod
  def RegisteredLoader(kind):
    """Returns the loader instance registered for kind.

    Raises:
      KeyError: if no loader is registered for kind.
    """
    return Loader.__loaders[kind]
class RestoreThread(_ThreadBase):
  """A thread to read saved entity_pbs from sqlite3.

  Every `value` column of the `result` table is put on the queue, followed
  by the _ENTITIES_DONE sentinel object.
  """

  NAME = 'RestoreThread'

  # Sentinel put on the queue after the last row. Consumers compare it by
  # identity.
  _ENTITIES_DONE = 'Entities Done'

  def __init__(self, queue, filename):
    """Initialize the RestoreThread.

    Args:
      queue: A Queue that receives the stored entity values.
      filename: Path of the sqlite3 file to read (presumably a --dump
        result database).
    """
    _ThreadBase.__init__(self)
    self.queue = queue
    self.filename = filename

  def PerformWork(self):
    """Stream every stored entity value onto the queue, then the sentinel."""
    db_conn = sqlite3.connect(self.filename)
    try:
      cursor = db_conn.cursor()
      cursor.execute('select id, value from result')
      for unused_entity_id, value in cursor:
        self.queue.put(value, block=True)
    finally:
      # Release the database handle even if reading or queueing fails.
      db_conn.close()
    self.queue.put(RestoreThread._ENTITIES_DONE, block=True)
class RestoreLoader(Loader):
  """A Loader which imports protobuffers from a file.

  The input is a sqlite3 file read by RestoreThread. Entity keys and
  reference property values are rewritten to this loader's app_id and to
  the namespace that was current when the loader was constructed.
  """

  def __init__(self, kind, app_id):
    """Initialize the RestoreLoader.

    Loader.__init__ is not called. generate_records and create_entity are
    overridden, so the CSV reader state that Loader.__init__ sets up is
    unused.

    Args:
      kind: The entity kind being restored.
      app_id: The application id to write into restored keys.
    """
    self.kind = kind
    self.app_id = app_id
    self.namespace = namespace_manager.get_namespace()

  def initialize(self, filename, loader_opts):
    """Validate filename and read it once to collect keys to reserve.

    Args:
      filename: The string given as the --filename flag argument.
      loader_opts: The string given as the --loader_opts flag argument.
    """
    CheckFile(filename)
    self.queue = Queue.Queue(1000)

    # First pass: stream the whole file through the queue to find keys whose
    # numeric ids must be reserved before upload.
    restore_thread = RestoreThread(self.queue, filename)
    restore_thread.start()
    self.keys_to_reserve = self._find_keys_to_reserve(
        self.generate_records(filename))

    # Second pass: restart the reader so generate_records() can stream the
    # entities again for the actual upload.
    restore_thread = RestoreThread(self.queue, filename)
    restore_thread.start()

  def get_keys_to_reserve(self):
    """Returns keys with ids in their paths to be reserved.

    Returns:
      A list of keys used to advance the id sequences associated with
      each id to prevent collisions with future ids.
    """
    return self.keys_to_reserve

  def _find_keys_to_reserve(self, record_generator):
    """Find all entity keys with ids in their paths.

    Args:
      record_generator: A generator of entity_encoding strings.

    Returns:
      A list of keys to reserve.
    """
    keys_to_reserve = []
    for values in record_generator:
      entity = self.create_entity(values)
      key = entity.key()
      # to_path() alternates kind and id/name; the odd slots hold ids/names.
      for id_or_name in key.to_path()[1::2]:
        if isinstance(id_or_name, (int, long)):
          keys_to_reserve.append(key)
          break
    return keys_to_reserve

  def generate_records(self, filename):
    """Yield translated datastore.Entity objects until the reader is done."""
    while True:
      record = self.queue.get(block=True)
      if record is RestoreThread._ENTITIES_DONE:
        break
      entity_proto = entity_pb.EntityProto(contents=str(record))
      fixed_entity_proto = self._translate_entity_proto(entity_proto)
      yield datastore.Entity._FromPb(fixed_entity_proto)

  def create_entity(self, values, key_name=None, parent=None):
    """Records are already entities; return them unchanged."""
    return values

  def rewrite_reference_proto(self, entity_namespace, reference_proto):
    """Transform the Reference protobuffer which underlies keys and references.

    Args:
      entity_namespace: The 'before' namespace of the entity that has this
        reference property. If this value does not match the reference
        property's current namespace, then the reference property namespace
        will not be modified.
      reference_proto: A Onestore Reference proto
    """
    reference_proto.set_app(self.app_id)
    if entity_namespace != reference_proto.name_space():
      return

    if self.namespace:
      reference_proto.set_name_space(self.namespace)
    else:
      reference_proto.clear_name_space()

  def _rewrite_property_references(self, entity_namespace, properties):
    """Rewrite the reference value of each property in properties, if any."""
    for prop in properties:
      prop_value = prop.mutable_value()
      if prop_value.has_referencevalue():
        self.rewrite_reference_proto(entity_namespace,
                                     prop_value.mutable_referencevalue())

  def _translate_entity_proto(self, entity_proto):
    """Transform the ReferenceProperties of the given entity to fix app_id."""
    entity_key = entity_proto.mutable_key()
    entity_key.set_app(self.app_id)
    original_entity_namespace = entity_key.name_space()
    if self.namespace:
      entity_key.set_name_space(self.namespace)
    else:
      entity_key.clear_name_space()

    # Indexed and unindexed (raw) properties are rewritten the same way.
    self._rewrite_property_references(original_entity_namespace,
                                      entity_proto.property_list())
    self._rewrite_property_references(original_entity_namespace,
                                      entity_proto.raw_property_list())
    return entity_proto
class Exporter(object):
  """A base class for serializing datastore entities.

  To add a handler for exporting an entity kind from your datastore,
  write a subclass of this class that calls Exporter.__init__ from your
  class's __init__.

  To change how the downloaded entities are written, override
  output_entities. To change their order, set
  calculate_sort_key_from_entity and override sort_key_from_entity.
  """

  # Registry of Exporter instances keyed by kind, shared by all subclasses.
  __exporters = {}
  kind = None
  __properties = None
  calculate_sort_key_from_entity = False

  def __init__(self, kind, properties):
    """Constructor.

    Populates this Exporter's kind and properties map.

    Args:
      kind: a string containing the entity kind that this exporter handles

      properties: list of (name, converter, default) tuples.

      This is used to automatically convert the entities to strings.
      The converter should be a function that takes one argument, a property
      value of the appropriate type, and returns a str or unicode.  The default
      is a string to be used if the property is not present, or None to fail
      with an error if the property is missing.

      For example:
        [('name', str, None),
         ('id_number', str, None),
         ('email', str, ''),
         ('user', str, None),
         ('birthdate',
          lambda x: str(datetime.datetime.fromtimestamp(float(x))),
          None),
         ('description', str, ''),
         ]
    """
    Validate(kind, basestring)
    self.kind = kind

    # Result is discarded; presumably called to fail early for unknown kinds.
    GetImplementationClass(kind)

    Validate(properties, list)
    for name, fn, default in properties:
      Validate(name, basestring)
      assert callable(fn), (
          'Conversion function %s for property %s is not callable.' % (
              fn, name))
      if default:
        Validate(default, basestring)

    self.__properties = properties

  @staticmethod
  def RegisterExporter(exporter):
    """Register exporter and the Exporter instance for its kind.

    Args:
      exporter: A Exporter instance.
    """
    Exporter.__exporters[exporter.kind] = exporter

  def __ExtractProperties(self, entity):
    """Converts an entity into a list of string values.

    Args:
      entity: An entity to extract the properties from.

    Returns:
      A list of the properties of the entity.

    Raises:
      MissingPropertyError: if an expected field on the entity is missing.
    """
    encoding = []
    for name, fn, default in self.__properties:
      try:
        encoding.append(fn(entity[name]))
      except KeyError:
        # '__key__' is not a real property; it names the entity's key.
        if name == '__key__':
          encoding.append(fn(entity.key()))
        elif default is None:
          raise MissingPropertyError(name)
        else:
          encoding.append(default)
    return encoding

  def __EncodeEntity(self, entity):
    """Convert the given entity into CSV string.

    Args:
      entity: The entity to encode.

    Returns:
      A CSV string.
    """
    output = StringIO.StringIO()
    writer = csv.writer(output)
    writer.writerow(self.__ExtractProperties(entity))
    return output.getvalue()

  def __SerializeEntity(self, entity):
    """Creates a string representation of an entity.

    Args:
      entity: The entity to serialize.

    Returns:
      A serialized representation of an entity, as a UTF-8 byte string.
    """
    encoding = self.__EncodeEntity(entity)
    if not isinstance(encoding, unicode):
      encoding = unicode(encoding, 'utf-8')
    encoding = encoding.encode('utf-8')
    return encoding

  def output_entities(self, entity_generator):
    """Outputs the downloaded entities.

    This implementation writes CSV to self.output_filename, which is set by
    initialize.

    Args:
      entity_generator: A generator that yields the downloaded entities
        in key order.
    """
    CheckOutputFile(self.output_filename)
    output_file = open(self.output_filename, 'w')
    try:
      logger.debug('Export complete, writing to file')
      output_file.writelines(self.__SerializeEntity(entity)
                             for entity in entity_generator)
    finally:
      # Close explicitly so buffered output is flushed and the handle is
      # released even if serialization fails part way through.
      output_file.close()

  def initialize(self, filename, exporter_opts):
    """Performs initialization and validation of the output file.

    This implementation validates the output path with CheckOutputFile and
    remembers it for output_entities.

    Args:
      filename: The string given as the --filename flag argument.
      exporter_opts: The string given as the --exporter_opts flag argument.
    """
    CheckOutputFile(filename)
    self.output_filename = filename

  def finalize(self):
    """Performs finalization actions after the download completes."""
    pass

  def sort_key_from_entity(self, entity):
    """A value to alter sorting of entities in output_entities entity_generator.

    Will only be called if calculate_sort_key_from_entity is true.

    Args:
      entity: A datastore.Entity.

    Returns:
      A value to store in the intermediate sqlite table. The table will later
      be sorted by this value then by the datastore key, so the sort_key need
      not be unique.
    """
    return ''

  @staticmethod
  def RegisteredExporters():
    """Returns a copy of the dict of registered exporter instances, by kind."""
    return dict(Exporter.__exporters)

  @staticmethod
  def RegisteredExporter(kind):
    """Returns the exporter instance registered for kind.

    Raises:
      KeyError: if no exporter is registered for kind.
    """
    return Exporter.__exporters[kind]
class DumpExporter(Exporter):
  """An exporter which dumps protobuffers to a file.

  The downloaded entities already live in the result database file, so the
  export is a straight copy of that file to the output filename.
  """

  def __init__(self, kind, result_db_filename):
    """Initialize the DumpExporter.

    Exporter.__init__ is not called: output_entities never converts
    individual properties, so no properties map is needed.

    Args:
      kind: The entity kind being dumped.
      result_db_filename: Path of the result database file to copy.
    """
    self.kind = kind
    self.result_db_filename = result_db_filename

  def output_entities(self, entity_generator):
    """Copy the result database to self.output_filename.

    entity_generator is ignored; the file already holds every entity.
    """
    source = self.result_db_filename
    destination = self.output_filename
    shutil.copyfile(source, destination)
class MapperRetry(Error):
  """Raised for a mapping error that is not fatal to the map run."""
class Mapper(object):
  """A base class for applying a function to every entity of a kind.

  Subclasses call Mapper.__init__ from their own __init__ and override
  apply (one entity at a time) or batch_apply (a list of entities). The
  default implementations only print each entity.
  """

  # Registry of Mapper instances keyed by kind, shared by all subclasses.
  __mappers = {}
  kind = None

  def __init__(self, kind):
    """Record the kind this mapper handles.

    Args:
      kind: a string containing the entity kind that this mapper handles
    """
    Validate(kind, basestring)
    self.kind = kind
    # Result is discarded; presumably called to fail early for unknown kinds.
    GetImplementationClass(kind)

  @staticmethod
  def RegisterMapper(mapper):
    """Register mapper under its kind, replacing any earlier registration.

    Args:
      mapper: A Mapper instance.
    """
    Mapper.__mappers[mapper.kind] = mapper

  def initialize(self, mapper_opts):
    """Hook run before mapping; the default does nothing.

    Args:
      mapper_opts: The string given as the --mapper_opts flag argument.
    """
    pass

  def finalize(self):
    """Hook run after the map completes; the default does nothing."""
    pass

  def apply(self, entity):
    """Default per-entity hook: print the entity and change nothing."""
    print('Default map function doing nothing to %s' % entity)

  def batch_apply(self, entities):
    """Default batch hook: call apply on each entity, in order."""
    for each_entity in entities:
      self.apply(each_entity)

  def map_over_keys_only(self):
    """Tell the runner whether to fetch only keys instead of full entities.

    Subclasses return True to iterate over keys only.

    Returns:
      True or False
    """
    return False

  @staticmethod
  def RegisteredMappers():
    """Return a copy of the kind -> mapper registry."""
    return dict(Mapper.__mappers)

  @staticmethod
  def RegisteredMapper(kind):
    """Return the mapper registered for kind (KeyError if there is none)."""
    return Mapper.__mappers[kind]
class QueueJoinThread(threading.Thread):
  """Daemon thread whose only job is to block on queue.join().

  Queue.join() accepts no timeout. Running it in this thread lets callers
  emulate a timed join by joining the thread with a timeout instead.
  """

  def __init__(self, queue):
    """Create the (not yet started) joiner thread.

    Args:
      queue: The Queue.Queue or ReQueue this thread should join.
    """
    threading.Thread.__init__(self)
    self.setDaemon(True)
    assert isinstance(queue, (Queue.Queue, ReQueue))
    self.queue = queue

  def run(self):
    """Block until every item put on the queue has been marked done."""
    self.queue.join()
def InterruptibleQueueJoin(queue,
                           thread_local,
                           thread_pool,
                           queue_join_thread_factory=QueueJoinThread,
                           check_workers=True):
  """Join queue in half-second slices, giving up on interrupt or worker death.

  Args:
    queue: A Queue.Queue or ReQueue instance.
    thread_local: A threading.local instance which indicates interrupts.
    thread_pool: An AdaptiveThreadPool instance.
    queue_join_thread_factory: Used for dependency injection.
    check_workers: Whether to interrupt the join on worker death.

  Returns:
    True if the queue drained; False if shut_down was set or (when
    check_workers is true) some worker thread has died.
  """
  joiner = queue_join_thread_factory(queue)
  joiner.start()
  while True:
    joiner.join(timeout=0.5)
    if not joiner.isAlive():
      return True
    if thread_local.shut_down:
      logger.debug('Queue join interrupted')
      return False
    if check_workers and not all(worker.isAlive()
                                 for worker in thread_pool.Threads()):
      return False
def ShutdownThreads(data_source_thread, thread_pool):
  """Stop the data source thread and the worker pool after an error.

  Args:
    data_source_thread: A running DataSourceThread instance.
    thread_pool: An AdaptiveThreadPool instance with workers registered.
  """
  logger.info('An error occurred. Shutting down...')

  data_source_thread.exit_flag = True
  thread_pool.Shutdown()

  # Give the data source a bounded amount of time to notice exit_flag.
  data_source_thread.join(timeout=3.0)
  if not data_source_thread.isAlive():
    return
  logger.warn('%s hung while trying to exit',
              data_source_thread.GetFriendlyName())
class BulkTransporterApp(object):
  """Class to wrap bulk transport application functionality.

  Subclasses may override RunPostAuthentication (which kinds to process)
  and must implement ReportStatus (final summary and exit code).
  """

  def __init__(self,
               arg_dict,
               input_generator_factory,
               throttle,
               progress_db,
               progresstrackerthread_factory,
               max_queue_size=DEFAULT_QUEUE_SIZE,
               request_manager_factory=RequestManager,
               datasourcethread_factory=DataSourceThread,
               progress_queue_factory=Queue.Queue,
               thread_pool_factory=adaptive_thread_pool.AdaptiveThreadPool,
               server=None):
    """Instantiate a BulkTransporterApp.

    Uploads or downloads data to or from application using HTTP requests.
    When run, the class will spin up a number of threads to read entities
    from the data source, pass those to a number of worker threads
    for sending to the application, and track all of the progress in a
    small database in case an error or pause/termination requires a
    restart/resumption of the upload process.

    Args:
      arg_dict: Dictionary of command line options.
      input_generator_factory: A factory that creates a WorkItem generator.
      throttle: A Throttle instance.
      progress_db: The database to use for replaying/recording progress.
      progresstrackerthread_factory: Used for dependency injection.
      max_queue_size: Maximum size of the queues before they should block.
      request_manager_factory: Used for dependency injection.
      datasourcethread_factory: Used for dependency injection.
      progress_queue_factory: Used for dependency injection.
      thread_pool_factory: Used for dependency injection.
      server: An existing AbstractRpcServer to reuse.
    """
    self.app_id = arg_dict['application']
    self.post_url = arg_dict['url']
    self.kind = arg_dict['kind']
    self.batch_size = arg_dict['batch_size']
    self.input_generator_factory = input_generator_factory
    self.num_threads = arg_dict['num_threads']
    self.email = arg_dict['email']
    self.passin = arg_dict['passin']
    self.dry_run = arg_dict['dry_run']
    self.throttle_class = arg_dict['throttle_class']
    self.throttle = throttle
    self.progress_db = progress_db
    self.progresstrackerthread_factory = progresstrackerthread_factory
    self.max_queue_size = max_queue_size
    self.request_manager_factory = request_manager_factory
    self.datasourcethread_factory = datasourcethread_factory
    self.progress_queue_factory = progress_queue_factory
    self.thread_pool_factory = thread_pool_factory
    self.server = server
    # Split --url into host and path. Query and fragment are discarded, and
    # an https scheme selects a secure connection.
    (scheme,
     self.host_port, self.url_path,
     unused_query, unused_fragment) = urlparse.urlsplit(self.post_url)
    self.secure = (scheme == 'https')

  def RunPostAuthentication(self):
    """Method that gets called after authentication.

    Returns:
      The list of kinds to transfer; a single kind string is wrapped in a
      list, anything else is returned as-is.
    """
    if isinstance(self.kind, basestring):
      return [self.kind]
    return self.kind

  def Run(self):
    """Perform the work of the BulkTransporterApp.

    Raises:
      AuthenticationError: If authentication is required and fails.

    Returns:
      Error code suitable for sys.exit, e.g. 0 on success, 1 on failure.
    """
    self.error = False
    thread_pool = self.thread_pool_factory(
        self.num_threads, queue_size=self.max_queue_size)

    progress_queue = self.progress_queue_factory(self.max_queue_size)
    self.request_manager = self.request_manager_factory(self.app_id,
                                                        self.host_port,
                                                        self.url_path,
                                                        self.kind,
                                                        self.throttle,
                                                        self.batch_size,
                                                        self.secure,
                                                        self.email,
                                                        self.passin,
                                                        self.dry_run,
                                                        self.server,
                                                        self.throttle_class)
    try:
      self.request_manager.Authenticate()
    except Exception, e:
      self.error = True
      # HTTP 302/401 are treated as ordinary auth failures and are not
      # logged with a traceback; anything else is logged in full.
      if not isinstance(e, urllib2.HTTPError) or (
          e.code != 302 and e.code != 401):
        logger.exception('Exception during authentication')
      raise AuthenticationError()
    if (self.request_manager.auth_called and
        not self.request_manager.authenticated):
      self.error = True
      raise AuthenticationError('Authentication failed')

    kinds = self.RunPostAuthentication()

    for thread in thread_pool.Threads():
      self.throttle.Register(thread)

    self.progress_thread = self.progresstrackerthread_factory(
        progress_queue, self.progress_db)

    # Resume from the progress database when it holds data from a prior run.
    if self.progress_db.UseProgressData():
      logger.debug('Restarting upload using progress database')
      progress_generator_factory = self.progress_db.GetProgressStatusGenerator
    else:
      progress_generator_factory = None

    self.data_source_thread = (
        self.datasourcethread_factory(self.request_manager,
                                      kinds,
                                      thread_pool,
                                      progress_queue,
                                      self.input_generator_factory,
                                      progress_generator_factory))

    self.throttle.Register(self.data_source_thread)

    thread_local = threading.local()
    thread_local.shut_down = False

    def Interrupt(unused_signum, unused_frame):
      """Shutdown gracefully in response to a signal."""
      thread_local.shut_down = True
      self.error = True

    signal.signal(signal.SIGINT, Interrupt)

    self.progress_thread.start()
    self.data_source_thread.start()

    # Poll until the data source finishes, SIGINT arrives, or a worker or
    # the progress thread dies unexpectedly.
    while not thread_local.shut_down:
      self.data_source_thread.join(timeout=0.25)

      if self.data_source_thread.isAlive():
        for thread in list(thread_pool.Threads()) + [self.progress_thread]:
          if not thread.isAlive():
            logger.info('Unexpected thread death: %s', thread.getName())
            thread_local.shut_down = True
            self.error = True
            break
      else:
        break

    def _Join(ob, msg):
      """Wait for a thread (3s timeout) or a queue (interruptible)."""
      logger.debug('Waiting for %s...', msg)

      if isinstance(ob, threading.Thread):
        ob.join(timeout=3.0)
        if ob.isAlive():
          logger.debug('Joining %s failed', ob)
        else:
          logger.debug('... done.')
      elif isinstance(ob, (Queue.Queue, ReQueue)):
        if not InterruptibleQueueJoin(ob, thread_local, thread_pool):
          ShutdownThreads(self.data_source_thread, thread_pool)
      else:
        ob.join()
        logger.debug('... done.')

    # On error or interrupt, stop everything now; otherwise let the workers
    # drain the pending work first.
    if self.data_source_thread.error or thread_local.shut_down:
      ShutdownThreads(self.data_source_thread, thread_pool)
    else:
      _Join(thread_pool.requeue, 'worker threads to finish')

    thread_pool.Shutdown()
    thread_pool.JoinThreads()
    thread_pool.CheckErrors()
    print ''

    # Let the progress thread record everything still queued, then stop it
    # with the exit sentinel.
    if self.progress_thread.isAlive():
      InterruptibleQueueJoin(progress_queue, thread_local, thread_pool,
                             check_workers=False)
    else:
      logger.warn('Progress thread exited prematurely')

    progress_queue.put(_THREAD_SHOULD_EXIT)
    _Join(self.progress_thread, 'progress_thread to terminate')
    self.progress_thread.CheckError()
    # WorkFinished (e.g. writing export output) runs only if not interrupted.
    if not thread_local.shut_down:
      self.progress_thread.WorkFinished()

    self.data_source_thread.CheckError()

    return self.ReportStatus()

  def ReportStatus(self):
    """Display a message reporting the final status of the transfer.

    Subclasses must implement this and return the process exit code.
    """
    raise NotImplementedError()
class BulkUploaderApp(BulkTransporterApp):
  """Bulk transporter that uploads entities produced by a Loader."""

  def __init__(self, *args, **kwargs):
    BulkTransporterApp.__init__(self, *args, **kwargs)

  def RunPostAuthentication(self):
    """Reserve the ids the registered loader needs, then return [kind]."""
    loader = Loader.RegisteredLoader(self.kind)
    self.request_manager.ReserveKeys(loader.get_keys_to_reserve())
    return [self.kind]

  def ReportStatus(self):
    """Log upload totals; return 0 if everything read was transferred, else 1."""
    http_up, duration = self.throttle.TotalTransferred(
        remote_api_throttle.BANDWIDTH_UP)
    https_up, unused_duration = self.throttle.TotalTransferred(
        remote_api_throttle.HTTPS_BANDWIDTH_UP)
    total = http_up + https_up

    source = self.data_source_thread
    logger.info('%d entities total, %d previously transferred',
                source.read_count,
                source.xfer_count)
    transfer_count = self.progress_thread.EntitiesTransferred()
    logger.info('%d entities (%d bytes) transferred in %.1f seconds',
                transfer_count, total, duration)

    all_done = (source.read_all and
                transfer_count + source.xfer_count >= source.read_count)
    if not all_done:
      logger.info('Some entities not successfully transferred')
      return 1
    logger.info('All entities successfully transferred')
    return 0
class BulkDownloaderApp(BulkTransporterApp):
  """Bulk transporter that downloads entities for export."""

  def __init__(self, *args, **kwargs):
    BulkTransporterApp.__init__(self, *args, **kwargs)

  def RunPostAuthentication(self):
    """Return the kinds to download; with no --kind, ask the server's schema."""
    requested = self.kind
    if not requested:
      return self.request_manager.GetSchemaKinds()
    if isinstance(requested, basestring):
      return [requested]
    return requested

  def ReportStatus(self):
    """Log download totals; return 1 if an error was flagged, else 0."""
    http_down, duration = self.throttle.TotalTransferred(
        remote_api_throttle.BANDWIDTH_DOWN)
    https_down, unused_duration = self.throttle.TotalTransferred(
        remote_api_throttle.HTTPS_BANDWIDTH_DOWN)
    total = http_down + https_down

    existing_count = self.progress_thread.existing_count
    xfer_count = self.progress_thread.EntitiesTransferred()
    logger.info('Have %d entities, %d previously transferred',
                xfer_count, existing_count)
    logger.info('%d entities (%d bytes) transferred in %.1f seconds',
                xfer_count, total, duration)
    if self.error:
      return 1
    return 0
class BulkMapperApp(BulkTransporterApp):
  """Bulk transporter application that applies a mapper to entities."""

  def __init__(self, *args, **kwargs):
    # No mapper-specific state; construction is delegated to the base class.
    BulkTransporterApp.__init__(self, *args, **kwargs)

  def ReportStatus(self):
    """Log a summary of the mapper run and derive the process exit status.

    Downstream byte counts from the plain and HTTPS bandwidth channels of
    self.throttle are added together. The elapsed time that is logged comes
    from the plain BANDWIDTH_DOWN channel only.

    Returns:
      1 if self.error is truthy, otherwise 0.
    """
    throttle = self.throttle
    plain_bytes, duration = throttle.TotalTransferred(
        remote_api_throttle.BANDWIDTH_DOWN)
    https_bytes, _ = throttle.TotalTransferred(
        remote_api_throttle.HTTPS_BANDWIDTH_DOWN)
    total_bytes = plain_bytes + https_bytes

    mapped = self.progress_thread.EntitiesTransferred()
    logger.info('The following may be inaccurate if any mapper tasks '
                'encountered errors and had to be retried.')
    logger.info('Applied mapper to %s entities.',
                mapped)
    logger.info('%s entities (%s bytes) transferred in %.1f seconds',
                mapped, total_bytes, duration)
    return 1 if self.error else 0
def PrintUsageExit(code):
  """Write the usage text to stdout, then terminate the process.

  The module docstring serves as a %-format template. Its %(arg0)s
  placeholder is filled with sys.argv[0].

  Args:
    code: Exit status handed to sys.exit() once the usage text is written.
  """
  # Inside a function, the bare name __doc__ resolves to the module global
  # (the module docstring), not to this function's own docstring.
  usage = __doc__ % {'arg0': sys.argv[0]}
  print(usage)
  for stream in (sys.stdout, sys.stderr):
    stream.flush()
  sys.exit(code)
# Unique sentinel used as a placeholder default for options the user must
# supply (e.g. 'url'). Compared by identity, so it cannot collide with any
# real command-line value.
REQUIRED_OPTION = object()

# Flags that take no value (registered with getopt without a trailing '=').
BOOL_ARGS = ('create_config', 'debug', 'download', 'dry_run', 'dump',
             'has_header', 'map', 'passin', 'restore')
# Value-taking flags documented in the usage text as <int>.
INT_ARGS = ('bandwidth_limit', 'batch_size', 'http_limit', 'num_threads',
            'rps_limit')
# Value-taking flags whose value is a file path.
FILENAME_ARGS = ('config_file', 'db_filename', 'filename', 'log_file',
                 'result_db_filename')
# Value-taking flags whose value is a free-form string.
STRING_ARGS = ('application', 'auth_domain', 'email', 'exporter_opts',
               'kind', 'loader_opts', 'mapper_opts', 'namespace', 'url')

# Old flag name -> the current flag it has been replaced by.
DEPRECATED_OPTIONS = {'csv_has_header': 'has_header', 'app_id': 'application'}

# Long-option spec for getopt. Deprecated names are derived from
# DEPRECATED_OPTIONS so they cannot drift out of sync. Each deprecated name
# takes a value exactly when its replacement does (i.e. is not a BOOL_ARG).
# getopt does not depend on the order of this list; the deprecated names are
# sorted only to keep it deterministic.
FLAG_SPEC = (['help'] +
             sorted(old if new in BOOL_ARGS else old + '='
                    for old, new in DEPRECATED_OPTIONS.items()) +
             list(BOOL_ARGS) +
             [arg + '=' for arg in INT_ARGS + FILENAME_ARGS + STRING_ARGS])
def ParseArguments(argv, die_fn=lambda: PrintUsageExit(1)):
"""Parses command-line arguments.
Prints out a help message if -h or --help is supplied.
Args:
argv: List of command-line arguments.
die_fn: Function to invoke to end the program.
Returns:
A dictionary containing the value of command-line options.
"""
opts, unused_args = getopt.getopt(
argv[1:],
'h',
FLAG_SPEC)
arg_dict = {}
arg_dict['url'] = REQUIRED_OPTION
arg_dict['filename'] = None
arg_dict['config_file'] = None
arg_dict['kind'] = None
arg_dict['batch_size'] = None
arg_dict['num_threads'] = DEFAULT_THREAD_COUNT
arg_dict['bandwidth_limit'] = DEFAULT_BANDWIDTH_LIMIT