| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234 |
- // Copyright (c) 2024 Tulir Asokan
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this
- // file, You can obtain one at http://mozilla.org/MPL/2.0/.
- package whatsmeow
- import (
- "context"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/base64"
- "errors"
- "fmt"
- "io"
- "os"
- "strings"
- "time"
- "go.mau.fi/util/fallocate"
- "go.mau.fi/util/retryafter"
- "go.mau.fi/whatsmeow/proto/waMediaTransport"
- "go.mau.fi/whatsmeow/util/cbcutil"
- )
- type File interface {
- io.Reader
- io.Writer
- io.Seeker
- io.ReaderAt
- io.WriterAt
- Truncate(size int64) error
- Stat() (os.FileInfo, error)
- }
- // DownloadToFile downloads the attachment from the given protobuf message.
- //
- // This is otherwise identical to [Download], but writes the attachment to a file instead of returning it as a byte slice.
- func (cli *Client) DownloadToFile(ctx context.Context, msg DownloadableMessage, file File) error {
- if cli == nil {
- return ErrClientIsNil
- }
- mediaType := GetMediaType(msg)
- if mediaType == "" {
- return fmt.Errorf("%w %T", ErrUnknownMediaType, msg)
- }
- urlable, ok := msg.(downloadableMessageWithURL)
- var url string
- var isWebWhatsappNetURL bool
- if ok {
- url = urlable.GetURL()
- isWebWhatsappNetURL = strings.HasPrefix(url, "https://web.whatsapp.net")
- }
- if len(url) > 0 && !isWebWhatsappNetURL {
- return cli.downloadAndDecryptToFile(ctx, url, msg.GetMediaKey(), mediaType, getSize(msg), msg.GetFileEncSHA256(), msg.GetFileSHA256(), file)
- } else if len(msg.GetDirectPath()) > 0 {
- return cli.DownloadMediaWithPathToFile(ctx, msg.GetDirectPath(), msg.GetFileEncSHA256(), msg.GetFileSHA256(), msg.GetMediaKey(), getSize(msg), mediaType, mediaTypeToMMSType[mediaType], file)
- } else {
- if isWebWhatsappNetURL {
- cli.Log.Warnf("Got a media message with a web.whatsapp.net URL (%s) and no direct path", url)
- }
- return ErrNoURLPresent
- }
- }
- func (cli *Client) DownloadFBToFile(
- ctx context.Context,
- transport *waMediaTransport.WAMediaTransport_Integral,
- mediaType MediaType,
- file File,
- ) error {
- return cli.DownloadMediaWithPathToFile(ctx, transport.GetDirectPath(), transport.GetFileEncSHA256(), transport.GetFileSHA256(), transport.GetMediaKey(), -1, mediaType, mediaTypeToMMSType[mediaType], file)
- }
- func (cli *Client) DownloadMediaWithPathToFile(
- ctx context.Context,
- directPath string,
- encFileHash, fileHash, mediaKey []byte,
- fileLength int,
- mediaType MediaType,
- mmsType string,
- file File,
- ) error {
- mediaConn, err := cli.refreshMediaConn(ctx, false)
- if err != nil {
- return fmt.Errorf("failed to refresh media connections: %w", err)
- }
- if len(mmsType) == 0 {
- mmsType = mediaTypeToMMSType[mediaType]
- }
- for i, host := range mediaConn.Hosts {
- // TODO omit hash for unencrypted media?
- mediaURL := fmt.Sprintf("https://%s%s&hash=%s&mms-type=%s&__wa-mms=", host.Hostname, directPath, base64.URLEncoding.EncodeToString(encFileHash), mmsType)
- err = cli.downloadAndDecryptToFile(ctx, mediaURL, mediaKey, mediaType, fileLength, encFileHash, fileHash, file)
- if err == nil ||
- errors.Is(err, ErrFileLengthMismatch) ||
- errors.Is(err, ErrInvalidMediaSHA256) ||
- errors.Is(err, ErrMediaDownloadFailedWith403) ||
- errors.Is(err, ErrMediaDownloadFailedWith404) ||
- errors.Is(err, ErrMediaDownloadFailedWith410) ||
- errors.Is(err, context.Canceled) {
- return err
- } else if i >= len(mediaConn.Hosts)-1 {
- return fmt.Errorf("failed to download media from last host: %w", err)
- }
- cli.Log.Warnf("Failed to download media: %s, trying with next host...", err)
- }
- return err
- }
- func (cli *Client) downloadAndDecryptToFile(
- ctx context.Context,
- url string,
- mediaKey []byte,
- appInfo MediaType,
- fileLength int,
- fileEncSHA256, fileSHA256 []byte,
- file File,
- ) error {
- iv, cipherKey, macKey, _ := getMediaKeys(mediaKey, appInfo)
- hasher := sha256.New()
- if mac, err := cli.downloadPossiblyEncryptedMediaWithRetriesToFile(ctx, url, fileEncSHA256, file); err != nil {
- return err
- } else if mediaKey == nil && fileEncSHA256 == nil && mac == nil {
- // Unencrypted media, just return the downloaded data
- return nil
- } else if err = validateMediaFile(file, iv, macKey, mac); err != nil {
- return err
- } else if _, err = file.Seek(0, io.SeekStart); err != nil {
- return fmt.Errorf("failed to seek to start of file after validating mac: %w", err)
- } else if err = cbcutil.DecryptFile(cipherKey, iv, file); err != nil {
- return fmt.Errorf("failed to decrypt file: %w", err)
- } else if ReturnDownloadWarnings {
- if info, err := file.Stat(); err != nil {
- return fmt.Errorf("failed to stat file: %w", err)
- } else if fileLength >= 0 && info.Size() != int64(fileLength) {
- return fmt.Errorf("%w: expected %d, got %d", ErrFileLengthMismatch, fileLength, info.Size())
- } else if _, err = file.Seek(0, io.SeekStart); err != nil {
- return fmt.Errorf("failed to seek to start of file after decrypting: %w", err)
- } else if _, err = io.Copy(hasher, file); err != nil {
- return fmt.Errorf("failed to hash file: %w", err)
- } else if !hmac.Equal(fileSHA256, hasher.Sum(nil)) {
- return ErrInvalidMediaSHA256
- }
- }
- return nil
- }
- func (cli *Client) downloadPossiblyEncryptedMediaWithRetriesToFile(ctx context.Context, url string, checksum []byte, file File) (mac []byte, err error) {
- for retryNum := 0; retryNum < 5; retryNum++ {
- if checksum == nil {
- _, _, err = cli.downloadMediaToFile(ctx, url, file)
- } else {
- mac, err = cli.downloadEncryptedMediaToFile(ctx, url, checksum, file)
- }
- if err == nil || !shouldRetryMediaDownload(err) {
- return
- }
- retryDuration := time.Duration(retryNum+1) * time.Second
- var httpErr DownloadHTTPError
- if errors.As(err, &httpErr) {
- retryDuration = retryafter.Parse(httpErr.Response.Header.Get("Retry-After"), retryDuration)
- }
- cli.Log.Warnf("Failed to download media due to network error: %v, retrying in %s...", err, retryDuration)
- _, err = file.Seek(0, io.SeekStart)
- if err != nil {
- return nil, fmt.Errorf("failed to seek to start of file to retry download: %w", err)
- }
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-time.After(retryDuration):
- }
- }
- return
- }
- func (cli *Client) downloadMediaToFile(ctx context.Context, url string, file io.Writer) (int64, []byte, error) {
- resp, err := cli.doMediaDownloadRequest(ctx, url)
- if err != nil {
- return 0, nil, err
- }
- defer resp.Body.Close()
- osFile, ok := file.(*os.File)
- if ok && resp.ContentLength > 0 {
- err = fallocate.Fallocate(osFile, int(resp.ContentLength))
- if err != nil {
- return 0, nil, fmt.Errorf("failed to preallocate file: %w", err)
- }
- }
- hasher := sha256.New()
- n, err := io.Copy(file, io.TeeReader(resp.Body, hasher))
- return n, hasher.Sum(nil), err
- }
- func (cli *Client) downloadEncryptedMediaToFile(ctx context.Context, url string, checksum []byte, file File) ([]byte, error) {
- size, hash, err := cli.downloadMediaToFile(ctx, url, file)
- if err != nil {
- return nil, err
- } else if size <= mediaHMACLength {
- return nil, ErrTooShortFile
- } else if len(checksum) == 32 && !hmac.Equal(checksum, hash) {
- return nil, ErrInvalidMediaEncSHA256
- }
- mac := make([]byte, mediaHMACLength)
- _, err = file.ReadAt(mac, size-mediaHMACLength)
- if err != nil {
- return nil, fmt.Errorf("failed to read MAC from file: %w", err)
- }
- err = file.Truncate(size - mediaHMACLength)
- if err != nil {
- return nil, fmt.Errorf("failed to truncate file to remove MAC: %w", err)
- }
- return mac, nil
- }
- func validateMediaFile(file io.ReadSeeker, iv, macKey, mac []byte) error {
- h := hmac.New(sha256.New, macKey)
- h.Write(iv)
- _, err := file.Seek(0, io.SeekStart)
- if err != nil {
- return fmt.Errorf("failed to seek to start of file: %w", err)
- }
- _, err = io.Copy(h, file)
- if err != nil {
- return fmt.Errorf("failed to hash file: %w", err)
- }
- if !hmac.Equal(h.Sum(nil)[:mediaHMACLength], mac) {
- return ErrInvalidMediaHMAC
- }
- return nil
- }
|