| // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package websocket |
| |
| import ( |
| "compress/flate" |
| "errors" |
| "io" |
| "strings" |
| "sync" |
| ) |
| |
// flateWriterPool caches *flate.Writer values between messages. It has no New
// function, so Get returns nil when the pool is empty and the caller is
// responsible for creating a writer in that case.
var flateWriterPool sync.Pool
| |
// decompressNoContextTakeover returns a reader that inflates the compressed
// payload read from r.
//
// The payload is expected to be missing the trailing 0x00 0x00 0xff 0xff of a
// sync flush, so those bytes are appended before inflating. An empty final
// stored block is appended after them so that the flate reader sees a properly
// terminated stream instead of reporting an unexpected EOF.
func decompressNoContextTakeover(r io.Reader) io.Reader {
	// Four bytes the RFC requires the receiver to add back to the payload.
	const syncFlushTail = "\x00\x00\xff\xff"
	// BFINAL=1, BTYPE=00 (stored), LEN=0, NLEN=0xffff: an empty last block.
	const emptyFinalBlock = "\x01\x00\x00\xff\xff"

	completed := io.MultiReader(r, strings.NewReader(syncFlushTail+emptyFinalBlock))
	return flate.NewReader(completed)
}
| |
| func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { |
| tw := &truncWriter{w: w} |
| i := flateWriterPool.Get() |
| var fw *flate.Writer |
| var err error |
| if i == nil { |
| fw, err = flate.NewWriter(tw, 3) |
| } else { |
| fw = i.(*flate.Writer) |
| fw.Reset(tw) |
| } |
| return &flateWrapper{fw: fw, tw: tw}, err |
| } |
| |
// truncWriter is an io.Writer that forwards every byte of the stream written
// to it on to w, except the final four bytes, which it keeps back in p.
//
// flateWrapper.Close inspects the withheld bytes after flushing the compressor
// and expects them to be 0x00 0x00 0xff 0xff.
type truncWriter struct {
	w io.WriteCloser // destination for all but the last four bytes
	n int            // number of valid bytes in p; never exceeds len(p)
	p [4]byte        // the most recently written bytes, not yet forwarded
}

// Write appends p to the stream. Any byte that stops being one of the last
// four bytes written is passed on to w, and the newest four stay buffered.
//
// On success Write reports len(p), as the io.Writer contract requires: every
// byte of p has been either forwarded to w or retained in the buffer. If w
// fails, the error is returned together with a partial count.
func (w *truncWriter) Write(p []byte) (int, error) {
	total := len(p)
	n := 0

	// Fill the holdback buffer first. Until it is full, nothing is forwarded.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
		if len(p) == 0 {
			return n, nil
		}
	}

	// The buffer is full and more data has arrived. The oldest m buffered bytes
	// are now followed by at least m newer bytes, so they can be forwarded.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}
	if nn, err := w.w.Write(w.p[:m]); err != nil {
		return n + nn, err
	}

	// Slide the still-withheld bytes to the front, then refill the tail with
	// the last m bytes of p, which are now the newest bytes of the stream.
	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])

	// Everything in p except those last m bytes goes straight to w.
	if nn, err := w.w.Write(p[:len(p)-m]); err != nil {
		return n + nn, err
	}
	return total, nil
}
| |
| type flateWrapper struct { |
| fw *flate.Writer |
| tw *truncWriter |
| } |
| |
| func (w *flateWrapper) Write(p []byte) (int, error) { |
| if w.fw == nil { |
| return 0, errWriteClosed |
| } |
| return w.fw.Write(p) |
| } |
| |
| func (w *flateWrapper) Close() error { |
| if w.fw == nil { |
| return errWriteClosed |
| } |
| err1 := w.fw.Flush() |
| flateWriterPool.Put(w.fw) |
| w.fw = nil |
| if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { |
| return errors.New("websocket: internal error, unexpected bytes at end of flate stream") |
| } |
| err2 := w.tw.w.Close() |
| if err1 != nil { |
| return err1 |
| } |
| return err2 |
| } |