prekeys.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. // Copyright (c) 2021 Tulir Asokan
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this
  5. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. package whatsmeow
  7. import (
  8. "context"
  9. "encoding/binary"
  10. "fmt"
  11. "time"
  12. "go.mau.fi/libsignal/ecc"
  13. "go.mau.fi/libsignal/keys/identity"
  14. "go.mau.fi/libsignal/keys/prekey"
  15. "go.mau.fi/libsignal/util/optional"
  16. waBinary "go.mau.fi/whatsmeow/binary"
  17. "go.mau.fi/whatsmeow/types"
  18. "go.mau.fi/whatsmeow/util/keys"
  19. )
  const (
  	// MinPreKeyCount is the number of prekeys remaining on the WhatsApp servers
  	// at which the client should upload a new batch of prekeys.
  	MinPreKeyCount = 5
  	// WantedPreKeyCount is the number of prekeys the client uploads to the
  	// WhatsApp servers in a single regular batch. uploadPreKeys also treats a
  	// server-side count at or above this value as "already topped up" when
  	// deciding whether to skip a recent duplicate upload request.
  	WantedPreKeyCount = 50
  )
  26. func (cli *Client) getServerPreKeyCount(ctx context.Context) (int, error) {
  27. resp, err := cli.sendIQ(ctx, infoQuery{
  28. Namespace: "encrypt",
  29. Type: "get",
  30. To: types.ServerJID,
  31. Content: []waBinary.Node{
  32. {Tag: "count"},
  33. },
  34. })
  35. if err != nil {
  36. return 0, fmt.Errorf("failed to get prekey count on server: %w", err)
  37. }
  38. count := resp.GetChildByTag("count")
  39. ag := count.AttrGetter()
  40. val := ag.Int("value")
  41. return val, ag.Error()
  42. }
  43. func (cli *Client) uploadPreKeys(ctx context.Context, initialUpload bool) {
  44. cli.uploadPreKeysLock.Lock()
  45. defer cli.uploadPreKeysLock.Unlock()
  46. if cli.lastPreKeyUpload.Add(10 * time.Minute).After(time.Now()) {
  47. sc, _ := cli.getServerPreKeyCount(ctx)
  48. if sc >= WantedPreKeyCount {
  49. cli.Log.Debugf("Canceling prekey upload request due to likely race condition")
  50. return
  51. }
  52. }
  53. var registrationIDBytes [4]byte
  54. binary.BigEndian.PutUint32(registrationIDBytes[:], cli.Store.RegistrationID)
  55. wantedCount := WantedPreKeyCount
  56. if initialUpload {
  57. wantedCount = 812
  58. }
  59. preKeys, err := cli.Store.PreKeys.GetOrGenPreKeys(ctx, uint32(wantedCount))
  60. if err != nil {
  61. cli.Log.Errorf("Failed to get prekeys to upload: %v", err)
  62. return
  63. }
  64. cli.Log.Infof("Uploading %d new prekeys to server", len(preKeys))
  65. _, err = cli.sendIQ(ctx, infoQuery{
  66. Namespace: "encrypt",
  67. Type: "set",
  68. To: types.ServerJID,
  69. Content: []waBinary.Node{
  70. {Tag: "registration", Content: registrationIDBytes[:]},
  71. {Tag: "type", Content: []byte{ecc.DjbType}},
  72. {Tag: "identity", Content: cli.Store.IdentityKey.Pub[:]},
  73. {Tag: "list", Content: preKeysToNodes(preKeys)},
  74. preKeyToNode(cli.Store.SignedPreKey),
  75. },
  76. })
  77. if err != nil {
  78. cli.Log.Errorf("Failed to send request to upload prekeys: %v", err)
  79. return
  80. }
  81. cli.Log.Debugf("Got response to uploading prekeys")
  82. err = cli.Store.PreKeys.MarkPreKeysAsUploaded(ctx, preKeys[len(preKeys)-1].KeyID)
  83. if err != nil {
  84. cli.Log.Warnf("Failed to mark prekeys as uploaded: %v", err)
  85. return
  86. }
  87. cli.lastPreKeyUpload = time.Now()
  88. return
  89. }
  90. func (cli *Client) fetchPreKeysNoError(ctx context.Context, retryDevices []types.JID) map[types.JID]*prekey.Bundle {
  91. if len(retryDevices) == 0 {
  92. return nil
  93. }
  94. bundlesResp, err := cli.fetchPreKeys(ctx, retryDevices)
  95. if err != nil {
  96. cli.Log.Warnf("Failed to fetch prekeys for %v with no existing session: %v", retryDevices, err)
  97. return nil
  98. }
  99. bundles := make(map[types.JID]*prekey.Bundle, len(retryDevices))
  100. for _, jid := range retryDevices {
  101. resp := bundlesResp[jid]
  102. if resp.err != nil {
  103. cli.Log.Warnf("Failed to fetch prekey for %s: %v", jid, resp.err)
  104. continue
  105. }
  106. bundles[jid] = resp.bundle
  107. }
  108. return bundles
  109. }
  110. type preKeyResp struct {
  111. bundle *prekey.Bundle
  112. err error
  113. }
  114. func (cli *Client) fetchPreKeys(ctx context.Context, users []types.JID) (map[types.JID]preKeyResp, error) {
  115. requests := make([]waBinary.Node, len(users))
  116. for i, user := range users {
  117. requests[i].Tag = "user"
  118. requests[i].Attrs = waBinary.Attrs{
  119. "jid": user,
  120. "reason": "identity",
  121. }
  122. }
  123. resp, err := cli.sendIQ(ctx, infoQuery{
  124. Namespace: "encrypt",
  125. Type: "get",
  126. To: types.ServerJID,
  127. Content: []waBinary.Node{{
  128. Tag: "key",
  129. Content: requests,
  130. }},
  131. })
  132. if err != nil {
  133. return nil, fmt.Errorf("failed to send prekey request: %w", err)
  134. } else if len(resp.GetChildren()) == 0 {
  135. return nil, fmt.Errorf("got empty response to prekey request")
  136. }
  137. list := resp.GetChildByTag("list")
  138. respData := make(map[types.JID]preKeyResp)
  139. for _, child := range list.GetChildren() {
  140. if child.Tag != "user" {
  141. continue
  142. }
  143. jid := child.AttrGetter().JID("jid")
  144. bundle, err := nodeToPreKeyBundle(uint32(jid.Device), child)
  145. respData[jid] = preKeyResp{bundle, err}
  146. }
  147. return respData, nil
  148. }
  149. func preKeyToNode(key *keys.PreKey) waBinary.Node {
  150. var keyID [4]byte
  151. binary.BigEndian.PutUint32(keyID[:], key.KeyID)
  152. node := waBinary.Node{
  153. Tag: "key",
  154. Content: []waBinary.Node{
  155. {Tag: "id", Content: keyID[1:]},
  156. {Tag: "value", Content: key.Pub[:]},
  157. },
  158. }
  159. if key.Signature != nil {
  160. node.Tag = "skey"
  161. node.Content = append(node.GetChildren(), waBinary.Node{
  162. Tag: "signature",
  163. Content: key.Signature[:],
  164. })
  165. }
  166. return node
  167. }
  168. func nodeToPreKeyBundle(deviceID uint32, node waBinary.Node) (*prekey.Bundle, error) {
  169. errorNode, ok := node.GetOptionalChildByTag("error")
  170. if ok && errorNode.Tag == "error" {
  171. return nil, fmt.Errorf("got error getting prekeys: %s", errorNode.XMLString())
  172. }
  173. registrationBytes, ok := node.GetChildByTag("registration").Content.([]byte)
  174. if !ok || len(registrationBytes) != 4 {
  175. return nil, fmt.Errorf("invalid registration ID in prekey response")
  176. }
  177. registrationID := binary.BigEndian.Uint32(registrationBytes)
  178. keysNode, ok := node.GetOptionalChildByTag("keys")
  179. if !ok {
  180. keysNode = node
  181. }
  182. identityKeyRaw, ok := keysNode.GetChildByTag("identity").Content.([]byte)
  183. if !ok || len(identityKeyRaw) != 32 {
  184. return nil, fmt.Errorf("invalid identity key in prekey response")
  185. }
  186. identityKeyPub := *(*[32]byte)(identityKeyRaw)
  187. preKeyNode, ok := keysNode.GetOptionalChildByTag("key")
  188. preKey := &keys.PreKey{}
  189. if ok {
  190. var err error
  191. preKey, err = nodeToPreKey(preKeyNode)
  192. if err != nil {
  193. return nil, fmt.Errorf("invalid prekey in prekey response: %w", err)
  194. }
  195. }
  196. signedPreKey, err := nodeToPreKey(keysNode.GetChildByTag("skey"))
  197. if err != nil {
  198. return nil, fmt.Errorf("invalid signed prekey in prekey response: %w", err)
  199. }
  200. var bundle *prekey.Bundle
  201. if ok {
  202. bundle = prekey.NewBundle(registrationID, deviceID,
  203. optional.NewOptionalUint32(preKey.KeyID), signedPreKey.KeyID,
  204. ecc.NewDjbECPublicKey(*preKey.Pub), ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
  205. identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub)))
  206. } else {
  207. bundle = prekey.NewBundle(registrationID, deviceID, optional.NewEmptyUint32(), signedPreKey.KeyID,
  208. nil, ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
  209. identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub)))
  210. }
  211. return bundle, nil
  212. }
  213. func nodeToPreKey(node waBinary.Node) (*keys.PreKey, error) {
  214. key := keys.PreKey{
  215. KeyPair: keys.KeyPair{},
  216. KeyID: 0,
  217. Signature: nil,
  218. }
  219. if id := node.GetChildByTag("id"); id.Tag != "id" {
  220. return nil, fmt.Errorf("prekey node doesn't contain ID tag")
  221. } else if idBytes, ok := id.Content.([]byte); !ok {
  222. return nil, fmt.Errorf("prekey ID has unexpected content (%T)", id.Content)
  223. } else if len(idBytes) != 3 {
  224. return nil, fmt.Errorf("prekey ID has unexpected number of bytes (%d, expected 3)", len(idBytes))
  225. } else {
  226. key.KeyID = binary.BigEndian.Uint32(append([]byte{0}, idBytes...))
  227. }
  228. if pubkey := node.GetChildByTag("value"); pubkey.Tag != "value" {
  229. return nil, fmt.Errorf("prekey node doesn't contain value tag")
  230. } else if pubkeyBytes, ok := pubkey.Content.([]byte); !ok {
  231. return nil, fmt.Errorf("prekey value has unexpected content (%T)", pubkey.Content)
  232. } else if len(pubkeyBytes) != 32 {
  233. return nil, fmt.Errorf("prekey value has unexpected number of bytes (%d, expected 32)", len(pubkeyBytes))
  234. } else {
  235. key.KeyPair.Pub = (*[32]byte)(pubkeyBytes)
  236. }
  237. if node.Tag == "skey" {
  238. if sig := node.GetChildByTag("signature"); sig.Tag != "signature" {
  239. return nil, fmt.Errorf("prekey node doesn't contain signature tag")
  240. } else if sigBytes, ok := sig.Content.([]byte); !ok {
  241. return nil, fmt.Errorf("prekey signature has unexpected content (%T)", sig.Content)
  242. } else if len(sigBytes) != 64 {
  243. return nil, fmt.Errorf("prekey signature has unexpected number of bytes (%d, expected 64)", len(sigBytes))
  244. } else {
  245. key.Signature = (*[64]byte)(sigBytes)
  246. }
  247. }
  248. return &key, nil
  249. }
  250. func preKeysToNodes(prekeys []*keys.PreKey) []waBinary.Node {
  251. nodes := make([]waBinary.Node, len(prekeys))
  252. for i, key := range prekeys {
  253. nodes[i] = preKeyToNode(key)
  254. }
  255. return nodes
  256. }