| package sftp |
| |
| import ( |
| "encoding" |
| "sort" |
| "sync" |
| ) |
| |
| // The goal of the packetManager is to keep the outgoing packets in the same |
| // order as the incoming ones, as required by section 7 of the RFC. |
| |
| type packetManager struct { |
| requests chan orderedPacket |
| responses chan orderedPacket |
| fini chan struct{} |
| incoming orderedPackets |
| outgoing orderedPackets |
| sender packetSender // connection object |
| working *sync.WaitGroup |
| packetCount uint32 |
| // it is not nil if the allocator is enabled |
| alloc *allocator |
| } |
| |
// packetSender is the only capability the packetManager requires from the
// connection object it writes responses to.
type packetSender interface {
	// sendPacket is handed a packet that can marshal itself to binary and
	// returns an error.
	sendPacket(pkt encoding.BinaryMarshaler) error
}
| |
| func newPktMgr(sender packetSender) *packetManager { |
| s := &packetManager{ |
| requests: make(chan orderedPacket, SftpServerWorkerCount), |
| responses: make(chan orderedPacket, SftpServerWorkerCount), |
| fini: make(chan struct{}), |
| incoming: make([]orderedPacket, 0, SftpServerWorkerCount), |
| outgoing: make([]orderedPacket, 0, SftpServerWorkerCount), |
| sender: sender, |
| working: &sync.WaitGroup{}, |
| } |
| go s.controller() |
| return s |
| } |
| |
| //// packet ordering |
| func (s *packetManager) newOrderID() uint32 { |
| s.packetCount++ |
| return s.packetCount |
| } |
| |
| // returns the next orderID without incrementing it. |
| // This is used before receiving a new packet, with the allocator enabled, to associate |
| // the slice allocated for the received packet with the orderID that will be used to mark |
| // the allocated slices for reuse once the request is served |
| func (s *packetManager) getNextOrderID() uint32 { |
| return s.packetCount + 1 |
| } |
| |
| type orderedRequest struct { |
| requestPacket |
| orderid uint32 |
| } |
| |
| func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest { |
| return orderedRequest{requestPacket: p, orderid: s.newOrderID()} |
| } |
| func (p orderedRequest) orderID() uint32 { return p.orderid } |
| func (p orderedRequest) setOrderID(oid uint32) { p.orderid = oid } |
| |
| type orderedResponse struct { |
| responsePacket |
| orderid uint32 |
| } |
| |
| func (s *packetManager) newOrderedResponse(p responsePacket, id uint32, |
| ) orderedResponse { |
| return orderedResponse{responsePacket: p, orderid: id} |
| } |
| func (p orderedResponse) orderID() uint32 { return p.orderid } |
| func (p orderedResponse) setOrderID(oid uint32) { p.orderid = oid } |
| |
// orderedPacket is any packet that exposes both an id and the order ID the
// packetManager sequences it by.
type orderedPacket interface {
	id() uint32
	orderID() uint32
}

// orderedPackets is a list of packets that can be sorted by order ID.
type orderedPackets []orderedPacket

// Sort reorders o in place so that order IDs are ascending.
func (o orderedPackets) Sort() {
	byOrderID := func(a, b int) bool {
		return o[a].orderID() < o[b].orderID()
	}
	sort.Slice(o, byOrderID)
}
| |
| //// packet registry |
| // register incoming packets to be handled |
| func (s *packetManager) incomingPacket(pkt orderedRequest) { |
| s.working.Add(1) |
| s.requests <- pkt |
| } |
| |
| // register outgoing packets as being ready |
| func (s *packetManager) readyPacket(pkt orderedResponse) { |
| s.responses <- pkt |
| s.working.Done() |
| } |
| |
| // shut down packetManager controller |
| func (s *packetManager) close() { |
| // pause until current packets are processed |
| s.working.Wait() |
| close(s.fini) |
| } |
| |
| // Passed a worker function, returns a channel for incoming packets. |
| // Keep process packet responses in the order they are received while |
| // maximizing throughput of file transfers. |
| func (s *packetManager) workerChan(runWorker func(chan orderedRequest), |
| ) chan orderedRequest { |
| // multiple workers for faster read/writes |
| rwChan := make(chan orderedRequest, SftpServerWorkerCount) |
| for i := 0; i < SftpServerWorkerCount; i++ { |
| runWorker(rwChan) |
| } |
| |
| // single worker to enforce sequential processing of everything else |
| cmdChan := make(chan orderedRequest) |
| runWorker(cmdChan) |
| |
| pktChan := make(chan orderedRequest, SftpServerWorkerCount) |
| go func() { |
| for pkt := range pktChan { |
| switch pkt.requestPacket.(type) { |
| case *sshFxpReadPacket, *sshFxpWritePacket: |
| s.incomingPacket(pkt) |
| rwChan <- pkt |
| continue |
| case *sshFxpClosePacket: |
| // wait for reads/writes to finish when file is closed |
| // incomingPacket() call must occur after this |
| s.working.Wait() |
| } |
| s.incomingPacket(pkt) |
| // all non-RW use sequential cmdChan |
| cmdChan <- pkt |
| } |
| close(rwChan) |
| close(cmdChan) |
| s.close() |
| }() |
| |
| return pktChan |
| } |
| |
| // process packets |
| func (s *packetManager) controller() { |
| for { |
| select { |
| case pkt := <-s.requests: |
| debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID()) |
| s.incoming = append(s.incoming, pkt) |
| s.incoming.Sort() |
| case pkt := <-s.responses: |
| debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID()) |
| s.outgoing = append(s.outgoing, pkt) |
| s.outgoing.Sort() |
| case <-s.fini: |
| return |
| } |
| s.maybeSendPackets() |
| } |
| } |
| |
| // send as many packets as are ready |
| func (s *packetManager) maybeSendPackets() { |
| for { |
| if len(s.outgoing) == 0 || len(s.incoming) == 0 { |
| debug("break! -- outgoing: %v; incoming: %v", |
| len(s.outgoing), len(s.incoming)) |
| break |
| } |
| out := s.outgoing[0] |
| in := s.incoming[0] |
| // debug("incoming: %v", ids(s.incoming)) |
| // debug("outgoing: %v", ids(s.outgoing)) |
| if in.orderID() == out.orderID() { |
| debug("Sending packet: %v", out.id()) |
| s.sender.sendPacket(out.(encoding.BinaryMarshaler)) |
| if s.alloc != nil { |
| // mark for reuse the slices allocated for this request |
| s.alloc.ReleasePages(in.orderID()) |
| } |
| // pop off heads |
| copy(s.incoming, s.incoming[1:]) // shift left |
| s.incoming[len(s.incoming)-1] = nil // clear last |
| s.incoming = s.incoming[:len(s.incoming)-1] // remove last |
| copy(s.outgoing, s.outgoing[1:]) // shift left |
| s.outgoing[len(s.outgoing)-1] = nil // clear last |
| s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last |
| } else { |
| break |
| } |
| } |
| } |
| |
| // func oids(o []orderedPacket) []uint32 { |
| // res := make([]uint32, 0, len(o)) |
| // for _, v := range o { |
| // res = append(res, v.orderId()) |
| // } |
| // return res |
| // } |
| // func ids(o []orderedPacket) []uint32 { |
| // res := make([]uint32, 0, len(o)) |
| // for _, v := range o { |
| // res = append(res, v.id()) |
| // } |
| // return res |
| // } |