blob: 3f292a9231f21f49c3135927a723ad1370ac0cca [file] [log] [blame]
# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
""" Tasks and handlers for maintaining the spam classifier model. These
should be run via cron and task queue rather than manually.
"""
import cgi
import csv
import logging
import webapp2
import cloudstorage
import json
from datetime import date
from datetime import datetime
from datetime import timedelta
from framework import servlet
from framework import urls
from google.appengine.api import taskqueue
from google.appengine.api import app_identity
from framework import gcs_helpers
class TrainingDataExport(webapp2.RequestHandler):
    """Trigger a training data export task."""

    def get(self):
        """Enqueue the export task; the actual work happens in the task queue."""
        logging.info("Training data export requested.")
        task_url = urls.SPAM_DATA_EXPORT_TASK + '.do'
        taskqueue.add(url=task_url)
BATCH_SIZE = 100  # Max issues fetched per GetTrainingIssues call; further pagination is a TODO below.
class TrainingDataExportTask(servlet.Servlet):
    """Export any human-labeled ham or spam from the previous day.

    These records will be used by a subsequent task to create an updated
    model.
    """

    # Requests come from cron/taskqueue, not a user-facing form, so the
    # XSRF token check is skipped.
    CHECK_SECURITY_TOKEN = False

    @staticmethod
    def _FlattenNewlines(text):
        """Replace all newline variants with spaces.

        The Cloud Prediction API doesn't allow newlines in the training
        data. The previous code only replaced CRLF pairs, letting bare
        LF or CR characters through; handle all three forms.
        """
        return text.replace('\r\n', ' ').replace('\r', ' ').replace('\n', ' ')

    def ProcessFormData(self, mr, post_data):
        """Write the previous day's labeled issues to a CSV file in GCS.

        Args:
          mr: commonly used info parsed from the request.
          post_data: POST form data (unused; the task is cron-triggered).

        Side effects: writes one CSV file per day to the app's default GCS
        bucket and sets a JSON response body with the exported row count.
        """
        logging.info("Training data export initiated.")
        bucket_name = app_identity.get_default_gcs_bucket_name()
        date_str = date.today().isoformat()
        # One file per day, e.g. /<bucket>/spam_training_data/2016-01-01.
        export_target_path = '/' + bucket_name + '/spam_training_data/' + date_str
        total_issues = 0
        with cloudstorage.open(export_target_path, mode='w',
                               content_type=None, options=None,
                               retry_params=None) as gcs_file:
            csv_writer = csv.writer(gcs_file, delimiter=',', quotechar='"',
                                    quoting=csv.QUOTE_ALL, lineterminator='\n')
            since = datetime.now() - timedelta(days=1)
            # TODO: Comments, and further pagination
            issues, first_comments, _count = (
                self.services.spam.GetTrainingIssues(
                    mr.cnxn, self.services.issue, since, offset=0,
                    limit=BATCH_SIZE))
            total_issues += len(issues)
            for issue in issues:
                fixed_summary = self._FlattenNewlines(issue.summary)
                # Use .get so an issue missing its first comment yields an
                # empty field instead of a KeyError aborting the export.
                fixed_comment = self._FlattenNewlines(
                    first_comments.get(issue.issue_id, ''))
                csv_writer.writerow([
                    'spam' if issue.is_spam else 'ham',
                    fixed_summary, fixed_comment,
                ])
        self.response.body = json.dumps({
            "exported_issue_count": total_issues,
        })