blob: f6c7c6f85f52163d8032e06406c08ac63b82eb36 [file] [log] [blame]
// Copyright 2015 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package overlord
import (
"encoding/json"
"errors"
"github.com/gorilla/websocket"
"io"
"log"
"net"
"strings"
"time"
)
// RegistrationFailedError wraps the error returned by Overlord.Register when a
// client registration is rejected (see handleRegisterRequest).
//
// NOTE(review): this declares an interface type with exactly the method set of
// error, so every non-nil error satisfies it and the type assertion
// err.(RegistrationFailedError) succeeds for any error, not only registration
// failures. Telling them apart needs a concrete type (e.g. a struct embedding
// error) plus matching changes where the value is built and where it is checked.
type RegistrationFailedError error
// LOG_BUFSIZ is the maximum number of bytes of logcat output kept in
// logHistory for observers that attach later.
const LOG_BUFSIZ = 16 * 1024

// PING_RECV_TIMEOUT is how long, in seconds (it is compared against Unix
// timestamps), an AGENT client may go without pinging before Listen drops it.
// It allows twice PING_TIMEOUT.
const PING_RECV_TIMEOUT = PING_TIMEOUT * 2
// ConnServer handles one client (ghost) connection on behalf of Overlord.
//
// Since Terminal and Shell sessions are initiated by Overlord, there is only
// one observer, i.e. the one who requested the connection, and their output
// goes to the single wsConn. On the other hand, logcat could have multiple
// observers, so we need to broadcast the result to all of them (logWsConns).
type ConnServer struct {
	*RPCCore
	Mode       int                    // Client mode, see constants.go
	Bridge     chan interface{}       // Channel for Overlord commands, handled by handleOverlordRequest
	Cid        string                 // Client ID
	Mid        string                 // Machine ID
	Properties map[string]interface{} // Client properties
	ovl        *Overlord              // Overlord handle
	registered bool                   // Whether we are registered or not
	wsConn     *websocket.Conn        // WebSocket for Terminal and Shell
	logFormat  int                    // Log format, see constants.go
	logWsConns []*websocket.Conn      // WebSockets for logcat observers
	logHistory string                 // Log buffer for logcat, capped at LOG_BUFSIZ bytes
	stopListen chan bool              // Stop the Listen() loop
	lastPing   int64                  // Unix time (seconds) of the client's last ping
}
// NewConnServer creates a ConnServer for the client connection conn, owned by
// ovl. The new server is unregistered, in NONE mode, with an empty property
// set. Its stop channel has room for one pending stop request.
func NewConnServer(ovl *Overlord, conn net.Conn) *ConnServer {
	srv := new(ConnServer)
	srv.RPCCore = NewRPCCore(conn)
	srv.ovl = ovl
	srv.Mode = NONE
	srv.registered = false
	srv.Bridge = make(chan interface{})
	srv.stopListen = make(chan bool, 1)
	srv.Properties = make(map[string]interface{})
	return srv
}
// SetProperties replaces the client's properties with prop when prop is
// non-nil. It then stores the remote address without its trailing ":port"
// under the "ip" key. An address with no ':' yields "".
//
// NOTE(review): prop is stored as-is, so the caller's map also receives the
// "ip" entry. For an IPv6 TCP peer the stored value keeps its brackets, e.g.
// "[::1]".
func (self *ConnServer) SetProperties(prop map[string]interface{}) {
	if prop != nil {
		self.Properties = prop
	}
	remote := self.Conn.RemoteAddr().String()
	host := ""
	if colon := strings.LastIndex(remote, ":"); colon >= 0 {
		host = remote[:colon]
	}
	self.Properties["ip"] = host
}
// Terminate tears the connection down. It unregisters the client from
// Overlord if registration succeeded and closes the client socket. If a
// WebSocket is attached, it sends a close frame on it and closes it. Errors
// from these calls are ignored; this is best-effort cleanup.
func (self *ConnServer) Terminate() {
	if self.registered {
		self.ovl.Unregister(self)
	}
	if conn := self.Conn; conn != nil {
		conn.Close()
	}
	ws := self.wsConn
	if ws == nil {
		return
	}
	ws.WriteMessage(websocket.CloseMessage, []byte(""))
	ws.Close()
}
// writeLogToWS sends buf to conn as a base64-encoded text message. For a
// LOGCAT client using the TEXT log format, buf is first passed through
// ToVTNewLine. It returns the error from the WebSocket write.
func (self *ConnServer) writeLogToWS(conn *websocket.Conn, buf string) error {
	data := buf
	isTextLogcat := self.Mode == LOGCAT && self.logFormat == TEXT
	if isTextLogcat {
		data = ToVTNewLine(data)
	}
	return conn.WriteMessage(websocket.TextMessage, B64Encode(data))
}
// forwardWSInput copies text messages read from the WebSocket (wsConn) to the
// client's TCP connection. Binary messages are logged and dropped.
//
// It stops when any of these happens:
//   - the WebSocket read fails;
//   - writing to the client fails;
//   - a message of unexpected type arrives.
//
// On exit it asks Listen to stop through stopListen.
func (self *ConnServer) forwardWSInput() {
	defer func() {
		self.stopListen <- true
	}()

	for {
		mt, payload, err := self.wsConn.ReadMessage()
		if err != nil {
			if err == io.EOF {
				log.Println("WebSocket connection terminated")
			} else {
				log.Printf("Unknown error while reading from WebSocket: %s\n", err)
			}
			return
		}

		switch mt {
		case websocket.BinaryMessage:
			log.Printf("Ignoring binary message: %q\n", payload)
		case websocket.TextMessage:
			// A failed write means the client side is gone; keep no further
			// input from the WebSocket flowing into a dead socket.
			if _, err := self.Conn.Write(payload); err != nil {
				log.Printf("Error forwarding WebSocket input: %s\n", err)
				return
			}
		default:
			log.Printf("Invalid message type %d\n", mt)
			return
		}
	}
}
// forwardTerminalOutput forwards PTY output from the client, base64-encoded,
// to the WebSocket observer. If no WebSocket is attached, it asks Listen to
// stop instead of writing.
func (self *ConnServer) forwardTerminalOutput(buffer string) {
	if self.wsConn == nil {
		// This is called from Listen, the only reader of stopListen in this
		// file. A blocking send would deadlock if a stop request were already
		// queued in the 1-slot buffer, and one pending request is enough.
		select {
		case self.stopListen <- true:
		default:
		}
		// Without this return, the nil wsConn would be dereferenced below.
		return
	}
	self.wsConn.WriteMessage(websocket.TextMessage, B64Encode(buffer))
}
// forwardShellOutput forwards the output of a shell command run by the client
// to the WebSocket observer via writeLogToWS. If no WebSocket is attached, it
// asks Listen to stop instead of writing.
func (self *ConnServer) forwardShellOutput(buffer string) {
	if self.wsConn == nil {
		// This is called from Listen, the only reader of stopListen in this
		// file. A blocking send would deadlock if a stop request were already
		// queued in the 1-slot buffer, and one pending request is enough.
		select {
		case self.stopListen <- true:
		default:
		}
		// Without this return, writeLogToWS would be handed a nil conn.
		return
	}
	self.writeLogToWS(self.wsConn, buffer)
}
// forwardLogcatOutput appends buffer to the logcat history, keeping only the
// last LOG_BUFSIZ bytes, and broadcasts buffer to every attached logcat
// WebSocket. Connections whose write fails are closed and removed from
// logWsConns.
func (self *ConnServer) forwardLogcatOutput(buffer string) {
	history := self.logHistory + buffer
	if excess := len(history) - LOG_BUFSIZ; excess > 0 {
		history = history[excess:]
	}
	self.logHistory = history

	observers := self.logWsConns
	var survivors []*websocket.Conn
	for i := range observers {
		ws := observers[i]
		if err := self.writeLogToWS(ws, buffer); err != nil {
			ws.Close()
			continue
		}
		survivors = append(survivors, ws)
	}
	self.logWsConns = survivors
}
// ProcessRequests handles reqs in order. It stops at the first request whose
// handler fails and returns that error; any later requests are not handled.
func (self *ConnServer) ProcessRequests(reqs []*Request) error {
	var err error
	for i := 0; i < len(reqs) && err == nil; i++ {
		err = self.handleRequest(reqs[i])
	}
	return err
}
// handleOverlordRequest dispatches a command received from Overlord on Bridge.
// Spawn commands are relayed to the client as RPC requests. A ConnectLogcatCmd
// attaches a new logcat observer, which is first sent the buffered log
// history. Other command types are only logged.
func (self *ConnServer) handleOverlordRequest(obj interface{}) {
	log.Printf("Received %T command from overlord\n", obj)

	switch cmd := obj.(type) {
	case SpawnTerminalCmd:
		self.SpawnTerminal(cmd.Sid)
	case SpawnShellCmd:
		self.SpawnShell(cmd.Sid, cmd.Command)
	case ConnectLogcatCmd:
		ws := cmd.Conn
		// Bring the new observer up to date before it starts receiving live
		// output from forwardLogcatOutput.
		self.writeLogToWS(ws, self.logHistory)
		self.logWsConns = append(self.logWsConns, ws)
	}
}
// Listen is the main loop serving this connection. Each iteration handles one
// of the following:
//   - data read from the client socket: forwarded to the WebSocket observer(s)
//     in TERMINAL, SHELL and LOGCAT mode, otherwise parsed and handled as RPC
//     requests;
//   - a read error, which ends the loop;
//   - a command from Overlord on Bridge;
//   - a periodic tick, used to expire timed-out RPC requests and to drop
//     AGENT clients that stopped pinging;
//   - a stop request on stopListen.
//
// When Listen returns, the ticker is stopped and the connection is torn down
// with Terminate.
func (self *ConnServer) Listen() {
	readChan, readErrChan := self.SpawnReaderRoutine()
	ticker := time.NewTicker(time.Duration(TIMEOUT_CHECK_SECS * time.Second))
	// An unstopped ticker keeps firing and holding resources after we return.
	defer ticker.Stop()
	defer self.Terminate()

	for {
		select {
		case buffer := <-readChan:
			switch self.Mode {
			case TERMINAL:
				self.forwardTerminalOutput(buffer)
			case SHELL:
				self.forwardShellOutput(buffer)
			case LOGCAT:
				self.forwardLogcatOutput(buffer)
			default:
				// Only parse the first message if we are not registered, since
				// if we are in logcat mode, we want to preserve the rest of the
				// data and forward it to the websocket.
				reqs := self.ParseRequests(buffer, !self.registered)
				if err := self.ProcessRequests(reqs); err != nil {
					// A type assertion against RegistrationFailedError cannot
					// be used here. That type has the same method set as error,
					// so the assertion succeeds for every error. Decide from
					// the connection state instead: an error while the client
					// is still unregistered means its registration failed, so
					// abort. Errors from registered clients are only logged.
					if !self.registered {
						log.Printf("%s, abort", err)
						return
					}
					log.Println(err)
				}

				// If self.Mode changed, we just got a registration message and
				// are now in a different mode.
				switch self.Mode {
				case TERMINAL, SHELL:
					// Start a goroutine to forward the WebSocket input.
					go self.forwardWSInput()
				case LOGCAT:
					// A logcat client does not wait for the ACK before sending
					// its stream, so forward what remains in the read buffer.
					if self.ReadBuffer != "" {
						self.forwardLogcatOutput(self.ReadBuffer)
						self.ReadBuffer = ""
					}
				}
			}
		case err := <-readErrChan:
			if err == io.EOF {
				log.Printf("connection dropped: %s\n", self.Mid)
			} else {
				log.Printf("unknown network error for %s: %s\n", self.Mid, err.Error())
			}
			return
		case msg := <-self.Bridge:
			self.handleOverlordRequest(msg)
		case <-ticker.C:
			if err := self.ScanForTimeoutRequests(); err != nil {
				log.Println(err)
			}
			if self.Mode == AGENT && self.lastPing != 0 &&
				time.Now().Unix()-self.lastPing > PING_RECV_TIMEOUT {
				log.Printf("Client %s timeout\n", self.Mid)
				return
			}
		case s := <-self.stopListen:
			if s {
				return
			}
		}
	}
}
// Request handlers

// handlePingRequest records the time of the client's ping as a Unix timestamp
// in seconds; Listen uses it to drop AGENT clients that stop pinging. It then
// answers the request with a "pong" response.
func (self *ConnServer) handlePingRequest(req *Request) error {
	self.lastPing = time.Now().Unix()
	return self.SendResponse(NewResponse(req.Rid, "pong", nil))
}
// handleRegisterRequest registers the client with Overlord. The JSON params
// carry the client ID and machine ID (both required), the client mode, the log
// format and optional properties.
//
// The client's fields are updated and Overlord.Register is called. The
// WebSocket it returns is stored in wsConn. On success the client is marked
// registered, its ping clock is started, and a SUCCESS response is sent. A
// Register failure is returned wrapped as RegistrationFailedError.
//
// NOTE(review): the fields are set before Register and are not rolled back if
// it fails.
func (self *ConnServer) handleRegisterRequest(req *Request) error {
	var args struct {
		Cid        string                 `json:"cid"`
		Mid        string                 `json:"mid"`
		Mode       int                    `json:"mode"`
		Format     int                    `json:"format"`
		Properties map[string]interface{} `json:"properties"`
	}
	if err := json.Unmarshal(req.Params, &args); err != nil {
		return err
	}
	switch {
	case args.Mid == "":
		return errors.New("handleRegisterRequest: Empty machine ID received")
	case args.Cid == "":
		return errors.New("handleRegisterRequest: Empty client ID received")
	}

	self.Cid, self.Mid = args.Cid, args.Mid
	self.Mode, self.logFormat = args.Mode, args.Format
	self.SetProperties(args.Properties)

	ws, err := self.ovl.Register(self)
	self.wsConn = ws
	if err != nil {
		return RegistrationFailedError(err)
	}

	self.registered = true
	self.lastPing = time.Now().Unix()
	return self.SendResponse(NewResponse(req.Rid, SUCCESS, nil))
}
// handleRequest dispatches req to the handler for its name ("ping" or
// "register"). Requests with any other name are ignored and yield nil.
func (self *ConnServer) handleRequest(req *Request) error {
	switch req.Name {
	case "ping":
		return self.handlePingRequest(req)
	case "register":
		return self.handleRegisterRequest(req)
	default:
		return nil
	}
}
// SpawnTerminal asks the client to spawn a remote terminal connection (a ghost
// with mode TERMINAL). sid is the session ID, which will be used as the client
// ID of the new ghost. The reply handler reports an error when there is no
// response (timeout) or the response is not SUCCESS.
func (self *ConnServer) SpawnTerminal(sid string) {
	params := map[string]interface{}{"sid": sid}
	self.SendRequest(NewRequest("terminal", params), func(res *Response) error {
		switch {
		case res == nil:
			return errors.New("SpawnTerminal: command timeout")
		case res.Response != SUCCESS:
			return errors.New("SpawnTerminal failed: " + res.Response)
		}
		return nil
	})
}
// SpawnShell asks the client to spawn a remote shell command connection (a
// ghost with mode SHELL). sid is the session ID, which will be used as the
// client ID of the new ghost. command is the command to execute. The reply
// handler reports an error when there is no response (timeout) or the
// response is not SUCCESS.
func (self *ConnServer) SpawnShell(sid string, command string) {
	params := map[string]interface{}{
		"sid":     sid,
		"command": command,
	}
	self.SendRequest(NewRequest("shell", params), func(res *Response) error {
		switch {
		case res == nil:
			return errors.New("SpawnShell: command timeout ")
		case res.Response != SUCCESS:
			return errors.New("SpawnShell failed: " + res.Response)
		}
		return nil
	})
}