download-to-file.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. // Copyright (c) 2024 Tulir Asokan
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this
  5. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. package whatsmeow
  7. import (
  8. "context"
  9. "crypto/hmac"
  10. "crypto/sha256"
  11. "encoding/base64"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "os"
  16. "strings"
  17. "time"
  18. "go.mau.fi/util/fallocate"
  19. "go.mau.fi/util/retryafter"
  20. "go.mau.fi/whatsmeow/proto/waMediaTransport"
  21. "go.mau.fi/whatsmeow/util/cbcutil"
  22. )
  // File is the download destination accepted by DownloadToFile and the other
  // *ToFile helpers in this file.
  //
  // Those helpers write the downloaded bytes into it, seek back to the start to
  // verify the MAC and hash the content, read the trailing MAC with ReadAt,
  // strip it with Truncate, hand the file to cbcutil.DecryptFile, and use Stat
  // to compare the final size with the expected length. *os.File satisfies it.
  type File interface {
  	io.ReadWriteSeeker
  	io.ReaderAt
  	io.WriterAt
  	Stat() (os.FileInfo, error)
  	Truncate(size int64) error
  }
  32. // DownloadToFile downloads the attachment from the given protobuf message.
  33. //
  34. // This is otherwise identical to [Download], but writes the attachment to a file instead of returning it as a byte slice.
  35. func (cli *Client) DownloadToFile(ctx context.Context, msg DownloadableMessage, file File) error {
  36. if cli == nil {
  37. return ErrClientIsNil
  38. }
  39. mediaType := GetMediaType(msg)
  40. if mediaType == "" {
  41. return fmt.Errorf("%w %T", ErrUnknownMediaType, msg)
  42. }
  43. urlable, ok := msg.(downloadableMessageWithURL)
  44. var url string
  45. var isWebWhatsappNetURL bool
  46. if ok {
  47. url = urlable.GetURL()
  48. isWebWhatsappNetURL = strings.HasPrefix(url, "https://web.whatsapp.net")
  49. }
  50. if len(url) > 0 && !isWebWhatsappNetURL {
  51. return cli.downloadAndDecryptToFile(ctx, url, msg.GetMediaKey(), mediaType, getSize(msg), msg.GetFileEncSHA256(), msg.GetFileSHA256(), file)
  52. } else if len(msg.GetDirectPath()) > 0 {
  53. return cli.DownloadMediaWithPathToFile(ctx, msg.GetDirectPath(), msg.GetFileEncSHA256(), msg.GetFileSHA256(), msg.GetMediaKey(), getSize(msg), mediaType, mediaTypeToMMSType[mediaType], file)
  54. } else {
  55. if isWebWhatsappNetURL {
  56. cli.Log.Warnf("Got a media message with a web.whatsapp.net URL (%s) and no direct path", url)
  57. }
  58. return ErrNoURLPresent
  59. }
  60. }
  61. func (cli *Client) DownloadFBToFile(
  62. ctx context.Context,
  63. transport *waMediaTransport.WAMediaTransport_Integral,
  64. mediaType MediaType,
  65. file File,
  66. ) error {
  67. return cli.DownloadMediaWithPathToFile(ctx, transport.GetDirectPath(), transport.GetFileEncSHA256(), transport.GetFileSHA256(), transport.GetMediaKey(), -1, mediaType, mediaTypeToMMSType[mediaType], file)
  68. }
  69. func (cli *Client) DownloadMediaWithPathToFile(
  70. ctx context.Context,
  71. directPath string,
  72. encFileHash, fileHash, mediaKey []byte,
  73. fileLength int,
  74. mediaType MediaType,
  75. mmsType string,
  76. file File,
  77. ) error {
  78. mediaConn, err := cli.refreshMediaConn(ctx, false)
  79. if err != nil {
  80. return fmt.Errorf("failed to refresh media connections: %w", err)
  81. }
  82. if len(mmsType) == 0 {
  83. mmsType = mediaTypeToMMSType[mediaType]
  84. }
  85. for i, host := range mediaConn.Hosts {
  86. // TODO omit hash for unencrypted media?
  87. mediaURL := fmt.Sprintf("https://%s%s&hash=%s&mms-type=%s&__wa-mms=", host.Hostname, directPath, base64.URLEncoding.EncodeToString(encFileHash), mmsType)
  88. err = cli.downloadAndDecryptToFile(ctx, mediaURL, mediaKey, mediaType, fileLength, encFileHash, fileHash, file)
  89. if err == nil ||
  90. errors.Is(err, ErrFileLengthMismatch) ||
  91. errors.Is(err, ErrInvalidMediaSHA256) ||
  92. errors.Is(err, ErrMediaDownloadFailedWith403) ||
  93. errors.Is(err, ErrMediaDownloadFailedWith404) ||
  94. errors.Is(err, ErrMediaDownloadFailedWith410) ||
  95. errors.Is(err, context.Canceled) {
  96. return err
  97. } else if i >= len(mediaConn.Hosts)-1 {
  98. return fmt.Errorf("failed to download media from last host: %w", err)
  99. }
  100. cli.Log.Warnf("Failed to download media: %s, trying with next host...", err)
  101. }
  102. return err
  103. }
  104. func (cli *Client) downloadAndDecryptToFile(
  105. ctx context.Context,
  106. url string,
  107. mediaKey []byte,
  108. appInfo MediaType,
  109. fileLength int,
  110. fileEncSHA256, fileSHA256 []byte,
  111. file File,
  112. ) error {
  113. iv, cipherKey, macKey, _ := getMediaKeys(mediaKey, appInfo)
  114. hasher := sha256.New()
  115. if mac, err := cli.downloadPossiblyEncryptedMediaWithRetriesToFile(ctx, url, fileEncSHA256, file); err != nil {
  116. return err
  117. } else if mediaKey == nil && fileEncSHA256 == nil && mac == nil {
  118. // Unencrypted media, just return the downloaded data
  119. return nil
  120. } else if err = validateMediaFile(file, iv, macKey, mac); err != nil {
  121. return err
  122. } else if _, err = file.Seek(0, io.SeekStart); err != nil {
  123. return fmt.Errorf("failed to seek to start of file after validating mac: %w", err)
  124. } else if err = cbcutil.DecryptFile(cipherKey, iv, file); err != nil {
  125. return fmt.Errorf("failed to decrypt file: %w", err)
  126. } else if ReturnDownloadWarnings {
  127. if info, err := file.Stat(); err != nil {
  128. return fmt.Errorf("failed to stat file: %w", err)
  129. } else if fileLength >= 0 && info.Size() != int64(fileLength) {
  130. return fmt.Errorf("%w: expected %d, got %d", ErrFileLengthMismatch, fileLength, info.Size())
  131. } else if _, err = file.Seek(0, io.SeekStart); err != nil {
  132. return fmt.Errorf("failed to seek to start of file after decrypting: %w", err)
  133. } else if _, err = io.Copy(hasher, file); err != nil {
  134. return fmt.Errorf("failed to hash file: %w", err)
  135. } else if !hmac.Equal(fileSHA256, hasher.Sum(nil)) {
  136. return ErrInvalidMediaSHA256
  137. }
  138. }
  139. return nil
  140. }
  141. func (cli *Client) downloadPossiblyEncryptedMediaWithRetriesToFile(ctx context.Context, url string, checksum []byte, file File) (mac []byte, err error) {
  142. for retryNum := 0; retryNum < 5; retryNum++ {
  143. if checksum == nil {
  144. _, _, err = cli.downloadMediaToFile(ctx, url, file)
  145. } else {
  146. mac, err = cli.downloadEncryptedMediaToFile(ctx, url, checksum, file)
  147. }
  148. if err == nil || !shouldRetryMediaDownload(err) {
  149. return
  150. }
  151. retryDuration := time.Duration(retryNum+1) * time.Second
  152. var httpErr DownloadHTTPError
  153. if errors.As(err, &httpErr) {
  154. retryDuration = retryafter.Parse(httpErr.Response.Header.Get("Retry-After"), retryDuration)
  155. }
  156. cli.Log.Warnf("Failed to download media due to network error: %v, retrying in %s...", err, retryDuration)
  157. _, err = file.Seek(0, io.SeekStart)
  158. if err != nil {
  159. return nil, fmt.Errorf("failed to seek to start of file to retry download: %w", err)
  160. }
  161. select {
  162. case <-ctx.Done():
  163. return nil, ctx.Err()
  164. case <-time.After(retryDuration):
  165. }
  166. }
  167. return
  168. }
  169. func (cli *Client) downloadMediaToFile(ctx context.Context, url string, file io.Writer) (int64, []byte, error) {
  170. resp, err := cli.doMediaDownloadRequest(ctx, url)
  171. if err != nil {
  172. return 0, nil, err
  173. }
  174. defer resp.Body.Close()
  175. osFile, ok := file.(*os.File)
  176. if ok && resp.ContentLength > 0 {
  177. err = fallocate.Fallocate(osFile, int(resp.ContentLength))
  178. if err != nil {
  179. return 0, nil, fmt.Errorf("failed to preallocate file: %w", err)
  180. }
  181. }
  182. hasher := sha256.New()
  183. n, err := io.Copy(file, io.TeeReader(resp.Body, hasher))
  184. return n, hasher.Sum(nil), err
  185. }
  186. func (cli *Client) downloadEncryptedMediaToFile(ctx context.Context, url string, checksum []byte, file File) ([]byte, error) {
  187. size, hash, err := cli.downloadMediaToFile(ctx, url, file)
  188. if err != nil {
  189. return nil, err
  190. } else if size <= mediaHMACLength {
  191. return nil, ErrTooShortFile
  192. } else if len(checksum) == 32 && !hmac.Equal(checksum, hash) {
  193. return nil, ErrInvalidMediaEncSHA256
  194. }
  195. mac := make([]byte, mediaHMACLength)
  196. _, err = file.ReadAt(mac, size-mediaHMACLength)
  197. if err != nil {
  198. return nil, fmt.Errorf("failed to read MAC from file: %w", err)
  199. }
  200. err = file.Truncate(size - mediaHMACLength)
  201. if err != nil {
  202. return nil, fmt.Errorf("failed to truncate file to remove MAC: %w", err)
  203. }
  204. return mac, nil
  205. }
  206. func validateMediaFile(file io.ReadSeeker, iv, macKey, mac []byte) error {
  207. h := hmac.New(sha256.New, macKey)
  208. h.Write(iv)
  209. _, err := file.Seek(0, io.SeekStart)
  210. if err != nil {
  211. return fmt.Errorf("failed to seek to start of file: %w", err)
  212. }
  213. _, err = io.Copy(h, file)
  214. if err != nil {
  215. return fmt.Errorf("failed to hash file: %w", err)
  216. }
  217. if !hmac.Equal(h.Sum(nil)[:mediaHMACLength], mac) {
  218. return ErrInvalidMediaHMAC
  219. }
  220. return nil
  221. }