blob: e14415a7ea40dc8e78268ba3de790566db74f027 [file] [log] [blame]
// Copyright 2015 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package overlord
import (
"encoding/json"
"errors"
"github.com/satori/go.uuid"
"log"
"net"
"strings"
"time"
)
const (
DEBUG_RPC = false
SEP = "\r\n"
BUFSIZ = 8192
REQUEST_TIMEOUT_SECS = 60 // Number of seconds before request timeouts
TIMEOUT_CHECK_SECS = 3 // The time between checking for timeout
)
// The interface which defines a sendable message.
type Message interface {
Marshal() ([]byte, error)
}
// Request Object.
// Implements the Message interface.
// If Timeout < 0, then the response can be omitted.
type Request struct {
Rid string `json:"rid"`
Timeout int64 `json:"timeout"`
Name string `json:"name"`
Params json.RawMessage `json:"params"`
}
// Create a new Request object.
// name is the name of the request.
// params is map between string and any other JSON-serializable data structure.
func NewRequest(name string, params map[string]interface{}) *Request {
req := &Request{
Rid: uuid.NewV4().String(),
Timeout: REQUEST_TIMEOUT_SECS,
Name: name,
}
if targs, err := json.Marshal(params); err != nil {
panic(err)
} else {
req.Params = json.RawMessage(targs)
}
return req
}
// Set the timeout of request.
// The default timeout is is defined in REQUEST_TIMEOUT_SECS.
func (self *Request) SetTimeout(timeout int64) {
self.Timeout = timeout
}
func (self *Request) Marshal() ([]byte, error) {
return json.Marshal(self)
}
// Response Object.
// Implements the Message interface.
type Response struct {
Rid string `json:"rid"`
Response string `json:"response"`
Params json.RawMessage `json:"params"`
}
// Create a new Response object.
// rid is the request ID of the request this response is intended for.
// response is the reponse status text.
// params is map between string and any other JSON-serializable data structure.
func NewResponse(rid, response string, params map[string]interface{}) *Response {
res := &Response{
Rid: rid,
Response: response,
}
if targs, err := json.Marshal(params); err != nil {
panic(err)
} else {
res.Params = json.RawMessage(targs)
}
return res
}
func (self *Response) Marshal() ([]byte, error) {
return json.Marshal(self)
}
// The function type of the response handler.
// if res is nil, means that the response timeout.
type ResponseHandler func(res *Response) error
// The structure that stores the response handler information.
type Responder struct {
RequestTime int64 // Time of request
Timeout int64 // Timeout in seconds
Handler ResponseHandler // The corresponding request handler
}
// RPCCore is the core implementation of the TCP-based 2-way RPC protocol.
type RPCCore struct {
Conn net.Conn // handle to the TCP connection
ReadBuffer string // internal read buffer
responders map[string]Responder // response handlers
}
func NewRPCCore(conn net.Conn) *RPCCore {
return &RPCCore{Conn: conn, responders: make(map[string]Responder)}
}
func (self *RPCCore) SendMessage(msg Message) error {
if self.Conn == nil {
return errors.New("SendMessage failed, connection not established")
}
var err error
var msgBytes []byte
if msgBytes, err = msg.Marshal(); err == nil {
if DEBUG_RPC {
log.Printf("-----> %s\n", string(msgBytes))
}
_, err = self.Conn.Write(append(msgBytes, []byte(SEP)...))
}
return err
}
func (self *RPCCore) SendRequest(req *Request, handler ResponseHandler) error {
err := self.SendMessage(req)
if err == nil && req.Timeout >= 0 {
res := Responder{time.Now().Unix(), req.Timeout, handler}
self.responders[req.Rid] = res
}
return err
}
func (self *RPCCore) SendResponse(res *Response) error {
return self.SendMessage(res)
}
func (self *RPCCore) handleResponse(res *Response) error {
defer delete(self.responders, res.Rid)
if responder, ok := self.responders[res.Rid]; ok {
if responder.Handler != nil {
if err := responder.Handler(res); err != nil {
return err
}
}
} else {
return errors.New("Received unsolicited response, ignored")
}
return nil
}
// Spawnes a goroutine that actively read from the socket.
// This function returns two channels. The first one is the channel that
// send the content from the socket, and the second channel send an error
// object if there is one.
func (self *RPCCore) SpawnReaderRoutine() (chan []byte, chan error) {
readChan := make(chan []byte)
readErrChan := make(chan error, 1)
go func() {
for {
buf := make([]byte, BUFSIZ)
n, err := self.Conn.Read(buf)
if err != nil {
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
continue
}
readErrChan <- err
return
}
readChan <- buf[:n]
}
}()
return readChan, readErrChan
}
// Parses a single JSON string into a Message object.
func (self *RPCCore) ParseMessage(msgJson string) (Message, error) {
var req Request
var res Response
err := json.Unmarshal([]byte(msgJson), &req)
if err != nil || len(req.Name) == 0 {
err := json.Unmarshal([]byte(msgJson), &res)
if err != nil {
err = errors.New("mal-formed JSON request, ignored")
} else {
return &res, nil
}
} else {
return &req, nil
}
return nil, err
}
// Parses a buffer from SpawnReaderRoutine into Request objects.
// The response message is automatically handled by the RPCCore itself by
// invoking the corresponding response handler.
func (self *RPCCore) ParseRequests(buffer string, single bool) []*Request {
var reqs []*Request
var writeback string
var msgsJson []string
self.ReadBuffer += buffer
if single {
idx := strings.Index(self.ReadBuffer, SEP)
if idx == -1 {
return nil
}
msgsJson = []string{self.ReadBuffer[:idx]}
self.ReadBuffer = self.ReadBuffer[idx+2:]
} else {
msgs := strings.Split(self.ReadBuffer, SEP)
if len(msgs) > 0 {
self.ReadBuffer = msgs[len(msgs)-1]
msgsJson = msgs[:len(msgs)-1]
} else {
return nil
}
}
for _, msgJson := range msgsJson {
if DEBUG_RPC {
log.Printf("<----- " + msgJson)
}
if msg, err := self.ParseMessage(msgJson); err != nil {
writeback += msgJson + SEP
log.Printf("Message parse failed: %s\n", err)
continue
} else {
switch m := msg.(type) {
case *Request:
reqs = append(reqs, m)
case *Response:
err := self.handleResponse(m)
if err != nil {
log.Printf("Response error: %s\n", err)
}
}
}
}
return reqs
}
// Scan for timeout requests.
func (self *RPCCore) ScanForTimeoutRequests() error {
for rid, res := range self.responders {
if time.Now().Unix()-res.RequestTime > res.Timeout {
if res.Handler != nil {
if err := res.Handler(nil); err != nil {
delete(self.responders, rid)
return err
}
} else {
log.Printf("Request %s timeout\n", rid)
}
delete(self.responders, rid)
}
}
return nil
}
func (self *RPCCore) ClearRequests() {
self.responders = make(map[string]Responder)
}