// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// gRPC server implementation for the firmware provisioning service.
package cli

import (
	"go.chromium.org/chromiumos/lro"
	common_utils "go.chromium.org/chromiumos/test/provision/v2/common-utils"
	"go.chromium.org/chromiumos/test/util/portdiscovery"
	"context"
	"fmt"
	"log"
	"net"
	"net/url"

	"github.com/pkg/errors"
	"go.chromium.org/chromiumos/config/go/test/api"
	api1 "go.chromium.org/chromiumos/config/go/test/lab/api"

	firmwareservice "go.chromium.org/chromiumos/test/provision/v2/cros-fw-provision/service"
	state_machine "go.chromium.org/chromiumos/test/provision/v2/cros-fw-provision/state-machine"

	"go.chromium.org/chromiumos/config/go/longrunning"
	"google.golang.org/grpc"
	"google.golang.org/protobuf/types/known/anypb"
)

// FWProvisionServer serves the GenericProvisionService gRPC API for firmware
// provisioning. Install requests are tracked as long-running operations by an
// lro.Manager, which Start also registers as the Operations service.
type FWProvisionServer struct {
	// dutAdapter drives the DUT through the cros-dut service. StartUp sets it
	// from the request's dut_server endpoint.
	// NOTE(review): earlier docs say the address may also be given at server
	// creation or in the ProvisionFirmwareRequest; confirm with callers.
	dutAdapter common_utils.ServiceAdapterInterface

	// servoClient drives the DUT through the cros-servod service. Nothing in
	// this file assigns it; it is forwarded unchanged to the firmware service.
	// NOTE(review): earlier docs say its address may be given at server
	// creation or in the ProvisionFirmwareRequest; confirm.
	servoClient api.ServodServiceClient

	// log receives the server's diagnostic output.
	log *log.Logger
	// listenPort is the TCP port that Start binds the gRPC listener to.
	listenPort int

	// manager owns the long-running operations handed out by Install.
	manager *lro.Manager

	// cacheServer is the http base URL of the cache server, set by StartUp.
	cacheServer url.URL

	// board and model come from the startup request's dut_model
	// (build_target and model_name respectively).
	board, model string
}

// ipEndpointToHostPort renders an IpEndpoint as a "host:port" string suitable
// for dialing or for use as a URL host.
//
// net.JoinHostPort brackets IPv6 literals ("[::1]:8080"). Plain "%v:%v"
// formatting would produce the ambiguous, unparsable "::1:8080". An address
// that already arrives bracketed ("[::1]") is unwrapped first so it is not
// double-bracketed.
//
// It returns an error if the address is empty or the port is zero. A nil
// endpoint is presumably reported as missing its address, since generated
// proto getters are nil-safe.
func ipEndpointToHostPort(i *api1.IpEndpoint) (string, error) {
	addr := i.GetAddress()
	if len(addr) == 0 {
		return "", errors.New("IpEndpoint missing address")
	}
	port := i.GetPort()
	if port == 0 {
		return "", errors.New("IpEndpoint missing port")
	}
	// Accept pre-bracketed IPv6 literals for backward compatibility.
	if len(addr) > 1 && addr[0] == '[' && addr[len(addr)-1] == ']' {
		addr = addr[1 : len(addr)-1]
	}
	return net.JoinHostPort(addr, fmt.Sprint(port)), nil
}

// NewFWProvisionServer builds an FWProvisionServer that will listen on
// listenPort and log through log. Along with the server it returns a cleanup
// function that closes the server's long-running-operation manager, and an
// error that is currently always nil.
func NewFWProvisionServer(listenPort int, log *log.Logger) (*FWProvisionServer, func(), error) {
	ps := &FWProvisionServer{
		listenPort: listenPort,
		log:        log,
		manager:    lro.New(),
	}
	closer := ps.manager.Close
	return ps, closer, nil
}

// Start binds a TCP listener on ps.listenPort and registers this server as the
// GenericProvisionService and ps.manager as the Operations service on a new
// gRPC server. It then records the bound address for port discovery and
// serves until the server stops.
//
// It returns a wrapped error, with the cause preserved, if the listener cannot
// be created. Otherwise it returns whatever grpc.Server.Serve returns.
func (ps *FWProvisionServer) Start() error {
	l, err := net.Listen("tcp", fmt.Sprintf(":%d", ps.listenPort))
	if err != nil {
		// Keep the underlying cause (e.g. "address already in use") visible.
		return fmt.Errorf("failed to create listener at %d: %w", ps.listenPort, err)
	}
	server := grpc.NewServer()
	api.RegisterGenericProvisionServiceServer(server, ps)
	longrunning.RegisterOperationsServer(server, ps.manager)
	ps.log.Println("provisionservice listen to request at ", l.Addr().String())

	// Write port number to ~/.cftmeta for go/cft-port-discovery. A failure here
	// is only logged; the server still starts serving.
	if err := portdiscovery.WriteServiceMetadata("provision", l.Addr().String(), ps.log); err != nil {
		ps.log.Println("Warning: error when writing to metadata file: ", err)
	}

	return server.Serve(l)
}

// StartUp handles the initialization of the GenericProvisionService by passing in parameters through the ProvisionStartupRequest.
//
// It records the DUT's board and model, connects to the cros-dut server, and
// stores the cache server URL that later Install calls use.
//
// Every check runs before any server state is modified or the DUT server is
// dialed. This covers structural validation, parsing of both endpoints, and
// the localhost check. A rejected request therefore leaves the server exactly
// as it was and does not open a DUT connection.
//
// On failure, the response status is STATUS_INVALID_REQUEST for a malformed
// request, or STATUS_STARTUP_FAILED if connecting to the DUT server fails. The
// returned error describes the cause.
func (ps *FWProvisionServer) StartUp(ctx context.Context, req *api.ProvisionStartupRequest) (*api.ProvisionStartupResponse, error) {
	ps.log.Println("Received api.ProvisionStartupRequest: ", req)
	response := api.ProvisionStartupResponse{}

	if err := ps.validateStartupRequest(req); err != nil {
		response.Status = api.ProvisionStartupResponse_STATUS_INVALID_REQUEST
		return &response, err
	}

	dutServAddr, err := ipEndpointToHostPort(req.DutServer)
	if err != nil {
		response.Status = api.ProvisionStartupResponse_STATUS_INVALID_REQUEST
		return &response, errors.Wrap(err, "failed to parse IpEndpoint of Dut Server")
	}

	cacheEndpoint := req.Dut.GetCacheServer().GetAddress()
	cacheServerAddr, err := ipEndpointToHostPort(cacheEndpoint)
	if err != nil {
		response.Status = api.ProvisionStartupResponse_STATUS_INVALID_REQUEST
		return &response, errors.Wrap(err, "failed to parse IpEndpoint of cache server")
	}
	// The cache server must be reachable from the DUT, so "localhost" is
	// rejected. This happens before any state is committed.
	if cacheEndpoint.GetAddress() == "localhost" {
		response.Status = api.ProvisionStartupResponse_STATUS_INVALID_REQUEST
		return &response, errors.New("ProvisionStartupRequest: cache_server_address must be visible from DUT, i.e. no localhost")
	}

	dutAdapter, err := connectToDutServer(dutServAddr)
	if err != nil {
		response.Status = api.ProvisionStartupResponse_STATUS_STARTUP_FAILED
		return &response, errors.Wrap(err, "connect to dut server")
	}

	// Commit server state only once everything above has succeeded.
	dutModel := req.Dut.GetChromeos().DutModel
	ps.board = dutModel.BuildTarget
	ps.model = dutModel.ModelName
	ps.dutAdapter = dutAdapter
	// Assign a fresh URL so no stale fields from a previous StartUp survive.
	ps.cacheServer = url.URL{Scheme: "http", Host: cacheServerAddr}

	response.Status = api.ProvisionStartupResponse_STATUS_SUCCESS
	return &response, nil
}

// validateStartupRequest reports whether req contains every field StartUp
// relies on. Required fields are the request itself, dut, and dut.chromeos
// with a dut_model that has a non-empty build_target and model_name. Also
// required are dut_server and a dut.cache_server address with a non-empty
// host and a non-zero port.
//
// Checks run in that order. The returned error names the first missing
// field; nil means the request is complete.
func (ps *FWProvisionServer) validateStartupRequest(req *api.ProvisionStartupRequest) error {
	if req == nil {
		return errors.New("ProvisionStartupRequest is required")
	}
	dut := req.Dut
	if dut == nil {
		return errors.New("ProvisionStartupRequest: dut is required")
	}

	switch {
	case dut.GetChromeos() == nil:
		return errors.New("ProvisionStartupRequest: dut.chromeos is required")
	case dut.GetChromeos().DutModel == nil:
		return errors.New("ProvisionStartupRequest: dut.chromeos.dut_model is required")
	case dut.GetChromeos().DutModel.BuildTarget == "":
		return errors.New("ProvisionStartupRequest: dut.chromeos.dut_model.build_target is required")
	case dut.GetChromeos().DutModel.ModelName == "":
		return errors.New("ProvisionStartupRequest: dut.chromeos.dut_model.model_name is required")
	case req.DutServer == nil:
		return errors.New("ProvisionStartupRequest: dut_server is required")
	case dut.GetCacheServer() == nil:
		return errors.New("ProvisionStartupRequest: dut.cache_server is required")
	case dut.GetCacheServer().GetAddress() == nil:
		return errors.New("ProvisionStartupRequest: dut.cache_server.address is required")
	case dut.GetCacheServer().GetAddress().Address == "":
		return errors.New("ProvisionStartupRequest: dut.cache_server.address.address is required")
	case dut.GetCacheServer().GetAddress().Port == 0:
		return errors.New("ProvisionStartupRequest: dut.cache_server.address.port is required")
	}
	return nil
}

// Install starts the firmware provisioning in the background, and returns a long running operation or an error.
//
// Provisioning runs in its own goroutine (doProvision). When it finishes, that
// goroutine stores an InstallResponse as the result of the returned operation.
// The returned error is currently always nil.
//
// The RPC's ctx is deliberately not forwarded. A gRPC handler context is
// cancelled once the handler returns, which would abort provisioning almost
// immediately, so the background work gets context.Background() instead.
func (ps *FWProvisionServer) Install(ctx context.Context, req *api.InstallRequest) (*longrunning.Operation, error) {
	ps.log.Println("Received api.InstallRequest: ", req)
	op := ps.manager.NewOperation()

	go ps.doProvision(context.Background(), req, op.Name)

	return op, nil
}

// doProvision runs the firmware provisioning state machine for req and
// publishes the resulting InstallResponse as the result of the long-running
// operation named lroName. Install launches it in a goroutine.
//
// The response status is whatever the last executed state reported, or
// STATUS_INVALID_REQUEST if the firmware service cannot be created. Metadata is
// an Any-wrapped FirmwareProvisionResponse. It comes from the last state that
// returned one, or else from fwService.GetVersions(). If the state machine
// stopped on an error, the metadata's ErrorMessage carries that error.
func (ps *FWProvisionServer) doProvision(ctx context.Context, req *api.InstallRequest, lroName string) {
	response := api.InstallResponse{}
	// Always publish a result, even on early return, so the operation completes.
	defer func() {
		ps.manager.SetResult(lroName, &response)
		ps.log.Printf("Provision set OP Response to:%s ", response.String())
	}()

	fwService, err := firmwareservice.NewFirmwareService(ctx, ps.dutAdapter, ps.servoClient, ps.cacheServer,
		ps.board, ps.model, false, req)
	if err != nil {
		response.Status = api.InstallResponse_STATUS_INVALID_REQUEST
		ps.log.Printf("Failed to initialize Firmware Service: %v", err)
		return
	}
	// Clean up the temporary directories on the DUT. Deferred after the
	// SetResult defer above, so it runs before the result is published.
	defer fwService.DeleteArchiveDirectories()

	response.Status = api.InstallResponse_STATUS_SUCCESS
	var firmwareResponse *api.FirmwareProvisionResponse
	// Execute the state machine until a state fails or there is no next state.
	cs := state_machine.NewFirmwarePrepareState(fwService)
	for cs != nil {
		var metadata *api.FirmwareProvisionResponse
		metadata, response.Status, err = cs.Execute(ctx, ps.log)
		if metadata != nil {
			firmwareResponse = metadata
		}
		if err != nil {
			ps.log.Printf("State machine failed: %v", err)
			break
		}
		cs = cs.Next()
	}

	// If the state machine didn't return metadata, get one from the fwService.
	if firmwareResponse == nil {
		firmwareResponse = fwService.GetVersions()
	}
	// NOTE(review): guard in case GetVersions can return nil. Setting
	// ErrorMessage on a nil pointer would panic inside this goroutine and
	// crash the whole process.
	if firmwareResponse == nil {
		firmwareResponse = &api.FirmwareProvisionResponse{}
	}
	if err != nil {
		firmwareResponse.ErrorMessage = err.Error()
	}

	packed, packErr := anypb.New(firmwareResponse)
	if packErr != nil {
		// Only log on an actual failure; previously this ran unconditionally.
		ps.log.Printf("Failed to create AnyPb: %v", packErr)
	}
	response.Metadata = packed
}
