| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- package binary
- import (
- "fmt"
- "math"
- "strconv"
- "go.mau.fi/whatsmeow/binary/token"
- "go.mau.fi/whatsmeow/types"
- )
// binaryEncoder accumulates the binary encoding produced by the write* methods
// in an in-memory byte slice.
type binaryEncoder struct {
	data []byte
}

// newEncoder returns an encoder whose buffer already contains a single leading
// zero byte.
// NOTE(review): that byte is presumably a header/flags byte expected by whoever
// consumes getData — confirm against the caller.
func newEncoder() *binaryEncoder {
	return &binaryEncoder{[]byte{0}}
}

// getData returns everything written so far. The slice aliases the encoder's
// internal buffer; it is not a copy.
func (w *binaryEncoder) getData() []byte {
	return w.data
}

// pushByte appends a single byte.
func (w *binaryEncoder) pushByte(b byte) {
	w.data = append(w.data, b)
}

// pushBytes appends bytes verbatim.
func (w *binaryEncoder) pushBytes(bytes []byte) {
	w.data = append(w.data, bytes...)
}

// pushIntN appends the low n bytes of value, most significant byte first
// unless littleEndian is true. Higher-order bytes are silently discarded, so
// callers must ensure value fits in n bytes.
func (w *binaryEncoder) pushIntN(value, n int, littleEndian bool) {
	for i := 0; i < n; i++ {
		var curShift int
		if littleEndian {
			curShift = i
		} else {
			curShift = n - i - 1
		}
		w.pushByte(byte((value >> uint(curShift*8)) & 0xFF))
	}
}

// pushInt20 appends the low 20 bits of value as three big-endian bytes; the
// upper nibble of the first byte is always zero.
func (w *binaryEncoder) pushInt20(value int) {
	w.pushBytes([]byte{byte((value >> 16) & 0x0F), byte((value >> 8) & 0xFF), byte(value & 0xFF)})
}

// pushInt8 appends the low byte of value.
func (w *binaryEncoder) pushInt8(value int) {
	w.pushIntN(value, 1, false)
}

// pushInt16 appends the low 16 bits of value, big-endian.
func (w *binaryEncoder) pushInt16(value int) {
	w.pushIntN(value, 2, false)
}

// pushInt32 appends the low 32 bits of value, big-endian.
func (w *binaryEncoder) pushInt32(value int) {
	w.pushIntN(value, 4, false)
}

// pushString appends the raw bytes of value.
func (w *binaryEncoder) pushString(value string) {
	// append accepts a string directly when the destination is a []byte,
	// which skips the intermediate []byte(value) conversion.
	w.data = append(w.data, value...)
}
- func (w *binaryEncoder) writeByteLength(length int) {
- if length < 256 {
- w.pushByte(token.Binary8)
- w.pushInt8(length)
- } else if length < (1 << 20) {
- w.pushByte(token.Binary20)
- w.pushInt20(length)
- } else if length < math.MaxInt32 {
- w.pushByte(token.Binary32)
- w.pushInt32(length)
- } else {
- panic(fmt.Errorf("length is too large: %d", length))
- }
- }
- const tagSize = 1
- func (w *binaryEncoder) writeNode(n Node) {
- if n.Tag == "0" {
- w.pushByte(token.List8)
- w.pushByte(token.ListEmpty)
- return
- }
- hasContent := 0
- if n.Content != nil {
- hasContent = 1
- }
- w.writeListStart(2*w.countAttributes(n.Attrs) + tagSize + hasContent)
- w.writeString(n.Tag)
- w.writeAttributes(n.Attrs)
- if n.Content != nil {
- w.write(n.Content)
- }
- }
- func (w *binaryEncoder) write(data interface{}) {
- switch typedData := data.(type) {
- case nil:
- w.pushByte(token.ListEmpty)
- case types.JID:
- w.writeJID(typedData)
- case string:
- w.writeString(typedData)
- case int:
- w.writeString(strconv.Itoa(typedData))
- case int32:
- w.writeString(strconv.FormatInt(int64(typedData), 10))
- case uint:
- w.writeString(strconv.FormatUint(uint64(typedData), 10))
- case uint32:
- w.writeString(strconv.FormatUint(uint64(typedData), 10))
- case int64:
- w.writeString(strconv.FormatInt(typedData, 10))
- case uint64:
- w.writeString(strconv.FormatUint(typedData, 10))
- case bool:
- w.writeString(strconv.FormatBool(typedData))
- case []byte:
- w.writeBytes(typedData)
- case []Node:
- w.writeListStart(len(typedData))
- for _, n := range typedData {
- w.writeNode(n)
- }
- default:
- panic(fmt.Errorf("%w: %T", ErrInvalidType, typedData))
- }
- }
- func (w *binaryEncoder) writeString(data string) {
- var dictIndex byte
- if tokenIndex, ok := token.IndexOfSingleToken(data); ok {
- w.pushByte(tokenIndex)
- } else if dictIndex, tokenIndex, ok = token.IndexOfDoubleByteToken(data); ok {
- w.pushByte(token.Dictionary0 + dictIndex)
- w.pushByte(tokenIndex)
- } else if validateNibble(data) {
- w.writePackedBytes(data, token.Nibble8)
- } else if validateHex(data) {
- w.writePackedBytes(data, token.Hex8)
- } else {
- w.writeStringRaw(data)
- }
- }
- func (w *binaryEncoder) writeBytes(value []byte) {
- w.writeByteLength(len(value))
- w.pushBytes(value)
- }
- func (w *binaryEncoder) writeStringRaw(value string) {
- w.writeByteLength(len(value))
- w.pushString(value)
- }
- func (w *binaryEncoder) writeJID(jid types.JID) {
- if ((jid.Server == types.DefaultUserServer || jid.Server == types.HiddenUserServer) && jid.Device > 0) || jid.Server == types.HostedServer {
- w.pushByte(token.ADJID)
- w.pushByte(jid.ActualAgent())
- w.pushByte(uint8(jid.Device))
- w.writeString(jid.User)
- } else if jid.Server == types.MessengerServer {
- w.pushByte(token.FBJID)
- w.write(jid.User)
- w.pushInt16(int(jid.Device))
- w.write(jid.Server)
- } else if jid.Server == types.InteropServer {
- w.pushByte(token.InteropJID)
- w.write(jid.User)
- w.pushInt16(int(jid.Device))
- w.pushInt16(int(jid.Integrator))
- w.write(jid.Server)
- } else {
- w.pushByte(token.JIDPair)
- if len(jid.User) == 0 {
- w.pushByte(token.ListEmpty)
- } else {
- w.write(jid.User)
- }
- w.write(jid.Server)
- }
- }
- func (w *binaryEncoder) writeAttributes(attributes Attrs) {
- for key, val := range attributes {
- if val == "" || val == nil {
- continue
- }
- w.writeString(key)
- w.write(val)
- }
- }
- func (w *binaryEncoder) countAttributes(attributes Attrs) (count int) {
- for _, val := range attributes {
- if val == "" || val == nil {
- continue
- }
- count += 1
- }
- return
- }
- func (w *binaryEncoder) writeListStart(listSize int) {
- if listSize == 0 {
- w.pushByte(byte(token.ListEmpty))
- } else if listSize < 256 {
- w.pushByte(byte(token.List8))
- w.pushInt8(listSize)
- } else {
- w.pushByte(byte(token.List16))
- w.pushInt16(listSize)
- }
- }
- func (w *binaryEncoder) writePackedBytes(value string, dataType int) {
- if len(value) > token.PackedMax {
- panic(fmt.Errorf("too many bytes to pack: %d", len(value)))
- }
- w.pushByte(byte(dataType))
- roundedLength := byte(math.Ceil(float64(len(value)) / 2.0))
- if len(value)%2 != 0 {
- roundedLength |= 128
- }
- w.pushByte(roundedLength)
- var packer func(byte) byte
- if dataType == token.Nibble8 {
- packer = packNibble
- } else if dataType == token.Hex8 {
- packer = packHex
- } else {
- // This should only be called with the correct values
- panic(fmt.Errorf("invalid packed byte data type %v", dataType))
- }
- for i, l := 0, len(value)/2; i < l; i++ {
- w.pushByte(w.packBytePair(packer, value[2*i], value[2*i+1]))
- }
- if len(value)%2 != 0 {
- w.pushByte(w.packBytePair(packer, value[len(value)-1], '\x00'))
- }
- }
- func (w *binaryEncoder) packBytePair(packer func(byte) byte, part1, part2 byte) byte {
- return (packer(part1) << 4) | packer(part2)
- }
- func validateNibble(value string) bool {
- if len(value) > token.PackedMax {
- return false
- }
- for _, char := range value {
- if !(char >= '0' && char <= '9') && char != '-' && char != '.' {
- return false
- }
- }
- return true
- }
// packNibble maps a nibble-alphabet character to its 4-bit code:
//   - '0'-'9' map to 0-9;
//   - '-' maps to 10 and '.' to 11;
//   - the zero byte (odd-length padding) maps to 15.
//
// Any other input panics; callers are expected to have checked the string
// with validateNibble first.
func packNibble(value byte) byte {
	if '0' <= value && value <= '9' {
		return value - '0'
	}
	switch value {
	case '-':
		return 10
	case '.':
		return 11
	case 0:
		return 15
	}
	// This should be validated beforehand
	panic(fmt.Errorf("invalid string to pack as nibble: %d / '%s'", value, string(value)))
}
- func validateHex(value string) bool {
- if len(value) > token.PackedMax {
- return false
- }
- for _, char := range value {
- if !(char >= '0' && char <= '9') && !(char >= 'A' && char <= 'F') {
- return false
- }
- }
- return true
- }
// packHex maps an uppercase hex character to its 4-bit value:
//   - '0'-'9' map to 0-9 and 'A'-'F' to 10-15;
//   - the zero byte (odd-length padding) maps to 15.
//
// Any other input, including lowercase hex, panics; callers are expected to
// have checked the string with validateHex first.
func packHex(value byte) byte {
	switch {
	case value == 0:
		return 15
	case '0' <= value && value <= '9':
		return value - '0'
	case 'A' <= value && value <= 'F':
		return value - 'A' + 10
	}
	// This should be validated beforehand
	panic(fmt.Errorf("invalid string to pack as hex: %d / '%s'", value, string(value)))
}
|