// Copyright 2015 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

package overlord

import (
	"encoding/json"
	"errors"
	"fmt"
	"github.com/flynn/go-shlex"
	"github.com/kr/pty"
	"github.com/satori/go.uuid"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"os/exec"
	"runtime"
	"strings"
	"time"
)

// Client tunables. The timeouts and intervals below, except PING_TIMEOUT,
// are multiplied by time.Second where they are used.
const (
	DEFAULT_SHELL = "/bin/bash"      // shell for terminal sessions when $SHELL is unset
	OVERLORD_IP   = "localhost"      // fallback Overlord host always tried by StartGhost
	RANDOM_MID    = "##random_mid##" // sentinel asking NewGhost for a random machine ID

	DIAL_TIMEOUT   = 3  // seconds allowed for the TCP dial to Overlord
	PING_INTERVAL  = 10 // seconds between keep-alive pings
	PING_TIMEOUT   = 10 // passed to Request.SetTimeout; presumably seconds — confirm
	RETRY_INTERVAL = 2  // seconds to wait before reconnecting
	READ_TIMEOUT   = 3  // seconds per LAN-discovery read / pause poll
)

// Ghost is a client of an Overlord server. Depending on mode it runs as the
// long-lived AGENT, as a TERMINAL (PTY) session or as a SHELL (single
// command) session; see NewGhost and Register.
type Ghost struct {
	*RPCCore
	addrs         []string               // Candidate Overlord addresses, tried in order
	connectedAddr string                 // Overlord address an AGENT registration succeeded with
	mid           string                 // Machine ID
	cid           string                 // Client ID (the session ID for terminal/shell sessions)
	mode          int                    // AGENT, TERMINAL or SHELL; see constants.go
	properties    map[string]interface{} // Client properties sent with the register request
	reset         bool                   // Set when a request times out; makes Listen drop the connection
	quit          bool                   // Set when a session ends; makes Start stop retrying
	readChan      chan string            // Incoming data from the connection (set by Listen)
	readErrChan   chan error             // Read errors from the connection (set by Listen)
	pauseLanDisc  bool                   // Pauses LAN discovery while registered
	shellCommand  string                 // Command line executed by SpawnShellServer in SHELL mode
}

// NewGhost creates a Ghost for the given candidate Overlord addresses and
// mode. mid picks the machine ID: RANDOM_MID requests a fresh random UUID,
// an empty string falls back to GetMachineID (panicking if that fails), and
// any other value is used as-is. The client ID is always a new random UUID.
func NewGhost(addrs []string, mode int, mid string) *Ghost {
	machineID := mid
	switch mid {
	case RANDOM_MID:
		machineID = uuid.NewV4().String()
	case "":
		id, err := GetMachineID()
		if err != nil {
			panic(err)
		}
		machineID = id
	}

	// reset, quit and pauseLanDisc intentionally start at their zero value.
	return &Ghost{
		RPCCore:    NewRPCCore(nil),
		addrs:      addrs,
		mid:        machineID,
		cid:        uuid.NewV4().String(),
		mode:       mode,
		properties: make(map[string]interface{}),
	}
}

// SetCid replaces the client ID sent at registration and returns the same
// Ghost so calls can be chained.
func (self *Ghost) SetCid(cid string) *Ghost {
	g := self
	g.cid = cid
	return g
}

// SetCommand stores the command line that SpawnShellServer runs in SHELL
// mode and returns the same Ghost so calls can be chained.
func (self *Ghost) SetCommand(command string) *Ghost {
	g := self
	g.shellCommand = command
	return g
}

// ExistsInAddr reports whether target is already among the candidate
// Overlord addresses.
func (self *Ghost) ExistsInAddr(target string) bool {
	for i := range self.addrs {
		if self.addrs[i] == target {
			return true
		}
	}
	return false
}

// LoadPropertiesFromFile decodes the JSON in filename into the client
// properties sent to Overlord at registration. Read or decode failures are
// only logged; on a decode error the properties may be partially updated.
func (self *Ghost) LoadPropertiesFromFile(filename string) {
	var content []byte
	var err error
	if content, err = ioutil.ReadFile(filename); err == nil {
		err = json.Unmarshal(content, &self.properties)
	}
	if err != nil {
		log.Printf("LoadPropertiesFromFile: %s\n", err)
	}
}

// handleTerminalRequest handles Overlord's "terminal" request. It starts a
// separate TERMINAL-mode Ghost in the background, whose client ID is the
// request's session ID and which dials the Overlord this agent is
// connected to. It then replies SUCCESS to the request.
func (self *Ghost) handleTerminalRequest(req *Request) error {
	type RequestParams struct {
		Sid string `json:"sid"`
	}

	var params RequestParams
	if err := json.Unmarshal(req.Params, &params); err != nil {
		return err
	}

	// Snapshot the address now. Reading self.connectedAddr inside the
	// goroutine would race with Register rewriting it after a reconnect, and
	// could send the session to a different Overlord than the requester.
	addrs := []string{self.connectedAddr}

	go func() {
		log.Printf("Received terminal command, Terminal %s spawned\n", params.Sid)
		// Terminal sessions are identified by session ID, so the machine ID
		// does not matter and can be random.
		g := NewGhost(addrs, TERMINAL, RANDOM_MID).SetCid(params.Sid)
		g.Start(true)
	}()

	res := NewResponse(req.Rid, SUCCESS, nil)
	return self.SendResponse(res)
}

// handleShellRequest handles Overlord's "shell" request. It starts a
// separate SHELL-mode Ghost in the background that runs the requested
// command, uses the request's session ID as its client ID and dials the
// Overlord this agent is connected to. It then replies SUCCESS.
func (self *Ghost) handleShellRequest(req *Request) error {
	type RequestParams struct {
		Sid string `json:"sid"`
		Cmd string `json:"command"`
	}

	var params RequestParams
	if err := json.Unmarshal(req.Params, &params); err != nil {
		return err
	}

	// Snapshot the address now. Reading self.connectedAddr inside the
	// goroutine would race with Register rewriting it after a reconnect, and
	// could send the session to a different Overlord than the requester.
	addrs := []string{self.connectedAddr}

	go func() {
		log.Printf("Received shell command: %s, shell %s spawned\n", params.Cmd, params.Sid)
		// Shell sessions are identified by session ID, so the machine ID does
		// not matter and can be random.
		g := NewGhost(addrs, SHELL, RANDOM_MID).SetCid(params.Sid).SetCommand(params.Cmd)
		g.Start(true)
	}()

	res := NewResponse(req.Rid, SUCCESS, nil)
	return self.SendResponse(res)
}

// handleRequest dispatches one Overlord request by its name ("shell" or
// "terminal"). Any other name produces an error and the request is not
// acted upon.
func (self *Ghost) handleRequest(req *Request) error {
	switch name := req.Name; name {
	case "shell":
		return self.handleShellRequest(req)
	case "terminal":
		return self.handleTerminalRequest(req)
	default:
		return errors.New(`Received unregistered command "` + name + `", ignoring`)
	}
}

// ProcessRequests handles every request in reqs, in order. A failing request
// no longer stops the remaining ones from being handled; previously one
// unknown command dropped the rest of its batch. The first error is
// returned, and any later errors are logged so they are not lost.
func (self *Ghost) ProcessRequests(reqs []*Request) error {
	var firstErr error
	for _, req := range reqs {
		err := self.handleRequest(req)
		if err == nil {
			continue
		}
		if firstErr == nil {
			firstErr = err
		} else {
			log.Println(err)
		}
	}
	return firstErr
}

// Ping sends a "ping" request with a PING_TIMEOUT timeout. A nil response
// (treated as a timeout) sets the reset flag, which makes Listen drop the
// connection on its next timeout scan.
func (self *Ghost) Ping() error {
	req := NewRequest("ping", nil)
	req.SetTimeout(PING_TIMEOUT)
	return self.SendRequest(req, func(res *Response) error {
		if res != nil {
			return nil
		}
		self.reset = true
		return errors.New("Ping timeout")
	})
}

// SpawnPTYServer is the registration response handler for TERMINAL mode.
// It starts $SHELL (or DEFAULT_SHELL) on a new PTY, in $HOME when that is an
// existing directory (HOME unset falls back to "/"). It copies PTY output to
// the Overlord connection and data from readChan into the PTY. It returns
// when the shell exits or the connection fails. It always sets quit so Start
// stops retrying.
func (self *Ghost) SpawnPTYServer(res *Response) error {
	log.Println("SpawnPTYServer: started")

	shell := os.Getenv("SHELL")
	if shell == "" {
		shell = DEFAULT_SHELL
	}

	home := os.Getenv("HOME")
	if home == "" {
		home = "/"
	}

	cmd := exec.Command(shell)
	// Set the child's working directory rather than calling os.Chdir, which
	// changes the directory of the whole process (the agent and every other
	// session). Like the old ignored Chdir error, a missing home is
	// best-effort and not fatal.
	if info, statErr := os.Stat(home); statErr == nil && info.IsDir() {
		cmd.Dir = home
	}

	tty, err := pty.Start(cmd)
	if err != nil {
		// End the session instead of leaving an idle connection registered.
		self.quit = true
		return fmt.Errorf(`SpawnPTYServer: Cannot start "%s": %v, abort`, shell, err)
	}

	defer func() {
		self.quit = true
		cmd.Process.Kill()
		tty.Close()
		log.Println("SpawnPTYServer: terminated")
	}()

	stopConn := make(chan bool, 1)

	// Forward PTY output until it closes (the shell exited or tty.Close
	// above), then reap the shell and signal the loop below.
	go func() {
		io.Copy(self.Conn, tty)
		cmd.Wait()
		stopConn <- true
	}()

	for {
		select {
		case buffer := <-self.readChan:
			tty.Write([]byte(buffer))
		case readErr := <-self.readErrChan:
			if readErr == io.EOF {
				log.Println("SpawnPTYServer: connection dropped")
				return nil
			}
			return readErr
		case <-stopConn:
			return nil
		}
	}
}

// SpawnShellServer is the registration response handler for SHELL mode. It
// splits shellCommand with shlex, runs it directly (no shell), forwards its
// stdout and stderr to the Overlord connection, and writes data from
// readChan to its stdin. It returns once the command has exited and all of
// its output has been forwarded, or when the connection fails. Setup errors
// are also written to the connection. On return the connection is closed
// and quit is set so Start stops retrying.
func (self *Ghost) SpawnShellServer(res *Response) error {
	log.Println("SpawnShellServer: started")

	var err error

	defer func() {
		if err != nil {
			self.Conn.Write([]byte(err.Error() + "\n"))
		}
		self.quit = true
		self.Conn.Close()
		log.Println("SpawnShellServer: terminated")
	}()

	parts, err := shlex.Split(self.shellCommand)
	if err != nil {
		return err
	}
	if len(parts) == 0 {
		// Without this check parts[0] below panics on an empty command.
		err = errors.New("SpawnShellServer: empty command")
		return err
	}

	cmdName, err := exec.LookPath(parts[0])
	if err != nil {
		return err
	}

	cmd := exec.Command(cmdName, parts[1:]...)
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return err
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return err
	}
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return err
	}

	if err = cmd.Start(); err != nil {
		return err
	}

	// cmd.Wait closes the stdout/stderr pipes, so it may only run after both
	// copies have drained them; otherwise trailing output is lost.
	copyDone := make(chan bool, 2)
	go func() {
		io.Copy(self.Conn, stdout)
		copyDone <- true
	}()
	go func() {
		io.Copy(self.Conn, stderr)
		copyDone <- true
	}()

	stopConn := make(chan bool, 1)
	go func() {
		<-copyDone
		<-copyDone
		cmd.Wait()
		stopConn <- true
	}()

	for {
		select {
		case buf := <-self.readChan:
			stdin.Write([]byte(buf))
		case readErr := <-self.readErrChan:
			// The connection is closing, so nobody can read the output any
			// more: don't leave the command running.
			cmd.Process.Kill()
			if readErr == io.EOF {
				return errors.New("SpawnShellServer: connection dropped")
			}
			return readErr
		case <-stopConn:
			return nil
		}
	}
}

// Register dials the candidate Overlord addresses in order and sends a
// "register" request (mid, cid, mode, properties) over the first
// connection that can be dialed and sent on. The response handler depends
// on mode:
//   - AGENT records connectedAddr and pauses LAN discovery.
//   - TERMINAL runs SpawnPTYServer.
//   - SHELL runs SpawnShellServer.
//
// It returns an error if no address worked.
func (self *Ghost) Register() error {
	for _, addr := range self.addrs {
		addr := addr // captured by the response handler below
		log.Printf("Trying %s ...\n", addr)
		conn, err := net.DialTimeout("tcp", addr, DIAL_TIMEOUT*time.Second)
		if err != nil {
			continue
		}

		log.Println("Connection established, registering...")
		self.Conn = conn
		req := NewRequest("register", map[string]interface{}{
			"mid":        self.mid,
			"cid":        self.cid,
			"mode":       self.mode,
			"properties": self.properties,
		})

		registered := func(res *Response) error {
			if res == nil {
				self.reset = true
				return errors.New("Register request timeout")
			}
			log.Printf("Registered with Overlord at %s", addr)
			self.connectedAddr = addr
			self.pauseLanDisc = true
			return nil
		}

		var handler ResponseHandler
		switch self.mode {
		case AGENT:
			handler = registered
		case TERMINAL:
			handler = self.SpawnPTYServer
		case SHELL:
			handler = self.SpawnShellServer
		}

		if err = self.SendRequest(req, handler); err != nil {
			// This error used to be dropped, and Listen then ran on a dead
			// connection. Give up on this address and try the next one.
			log.Printf("Register: %s\n", err)
			conn.Close()
			continue
		}
		return nil
	}

	return errors.New("Cannot connect to any server")
}

// Reset prepares the Ghost for another connection attempt. It calls
// ClearRequests (presumably discarding pending requests) and clears the
// reset flag. It does not touch quit, connectedAddr or the other fields.
// NOTE(review): the flag is cleared after ClearRequests; keep that order
// unless it is confirmed that ClearRequests cannot set it.
func (self *Ghost) Reset() {
	self.ClearRequests()
	self.reset = false
}

// Listen is the main loop for an established connection. It:
//   - parses and dispatches incoming requests;
//   - pings every PING_INTERVAL seconds;
//   - scans for timed-out requests every TIMEOUT_CHECK_SECS seconds.
//
// It returns nil once quit is set, and an error when the connection fails.
// When a timeout has set the reset flag, it returns whatever
// ScanForTimeoutRequests reported. On return the connection is closed and
// LAN discovery resumes.
func (self *Ghost) Listen() error {
	readChan, readErrChan := self.SpawnReaderRoutine()
	pingTicker := time.NewTicker(PING_INTERVAL * time.Second)
	reqTicker := time.NewTicker(TIMEOUT_CHECK_SECS * time.Second)

	self.readChan = readChan
	self.readErrChan = readErrChan

	defer func() {
		// Listen runs once per connection; unstopped tickers would pile up
		// across reconnects.
		pingTicker.Stop()
		reqTicker.Stop()
		self.Conn.Close()
		self.pauseLanDisc = false
	}()

	for {
		select {
		case buffer := <-readChan:
			reqs := self.ParseRequests(buffer, false)
			if self.quit {
				return nil
			}
			if err := self.ProcessRequests(reqs); err != nil {
				log.Println(err)
			}
		case err := <-readErrChan:
			if err == io.EOF {
				return errors.New("Connection dropped")
			}
			return err
		case <-pingTicker.C:
			// A missed pong is detected via the reset flag, not this result.
			self.Ping()
		case <-reqTicker.C:
			err := self.ScanForTimeoutRequests()
			if self.reset {
				return err
			}
		}
	}
}

// StartLanDiscovery listens for Overlord LAN discovery broadcasts on UDP
// port OVERLORD_LD_PORT and appends each newly announced address to addrs.
// While pauseLanDisc is set it stops processing packets and re-checks the
// flag every READ_TIMEOUT seconds. It returns only if the UDP socket cannot
// be opened.
func (self *Ghost) StartLanDiscovery() {
	log.Println("LAN discovery: started")
	buf := make([]byte, BUFSIZ)
	conn, err := net.ListenPacket("udp", fmt.Sprintf(":%d", OVERLORD_LD_PORT))
	if err != nil {
		log.Printf("LAN discovery: %s, abort\n", err.Error())
		return
	}

	defer func() {
		conn.Close()
		log.Println("LAN discovery: stopped")
	}()

	for {
		conn.SetReadDeadline(time.Now().Add(READ_TIMEOUT * time.Second))
		n, remote, err := conn.ReadFrom(buf)

		if self.pauseLanDisc {
			log.Println("LAN discovery: paused")
			// Simple sleep loop; the ticker previously used here was never
			// stopped and leaked on every pause.
			for self.pauseLanDisc {
				time.Sleep(READ_TIMEOUT * time.Second)
			}
			log.Println("LAN discovery: resumed")
			continue
		}

		if err != nil {
			continue
		}

		remoteAddr, ok := parseLanDiscoveryPacket(string(buf[:n]), remote.String())
		if !ok {
			continue
		}

		if !self.ExistsInAddr(remoteAddr) {
			log.Printf("LAN discovery: got overlord address %s", remoteAddr)
			self.addrs = append(self.addrs, remoteAddr)
		}
	}
}

// parseLanDiscoveryPacket extracts the Overlord address from a LAN discovery
// payload of the form "OVERLORD [host]:port". When host is empty, the host
// of sender (the packet's source address) is used instead. ok is false for
// payloads that do not match, so malformed broadcasts from the network are
// skipped rather than causing an index-out-of-range panic.
func parseLanDiscoveryPacket(payload, sender string) (addr string, ok bool) {
	fields := strings.Fields(payload)
	if len(fields) < 2 || fields[0] != "OVERLORD" {
		return "", false
	}

	host, port, err := net.SplitHostPort(fields[1])
	if err != nil {
		return "", false
	}
	if host != "" {
		return fields[1], true
	}

	// SplitHostPort/JoinHostPort, unlike splitting on ":", keep IPv6 intact.
	senderHost, _, err := net.SplitHostPort(sender)
	if err != nil {
		return "", false
	}
	return net.JoinHostPort(senderHost, port), true
}

// ScanGateway adds "<gateway>:OVERLORD_PORT" to addrs for every gateway
// returned by GetGateWayIP, skipping addresses that are already present.
// A GetGateWayIP error is ignored and leaves addrs unchanged.
func (self *Ghost) ScanGateway() {
	gateways, err := GetGateWayIP()
	if err != nil {
		return
	}
	for _, gw := range gateways {
		if candidate := fmt.Sprintf("%s:%d", gw, OVERLORD_PORT); !self.ExistsInAddr(candidate) {
			self.addrs = append(self.addrs, candidate)
		}
	}
}

// Start runs the client. It logs its mode and IDs and, unless noLanDisc is
// true, starts LAN discovery in the background. It then loops: scan
// gateways, register, and listen until the connection ends. After each
// attempt it stops if quit is set; otherwise it resets state and retries
// after RETRY_INTERVAL seconds.
func (self *Ghost) Start(noLanDisc bool) {
	log.Printf("%s started\n", ModeStr(self.mode))
	log.Printf("MID: %s\n", self.mid)
	log.Printf("CID: %s\n", self.cid)

	if !noLanDisc {
		go self.StartLanDiscovery()
	}

	// One connection attempt: Listen only runs after a successful Register.
	attempt := func() error {
		self.ScanGateway()
		if err := self.Register(); err != nil {
			return err
		}
		return self.Listen()
	}

	for {
		err := attempt()
		if self.quit {
			return
		}
		self.Reset()
		log.Printf("%s, retrying in %ds\n", err, RETRY_INTERVAL)
		time.Sleep(RETRY_INTERVAL * time.Second)
	}
}

// StartGhost runs an AGENT-mode Ghost in the background and then blocks
// forever, logging the goroutine count once a minute. The candidate
// addresses are args[0] (when given) followed by OVERLORD_IP, both on
// OVERLORD_PORT. Properties are loaded from propFile when it is non-empty.
func StartGhost(args []string, mid string, noLanDisc bool, propFile string) {
	var addrs []string
	if len(args) > 0 {
		addrs = append(addrs, fmt.Sprintf("%s:%d", args[0], OVERLORD_PORT))
	}
	addrs = append(addrs, fmt.Sprintf("%s:%d", OVERLORD_IP, OVERLORD_PORT))

	g := NewGhost(addrs, AGENT, mid)
	if propFile != "" {
		g.LoadPropertiesFromFile(propFile)
	}
	go g.Start(noLanDisc)

	for range time.NewTicker(60 * time.Second).C {
		log.Printf("Num of Goroutines: %d\n", runtime.NumGoroutine())
	}
}
