blob: 3291ecf307ed9af3ff0c58ce5e9b44af0155ad16 [file] [log] [blame]
// Copyright 2018 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rpc
import (
"context"
"strings"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/Masterminds/squirrel"
"github.com/VividCortex/mysqlerr"
"github.com/go-sql-driver/mysql"
"go.chromium.org/luci/common/errors"
states "go.chromium.org/luci/machine-db/api/common/v1"
"go.chromium.org/luci/machine-db/api/crimson/v1"
"go.chromium.org/luci/machine-db/appengine/database"
"go.chromium.org/luci/machine-db/appengine/model"
"go.chromium.org/luci/machine-db/common"
)
// CreateVM handles a request to create a new VM.
func (*Service) CreateVM(c context.Context, req *crimson.CreateVMRequest) (*crimson.VM, error) {
return createVM(c, req.Vm)
}
// ListVMs handles a request to list VMs.
func (*Service) ListVMs(c context.Context, req *crimson.ListVMsRequest) (*crimson.ListVMsResponse, error) {
vms, err := listVMs(c, database.Get(c), req)
if err != nil {
return nil, err
}
return &crimson.ListVMsResponse{
Vms: vms,
}, nil
}
// UpdateVM handles a request to update an existing VM.
func (*Service) UpdateVM(c context.Context, req *crimson.UpdateVMRequest) (*crimson.VM, error) {
return updateVM(c, req.Vm, req.UpdateMask)
}
// createVM creates a new VM in the database.
func createVM(c context.Context, v *crimson.VM) (*crimson.VM, error) {
if err := validateVMForCreation(v); err != nil {
return nil, err
}
ip, _ := common.ParseIPv4(v.Ipv4)
tx, err := database.Begin(c)
if err != nil {
return nil, errors.Annotate(err, "failed to begin transaction").Err()
}
defer tx.MaybeRollback(c)
hostnameID, err := model.AssignHostnameAndIP(c, tx, v.Name, ip)
if err != nil {
return nil, err
}
// vms.hostname_id, vms.physical_host_id, and vms.os_id are NOT NULL as above.
_, err = tx.ExecContext(c, `
INSERT INTO vms (hostname_id, physical_host_id, os_id, description, deployment_ticket, state)
VALUES (
?,
(SELECT p.id FROM physical_hosts p, hostnames h WHERE p.hostname_id = h.id AND h.name = ?),
(SELECT id FROM oses WHERE name = ?),
?,
?,
?
)
`, hostnameID, v.Host, v.Os, v.Description, v.DeploymentTicket, v.State)
if err != nil {
switch e, ok := err.(*mysql.MySQLError); {
case !ok:
// Type assertion failed.
case e.Number == mysqlerr.ER_BAD_NULL_ERROR && strings.Contains(e.Message, "'physical_host_id'"):
// e.g. "Error 1048: Column 'physical_host_id' cannot be null".
return nil, status.Errorf(codes.NotFound, "physical host %q does not exist", v.Host)
case e.Number == mysqlerr.ER_BAD_NULL_ERROR && strings.Contains(e.Message, "'os_id'"):
// e.g. "Error 1048: Column 'os_id' cannot be null".
return nil, status.Errorf(codes.NotFound, "operating system %q does not exist", v.Os)
}
return nil, errors.Annotate(err, "failed to create VM").Err()
}
vms, err := listVMs(c, tx, &crimson.ListVMsRequest{
Names: []string{v.Name},
})
if err != nil {
return nil, errors.Annotate(err, "failed to fetch created VM").Err()
}
if err := tx.Commit(); err != nil {
return nil, errors.Annotate(err, "failed to commit transaction").Err()
}
return vms[0], nil
}
// listVMs returns a slice of VMs in the database.
func listVMs(c context.Context, q database.QueryerContext, req *crimson.ListVMsRequest) ([]*crimson.VM, error) {
ipv4s, err := parseIPv4s(req.Ipv4S)
if err != nil {
return nil, err
}
stmt := squirrel.Select(
"hv.name",
"hv.vlan_id",
"hp.name",
"hp.vlan_id",
"o.name",
"v.description",
"v.deployment_ticket",
"i.ipv4",
"v.state",
)
stmt = stmt.From("vms v, hostnames hv, physical_hosts p, hostnames hp, oses o, ips i").
Where("v.hostname_id = hv.id").
Where("v.physical_host_id = p.id").
Where("p.hostname_id = hp.id").
Where("v.os_id = o.id").
Where("i.hostname_id = hv.id")
stmt = selectInString(stmt, "hv.name", req.Names)
stmt = selectInInt64(stmt, "hv.vlan_id", req.Vlans)
stmt = selectInInt64(stmt, "i.ipv4", ipv4s)
stmt = selectInString(stmt, "hp.name", req.Hosts)
stmt = selectInInt64(stmt, "hp.vlan_id", req.HostVlans)
stmt = selectInString(stmt, "o.name", req.Oses)
stmt = selectInState(stmt, "v.state", req.States)
query, args, err := stmt.ToSql()
if err != nil {
return nil, errors.Annotate(err, "failed to generate statement").Err()
}
rows, err := q.QueryContext(c, query, args...)
if err != nil {
return nil, errors.Annotate(err, "failed to fetch VMs").Err()
}
defer rows.Close()
var vms []*crimson.VM
for rows.Next() {
v := &crimson.VM{}
var ipv4 common.IPv4
if err = rows.Scan(
&v.Name,
&v.Vlan,
&v.Host,
&v.HostVlan,
&v.Os,
&v.Description,
&v.DeploymentTicket,
&ipv4,
&v.State,
); err != nil {
return nil, errors.Annotate(err, "failed to fetch VM").Err()
}
v.Ipv4 = ipv4.String()
vms = append(vms, v)
}
return vms, nil
}
// updateVM updates an existing VM in the database.
func updateVM(c context.Context, v *crimson.VM, mask *field_mask.FieldMask) (*crimson.VM, error) {
if err := validateVMForUpdate(v, mask); err != nil {
return nil, err
}
stmt := squirrel.Update("vms")
updatedHost := false
for _, path := range mask.Paths {
switch path {
case "host":
if !updatedHost {
stmt = stmt.Set("physical_host_id", squirrel.Expr("(SELECT id FROM physical_hosts WHERE hostname_id = (SELECT id FROM hostnames WHERE name = ?))", v.Host))
}
updatedHost = true
case "os":
stmt = stmt.Set("os_id", squirrel.Expr("(SELECT id FROM oses WHERE name = ?)", v.Os))
case "state":
stmt = stmt.Set("state", v.State)
case "description":
stmt = stmt.Set("description", v.Description)
case "deployment_ticket":
stmt = stmt.Set("deployment_ticket", v.DeploymentTicket)
}
}
stmt = stmt.Where("hostname_id = (SELECT id FROM hostnames WHERE name = ?)", v.Name)
query, args, err := stmt.ToSql()
if err != nil {
return nil, errors.Annotate(err, "failed to generate statement").Err()
}
tx, err := database.Begin(c)
if err != nil {
return nil, errors.Annotate(err, "failed to begin transaction").Err()
}
defer tx.MaybeRollback(c)
_, err = tx.ExecContext(c, query, args...)
if err != nil {
switch e, ok := err.(*mysql.MySQLError); {
case !ok:
// Type assertion failed.
case e.Number == mysqlerr.ER_BAD_NULL_ERROR && strings.Contains(e.Message, "'physical_host_id'"):
// e.g. "Error 1048: Column 'physical_host_id' cannot be null".
return nil, status.Errorf(codes.NotFound, "physical host %q does not exist", v.Host)
case e.Number == mysqlerr.ER_BAD_NULL_ERROR && strings.Contains(e.Message, "'os_id'"):
// e.g. "Error 1048: Column 'os_id' cannot be null".
return nil, status.Errorf(codes.NotFound, "operating system %q does not exist", v.Os)
}
return nil, errors.Annotate(err, "failed to update VM").Err()
}
// The number of rows affected cannot distinguish between zero because the VM didn't exist
// and zero because the row already matched, so skip looking at the number of rows affected.
vms, err := listVMs(c, tx, &crimson.ListVMsRequest{
Names: []string{v.Name},
})
switch {
case err != nil:
return nil, errors.Annotate(err, "failed to fetch updated VM").Err()
case len(vms) == 0:
return nil, status.Errorf(codes.NotFound, "VM %q does not exist", v.Name)
}
if err := tx.Commit(); err != nil {
return nil, errors.Annotate(err, "failed to commit transaction").Err()
}
return vms[0], nil
}
// validateVMForCreation validates a VM for creation.
func validateVMForCreation(v *crimson.VM) error {
switch {
case v == nil:
return status.Error(codes.InvalidArgument, "VM specification is required")
case v.Name == "":
return status.Error(codes.InvalidArgument, "hostname is required and must be non-empty")
case v.Vlan != 0:
return status.Error(codes.InvalidArgument, "VLAN must not be specified, use IP address instead")
case v.Host == "":
return status.Error(codes.InvalidArgument, "physical hostname is required and must be non-empty")
case v.HostVlan != 0:
return status.Error(codes.InvalidArgument, "host VLAN must not be specified, use physical hostname instead")
case v.Os == "":
return status.Error(codes.InvalidArgument, "operating system is required and must be non-empty")
case v.Ipv4 == "":
return status.Error(codes.InvalidArgument, "IPv4 address is required and must be non-empty")
case v.State == states.State_STATE_UNSPECIFIED:
return status.Error(codes.InvalidArgument, "state is required")
default:
_, err := common.ParseIPv4(v.Ipv4)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid IPv4 address %q", v.Ipv4)
}
return nil
}
}
// validateVMForUpdate validates a VM for update.
func validateVMForUpdate(v *crimson.VM, mask *field_mask.FieldMask) error {
switch err := validateUpdateMask(mask); {
case v == nil:
return status.Error(codes.InvalidArgument, "VM specification is required")
case v.Name == "":
return status.Error(codes.InvalidArgument, "hostname is required and must be non-empty")
case err != nil:
return err
}
for _, path := range mask.Paths {
// TODO(smut): Allow IPv4 address to be updated.
switch path {
case "name":
return status.Error(codes.InvalidArgument, "hostname cannot be updated, delete and create a new VM instead")
case "vlan":
return status.Error(codes.InvalidArgument, "VLAN cannot be updated, delete and create a new VM instead")
case "host":
if v.Host == "" {
return status.Error(codes.InvalidArgument, "physical hostname is required and must be non-empty")
}
case "host_vlan":
return status.Error(codes.InvalidArgument, "host VLAN cannot be updated, update the host instead")
case "os":
if v.Os == "" {
return status.Error(codes.InvalidArgument, "operating system is required and must be non-empty")
}
case "state":
if v.State == states.State_STATE_UNSPECIFIED {
return status.Error(codes.InvalidArgument, "state is required")
}
case "description":
// Empty description is allowed, nothing to validate.
case "deployment_ticket":
// Empty deployment ticket is allowed, nothing to validate.
default:
return status.Errorf(codes.InvalidArgument, "unsupported update mask path %q", path)
}
}
return nil
}