| // Copyright 2015 The Chromium OS Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package overlord |
| |
| import ( |
| "encoding/json" |
| "errors" |
| "fmt" |
| "github.com/flynn/go-shlex" |
| "github.com/kr/pty" |
| "github.com/satori/go.uuid" |
| "io" |
| "io/ioutil" |
| "log" |
| "net" |
| "os" |
| "os/exec" |
| "runtime" |
| "strings" |
| "time" |
| ) |
| |
// Default settings of the ghost client.
const (
	// RANDOM_MID is the sentinel machine ID that tells NewGhost to generate
	// a random UUID instead of using a fixed machine ID.
	RANDOM_MID = "##random_mid##"

	// DEFAULT_SHELL is the shell spawned for terminal sessions when the
	// SHELL environment variable is empty.
	DEFAULT_SHELL = "/bin/bash"

	// OVERLORD_IP is the fallback Overlord host that StartGhost always
	// appends to the list of candidate addresses.
	OVERLORD_IP = "localhost"

	// Intervals and timeouts. The ones multiplied by time.Second at their
	// use sites are expressed in seconds.
	DIAL_TIMEOUT   = 3  // TCP dial timeout, in seconds
	READ_TIMEOUT   = 3  // LAN discovery read deadline / pause poll period, in seconds
	RETRY_INTERVAL = 2  // delay between reconnection attempts, in seconds
	PING_INTERVAL  = 10 // period between pings, in seconds
	PING_TIMEOUT   = 10 // passed to Request.SetTimeout for pings (presumably seconds; confirm)
)
| |
| type Ghost struct { |
| *RPCCore |
| addrs []string // List of possible Overlord addresses |
| connectedAddr string // Current connected Overlord address |
| mid string // Machine ID |
| cid string // Client ID |
| mode int // mode, see constants.go |
| properties map[string]interface{} // Client properties |
| reset bool // Whether to reset the connection |
| quit bool // Whether to quit the connection |
| readChan chan string // The incoming data channel |
| readErrChan chan error // The incoming data error channel |
| pauseLanDisc bool // Stop LAN discovery |
| shellCommand string // filename to cat in logcat mode |
| } |
| |
| func NewGhost(addrs []string, mode int, mid string) *Ghost { |
| var finalMid string |
| var err error |
| |
| if mid == RANDOM_MID { |
| finalMid = uuid.NewV4().String() |
| } else if mid != "" { |
| finalMid = mid |
| } else { |
| finalMid, err = GetMachineID() |
| if err != nil { |
| panic(err) |
| } |
| } |
| return &Ghost{ |
| RPCCore: NewRPCCore(nil), |
| addrs: addrs, |
| mid: finalMid, |
| cid: uuid.NewV4().String(), |
| mode: mode, |
| properties: make(map[string]interface{}), |
| reset: false, |
| quit: false, |
| pauseLanDisc: false, |
| } |
| } |
| |
| func (self *Ghost) SetCid(cid string) *Ghost { |
| self.cid = cid |
| return self |
| } |
| |
| func (self *Ghost) SetCommand(command string) *Ghost { |
| self.shellCommand = command |
| return self |
| } |
| |
| func (self *Ghost) ExistsInAddr(target string) bool { |
| for _, x := range self.addrs { |
| if target == x { |
| return true |
| } |
| } |
| return false |
| } |
| |
| func (self *Ghost) LoadPropertiesFromFile(filename string) { |
| bytes, err := ioutil.ReadFile(filename) |
| if err != nil { |
| log.Printf("LoadPropertiesFromFile: %s\n", err) |
| return |
| } |
| |
| if err := json.Unmarshal(bytes, &self.properties); err != nil { |
| log.Printf("LoadPropertiesFromFile: %s\n", err) |
| return |
| } |
| } |
| |
| func (self *Ghost) handleTerminalRequest(req *Request) error { |
| type RequestParams struct { |
| Sid string `json:"sid"` |
| } |
| |
| var params RequestParams |
| if err := json.Unmarshal(req.Params, ¶ms); err != nil { |
| return err |
| } |
| |
| go func() { |
| log.Printf("Received terminal command, Terminal %s spawned\n", params.Sid) |
| addrs := []string{self.connectedAddr} |
| // Terminal sessions are identified with session ID, thus we don't care |
| // machine ID and can make them random. |
| g := NewGhost(addrs, TERMINAL, RANDOM_MID).SetCid(params.Sid) |
| g.Start(true) |
| }() |
| |
| res := NewResponse(req.Rid, SUCCESS, nil) |
| return self.SendResponse(res) |
| } |
| |
| func (self *Ghost) handleShellRequest(req *Request) error { |
| type RequestParams struct { |
| Sid string `json:"sid"` |
| Cmd string `json:"command"` |
| } |
| |
| var params RequestParams |
| if err := json.Unmarshal(req.Params, ¶ms); err != nil { |
| return err |
| } |
| |
| go func() { |
| log.Printf("Received shell command: %s, shell %s spawned\n", params.Cmd, params.Sid) |
| addrs := []string{self.connectedAddr} |
| // Shell sessions are identified with session ID, thus we don't care |
| // machine ID and can make them random. |
| g := NewGhost(addrs, SHELL, RANDOM_MID).SetCid(params.Sid).SetCommand(params.Cmd) |
| g.Start(true) |
| }() |
| |
| res := NewResponse(req.Rid, SUCCESS, nil) |
| return self.SendResponse(res) |
| } |
| |
| func (self *Ghost) handleRequest(req *Request) error { |
| var err error |
| switch req.Name { |
| case "shell": |
| err = self.handleShellRequest(req) |
| case "terminal": |
| err = self.handleTerminalRequest(req) |
| default: |
| err = errors.New(`Received unregistered command "` + req.Name + `", ignoring`) |
| } |
| return err |
| } |
| |
| func (self *Ghost) ProcessRequests(reqs []*Request) error { |
| for _, req := range reqs { |
| if err := self.handleRequest(req); err != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| func (self *Ghost) Ping() error { |
| pingHandler := func(res *Response) error { |
| if res == nil { |
| self.reset = true |
| return errors.New("Ping timeout") |
| } |
| return nil |
| } |
| req := NewRequest("ping", nil) |
| req.SetTimeout(PING_TIMEOUT) |
| return self.SendRequest(req, pingHandler) |
| } |
| |
| // Spawn a PTY server and forward I/O to the TCP socket. |
| func (self *Ghost) SpawnPTYServer(res *Response) error { |
| log.Println("SpawnPTYServer: started") |
| |
| shell := os.Getenv("SHELL") |
| if shell == "" { |
| shell = DEFAULT_SHELL |
| } |
| |
| home := os.Getenv("HOME") |
| if home == "" { |
| home = "/" |
| } |
| |
| os.Chdir(home) |
| cmd := exec.Command(shell) |
| tty, err := pty.Start(cmd) |
| if err != nil { |
| return errors.New(`SpawnPTYServer: Cannot start "` + shell + `", abort`) |
| } |
| |
| defer func() { |
| self.quit = true |
| cmd.Process.Kill() |
| tty.Close() |
| log.Println("SpawnPTYServer: terminated") |
| }() |
| |
| stopConn := make(chan bool, 1) |
| |
| go func() { |
| io.Copy(self.Conn, tty) |
| cmd.Wait() |
| stopConn <- true |
| }() |
| |
| for { |
| select { |
| case buffer := <-self.readChan: |
| tty.Write([]byte(buffer)) |
| case err := <-self.readErrChan: |
| if err == io.EOF { |
| log.Println("SpawnPTYServer: connection dropped") |
| return nil |
| } else { |
| return err |
| } |
| case s := <-stopConn: |
| if s { |
| return nil |
| } |
| } |
| } |
| |
| return nil |
| } |
| |
| // Spawn a Shell server and forward input/output from/to the TCP socket. |
| func (self *Ghost) SpawnShellServer(res *Response) error { |
| log.Println("SpawnShellServer: started") |
| |
| var err error |
| |
| defer func() { |
| if err != nil { |
| self.Conn.Write([]byte(err.Error() + "\n")) |
| } |
| self.quit = true |
| self.Conn.Close() |
| log.Println("SpawnShellServer: terminated") |
| }() |
| |
| parts, err := shlex.Split(self.shellCommand) |
| cmd_name, err := exec.LookPath(parts[0]) |
| if err != nil { |
| return err |
| } |
| |
| cmd := exec.Command(cmd_name, parts[1:]...) |
| stdout, err := cmd.StdoutPipe() |
| if err != nil { |
| return err |
| } |
| stderr, err := cmd.StderrPipe() |
| if err != nil { |
| return err |
| } |
| stdin, err := cmd.StdinPipe() |
| if err != nil { |
| return err |
| } |
| |
| stopConn := make(chan bool, 1) |
| |
| go io.Copy(self.Conn, stdout) |
| go func() { |
| io.Copy(self.Conn, stderr) |
| cmd.Wait() |
| stopConn <- true |
| }() |
| |
| if err = cmd.Start(); err != nil { |
| return err |
| } |
| |
| for { |
| select { |
| case buf := <-self.readChan: |
| stdin.Write([]byte(buf)) |
| case err := <-self.readErrChan: |
| if err == io.EOF { |
| cmd.Process.Kill() |
| return errors.New("SpawnShellServer: connection dropped") |
| } else { |
| return err |
| } |
| case s := <-stopConn: |
| if s { |
| return nil |
| } |
| } |
| } |
| |
| return nil |
| } |
| |
| // Register existent to Overlord. |
| func (self *Ghost) Register() error { |
| for _, addr := range self.addrs { |
| log.Printf("Trying %s ...\n", addr) |
| conn, err := net.DialTimeout("tcp", addr, DIAL_TIMEOUT*time.Second) |
| if err == nil { |
| log.Println("Connection established, registering...") |
| self.Conn = conn |
| req := NewRequest("register", map[string]interface{}{ |
| "mid": self.mid, |
| "cid": self.cid, |
| "mode": self.mode, |
| "properties": self.properties, |
| }) |
| |
| registered := func(res *Response) error { |
| if res == nil { |
| self.reset = true |
| return errors.New("Register request timeout") |
| } else { |
| log.Printf("Registered with Overlord at %s", addr) |
| self.connectedAddr = addr |
| self.pauseLanDisc = true |
| } |
| return nil |
| } |
| |
| var handler ResponseHandler |
| switch self.mode { |
| case AGENT: |
| handler = registered |
| case TERMINAL: |
| handler = self.SpawnPTYServer |
| case SHELL: |
| handler = self.SpawnShellServer |
| } |
| err = self.SendRequest(req, handler) |
| return nil |
| } |
| } |
| |
| return errors.New("Cannot connect to any server") |
| } |
| |
| // Reset all states for a new connection. |
| func (self *Ghost) Reset() { |
| self.ClearRequests() |
| self.reset = false |
| } |
| |
| // Main routine for listen to socket messages. |
| func (self *Ghost) Listen() error { |
| readChan, readErrChan := self.SpawnReaderRoutine() |
| pingTicker := time.NewTicker(time.Duration(PING_INTERVAL * time.Second)) |
| reqTicker := time.NewTicker(time.Duration(TIMEOUT_CHECK_SECS * time.Second)) |
| |
| self.readChan = readChan |
| self.readErrChan = readErrChan |
| |
| defer func() { |
| self.Conn.Close() |
| self.pauseLanDisc = false |
| }() |
| |
| for { |
| select { |
| case buffer := <-readChan: |
| reqs := self.ParseRequests(buffer, false) |
| if self.quit { |
| return nil |
| } |
| if err := self.ProcessRequests(reqs); err != nil { |
| log.Println(err) |
| continue |
| } |
| case err := <-readErrChan: |
| if err == io.EOF { |
| return errors.New("Connection dropped") |
| } else { |
| return err |
| } |
| case <-pingTicker.C: |
| self.Ping() |
| case <-reqTicker.C: |
| err := self.ScanForTimeoutRequests() |
| if self.reset { |
| return err |
| } |
| } |
| } |
| } |
| |
| // Start listening to LAN discovery message. |
| func (self *Ghost) StartLanDiscovery() { |
| log.Println("LAN discovery: started") |
| buf := make([]byte, BUFSIZ) |
| conn, err := net.ListenPacket("udp", fmt.Sprintf(":%d", OVERLORD_LD_PORT)) |
| if err != nil { |
| log.Printf("LAN discovery: %s, abort\n", err.Error()) |
| return |
| } |
| |
| defer func() { |
| conn.Close() |
| log.Println("LAN discovery: stopped") |
| }() |
| |
| for { |
| conn.SetReadDeadline(time.Now().Add(READ_TIMEOUT * time.Second)) |
| n, remote, err := conn.ReadFrom(buf) |
| |
| if self.pauseLanDisc { |
| log.Println("LAN discovery: paused") |
| ticker := time.NewTicker(READ_TIMEOUT * time.Second) |
| waitLoop: |
| for { |
| select { |
| case <-ticker.C: |
| if !self.pauseLanDisc { |
| break waitLoop |
| } |
| } |
| } |
| log.Println("LAN discovery: resumed") |
| continue |
| } |
| |
| if err != nil { |
| continue |
| } |
| |
| // LAN discovery packet format: "OVERLOARD [host]:port" |
| data := strings.Split(string(buf[:n]), " ") |
| if data[0] != "OVERLORD" { |
| continue |
| } |
| |
| overlordAddrParts := strings.Split(data[1], ":") |
| remoteAddrParts := strings.Split(remote.String(), ":") |
| |
| var remoteAddr string |
| if strings.Trim(overlordAddrParts[0], " ") == "" { |
| remoteAddr = remoteAddrParts[0] + ":" + overlordAddrParts[1] |
| } else { |
| remoteAddr = data[1] |
| } |
| |
| if !self.ExistsInAddr(remoteAddr) { |
| log.Printf("LAN discovery: got overlord address %s", remoteAddr) |
| self.addrs = append(self.addrs, remoteAddr) |
| } |
| } |
| } |
| |
| // ScanGateWay scans currenty netowrk gateway and add it into addrs if not |
| // already exist. |
| func (self *Ghost) ScanGateway() { |
| if gateways, err := GetGateWayIP(); err == nil { |
| for _, gw := range gateways { |
| addr := fmt.Sprintf("%s:%d", gw, OVERLORD_PORT) |
| if !self.ExistsInAddr(addr) { |
| self.addrs = append(self.addrs, addr) |
| } |
| } |
| } |
| } |
| |
| // Bootstrap and start the client. |
| func (self *Ghost) Start(noLanDisc bool) { |
| log.Printf("%s started\n", ModeStr(self.mode)) |
| log.Printf("MID: %s\n", self.mid) |
| log.Printf("CID: %s\n", self.cid) |
| |
| if !noLanDisc { |
| go self.StartLanDiscovery() |
| } |
| |
| for { |
| self.ScanGateway() |
| err := self.Register() |
| if err == nil { |
| err = self.Listen() |
| } |
| if self.quit { |
| break |
| } |
| self.Reset() |
| log.Printf("%s, retrying in %ds\n", err, RETRY_INTERVAL) |
| time.Sleep(RETRY_INTERVAL * time.Second) |
| } |
| } |
| |
| func StartGhost(args []string, mid string, noLanDisc bool, propFile string) { |
| var addrs []string |
| |
| if len(args) >= 1 { |
| addrs = append(addrs, fmt.Sprintf("%s:%d", args[0], OVERLORD_PORT)) |
| } |
| addrs = append(addrs, fmt.Sprintf("%s:%d", OVERLORD_IP, OVERLORD_PORT)) |
| |
| g := NewGhost(addrs, AGENT, mid) |
| if propFile != "" { |
| g.LoadPropertiesFromFile(propFile) |
| } |
| go g.Start(noLanDisc) |
| |
| ticker := time.NewTicker(time.Duration(60 * time.Second)) |
| |
| for { |
| select { |
| case <-ticker.C: |
| log.Printf("Num of Goroutines: %d\n", runtime.NumGoroutine()) |
| } |
| } |
| } |