framesocket.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. // Copyright (c) 2025 Tulir Asokan
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this
  5. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. package socket
  7. import (
  8. "context"
  9. "errors"
  10. "fmt"
  11. "net/http"
  12. "sync"
  13. "github.com/coder/websocket"
  14. waLog "git.bobomao.top/joey/testwh/util/log"
  15. )
  16. type FrameSocket struct {
  17. parentCtx context.Context
  18. cancelCtx context.Context
  19. cancel context.CancelFunc
  20. conn *websocket.Conn
  21. log waLog.Logger
  22. lock sync.Mutex
  23. URL string
  24. HTTPHeaders http.Header
  25. HTTPClient *http.Client
  26. Frames chan []byte
  27. OnDisconnect func(ctx context.Context, remote bool)
  28. Header []byte
  29. closed bool
  30. incomingLength int
  31. receivedLength int
  32. incoming []byte
  33. partialHeader []byte
  34. }
  35. func NewFrameSocket(log waLog.Logger, client *http.Client) *FrameSocket {
  36. return &FrameSocket{
  37. log: log,
  38. Header: WAConnHeader,
  39. Frames: make(chan []byte),
  40. URL: URL,
  41. HTTPHeaders: http.Header{"Origin": {Origin}},
  42. HTTPClient: client,
  43. }
  44. }
  45. func (fs *FrameSocket) IsConnected() bool {
  46. return fs.conn != nil
  47. }
  48. func (fs *FrameSocket) Close(code websocket.StatusCode) {
  49. fs.lock.Lock()
  50. defer fs.lock.Unlock()
  51. if fs.conn == nil {
  52. return
  53. }
  54. fs.closed = true
  55. if code > 0 {
  56. err := fs.conn.Close(code, "")
  57. if err != nil {
  58. fs.log.Warnf("Error sending close to websocket: %v", err)
  59. }
  60. } else {
  61. err := fs.conn.CloseNow()
  62. if err != nil {
  63. fs.log.Debugf("Error force closing websocket: %v", err)
  64. }
  65. }
  66. fs.conn = nil
  67. fs.cancel()
  68. fs.cancel = nil
  69. if fs.OnDisconnect != nil {
  70. go fs.OnDisconnect(fs.parentCtx, code == 0)
  71. }
  72. }
  73. func (fs *FrameSocket) Connect(ctx context.Context) error {
  74. fs.lock.Lock()
  75. defer fs.lock.Unlock()
  76. if fs.conn != nil {
  77. return ErrSocketAlreadyOpen
  78. }
  79. fs.parentCtx = ctx
  80. fs.cancelCtx, fs.cancel = context.WithCancel(ctx)
  81. fs.log.Debugf("Dialing %s", fs.URL)
  82. conn, resp, err := websocket.Dial(ctx, fs.URL, &websocket.DialOptions{
  83. HTTPClient: fs.HTTPClient,
  84. HTTPHeader: fs.HTTPHeaders,
  85. })
  86. if err != nil {
  87. if resp != nil {
  88. err = ErrWithStatusCode{err, resp.StatusCode}
  89. }
  90. fs.cancel()
  91. return fmt.Errorf("failed to dial whatsapp web websocket: %w", err)
  92. }
  93. conn.SetReadLimit(FrameMaxSize)
  94. fs.conn = conn
  95. go fs.readPump(conn, ctx)
  96. return nil
  97. }
  98. func (fs *FrameSocket) Context() context.Context {
  99. return fs.cancelCtx
  100. }
  101. func (fs *FrameSocket) SendFrame(data []byte) error {
  102. conn := fs.conn
  103. if conn == nil {
  104. return ErrSocketClosed
  105. }
  106. dataLength := len(data)
  107. if dataLength >= FrameMaxSize {
  108. return fmt.Errorf("%w (got %d bytes, max %d bytes)", ErrFrameTooLarge, len(data), FrameMaxSize)
  109. }
  110. headerLength := len(fs.Header)
  111. // Whole frame is header + 3 bytes for length + data
  112. wholeFrame := make([]byte, headerLength+FrameLengthSize+dataLength)
  113. // Copy the header if it's there
  114. if fs.Header != nil {
  115. copy(wholeFrame[:headerLength], fs.Header)
  116. // We only want to send the header once
  117. fs.Header = nil
  118. }
  119. // Encode length of frame
  120. wholeFrame[headerLength] = byte(dataLength >> 16)
  121. wholeFrame[headerLength+1] = byte(dataLength >> 8)
  122. wholeFrame[headerLength+2] = byte(dataLength)
  123. // Copy actual frame data
  124. copy(wholeFrame[headerLength+FrameLengthSize:], data)
  125. return conn.Write(fs.cancelCtx, websocket.MessageBinary, wholeFrame)
  126. }
  127. func (fs *FrameSocket) frameComplete() {
  128. data := fs.incoming
  129. fs.incoming = nil
  130. fs.partialHeader = nil
  131. fs.incomingLength = 0
  132. fs.receivedLength = 0
  133. fs.Frames <- data
  134. }
  135. func (fs *FrameSocket) processData(msg []byte) {
  136. for len(msg) > 0 {
  137. // This probably doesn't happen a lot (if at all), so the code is unoptimized
  138. if fs.partialHeader != nil {
  139. msg = append(fs.partialHeader, msg...)
  140. fs.partialHeader = nil
  141. }
  142. if fs.incoming == nil {
  143. if len(msg) >= FrameLengthSize {
  144. length := (int(msg[0]) << 16) + (int(msg[1]) << 8) + int(msg[2])
  145. fs.incomingLength = length
  146. fs.receivedLength = len(msg)
  147. msg = msg[FrameLengthSize:]
  148. if len(msg) >= length {
  149. fs.incoming = msg[:length]
  150. msg = msg[length:]
  151. fs.frameComplete()
  152. } else {
  153. fs.incoming = make([]byte, length)
  154. copy(fs.incoming, msg)
  155. msg = nil
  156. }
  157. } else {
  158. fs.log.Warnf("Received partial header (report if this happens often)")
  159. fs.partialHeader = msg
  160. msg = nil
  161. }
  162. } else {
  163. if fs.receivedLength+len(msg) >= fs.incomingLength {
  164. copy(fs.incoming[fs.receivedLength:], msg[:fs.incomingLength-fs.receivedLength])
  165. msg = msg[fs.incomingLength-fs.receivedLength:]
  166. fs.frameComplete()
  167. } else {
  168. copy(fs.incoming[fs.receivedLength:], msg)
  169. fs.receivedLength += len(msg)
  170. msg = nil
  171. }
  172. }
  173. }
  174. }
  175. func (fs *FrameSocket) readPump(conn *websocket.Conn, ctx context.Context) {
  176. fs.log.Debugf("Frame websocket read pump starting %p", fs)
  177. defer func() {
  178. fs.log.Debugf("Frame websocket read pump exiting %p", fs)
  179. go fs.Close(0)
  180. }()
  181. for {
  182. msgType, data, err := conn.Read(ctx)
  183. if err != nil {
  184. // Ignore the error if the context has been closed
  185. if !fs.closed && !errors.Is(ctx.Err(), context.Canceled) {
  186. fs.log.Errorf("Error reading from websocket: %v", err)
  187. }
  188. return
  189. } else if msgType != websocket.MessageBinary {
  190. fs.log.Warnf("Got unexpected websocket message type %d", msgType)
  191. continue
  192. }
  193. fs.processData(data)
  194. }
  195. }