blob: 8881f6e1e999fafa05b7893f64634b813b21f161 [file]
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Simple wrapper of Google BigQuery client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import json
import logging
import StringIO
import time
import pytz
from google.cloud import bigquery
from google.cloud.exceptions import NotFound
class BigQueryError(Exception):
  """Represents errors on BigQuery backends."""

  def __init__(self, errors):
    """Constructor.

    Args:
      errors: A list of dictionaries representing errors.
    """
    # Keep the raw error payload on the exception so callers can inspect it.
    message = 'BigQuery query failed: %r' % errors
    super(BigQueryError, self).__init__(message)
    self.errors = errors
class BigQueryWrapper(object):
  """Simple wrapper of google.cloud.bigquery.Client object."""

  def __init__(self, bigquery_client, dataset, bucket=None):
    """Initializes the wrapper.

    Args:
      bigquery_client: google.cloud.bigquery.Client object.
      dataset: google.cloud.bigquery.Dataset object.
      bucket: (Optional) google.cloud.storage.Bucket object. Required only by
        LoadEntries, which stages entries through Cloud Storage.
    """
    self._bigquery_client = bigquery_client
    self._dataset = dataset
    self._bucket = bucket

  def TableExists(self, table_name):
    """Checks if the specified table exists or not.

    Args:
      table_name: Table name.

    Returns:
      True if the table exists, otherwise False.

    Raises:
      google.cloud.exceptions.GoogleCloudError: On Google Cloud errors.
    """
    table_ref = self._dataset.table(table_name)
    try:
      # EAFP: get_table() raises NotFound for a missing table; any other
      # Google Cloud error propagates to the caller.
      self._bigquery_client.get_table(table_ref)
      return True
    except NotFound:
      return False

  def TableSchema(self, table_name):
    """Returns the table schema as a list.

    Args:
      table_name: Table name.

    Returns:
      Table schema as a list.

    Raises:
      google.cloud.exceptions.NotFound: If the table does not exist.
    """
    table_ref = self._dataset.table(table_name)
    table = self._bigquery_client.get_table(table_ref)
    return table.schema

  def ListTableNames(self):
    """Retrieves a list of table names.

    Returns:
      A sorted list of table names.

    Raises:
      google.cloud.exceptions.GoogleCloudError: On Google Cloud errors.
    """
    return sorted(table.table_id
                  for table in self._bigquery_client.list_tables(self._dataset))

  def RunQuery(self, query, timeout_seconds, description):
    """Runs a query.

    Args:
      query: Query string.
      timeout_seconds: Timeout in seconds.
      description: Description message explaining what this query does.
        This is used only for logging.

    Returns:
      A list of tuples representing result rows.

    Raises:
      google.cloud.exceptions.GoogleCloudError: On Google Cloud errors, e.g. if
        the query job failed.
      BigQueryError: On BigQuery errors.
      concurrent.futures.TimeoutError: If the job did not complete in
        the given timeout.
      AssertionError: On internal BigQuery client failures.
    """
    logging.info('Running BigQuery query: %s', description)
    # The query text can be large; log it at debug level only.
    logging.debug('Query: %s', query)
    # default_dataset lets the query reference tables without a dataset prefix.
    query_job_config = bigquery.QueryJobConfig(default_dataset=self._dataset)
    query_job = self._bigquery_client.query(
        query, job_config=query_job_config)
    start_time = time.time()
    # result() blocks until the job finishes or timeout_seconds elapses.
    iterator = query_job.result(timeout=timeout_seconds)
    assert query_job.state == 'DONE'
    rows = list(iterator)
    # NOTE(review): result() normally raises on failed jobs, so this branch
    # looks like a defensive double-check -- confirm before removing.
    if query_job.error_result:
      raise BigQueryError(query_job.errors)
    end_time = time.time()
    logging.info('BigQuery query successfully finished in %.3fs; '
                 'processed %d bytes, got %d rows.', end_time - start_time,
                 query_job.total_bytes_processed, len(rows))
    return rows

  def CopyTable(self, dst_table_name, *src_table_names):
    """Copy data from multiple tables into a new one.

    Args:
      dst_table_name: Name of the new table to create.
      *src_table_names: Names of tables to copy data from.

    Raises:
      google.cloud.exceptions.GoogleCloudError: On Google Cloud errors.
      BigQueryError: If no table to copy.
      concurrent.futures.TimeoutError: If the job did not complete in
        the given timeout.
      AssertionError: On internal BigQuery client failures.
    """
    if not src_table_names:
      raise BigQueryError('No table to copy from')
    src_tables = [self._dataset.table(t) for t in src_table_names]
    # The destination table takes the schema of the first source table.
    # NOTE(review): recent google-cloud-bigquery Dataset.table() accepts only a
    # table ID; the extra schema argument suggests an older pinned client
    # version -- verify against the library version actually in use.
    dst_table = self._dataset.table(dst_table_name,
                                    self.TableSchema(src_table_names[0]))
    # Timestamp (with microseconds) keeps job ids unique across repeated runs.
    job_id = 'copy_into_%s_%s_%s' % (
        self._dataset.dataset_id, dst_table_name,
        datetime.datetime.utcnow().strftime('%Y%m%d%H%M%S%f'))
    job = self._bigquery_client.copy_table(src_tables, dst_table, job_id=job_id)
    # Wait for the copy job to finish.
    job.result()
    assert job.state == 'DONE'

  def DeleteTable(self, table_name):
    """Delete a table.

    Deleting a missing table is not an error, so the call is idempotent.

    Args:
      table_name: Name of the table to delete.
    """
    table = self._dataset.table(table_name)
    # not_found_ok suppresses NotFound so repeated deletes are safe.
    self._bigquery_client.delete_table(table, not_found_ok=True)

  def PatchSchema(self, table_name, schema):
    """Patches the table schema for the given table.

    Args:
      table_name: Table to update.
      schema: The new schema as a list of SchemaField.
    """
    table_ref = self._dataset.table(table_name)
    # NOTE(review): newer google-cloud-bigquery expects update_table(table,
    # fields) where fields is a list of property names; passing the schema
    # list directly suggests an older pinned client version -- confirm.
    self._bigquery_client.update_table(table_ref, schema)

  def LoadEntries(self, table_name, schema, entries, job_name=None):
    """Loads entries into a table.

    Entries are first serialized as newline-delimited JSON and uploaded to
    Google Cloud Storage, then a load job is requested to BigQuery.

    The target table is created if it does not exist yet. Entries are always
    appended to the table and do not overwrite existing entries.

    Args:
      table_name: Table name.
      schema: Schema of the table as a list of SchemaField.
      entries: A list of dictionaries representing entries.
      job_name: If set, sets the load job name explicitly. Usually you do not
        need to set this.

    Raises:
      google.cloud.exceptions.GoogleCloudError: On Google Cloud errors.
      BigQueryError: GS bucket name is not provided or Bigquery errors.
      concurrent.futures.TimeoutError: If the job did not complete in
        the given timeout.
      AssertionError: On internal BigQuery client failures.
    """
    if self._bucket is None:
      raise BigQueryError(
          'GS bucket not provided at object creation to stage entries')
    if job_name is None:
      # Timestamp (with microseconds) keeps auto-generated job names unique.
      job_name = 'load_%s_%s_%s' % (
          self._dataset.dataset_id, table_name,
          datetime.datetime.utcnow().strftime('%Y%m%d%H%M%S%f'))
    # The staged file path embeds the job name, so each load job reads its own
    # distinct GCS object.
    json_path = 'loads/%s/%s/data.json' % (self._dataset.dataset_id, job_name)
    self._UploadEntriesToGs(entries, json_path)
    # Align an existing table's schema with the requested one before loading.
    if self._ShouldPatchSchema(table_name, schema):
      self.PatchSchema(table_name, schema)
    job = self._CreateLoadJobFromGs(json_path, schema, table_name, job_name)
    # Wait for the load job to finish.
    job.result()
    assert job.state == 'DONE'
    if job.error_result:
      raise BigQueryError(job.errors)

  def ExportToStorage(self, table_name, destination_url, job_name=None):
    """Exports a table into Google Cloud Storage.

    A table is exported as a gzip'ed newline-delimited JSON file.

    Args:
      table_name: Table name.
      destination_url: Destination URL starting with gs://.
      job_name: If set, sets the load job name explicitly. Usually you do not
        need to set this.

    Raises:
      google.cloud.exceptions.GoogleCloudError: On Google Cloud errors.
      BigQueryError: On BigQuery errors.
      concurrent.futures.TimeoutError: If the job did not complete in
        the given timeout.
      AssertionError: On internal BigQuery client failures.
    """
    if job_name is None:
      # Timestamp (with microseconds) keeps auto-generated job names unique.
      job_name = 'export_%s_%s_%s' % (
          self._dataset.dataset_id, table_name,
          datetime.datetime.utcnow().strftime('%Y%m%d%H%M%S%f'))
    job = self._CreateExportJob(table_name, destination_url, job_name)
    # Wait for the export job to finish.
    job.result()
    assert job.state == 'DONE'
    if job.error_result:
      raise BigQueryError(job.errors)

  def _UploadEntriesToGs(self, entries, json_path):
    """Uploads entries to Google Cloud Storage as a newline-delimited JSON.

    Args:
      entries: A list of dictionaries representing entries.
      json_path: Path to a new JSON file, relative from Google Cloud Storage
        bucket root.
    """
    entries_json = _SerializeToNewlineDelimitedJson(entries)
    logging.info('Uploading %d entries to GCS', len(entries))
    self._bucket.blob(json_path).upload_from_string(entries_json)

  def _CreateLoadJobFromGs(self, json_path, schema, table_name, job_id):
    """Creates a new load job that imports entries from Google Cloud Storage.

    Args:
      json_path: Path to a JSON file, relative from Google Cloud Storage
        bucket root.
      schema: Schema of the table as a list of SchemaField.
      table_name: Name of a BigQuery table to insert new entries to.
      job_id: Unique name of a load job.

    Returns:
      google.cloud.bigquery.job.LoadJob object.
    """
    json_url = 'gs://%s/%s' % (self._bucket.name, json_path)
    # CREATE_IF_NEEDED + WRITE_APPEND: create the table on first load, and
    # append (never overwrite) on every subsequent load.
    job_config = bigquery.LoadJobConfig(
        schema=schema,
        create_disposition='CREATE_IF_NEEDED',
        write_disposition='WRITE_APPEND',
        source_format='NEWLINE_DELIMITED_JSON')
    job = self._bigquery_client.load_table_from_uri(
        json_url,
        self._dataset.table(table_name),
        job_id=job_id,
        job_config=job_config)
    return job

  def _ShouldPatchSchema(self, table_name, schema):
    """Decide whether we should patch the schema of the given table.

    Args:
      table_name: Table to update.
      schema: The new schema as a list of SchemaField.

    Returns:
      True if we should patch schema, False otherwise.
    """
    if not self.TableExists(table_name):
      # A missing table is created by the load job (CREATE_IF_NEEDED) with
      # the new schema, so there is nothing to patch.
      return False
    old_schema = self.TableSchema(table_name)
    # bigquery.schema.SchemaField are not hashable. We take the easy path and
    # patch the table as long as *any* fields are different.
    if len(old_schema) != len(schema):
      return True
    for i in xrange(len(old_schema)):
      if old_schema[i] != schema[i]:
        return True
    return False

  def _CreateExportJob(self, table_name, url, job_id):
    """Creates a new export job.

    Args:
      table_name: Table name.
      url: GS URL of the destination.
      job_id: Unique name of an export job.

    Returns:
      google.cloud.bigquery.job.ExtractJob object.
    """
    # Export as gzip-compressed newline-delimited JSON.
    job_config = bigquery.job.ExtractJobConfig(
        compression='GZIP',
        destination_format='NEWLINE_DELIMITED_JSON')
    job = self._bigquery_client.extract_table(
        self._dataset.table(table_name),
        url,
        job_id=job_id,
        job_config=job_config)
    return job
def _SerializeToNewlineDelimitedJson(entries):
  """Serializes entries into a newline-delimited JSON.

  Args:
    entries: A list of dictionaries representing entries.

  Returns:
    UTF-8 encoded string of a newline-delimited JSON.
  """
  # One compact JSON document per line, each terminated with a newline.
  serialized_lines = [
      json.dumps(entry, separators=(',', ':'), default=_CustomEncoder) + '\n'
      for entry in entries
  ]
  return ''.join(serialized_lines).encode('utf-8')
def _CustomEncoder(value):
"""Custom JSON value encoder to format timezone-aware datetime values.
Args:
value: Any object.
Returns:
A serialized string.
Raises:
TypeError: On failure.
"""
if isinstance(value, datetime.datetime):
if value.tzinfo is None:
raise TypeError('Naive datetime can not be serialized')
utc_datetime = value.astimezone(pytz.utc)
return utc_datetime.strftime('%Y-%m-%d %H:%M:%S UTC')
raise TypeError('Could not encode to JSON: %r' % value)