blob: d144d860e2087dee6462b728a159c56f427aec60 [file] [log] [blame]
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Functions for training and evaluating a GBM model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
from chromite.lib import cros_build_lib
from chromite.lib import cidb
import lightgbm
from lib import cl_feature_extractor
from lib import feature_preprocessor
from lib import feature_feeder
# Default keyword arguments handed to lightgbm.LGBMClassifier by Model().
# Callers of Model() can override any of them per call.
PARAMS = {
  'num_leaves': 200,
  'learning_rate': 0.003,
  'min_child_samples': 20,
  'min_sum_hessian_in_leaf': 1e-3,
  'n_estimators': 2000,
  'metric': 'auc',
  'subsample': 0.9,
  'colsample_bytree': 0.6,
}

# Argument passed to cidb.CIDBConnection when opening the CIDB connection.
CRED_DIR = 'creds/cidb'

# Directory and file name of the pickled model read by LoadModel().
SAVED_MODEL_DIR = os.path.join('model', 'saved', 'trees')
LATEST_MODEL = 'latest.pkl'
def Model(**kwargs):
  """Builds a LightGBM classifier configured with the module defaults.

  Args:
    **kwargs: Keyword arguments that take precedence over the matching
      entries of PARAMS. Keys that are not in PARAMS are forwarded too.

  Returns:
    A new lightgbm.LGBMClassifier instance.
  """
  # Work on a copy so the module-level defaults are never mutated.
  merged = dict(PARAMS)
  merged.update(kwargs)
  return lightgbm.LGBMClassifier(**merged)
@cros_build_lib.Memoize
def LoadModel():
  """Loads the trained model from the saved models directory.

  The model is unpickled from SAVED_MODEL_DIR/LATEST_MODEL. The function is
  wrapped in cros_build_lib.Memoize, which is presumably what keeps the file
  from being re-read on every call -- confirm against chromite.

  Returns:
    Whatever object was pickled into the model file.

  Raises:
    IOError / OSError: If the model file cannot be opened.
    pickle.UnpicklingError: If the file contents are not a valid pickle.
  """
  model_path = os.path.join(SAVED_MODEL_DIR, LATEST_MODEL)
  # Pickle data is binary: the file must be opened with 'rb'. Text mode
  # breaks pickle.load on Python 3 and can corrupt the stream on platforms
  # that translate newlines.
  # NOTE: unpickling executes arbitrary code; only load model files that
  # come from a trusted source.
  with open(model_path, 'rb') as fp:
    return pickle.load(fp)
@cros_build_lib.Memoize
def _GetCIDBConnection():
  """Creates the CIDB connection used by this module.

  The connection is built from CRED_DIR. The Memoize decorator is relied
  upon to hand back the same cached connection on repeat calls.

  Returns:
    A cidb.CIDBConnection instance.
  """
  connection = cidb.CIDBConnection(CRED_DIR)
  return connection
def BuildRisks(build_id):
  """Computes the Bad CL probability for the CLs of a build.

  Fetches the records needed for the build, extracts and preprocesses the CL
  features, and scores them with the saved model.

  Args:
    build_id: The master build id.

  Returns:
    A dictionary mapping CL numbers to probabilities.
  """
  # TODO(phobbs) this is a bit slow for new CQs (~10s). Can we do this faster
  # with some query batching?
  connection = _GetCIDBConnection()
  feature_feeder.FetchNecessaryRecords(connection, build_id)

  cl_feature_extractor.ExtractCLFeatures(for_training=False)
  preprocessor = feature_preprocessor.CLFeaturePreprocessor(
      build_to_identify=build_id)
  feature_rows = preprocessor.features

  classifier = LoadModel()
  # predict_proba yields an [N x 2] array. Column 1 is the Bad CL
  # probability, which is the only one we need; the good CL probability is
  # dropped.
  bad_cl_probs = classifier.predict_proba(feature_rows)[:, 1]

  # TODO(phobbs) for some reason we store the CL numbers in the labels array.
  # That's a weird decision, since they're not labels we want to predict.
  cl_numbers = preprocessor.labels[:, 1]
  return dict(zip(cl_numbers, bad_cl_probs))