encoder.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. package binary
  2. import (
  3. "fmt"
  4. "math"
  5. "strconv"
  6. "go.mau.fi/whatsmeow/binary/token"
  7. "go.mau.fi/whatsmeow/types"
  8. )
  9. type binaryEncoder struct {
  10. data []byte
  11. }
  12. func newEncoder() *binaryEncoder {
  13. return &binaryEncoder{[]byte{0}}
  14. }
  15. func (w *binaryEncoder) getData() []byte {
  16. return w.data
  17. }
  18. func (w *binaryEncoder) pushByte(b byte) {
  19. w.data = append(w.data, b)
  20. }
  21. func (w *binaryEncoder) pushBytes(bytes []byte) {
  22. w.data = append(w.data, bytes...)
  23. }
  24. func (w *binaryEncoder) pushIntN(value, n int, littleEndian bool) {
  25. for i := 0; i < n; i++ {
  26. var curShift int
  27. if littleEndian {
  28. curShift = i
  29. } else {
  30. curShift = n - i - 1
  31. }
  32. w.pushByte(byte((value >> uint(curShift*8)) & 0xFF))
  33. }
  34. }
  35. func (w *binaryEncoder) pushInt20(value int) {
  36. w.pushBytes([]byte{byte((value >> 16) & 0x0F), byte((value >> 8) & 0xFF), byte(value & 0xFF)})
  37. }
  38. func (w *binaryEncoder) pushInt8(value int) {
  39. w.pushIntN(value, 1, false)
  40. }
  41. func (w *binaryEncoder) pushInt16(value int) {
  42. w.pushIntN(value, 2, false)
  43. }
  44. func (w *binaryEncoder) pushInt32(value int) {
  45. w.pushIntN(value, 4, false)
  46. }
  47. func (w *binaryEncoder) pushString(value string) {
  48. w.pushBytes([]byte(value))
  49. }
  50. func (w *binaryEncoder) writeByteLength(length int) {
  51. if length < 256 {
  52. w.pushByte(token.Binary8)
  53. w.pushInt8(length)
  54. } else if length < (1 << 20) {
  55. w.pushByte(token.Binary20)
  56. w.pushInt20(length)
  57. } else if length < math.MaxInt32 {
  58. w.pushByte(token.Binary32)
  59. w.pushInt32(length)
  60. } else {
  61. panic(fmt.Errorf("length is too large: %d", length))
  62. }
  63. }
  64. const tagSize = 1
  65. func (w *binaryEncoder) writeNode(n Node) {
  66. if n.Tag == "0" {
  67. w.pushByte(token.List8)
  68. w.pushByte(token.ListEmpty)
  69. return
  70. }
  71. hasContent := 0
  72. if n.Content != nil {
  73. hasContent = 1
  74. }
  75. w.writeListStart(2*w.countAttributes(n.Attrs) + tagSize + hasContent)
  76. w.writeString(n.Tag)
  77. w.writeAttributes(n.Attrs)
  78. if n.Content != nil {
  79. w.write(n.Content)
  80. }
  81. }
  82. func (w *binaryEncoder) write(data interface{}) {
  83. switch typedData := data.(type) {
  84. case nil:
  85. w.pushByte(token.ListEmpty)
  86. case types.JID:
  87. w.writeJID(typedData)
  88. case string:
  89. w.writeString(typedData)
  90. case int:
  91. w.writeString(strconv.Itoa(typedData))
  92. case int32:
  93. w.writeString(strconv.FormatInt(int64(typedData), 10))
  94. case uint:
  95. w.writeString(strconv.FormatUint(uint64(typedData), 10))
  96. case uint32:
  97. w.writeString(strconv.FormatUint(uint64(typedData), 10))
  98. case int64:
  99. w.writeString(strconv.FormatInt(typedData, 10))
  100. case uint64:
  101. w.writeString(strconv.FormatUint(typedData, 10))
  102. case bool:
  103. w.writeString(strconv.FormatBool(typedData))
  104. case []byte:
  105. w.writeBytes(typedData)
  106. case []Node:
  107. w.writeListStart(len(typedData))
  108. for _, n := range typedData {
  109. w.writeNode(n)
  110. }
  111. default:
  112. panic(fmt.Errorf("%w: %T", ErrInvalidType, typedData))
  113. }
  114. }
  115. func (w *binaryEncoder) writeString(data string) {
  116. var dictIndex byte
  117. if tokenIndex, ok := token.IndexOfSingleToken(data); ok {
  118. w.pushByte(tokenIndex)
  119. } else if dictIndex, tokenIndex, ok = token.IndexOfDoubleByteToken(data); ok {
  120. w.pushByte(token.Dictionary0 + dictIndex)
  121. w.pushByte(tokenIndex)
  122. } else if validateNibble(data) {
  123. w.writePackedBytes(data, token.Nibble8)
  124. } else if validateHex(data) {
  125. w.writePackedBytes(data, token.Hex8)
  126. } else {
  127. w.writeStringRaw(data)
  128. }
  129. }
  130. func (w *binaryEncoder) writeBytes(value []byte) {
  131. w.writeByteLength(len(value))
  132. w.pushBytes(value)
  133. }
  134. func (w *binaryEncoder) writeStringRaw(value string) {
  135. w.writeByteLength(len(value))
  136. w.pushString(value)
  137. }
  138. func (w *binaryEncoder) writeJID(jid types.JID) {
  139. if ((jid.Server == types.DefaultUserServer || jid.Server == types.HiddenUserServer) && jid.Device > 0) || jid.Server == types.HostedServer {
  140. w.pushByte(token.ADJID)
  141. w.pushByte(jid.ActualAgent())
  142. w.pushByte(uint8(jid.Device))
  143. w.writeString(jid.User)
  144. } else if jid.Server == types.MessengerServer {
  145. w.pushByte(token.FBJID)
  146. w.write(jid.User)
  147. w.pushInt16(int(jid.Device))
  148. w.write(jid.Server)
  149. } else if jid.Server == types.InteropServer {
  150. w.pushByte(token.InteropJID)
  151. w.write(jid.User)
  152. w.pushInt16(int(jid.Device))
  153. w.pushInt16(int(jid.Integrator))
  154. w.write(jid.Server)
  155. } else {
  156. w.pushByte(token.JIDPair)
  157. if len(jid.User) == 0 {
  158. w.pushByte(token.ListEmpty)
  159. } else {
  160. w.write(jid.User)
  161. }
  162. w.write(jid.Server)
  163. }
  164. }
  165. func (w *binaryEncoder) writeAttributes(attributes Attrs) {
  166. for key, val := range attributes {
  167. if val == "" || val == nil {
  168. continue
  169. }
  170. w.writeString(key)
  171. w.write(val)
  172. }
  173. }
  174. func (w *binaryEncoder) countAttributes(attributes Attrs) (count int) {
  175. for _, val := range attributes {
  176. if val == "" || val == nil {
  177. continue
  178. }
  179. count += 1
  180. }
  181. return
  182. }
  183. func (w *binaryEncoder) writeListStart(listSize int) {
  184. if listSize == 0 {
  185. w.pushByte(byte(token.ListEmpty))
  186. } else if listSize < 256 {
  187. w.pushByte(byte(token.List8))
  188. w.pushInt8(listSize)
  189. } else {
  190. w.pushByte(byte(token.List16))
  191. w.pushInt16(listSize)
  192. }
  193. }
  194. func (w *binaryEncoder) writePackedBytes(value string, dataType int) {
  195. if len(value) > token.PackedMax {
  196. panic(fmt.Errorf("too many bytes to pack: %d", len(value)))
  197. }
  198. w.pushByte(byte(dataType))
  199. roundedLength := byte(math.Ceil(float64(len(value)) / 2.0))
  200. if len(value)%2 != 0 {
  201. roundedLength |= 128
  202. }
  203. w.pushByte(roundedLength)
  204. var packer func(byte) byte
  205. if dataType == token.Nibble8 {
  206. packer = packNibble
  207. } else if dataType == token.Hex8 {
  208. packer = packHex
  209. } else {
  210. // This should only be called with the correct values
  211. panic(fmt.Errorf("invalid packed byte data type %v", dataType))
  212. }
  213. for i, l := 0, len(value)/2; i < l; i++ {
  214. w.pushByte(w.packBytePair(packer, value[2*i], value[2*i+1]))
  215. }
  216. if len(value)%2 != 0 {
  217. w.pushByte(w.packBytePair(packer, value[len(value)-1], '\x00'))
  218. }
  219. }
  220. func (w *binaryEncoder) packBytePair(packer func(byte) byte, part1, part2 byte) byte {
  221. return (packer(part1) << 4) | packer(part2)
  222. }
  223. func validateNibble(value string) bool {
  224. if len(value) > token.PackedMax {
  225. return false
  226. }
  227. for _, char := range value {
  228. if !(char >= '0' && char <= '9') && char != '-' && char != '.' {
  229. return false
  230. }
  231. }
  232. return true
  233. }
  234. func packNibble(value byte) byte {
  235. switch value {
  236. case '-':
  237. return 10
  238. case '.':
  239. return 11
  240. case 0:
  241. return 15
  242. default:
  243. if value >= '0' && value <= '9' {
  244. return value - '0'
  245. }
  246. // This should be validated beforehand
  247. panic(fmt.Errorf("invalid string to pack as nibble: %d / '%s'", value, string(value)))
  248. }
  249. }
  250. func validateHex(value string) bool {
  251. if len(value) > token.PackedMax {
  252. return false
  253. }
  254. for _, char := range value {
  255. if !(char >= '0' && char <= '9') && !(char >= 'A' && char <= 'F') {
  256. return false
  257. }
  258. }
  259. return true
  260. }
  261. func packHex(value byte) byte {
  262. switch {
  263. case value >= '0' && value <= '9':
  264. return value - '0'
  265. case value >= 'A' && value <= 'F':
  266. return 10 + value - 'A'
  267. case value == 0:
  268. return 15
  269. default:
  270. // This should be validated beforehand
  271. panic(fmt.Errorf("invalid string to pack as hex: %d / '%s'", value, string(value)))
  272. }
  273. }