blob: bd907bb31588435126d5ca003c5a60ff184382a9 [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package servod
import (
"fmt"
"io"
"net"
"sync"
"go.chromium.org/luci/common/errors"
"infra/libs/sshpool"
)
// proxy holds info to perform proxy confection to servod daemon.
type proxy struct {
host string
connFunc func() (net.Conn, error)
ls net.Listener
mutex sync.Mutex
errFuncs []func(error)
closed bool
}
const (
// Local address with dynamic port.
localAddr = "127.0.0.1:0"
// Local address template for remote host.
remoteAddrFmt = "127.0.0.1:%d"
)
// newProxy creates a new proxy with forward from remote to local host.
// Function is using a goroutine to listen and handle each incoming connection.
// Initialization of proxy is going asynchronous after return proxy instance.
func newProxy(pool *sshpool.Pool, host string, remotePort int32, errFuncs ...func(error)) (*proxy, error) {
remoteAddr := fmt.Sprintf(remoteAddrFmt, remotePort)
connFunc := func() (net.Conn, error) {
conn, err := pool.Get(host)
if err != nil {
return nil, errors.Annotate(err, "get proxy %q: fail to get client from pool", host).Err()
}
defer pool.Put(host, conn)
// Establish connection with remote server.
return conn.Dial("tcp", remoteAddr)
}
// Create listener for local port.
local, err := net.Listen("tcp", localAddr)
if err != nil {
return nil, err
}
proxy := &proxy{
host: host,
ls: local,
connFunc: connFunc,
errFuncs: errFuncs,
closed: false,
}
// Start a goroutine that serves as the listener and launches
// a new goroutine to handle each incoming connection.
// Running by goroutine to avoid waiting connections and return proxy for usage.
go func() {
for {
if proxy.closed {
break
}
// Waits for and returns the next connection.
local, err := proxy.ls.Accept()
if err != nil {
break
}
go func() {
if err := proxy.handleConn(local); err != nil && len(proxy.errFuncs) > 0 {
proxy.mutex.Lock()
for _, ef := range proxy.errFuncs {
ef(err)
}
proxy.mutex.Unlock()
}
}()
}
}()
return proxy, nil
}
// Close closes listening for incoming connections of proxy.
func (p *proxy) Close() error {
p.closed = true
p.mutex.Lock()
p.errFuncs = nil
p.mutex.Unlock()
return p.ls.Close()
}
// handleConn establishes a new connection to the destination port using connFunc
// and copies data between it and src. It closes src before returning.
func (p *proxy) handleConn(src net.Conn) error {
if p.closed {
return errors.Reason("handle connection: proxy closed").Err()
}
defer src.Close()
dst, err := p.connFunc()
if err != nil {
return err
}
defer dst.Close()
ch := make(chan error)
go func() {
_, err := io.Copy(src, dst)
ch <- err
}()
go func() {
_, err := io.Copy(dst, src)
ch <- err
}()
var firstErr error
for i := 0; i < 2; i++ {
if err := <-ch; err != io.EOF && firstErr == nil {
firstErr = err
}
}
return firstErr
}
// LocalAddr provides assigned local address.
// Example: 127.0.0.1:23456
func (p *proxy) LocalAddr() string {
return p.ls.Addr().String()
}