decode.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. // Copyright (c) 2021 Tulir Asokan
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this
  5. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. package appstate
  7. import (
  8. "bytes"
  9. "context"
  10. "crypto/sha256"
  11. "encoding/json"
  12. "fmt"
  13. "google.golang.org/protobuf/proto"
  14. waBinary "git.bobomao.top/joey/testwh/binary"
  15. "git.bobomao.top/joey/testwh/proto/waServerSync"
  16. "git.bobomao.top/joey/testwh/proto/waSyncAction"
  17. "git.bobomao.top/joey/testwh/store"
  18. "git.bobomao.top/joey/testwh/util/cbcutil"
  19. )
  20. // PatchList represents a decoded response to getting app state patches from the WhatsApp servers.
  21. type PatchList struct {
  22. Name WAPatchName
  23. HasMorePatches bool
  24. Patches []*waServerSync.SyncdPatch
  25. Snapshot *waServerSync.SyncdSnapshot
  26. }
  27. // DownloadExternalFunc is a function that can download a blob of external app state patches.
  28. type DownloadExternalFunc func(context.Context, *waServerSync.ExternalBlobReference) ([]byte, error)
  29. func parseSnapshotInternal(ctx context.Context, collection *waBinary.Node, downloadExternal DownloadExternalFunc) (*waServerSync.SyncdSnapshot, error) {
  30. snapshotNode := collection.GetChildByTag("snapshot")
  31. rawSnapshot, ok := snapshotNode.Content.([]byte)
  32. if snapshotNode.Tag != "snapshot" || !ok {
  33. return nil, nil
  34. }
  35. var snapshot waServerSync.ExternalBlobReference
  36. err := proto.Unmarshal(rawSnapshot, &snapshot)
  37. if err != nil {
  38. return nil, fmt.Errorf("failed to unmarshal snapshot: %w", err)
  39. }
  40. var rawData []byte
  41. rawData, err = downloadExternal(ctx, &snapshot)
  42. if err != nil {
  43. return nil, fmt.Errorf("failed to download external mutations: %w", err)
  44. }
  45. var downloaded waServerSync.SyncdSnapshot
  46. err = proto.Unmarshal(rawData, &downloaded)
  47. if err != nil {
  48. return nil, fmt.Errorf("failed to unmarshal mutation list: %w", err)
  49. }
  50. return &downloaded, nil
  51. }
  52. func parsePatchListInternal(ctx context.Context, collection *waBinary.Node, downloadExternal DownloadExternalFunc) ([]*waServerSync.SyncdPatch, error) {
  53. patchesNode := collection.GetChildByTag("patches")
  54. patchNodes := patchesNode.GetChildren()
  55. patches := make([]*waServerSync.SyncdPatch, 0, len(patchNodes))
  56. for i, patchNode := range patchNodes {
  57. rawPatch, ok := patchNode.Content.([]byte)
  58. if patchNode.Tag != "patch" || !ok {
  59. continue
  60. }
  61. var patch waServerSync.SyncdPatch
  62. err := proto.Unmarshal(rawPatch, &patch)
  63. if err != nil {
  64. return nil, fmt.Errorf("failed to unmarshal patch #%d: %w", i+1, err)
  65. }
  66. if patch.GetExternalMutations() != nil && downloadExternal != nil {
  67. var rawData []byte
  68. rawData, err = downloadExternal(ctx, patch.GetExternalMutations())
  69. if err != nil {
  70. return nil, fmt.Errorf("failed to download external mutations: %w", err)
  71. }
  72. var downloaded waServerSync.SyncdMutations
  73. err = proto.Unmarshal(rawData, &downloaded)
  74. if err != nil {
  75. return nil, fmt.Errorf("failed to unmarshal mutation list: %w", err)
  76. } else if len(downloaded.GetMutations()) == 0 {
  77. return nil, fmt.Errorf("didn't get any mutations from download")
  78. }
  79. patch.Mutations = downloaded.Mutations
  80. }
  81. patches = append(patches, &patch)
  82. }
  83. return patches, nil
  84. }
  85. // ParsePatchList will decode an XML node containing app state patches, including downloading any external blobs.
  86. func ParsePatchList(ctx context.Context, collection *waBinary.Node, downloadExternal DownloadExternalFunc) (*PatchList, error) {
  87. ag := collection.AttrGetter()
  88. snapshot, err := parseSnapshotInternal(ctx, collection, downloadExternal)
  89. if err != nil {
  90. return nil, err
  91. }
  92. patches, err := parsePatchListInternal(ctx, collection, downloadExternal)
  93. if err != nil {
  94. return nil, err
  95. }
  96. list := &PatchList{
  97. Name: WAPatchName(ag.String("name")),
  98. HasMorePatches: ag.OptionalBool("has_more_patches"),
  99. Patches: patches,
  100. Snapshot: snapshot,
  101. }
  102. return list, ag.Error()
  103. }
  104. type patchOutput struct {
  105. RemovedMACs [][]byte
  106. AddedMACs []store.AppStateMutationMAC
  107. Mutations []Mutation
  108. }
  109. func (proc *Processor) decodeMutations(ctx context.Context, mutations []*waServerSync.SyncdMutation, out *patchOutput, validateMACs bool) error {
  110. for i, mutation := range mutations {
  111. keyID := mutation.GetRecord().GetKeyID().GetID()
  112. keys, err := proc.getAppStateKey(ctx, keyID)
  113. if err != nil {
  114. return fmt.Errorf("failed to get key %X to decode mutation: %w", keyID, err)
  115. }
  116. content := mutation.GetRecord().GetValue().GetBlob()
  117. content, valueMAC := content[:len(content)-32], content[len(content)-32:]
  118. if validateMACs {
  119. expectedValueMAC := generateContentMAC(mutation.GetOperation(), content, keyID, keys.ValueMAC)
  120. if !bytes.Equal(expectedValueMAC, valueMAC) {
  121. return fmt.Errorf("failed to verify mutation #%d: %w", i+1, ErrMismatchingContentMAC)
  122. }
  123. }
  124. iv, content := content[:16], content[16:]
  125. plaintext, err := cbcutil.Decrypt(keys.ValueEncryption, iv, content)
  126. if err != nil {
  127. return fmt.Errorf("failed to decrypt mutation #%d: %w", i+1, err)
  128. }
  129. var syncAction waSyncAction.SyncActionData
  130. err = proto.Unmarshal(plaintext, &syncAction)
  131. if err != nil {
  132. return fmt.Errorf("failed to unmarshal mutation #%d: %w", i+1, err)
  133. }
  134. indexMAC := mutation.GetRecord().GetIndex().GetBlob()
  135. if validateMACs {
  136. expectedIndexMAC := concatAndHMAC(sha256.New, keys.Index, syncAction.Index)
  137. if !bytes.Equal(expectedIndexMAC, indexMAC) {
  138. return fmt.Errorf("failed to verify mutation #%d: %w", i+1, ErrMismatchingIndexMAC)
  139. }
  140. }
  141. var index []string
  142. err = json.Unmarshal(syncAction.GetIndex(), &index)
  143. if err != nil {
  144. return fmt.Errorf("failed to unmarshal index of mutation #%d: %w", i+1, err)
  145. }
  146. if mutation.GetOperation() == waServerSync.SyncdMutation_REMOVE {
  147. out.RemovedMACs = append(out.RemovedMACs, indexMAC)
  148. } else if mutation.GetOperation() == waServerSync.SyncdMutation_SET {
  149. out.AddedMACs = append(out.AddedMACs, store.AppStateMutationMAC{
  150. IndexMAC: indexMAC,
  151. ValueMAC: valueMAC,
  152. })
  153. }
  154. out.Mutations = append(out.Mutations, Mutation{
  155. Operation: mutation.GetOperation(),
  156. Action: syncAction.GetValue(),
  157. Version: syncAction.GetVersion(),
  158. Index: index,
  159. IndexMAC: indexMAC,
  160. ValueMAC: valueMAC,
  161. })
  162. }
  163. return nil
  164. }
  165. func (proc *Processor) storeMACs(ctx context.Context, name WAPatchName, currentState HashState, out *patchOutput) {
  166. err := proc.Store.AppState.PutAppStateVersion(ctx, string(name), currentState.Version, currentState.Hash)
  167. if err != nil {
  168. proc.Log.Errorf("Failed to update app state version in the database: %v", err)
  169. }
  170. err = proc.Store.AppState.DeleteAppStateMutationMACs(ctx, string(name), out.RemovedMACs)
  171. if err != nil {
  172. proc.Log.Errorf("Failed to remove deleted mutation MACs from the database: %v", err)
  173. }
  174. err = proc.Store.AppState.PutAppStateMutationMACs(ctx, string(name), currentState.Version, out.AddedMACs)
  175. if err != nil {
  176. proc.Log.Errorf("Failed to insert added mutation MACs to the database: %v", err)
  177. }
  178. }
  179. func (proc *Processor) validateSnapshotMAC(ctx context.Context, name WAPatchName, currentState HashState, keyID, expectedSnapshotMAC []byte) (keys ExpandedAppStateKeys, err error) {
  180. keys, err = proc.getAppStateKey(ctx, keyID)
  181. if err != nil {
  182. err = fmt.Errorf("failed to get key %X to verify patch v%d MACs: %w", keyID, currentState.Version, err)
  183. return
  184. }
  185. snapshotMAC := currentState.generateSnapshotMAC(name, keys.SnapshotMAC)
  186. if !bytes.Equal(snapshotMAC, expectedSnapshotMAC) {
  187. err = fmt.Errorf("failed to verify patch v%d: %w", currentState.Version, ErrMismatchingLTHash)
  188. }
  189. return
  190. }
  191. func (proc *Processor) decodeSnapshot(ctx context.Context, name WAPatchName, ss *waServerSync.SyncdSnapshot, initialState HashState, validateMACs bool, newMutationsInput []Mutation) (newMutations []Mutation, currentState HashState, err error) {
  192. currentState = initialState
  193. currentState.Version = ss.GetVersion().GetVersion()
  194. encryptedMutations := make([]*waServerSync.SyncdMutation, len(ss.GetRecords()))
  195. for i, record := range ss.GetRecords() {
  196. encryptedMutations[i] = &waServerSync.SyncdMutation{
  197. Operation: waServerSync.SyncdMutation_SET.Enum(),
  198. Record: record,
  199. }
  200. }
  201. var warn []error
  202. warn, err = currentState.updateHash(encryptedMutations, func(indexMAC []byte, maxIndex int) ([]byte, error) {
  203. return nil, nil
  204. })
  205. if len(warn) > 0 {
  206. proc.Log.Warnf("Warnings while updating hash for %s: %+v", name, warn)
  207. }
  208. if err != nil {
  209. err = fmt.Errorf("failed to update state hash: %w", err)
  210. return
  211. }
  212. if validateMACs {
  213. _, err = proc.validateSnapshotMAC(ctx, name, currentState, ss.GetKeyID().GetID(), ss.GetMac())
  214. if err != nil {
  215. return
  216. }
  217. }
  218. var out patchOutput
  219. out.Mutations = newMutationsInput
  220. err = proc.decodeMutations(ctx, encryptedMutations, &out, validateMACs)
  221. if err != nil {
  222. err = fmt.Errorf("failed to decode snapshot of v%d: %w", currentState.Version, err)
  223. return
  224. }
  225. proc.storeMACs(ctx, name, currentState, &out)
  226. newMutations = out.Mutations
  227. return
  228. }
  229. // DecodePatches will decode all the patches in a PatchList into a list of app state mutations.
  230. func (proc *Processor) DecodePatches(ctx context.Context, list *PatchList, initialState HashState, validateMACs bool) (newMutations []Mutation, currentState HashState, err error) {
  231. currentState = initialState
  232. var expectedLength int
  233. if list.Snapshot != nil {
  234. expectedLength = len(list.Snapshot.GetRecords())
  235. }
  236. for _, patch := range list.Patches {
  237. expectedLength += len(patch.GetMutations())
  238. }
  239. newMutations = make([]Mutation, 0, expectedLength)
  240. if list.Snapshot != nil {
  241. newMutations, currentState, err = proc.decodeSnapshot(ctx, list.Name, list.Snapshot, currentState, validateMACs, newMutations)
  242. if err != nil {
  243. return
  244. }
  245. }
  246. for _, patch := range list.Patches {
  247. version := patch.GetVersion().GetVersion()
  248. currentState.Version = version
  249. var warn []error
  250. warn, err = currentState.updateHash(patch.GetMutations(), func(indexMAC []byte, maxIndex int) ([]byte, error) {
  251. for i := maxIndex - 1; i >= 0; i-- {
  252. if bytes.Equal(patch.Mutations[i].GetRecord().GetIndex().GetBlob(), indexMAC) {
  253. value := patch.Mutations[i].GetRecord().GetValue().GetBlob()
  254. return value[len(value)-32:], nil
  255. }
  256. }
  257. // Previous value not found in current patch, look in the database
  258. return proc.Store.AppState.GetAppStateMutationMAC(ctx, string(list.Name), indexMAC)
  259. })
  260. if len(warn) > 0 {
  261. proc.Log.Warnf("Warnings while updating hash for %s: %+v", list.Name, warn)
  262. }
  263. if err != nil {
  264. err = fmt.Errorf("failed to update state hash: %w", err)
  265. return
  266. }
  267. if validateMACs {
  268. var keys ExpandedAppStateKeys
  269. keys, err = proc.validateSnapshotMAC(ctx, list.Name, currentState, patch.GetKeyID().GetID(), patch.GetSnapshotMAC())
  270. if err != nil {
  271. return
  272. }
  273. patchMAC := generatePatchMAC(patch, list.Name, keys.PatchMAC, patch.GetVersion().GetVersion())
  274. if !bytes.Equal(patchMAC, patch.GetPatchMAC()) {
  275. err = fmt.Errorf("failed to verify patch v%d: %w", version, ErrMismatchingPatchMAC)
  276. return
  277. }
  278. }
  279. var out patchOutput
  280. out.Mutations = newMutations
  281. err = proc.decodeMutations(ctx, patch.GetMutations(), &out, validateMACs)
  282. if err != nil {
  283. return
  284. }
  285. proc.storeMACs(ctx, list.Name, currentState, &out)
  286. newMutations = out.Mutations
  287. }
  288. return
  289. }