| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- // Copyright (c) 2025 Tulir Asokan
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this
- // file, You can obtain one at http://mozilla.org/MPL/2.0/.
- package socket
- import (
- "context"
- "errors"
- "fmt"
- "net/http"
- "sync"
- "github.com/coder/websocket"
- waLog "git.bobomao.top/joey/testwh/util/log"
- )
- type FrameSocket struct {
- parentCtx context.Context
- cancelCtx context.Context
- cancel context.CancelFunc
- conn *websocket.Conn
- log waLog.Logger
- lock sync.Mutex
- URL string
- HTTPHeaders http.Header
- HTTPClient *http.Client
- Frames chan []byte
- OnDisconnect func(ctx context.Context, remote bool)
- Header []byte
- closed bool
- incomingLength int
- receivedLength int
- incoming []byte
- partialHeader []byte
- }
- func NewFrameSocket(log waLog.Logger, client *http.Client) *FrameSocket {
- return &FrameSocket{
- log: log,
- Header: WAConnHeader,
- Frames: make(chan []byte),
- URL: URL,
- HTTPHeaders: http.Header{"Origin": {Origin}},
- HTTPClient: client,
- }
- }
- func (fs *FrameSocket) IsConnected() bool {
- return fs.conn != nil
- }
- func (fs *FrameSocket) Close(code websocket.StatusCode) {
- fs.lock.Lock()
- defer fs.lock.Unlock()
- if fs.conn == nil {
- return
- }
- fs.closed = true
- if code > 0 {
- err := fs.conn.Close(code, "")
- if err != nil {
- fs.log.Warnf("Error sending close to websocket: %v", err)
- }
- } else {
- err := fs.conn.CloseNow()
- if err != nil {
- fs.log.Debugf("Error force closing websocket: %v", err)
- }
- }
- fs.conn = nil
- fs.cancel()
- fs.cancel = nil
- if fs.OnDisconnect != nil {
- go fs.OnDisconnect(fs.parentCtx, code == 0)
- }
- }
- func (fs *FrameSocket) Connect(ctx context.Context) error {
- fs.lock.Lock()
- defer fs.lock.Unlock()
- if fs.conn != nil {
- return ErrSocketAlreadyOpen
- }
- fs.parentCtx = ctx
- fs.cancelCtx, fs.cancel = context.WithCancel(ctx)
- fs.log.Debugf("Dialing %s", fs.URL)
- conn, resp, err := websocket.Dial(ctx, fs.URL, &websocket.DialOptions{
- HTTPClient: fs.HTTPClient,
- HTTPHeader: fs.HTTPHeaders,
- })
- if err != nil {
- if resp != nil {
- err = ErrWithStatusCode{err, resp.StatusCode}
- }
- fs.cancel()
- return fmt.Errorf("failed to dial whatsapp web websocket: %w", err)
- }
- conn.SetReadLimit(FrameMaxSize)
- fs.conn = conn
- go fs.readPump(conn, ctx)
- return nil
- }
- func (fs *FrameSocket) Context() context.Context {
- return fs.cancelCtx
- }
- func (fs *FrameSocket) SendFrame(data []byte) error {
- conn := fs.conn
- if conn == nil {
- return ErrSocketClosed
- }
- dataLength := len(data)
- if dataLength >= FrameMaxSize {
- return fmt.Errorf("%w (got %d bytes, max %d bytes)", ErrFrameTooLarge, len(data), FrameMaxSize)
- }
- headerLength := len(fs.Header)
- // Whole frame is header + 3 bytes for length + data
- wholeFrame := make([]byte, headerLength+FrameLengthSize+dataLength)
- // Copy the header if it's there
- if fs.Header != nil {
- copy(wholeFrame[:headerLength], fs.Header)
- // We only want to send the header once
- fs.Header = nil
- }
- // Encode length of frame
- wholeFrame[headerLength] = byte(dataLength >> 16)
- wholeFrame[headerLength+1] = byte(dataLength >> 8)
- wholeFrame[headerLength+2] = byte(dataLength)
- // Copy actual frame data
- copy(wholeFrame[headerLength+FrameLengthSize:], data)
- return conn.Write(fs.cancelCtx, websocket.MessageBinary, wholeFrame)
- }
- func (fs *FrameSocket) frameComplete() {
- data := fs.incoming
- fs.incoming = nil
- fs.partialHeader = nil
- fs.incomingLength = 0
- fs.receivedLength = 0
- fs.Frames <- data
- }
- func (fs *FrameSocket) processData(msg []byte) {
- for len(msg) > 0 {
- // This probably doesn't happen a lot (if at all), so the code is unoptimized
- if fs.partialHeader != nil {
- msg = append(fs.partialHeader, msg...)
- fs.partialHeader = nil
- }
- if fs.incoming == nil {
- if len(msg) >= FrameLengthSize {
- length := (int(msg[0]) << 16) + (int(msg[1]) << 8) + int(msg[2])
- fs.incomingLength = length
- fs.receivedLength = len(msg)
- msg = msg[FrameLengthSize:]
- if len(msg) >= length {
- fs.incoming = msg[:length]
- msg = msg[length:]
- fs.frameComplete()
- } else {
- fs.incoming = make([]byte, length)
- copy(fs.incoming, msg)
- msg = nil
- }
- } else {
- fs.log.Warnf("Received partial header (report if this happens often)")
- fs.partialHeader = msg
- msg = nil
- }
- } else {
- if fs.receivedLength+len(msg) >= fs.incomingLength {
- copy(fs.incoming[fs.receivedLength:], msg[:fs.incomingLength-fs.receivedLength])
- msg = msg[fs.incomingLength-fs.receivedLength:]
- fs.frameComplete()
- } else {
- copy(fs.incoming[fs.receivedLength:], msg)
- fs.receivedLength += len(msg)
- msg = nil
- }
- }
- }
- }
- func (fs *FrameSocket) readPump(conn *websocket.Conn, ctx context.Context) {
- fs.log.Debugf("Frame websocket read pump starting %p", fs)
- defer func() {
- fs.log.Debugf("Frame websocket read pump exiting %p", fs)
- go fs.Close(0)
- }()
- for {
- msgType, data, err := conn.Read(ctx)
- if err != nil {
- // Ignore the error if the context has been closed
- if !fs.closed && !errors.Is(ctx.Err(), context.Canceled) {
- fs.log.Errorf("Error reading from websocket: %v", err)
- }
- return
- } else if msgType != websocket.MessageBinary {
- fs.log.Warnf("Got unexpected websocket message type %d", msgType)
- continue
- }
- fs.processData(data)
- }
- }
|