||
- // Copyright (c) 2025 Tulir Asokan
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this
- // file, You can obtain one at http://mozilla.org/MPL/2.0/.
- // Package sqlstore contains an SQL-backed implementation of the interfaces in the store package.
- package sqlstore
- import (
- "context"
- "database/sql"
- "database/sql/driver"
- "errors"
- "fmt"
- "slices"
- "strings"
- "sync"
- "time"
- "go.mau.fi/util/dbutil"
- "go.mau.fi/util/exslices"
- "go.mau.fi/util/exsync"
- "git.bobomao.top/joey/testwh/store"
- "git.bobomao.top/joey/testwh/types"
- "git.bobomao.top/joey/testwh/util/keys"
- )
// ErrInvalidLength is returned by some database getters if the database returned a byte array with an unexpected length.
// This should be impossible, as the database schema contains CHECK()s for all the relevant columns.
var ErrInvalidLength = errors.New("database returned byte array with illegal length")

// PostgresArrayWrapper is a function to wrap array values before passing them to the sql package.
//
// It is only consulted when the database dialect is Postgres. When it is nil (the default),
// queries that take a list of values fall back to an explicit IN (...) placeholder list.
//
// When using github.com/lib/pq, you should set
//
//	sqlstore.PostgresArrayWrapper = pq.Array
var PostgresArrayWrapper func(any) interface {
	driver.Valuer
	sql.Scanner
}
// SQLStore is the per-device SQL-backed store. It embeds the shared *Container, which supplies
// the database handle and logger used by the methods in this file, and scopes every query to JID.
type SQLStore struct {
	*Container
	// JID is the string form of this device's JID. It is passed as the our_jid/jid column value
	// in every query, so each device only sees its own rows.
	JID string

	// preKeyLock serializes prekey ID allocation and generation (GenOnePreKey, GetOrGenPreKeys).
	preKeyLock sync.Mutex

	// contactCache caches contact info by JID and is guarded by contactCacheLock.
	contactCache     map[types.JID]*types.ContactInfo
	contactCacheLock sync.Mutex

	// migratedPNSessionsCache records phone-number signal addresses already handled by
	// MigratePNToLID, so repeated calls skip the database work.
	migratedPNSessionsCache *exsync.Set[string]
}
- // NewSQLStore creates a new SQLStore with the given database container and user JID.
- // It contains implementations of all the different stores in the store package.
- //
- // In general, you should use Container.NewDevice or Container.GetDevice instead of this.
- func NewSQLStore(c *Container, jid types.JID) *SQLStore {
- return &SQLStore{
- Container: c,
- JID: jid.String(),
- contactCache: make(map[types.JID]*types.ContactInfo),
- migratedPNSessionsCache: exsync.NewSet[string](),
- }
- }
- var _ store.AllSessionSpecificStores = (*SQLStore)(nil)
// Queries for the whatsmeow_identity_keys table. Rows are keyed by (our_jid, their_id), where
// their_id is the signal address string passed by the callers below.
const (
	// putIdentityQuery inserts or replaces the identity key for one address.
	putIdentityQuery = `
INSERT INTO whatsmeow_identity_keys (our_jid, their_id, identity) VALUES ($1, $2, $3)
ON CONFLICT (our_jid, their_id) DO UPDATE SET identity=excluded.identity
`
	// deleteAllIdentitiesQuery deletes by LIKE pattern, so % and _ in $2 act as wildcards.
	deleteAllIdentitiesQuery = `DELETE FROM whatsmeow_identity_keys WHERE our_jid=$1 AND their_id LIKE $2`
	// deleteIdentityQuery deletes exactly one address.
	deleteIdentityQuery = `DELETE FROM whatsmeow_identity_keys WHERE our_jid=$1 AND their_id=$2`
	getIdentityQuery    = `SELECT identity FROM whatsmeow_identity_keys WHERE our_jid=$1 AND their_id=$2`
)
- func (s *SQLStore) PutIdentity(ctx context.Context, address string, key [32]byte) error {
- _, err := s.db.Exec(ctx, putIdentityQuery, s.JID, address, key[:])
- return err
- }
- func (s *SQLStore) DeleteAllIdentities(ctx context.Context, phone string) error {
- _, err := s.db.Exec(ctx, deleteAllIdentitiesQuery, s.JID, phone+":%")
- return err
- }
- func (s *SQLStore) DeleteIdentity(ctx context.Context, address string) error {
- _, err := s.db.Exec(ctx, deleteAllIdentitiesQuery, s.JID, address)
- return err
- }
- func (s *SQLStore) IsTrustedIdentity(ctx context.Context, address string, key [32]byte) (bool, error) {
- var existingIdentity []byte
- err := s.db.QueryRow(ctx, getIdentityQuery, s.JID, address).Scan(&existingIdentity)
- if errors.Is(err, sql.ErrNoRows) {
- // Trust if not known, it'll be saved automatically later
- return true, nil
- } else if err != nil {
- return false, err
- } else if len(existingIdentity) != 32 {
- return false, ErrInvalidLength
- }
- return *(*[32]byte)(existingIdentity) == key, nil
- }
// Queries for the whatsmeow_sessions table, plus the statements MigratePNToLID uses to move
// sessions, identity keys and sender keys from a phone-number address to a LID address.
const (
	getSessionQuery = `SELECT session FROM whatsmeow_sessions WHERE our_jid=$1 AND their_id=$2`
	hasSessionQuery = `SELECT true FROM whatsmeow_sessions WHERE our_jid=$1 AND their_id=$2`
	// getManySessionQueryPostgres takes the whole address list as one array parameter ($2).
	getManySessionQueryPostgres = `SELECT their_id, session FROM whatsmeow_sessions WHERE our_jid=$1 AND their_id = ANY($2)`
	// getManySessionQueryGeneric is a fmt template: %s is replaced by one placeholder per address.
	getManySessionQueryGeneric = `SELECT their_id, session FROM whatsmeow_sessions WHERE our_jid=$1 AND their_id IN (%s)`
	putSessionQuery            = `
INSERT INTO whatsmeow_sessions (our_jid, their_id, session) VALUES ($1, $2, $3)
ON CONFLICT (our_jid, their_id) DO UPDATE SET session=excluded.session
`
	// deleteAllSessionsQuery deletes by LIKE pattern; callers in this file pass "<user>:%".
	deleteAllSessionsQuery = `DELETE FROM whatsmeow_sessions WHERE our_jid=$1 AND their_id LIKE $2`
	deleteSessionQuery     = `DELETE FROM whatsmeow_sessions WHERE our_jid=$1 AND their_id=$2`
	// The migratePNToLID*Query statements copy every row whose ID starts with "$2:" to a new ID
	// where every occurrence of $2 is replaced with $3, overwriting any row already there. They
	// do not remove the source rows; MigratePNToLID runs the matching deleteAll*Query for that.
	migratePNToLIDSessionsQuery = `
INSERT INTO whatsmeow_sessions (our_jid, their_id, session)
SELECT our_jid, replace(their_id, $2, $3), session
FROM whatsmeow_sessions
WHERE our_jid=$1 AND their_id LIKE $2 || ':%'
ON CONFLICT (our_jid, their_id) DO UPDATE SET session=excluded.session
`
	deleteAllIdentityKeysQuery      = `DELETE FROM whatsmeow_identity_keys WHERE our_jid=$1 AND their_id LIKE $2`
	migratePNToLIDIdentityKeysQuery = `
INSERT INTO whatsmeow_identity_keys (our_jid, their_id, identity)
SELECT our_jid, replace(their_id, $2, $3), identity
FROM whatsmeow_identity_keys
WHERE our_jid=$1 AND their_id LIKE $2 || ':%'
ON CONFLICT (our_jid, their_id) DO UPDATE SET identity=excluded.identity
`
	deleteAllSenderKeysQuery      = `DELETE FROM whatsmeow_sender_keys WHERE our_jid=$1 AND sender_id LIKE $2`
	migratePNToLIDSenderKeysQuery = `
INSERT INTO whatsmeow_sender_keys (our_jid, chat_id, sender_id, sender_key)
SELECT our_jid, chat_id, replace(sender_id, $2, $3), sender_key
FROM whatsmeow_sender_keys
WHERE our_jid=$1 AND sender_id LIKE $2 || ':%'
ON CONFLICT (our_jid, chat_id, sender_id) DO UPDATE SET sender_key=excluded.sender_key
`
)
- func (s *SQLStore) GetSession(ctx context.Context, address string) (session []byte, err error) {
- err = s.db.QueryRow(ctx, getSessionQuery, s.JID, address).Scan(&session)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
- }
- func (s *SQLStore) HasSession(ctx context.Context, address string) (has bool, err error) {
- err = s.db.QueryRow(ctx, hasSessionQuery, s.JID, address).Scan(&has)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
- }
- type addressSessionTuple struct {
- Address string
- Session []byte
- }
- var sessionScanner = dbutil.ConvertRowFn[addressSessionTuple](func(row dbutil.Scannable) (out addressSessionTuple, err error) {
- err = row.Scan(&out.Address, &out.Session)
- return
- })
- func (s *SQLStore) GetManySessions(ctx context.Context, addresses []string) (map[string][]byte, error) {
- if len(addresses) == 0 {
- return nil, nil
- }
- var rows dbutil.Rows
- var err error
- if s.db.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
- rows, err = s.db.Query(ctx, getManySessionQueryPostgres, s.JID, PostgresArrayWrapper(addresses))
- } else {
- args := make([]any, len(addresses)+1)
- placeholders := make([]string, len(addresses))
- args[0] = s.JID
- for i, addr := range addresses {
- args[i+1] = addr
- placeholders[i] = fmt.Sprintf("$%d", i+2)
- }
- rows, err = s.db.Query(ctx, fmt.Sprintf(getManySessionQueryGeneric, strings.Join(placeholders, ",")), args...)
- }
- result := make(map[string][]byte, len(addresses))
- for _, addr := range addresses {
- result[addr] = nil
- }
- err = sessionScanner.NewRowIter(rows, err).Iter(func(tuple addressSessionTuple) (bool, error) {
- result[tuple.Address] = tuple.Session
- return true, nil
- })
- if err != nil {
- return nil, err
- }
- return result, nil
- }
- func (s *SQLStore) PutManySessions(ctx context.Context, sessions map[string][]byte) error {
- return s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
- for addr, sess := range sessions {
- err := s.PutSession(ctx, addr, sess)
- if err != nil {
- return err
- }
- }
- return nil
- })
- }
- func (s *SQLStore) PutSession(ctx context.Context, address string, session []byte) error {
- _, err := s.db.Exec(ctx, putSessionQuery, s.JID, address, session)
- return err
- }
- func (s *SQLStore) DeleteAllSessions(ctx context.Context, phone string) error {
- return s.deleteAllSessions(ctx, phone)
- }
- func (s *SQLStore) deleteAllSessions(ctx context.Context, phone string) error {
- _, err := s.db.Exec(ctx, deleteAllSessionsQuery, s.JID, phone+":%")
- return err
- }
- func (s *SQLStore) deleteAllSenderKeys(ctx context.Context, phone string) error {
- _, err := s.db.Exec(ctx, deleteAllSenderKeysQuery, s.JID, phone+":%")
- return err
- }
- func (s *SQLStore) deleteAllIdentityKeys(ctx context.Context, phone string) error {
- _, err := s.db.Exec(ctx, deleteAllIdentityKeysQuery, s.JID, phone+":%")
- return err
- }
- func (s *SQLStore) DeleteSession(ctx context.Context, address string) error {
- _, err := s.db.Exec(ctx, deleteSessionQuery, s.JID, address)
- return err
- }
- func (s *SQLStore) MigratePNToLID(ctx context.Context, pn, lid types.JID) error {
- pnSignal := pn.SignalAddressUser()
- if !s.migratedPNSessionsCache.Add(pnSignal) {
- return nil
- }
- var sessionsUpdated, identityKeysUpdated, senderKeysUpdated int64
- lidSignal := lid.SignalAddressUser()
- err := s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
- res, err := s.db.Exec(ctx, migratePNToLIDSessionsQuery, s.JID, pnSignal, lidSignal)
- if err != nil {
- return fmt.Errorf("failed to migrate sessions: %w", err)
- }
- sessionsUpdated, err = res.RowsAffected()
- if err != nil {
- return fmt.Errorf("failed to get rows affected for sessions: %w", err)
- }
- err = s.deleteAllSessions(ctx, pnSignal)
- if err != nil {
- return fmt.Errorf("failed to delete extra sessions: %w", err)
- }
- res, err = s.db.Exec(ctx, migratePNToLIDIdentityKeysQuery, s.JID, pnSignal, lidSignal)
- if err != nil {
- return fmt.Errorf("failed to migrate identity keys: %w", err)
- }
- identityKeysUpdated, err = res.RowsAffected()
- if err != nil {
- return fmt.Errorf("failed to get rows affected for identity keys: %w", err)
- }
- err = s.deleteAllIdentityKeys(ctx, pnSignal)
- if err != nil {
- return fmt.Errorf("failed to delete extra identity keys: %w", err)
- }
- res, err = s.db.Exec(ctx, migratePNToLIDSenderKeysQuery, s.JID, pnSignal, lidSignal)
- if err != nil {
- return fmt.Errorf("failed to migrate sender keys: %w", err)
- }
- senderKeysUpdated, err = res.RowsAffected()
- if err != nil {
- return fmt.Errorf("failed to get rows affected for sender keys: %w", err)
- }
- err = s.deleteAllSenderKeys(ctx, pnSignal)
- if err != nil {
- return fmt.Errorf("failed to delete extra sender keys: %w", err)
- }
- return nil
- })
- if err != nil {
- return err
- }
- if sessionsUpdated > 0 || senderKeysUpdated > 0 || identityKeysUpdated > 0 {
- s.log.Infof("Migrated %d sessions, %d identity keys and %d sender keys from %s to %s", sessionsUpdated, identityKeysUpdated, senderKeysUpdated, pnSignal, lidSignal)
- } else {
- s.log.Debugf("No sessions or sender keys found to migrate from %s to %s", pnSignal, lidSignal)
- }
- return nil
- }
// Queries for the whatsmeow_pre_keys table. Rows are scoped by jid (this device) and key_id.
const (
	// getLastPreKeyIDQuery returns NULL when the device has no prekeys yet.
	getLastPreKeyIDQuery = `SELECT MAX(key_id) FROM whatsmeow_pre_keys WHERE jid=$1`
	insertPreKeyQuery    = `INSERT INTO whatsmeow_pre_keys (jid, key_id, key, uploaded) VALUES ($1, $2, $3, $4)`
	// getUnuploadedPreKeysQuery returns at most $2 not-yet-uploaded keys, lowest IDs first.
	getUnuploadedPreKeysQuery = `SELECT key_id, key FROM whatsmeow_pre_keys WHERE jid=$1 AND uploaded=false ORDER BY key_id LIMIT $2`
	getPreKeyQuery            = `SELECT key_id, key FROM whatsmeow_pre_keys WHERE jid=$1 AND key_id=$2`
	deletePreKeyQuery         = `DELETE FROM whatsmeow_pre_keys WHERE jid=$1 AND key_id=$2`
	// markPreKeysAsUploadedQuery marks every key with an ID up to and including $2 as uploaded.
	markPreKeysAsUploadedQuery  = `UPDATE whatsmeow_pre_keys SET uploaded=true WHERE jid=$1 AND key_id<=$2`
	getUploadedPreKeyCountQuery = `SELECT COUNT(*) FROM whatsmeow_pre_keys WHERE jid=$1 AND uploaded=true`
)
- func (s *SQLStore) genOnePreKey(ctx context.Context, id uint32, markUploaded bool) (*keys.PreKey, error) {
- key := keys.NewPreKey(id)
- _, err := s.db.Exec(ctx, insertPreKeyQuery, s.JID, key.KeyID, key.Priv[:], markUploaded)
- return key, err
- }
- func (s *SQLStore) getNextPreKeyID(ctx context.Context) (uint32, error) {
- var lastKeyID sql.NullInt32
- err := s.db.QueryRow(ctx, getLastPreKeyIDQuery, s.JID).Scan(&lastKeyID)
- if err != nil {
- return 0, fmt.Errorf("failed to query next prekey ID: %w", err)
- }
- return uint32(lastKeyID.Int32) + 1, nil
- }
- func (s *SQLStore) GenOnePreKey(ctx context.Context) (*keys.PreKey, error) {
- s.preKeyLock.Lock()
- defer s.preKeyLock.Unlock()
- nextKeyID, err := s.getNextPreKeyID(ctx)
- if err != nil {
- return nil, err
- }
- return s.genOnePreKey(ctx, nextKeyID, true)
- }
- func (s *SQLStore) GetOrGenPreKeys(ctx context.Context, count uint32) ([]*keys.PreKey, error) {
- s.preKeyLock.Lock()
- defer s.preKeyLock.Unlock()
- res, err := s.db.Query(ctx, getUnuploadedPreKeysQuery, s.JID, count)
- if err != nil {
- return nil, fmt.Errorf("failed to query existing prekeys: %w", err)
- }
- newKeys := make([]*keys.PreKey, count)
- var existingCount uint32
- for res.Next() {
- var key *keys.PreKey
- key, err = scanPreKey(res)
- if err != nil {
- return nil, err
- } else if key != nil {
- newKeys[existingCount] = key
- existingCount++
- }
- }
- if existingCount < uint32(len(newKeys)) {
- var nextKeyID uint32
- nextKeyID, err = s.getNextPreKeyID(ctx)
- if err != nil {
- return nil, err
- }
- for i := existingCount; i < count; i++ {
- newKeys[i], err = s.genOnePreKey(ctx, nextKeyID, false)
- if err != nil {
- return nil, fmt.Errorf("failed to generate prekey: %w", err)
- }
- nextKeyID++
- }
- }
- return newKeys, nil
- }
- func scanPreKey(row dbutil.Scannable) (*keys.PreKey, error) {
- var priv []byte
- var id uint32
- err := row.Scan(&id, &priv)
- if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- } else if err != nil {
- return nil, err
- } else if len(priv) != 32 {
- return nil, ErrInvalidLength
- }
- return &keys.PreKey{
- KeyPair: *keys.NewKeyPairFromPrivateKey(*(*[32]byte)(priv)),
- KeyID: id,
- }, nil
- }
- func (s *SQLStore) GetPreKey(ctx context.Context, id uint32) (*keys.PreKey, error) {
- return scanPreKey(s.db.QueryRow(ctx, getPreKeyQuery, s.JID, id))
- }
- func (s *SQLStore) RemovePreKey(ctx context.Context, id uint32) error {
- _, err := s.db.Exec(ctx, deletePreKeyQuery, s.JID, id)
- return err
- }
- func (s *SQLStore) MarkPreKeysAsUploaded(ctx context.Context, upToID uint32) error {
- _, err := s.db.Exec(ctx, markPreKeysAsUploadedQuery, s.JID, upToID)
- return err
- }
- func (s *SQLStore) UploadedPreKeyCount(ctx context.Context) (count int, err error) {
- err = s.db.QueryRow(ctx, getUploadedPreKeyCountQuery, s.JID).Scan(&count)
- return
- }
// Queries for the whatsmeow_sender_keys table, keyed by (our_jid, chat_id, sender_id).
const (
	getSenderKeyQuery = `SELECT sender_key FROM whatsmeow_sender_keys WHERE our_jid=$1 AND chat_id=$2 AND sender_id=$3`
	// putSenderKeyQuery inserts or replaces the sender key for one (chat, sender) pair.
	putSenderKeyQuery = `
INSERT INTO whatsmeow_sender_keys (our_jid, chat_id, sender_id, sender_key) VALUES ($1, $2, $3, $4)
ON CONFLICT (our_jid, chat_id, sender_id) DO UPDATE SET sender_key=excluded.sender_key
`
)
- func (s *SQLStore) PutSenderKey(ctx context.Context, group, user string, session []byte) error {
- _, err := s.db.Exec(ctx, putSenderKeyQuery, s.JID, group, user, session)
- return err
- }
- func (s *SQLStore) GetSenderKey(ctx context.Context, group, user string) (key []byte, err error) {
- err = s.db.QueryRow(ctx, getSenderKeyQuery, s.JID, group, user).Scan(&key)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
- }
// Queries for the whatsmeow_app_state_sync_keys table, keyed by (jid, key_id).
const (
	// putAppStateSyncKeyQuery inserts a key; an existing key is only overwritten when the new
	// timestamp is strictly newer than the stored one.
	putAppStateSyncKeyQuery = `
INSERT INTO whatsmeow_app_state_sync_keys (jid, key_id, key_data, timestamp, fingerprint) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (jid, key_id) DO UPDATE
SET key_data=excluded.key_data, timestamp=excluded.timestamp, fingerprint=excluded.fingerprint
WHERE excluded.timestamp > whatsmeow_app_state_sync_keys.timestamp
`
	getAppStateSyncKeyQuery = `SELECT key_data, timestamp, fingerprint FROM whatsmeow_app_state_sync_keys WHERE jid=$1 AND key_id=$2`
	// getLatestAppStateSyncKeyIDQuery returns the ID of the key with the newest timestamp.
	getLatestAppStateSyncKeyIDQuery = `SELECT key_id FROM whatsmeow_app_state_sync_keys WHERE jid=$1 ORDER BY timestamp DESC LIMIT 1`
)
- func (s *SQLStore) PutAppStateSyncKey(ctx context.Context, id []byte, key store.AppStateSyncKey) error {
- _, err := s.db.Exec(ctx, putAppStateSyncKeyQuery, s.JID, id, key.Data, key.Timestamp, key.Fingerprint)
- return err
- }
- func (s *SQLStore) GetAppStateSyncKey(ctx context.Context, id []byte) (*store.AppStateSyncKey, error) {
- var key store.AppStateSyncKey
- err := s.db.QueryRow(ctx, getAppStateSyncKeyQuery, s.JID, id).Scan(&key.Data, &key.Timestamp, &key.Fingerprint)
- if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- }
- return &key, err
- }
- func (s *SQLStore) GetLatestAppStateSyncKeyID(ctx context.Context) ([]byte, error) {
- var keyID []byte
- err := s.db.QueryRow(ctx, getLatestAppStateSyncKeyIDQuery, s.JID).Scan(&keyID)
- if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- }
- return keyID, err
- }
// Queries for app state versions (whatsmeow_app_state_version) and mutation MACs
// (whatsmeow_app_state_mutation_macs).
const (
	putAppStateVersionQuery = `
INSERT INTO whatsmeow_app_state_version (jid, name, version, hash) VALUES ($1, $2, $3, $4)
ON CONFLICT (jid, name) DO UPDATE SET version=excluded.version, hash=excluded.hash
`
	getAppStateVersionQuery    = `SELECT version, hash FROM whatsmeow_app_state_version WHERE jid=$1 AND name=$2`
	deleteAppStateVersionQuery = `DELETE FROM whatsmeow_app_state_version WHERE jid=$1 AND name=$2`
	// putAppStateMutationMACsQuery is only a prefix: putAppStateMutationMACs appends the
	// comma-separated VALUES tuples (hence the trailing space).
	putAppStateMutationMACsQuery = `INSERT INTO whatsmeow_app_state_mutation_macs (jid, name, version, index_mac, value_mac) VALUES `
	// deleteAppStateMutationMACsQueryPostgres takes all index MACs as one bytea[] parameter.
	deleteAppStateMutationMACsQueryPostgres = `DELETE FROM whatsmeow_app_state_mutation_macs WHERE jid=$1 AND name=$2 AND index_mac=ANY($3::bytea[])`
	// deleteAppStateMutationMACsQueryGeneric is a prefix; the caller appends "(<placeholders>)".
	deleteAppStateMutationMACsQueryGeneric = `DELETE FROM whatsmeow_app_state_mutation_macs WHERE jid=$1 AND name=$2 AND index_mac IN `
	// getAppStateMutationMACQuery returns the value MAC from the highest version for one index MAC.
	getAppStateMutationMACQuery = `SELECT value_mac FROM whatsmeow_app_state_mutation_macs WHERE jid=$1 AND name=$2 AND index_mac=$3 ORDER BY version DESC LIMIT 1`
)
- func (s *SQLStore) PutAppStateVersion(ctx context.Context, name string, version uint64, hash [128]byte) error {
- _, err := s.db.Exec(ctx, putAppStateVersionQuery, s.JID, name, version, hash[:])
- return err
- }
- func (s *SQLStore) GetAppStateVersion(ctx context.Context, name string) (version uint64, hash [128]byte, err error) {
- var uncheckedHash []byte
- err = s.db.QueryRow(ctx, getAppStateVersionQuery, s.JID, name).Scan(&version, &uncheckedHash)
- if errors.Is(err, sql.ErrNoRows) {
- // version will be 0 and hash will be an empty array, which is the correct initial state
- err = nil
- } else if err != nil {
- // There's an error, just return it
- } else if len(uncheckedHash) != 128 {
- // This shouldn't happen
- err = ErrInvalidLength
- } else {
- // No errors, convert hash slice to array
- hash = *(*[128]byte)(uncheckedHash)
- }
- return
- }
- func (s *SQLStore) DeleteAppStateVersion(ctx context.Context, name string) error {
- _, err := s.db.Exec(ctx, deleteAppStateVersionQuery, s.JID, name)
- return err
- }
- func (s *SQLStore) putAppStateMutationMACs(ctx context.Context, name string, version uint64, mutations []store.AppStateMutationMAC) error {
- values := make([]any, 3+len(mutations)*2)
- queryParts := make([]string, len(mutations))
- values[0] = s.JID
- values[1] = name
- values[2] = version
- placeholderSyntax := "($1, $2, $3, $%d, $%d)"
- if s.db.Dialect == dbutil.SQLite {
- placeholderSyntax = "(?1, ?2, ?3, ?%d, ?%d)"
- }
- for i, mutation := range mutations {
- baseIndex := 3 + i*2
- values[baseIndex] = mutation.IndexMAC
- values[baseIndex+1] = mutation.ValueMAC
- queryParts[i] = fmt.Sprintf(placeholderSyntax, baseIndex+1, baseIndex+2)
- }
- _, err := s.db.Exec(ctx, putAppStateMutationMACsQuery+strings.Join(queryParts, ","), values...)
- return err
- }
- const mutationBatchSize = 400
- func (s *SQLStore) PutAppStateMutationMACs(ctx context.Context, name string, version uint64, mutations []store.AppStateMutationMAC) error {
- if len(mutations) == 0 {
- return nil
- }
- return s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
- for slice := range slices.Chunk(mutations, mutationBatchSize) {
- err := s.putAppStateMutationMACs(ctx, name, version, slice)
- if err != nil {
- return err
- }
- }
- return nil
- })
- }
- func (s *SQLStore) DeleteAppStateMutationMACs(ctx context.Context, name string, indexMACs [][]byte) (err error) {
- if len(indexMACs) == 0 {
- return
- }
- if s.db.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
- _, err = s.db.Exec(ctx, deleteAppStateMutationMACsQueryPostgres, s.JID, name, PostgresArrayWrapper(indexMACs))
- } else {
- args := make([]any, 2+len(indexMACs))
- args[0] = s.JID
- args[1] = name
- queryParts := make([]string, len(indexMACs))
- for i, item := range indexMACs {
- args[2+i] = item
- queryParts[i] = fmt.Sprintf("$%d", i+3)
- }
- _, err = s.db.Exec(ctx, deleteAppStateMutationMACsQueryGeneric+"("+strings.Join(queryParts, ",")+")", args...)
- }
- return
- }
- func (s *SQLStore) GetAppStateMutationMAC(ctx context.Context, name string, indexMAC []byte) (valueMAC []byte, err error) {
- err = s.db.QueryRow(ctx, getAppStateMutationMACQuery, s.JID, name, indexMAC).Scan(&valueMAC)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
- }
// Queries for the whatsmeow_contacts table, keyed by (our_jid, their_jid). Each put query
// upserts only the columns it names and leaves the other columns of an existing row unchanged.
const (
	putContactNameQuery = `
INSERT INTO whatsmeow_contacts (our_jid, their_jid, first_name, full_name) VALUES ($1, $2, $3, $4)
ON CONFLICT (our_jid, their_jid) DO UPDATE SET first_name=excluded.first_name, full_name=excluded.full_name
`
	putRedactedPhoneQuery = `
INSERT INTO whatsmeow_contacts (our_jid, their_jid, redacted_phone)
VALUES ($1, $2, $3)
ON CONFLICT (our_jid, their_jid) DO UPDATE SET redacted_phone=excluded.redacted_phone
`
	putPushNameQuery = `
INSERT INTO whatsmeow_contacts (our_jid, their_jid, push_name) VALUES ($1, $2, $3)
ON CONFLICT (our_jid, their_jid) DO UPDATE SET push_name=excluded.push_name
`
	putBusinessNameQuery = `
INSERT INTO whatsmeow_contacts (our_jid, their_jid, business_name) VALUES ($1, $2, $3)
ON CONFLICT (our_jid, their_jid) DO UPDATE SET business_name=excluded.business_name
`
	getContactQuery = `
SELECT first_name, full_name, push_name, business_name, redacted_phone FROM whatsmeow_contacts WHERE our_jid=$1 AND their_jid=$2
`
	getAllContactsQuery = `
SELECT their_jid, first_name, full_name, push_name, business_name, redacted_phone FROM whatsmeow_contacts WHERE our_jid=$1
`
)

// putContactNamesMassInsertBuilder expands putContactNameQuery into a multi-row INSERT.
// $1 is the shared static parameter (this device's JID); the remaining placeholders of each
// tuple come from one store.ContactEntry.
var putContactNamesMassInsertBuilder = dbutil.NewMassInsertBuilder[store.ContactEntry, [1]any](
	putContactNameQuery, "($1, $%d, $%d, $%d)",
)

// putRedactedPhonesMassInsertBuilder expands putRedactedPhoneQuery into a multi-row INSERT,
// with $1 shared and two per-entry placeholders from each store.RedactedPhoneEntry.
var putRedactedPhonesMassInsertBuilder = dbutil.NewMassInsertBuilder[store.RedactedPhoneEntry, [1]any](
	putRedactedPhoneQuery, "($1, $%d, $%d)",
)
- func (s *SQLStore) PutPushName(ctx context.Context, user types.JID, pushName string) (bool, string, error) {
- s.contactCacheLock.Lock()
- defer s.contactCacheLock.Unlock()
- cached, err := s.getContact(ctx, user)
- if err != nil {
- return false, "", err
- }
- if cached.PushName != pushName {
- _, err = s.db.Exec(ctx, putPushNameQuery, s.JID, user, pushName)
- if err != nil {
- return false, "", err
- }
- previousName := cached.PushName
- cached.PushName = pushName
- cached.Found = true
- return true, previousName, nil
- }
- return false, "", nil
- }
- func (s *SQLStore) PutBusinessName(ctx context.Context, user types.JID, businessName string) (bool, string, error) {
- s.contactCacheLock.Lock()
- defer s.contactCacheLock.Unlock()
- cached, err := s.getContact(ctx, user)
- if err != nil {
- return false, "", err
- }
- if cached.BusinessName != businessName {
- _, err = s.db.Exec(ctx, putBusinessNameQuery, s.JID, user, businessName)
- if err != nil {
- return false, "", err
- }
- previousName := cached.BusinessName
- cached.BusinessName = businessName
- cached.Found = true
- return true, previousName, nil
- }
- return false, "", nil
- }
- func (s *SQLStore) PutContactName(ctx context.Context, user types.JID, firstName, fullName string) error {
- s.contactCacheLock.Lock()
- defer s.contactCacheLock.Unlock()
- cached, err := s.getContact(ctx, user)
- if err != nil {
- return err
- }
- if cached.FirstName != firstName || cached.FullName != fullName {
- _, err = s.db.Exec(ctx, putContactNameQuery, s.JID, user, firstName, fullName)
- if err != nil {
- return err
- }
- cached.FirstName = firstName
- cached.FullName = fullName
- cached.Found = true
- }
- return nil
- }
- const contactBatchSize = 300
- func (s *SQLStore) PutAllContactNames(ctx context.Context, contacts []store.ContactEntry) error {
- if len(contacts) == 0 {
- return nil
- }
- origLen := len(contacts)
- contacts = exslices.DeduplicateUnsortedOverwriteFunc(contacts, func(t store.ContactEntry) types.JID {
- return t.JID
- })
- if origLen != len(contacts) {
- s.log.Warnf("%d duplicate contacts found in PutAllContactNames", origLen-len(contacts))
- }
- err := s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
- for slice := range slices.Chunk(contacts, contactBatchSize) {
- query, vars := putContactNamesMassInsertBuilder.Build([1]any{s.JID}, slice)
- _, err := s.db.Exec(ctx, query, vars...)
- if err != nil {
- return err
- }
- }
- return nil
- })
- if err != nil {
- return err
- }
- s.contactCacheLock.Lock()
- // Just clear the cache, fetching pushnames and business names would be too much effort
- s.contactCache = make(map[types.JID]*types.ContactInfo)
- s.contactCacheLock.Unlock()
- return nil
- }
- func (s *SQLStore) PutManyRedactedPhones(ctx context.Context, entries []store.RedactedPhoneEntry) error {
- if len(entries) == 0 {
- return nil
- }
- origLen := len(entries)
- entries = exslices.DeduplicateUnsortedOverwriteFunc(entries, func(t store.RedactedPhoneEntry) types.JID {
- return t.JID
- })
- if origLen != len(entries) {
- s.log.Warnf("%d duplicate contacts found in PutManyRedactedPhones", origLen-len(entries))
- }
- err := s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
- for slice := range slices.Chunk(entries, contactBatchSize) {
- query, vars := putRedactedPhonesMassInsertBuilder.Build([1]any{s.JID}, slice)
- _, err := s.db.Exec(ctx, query, vars...)
- if err != nil {
- return err
- }
- }
- return nil
- })
- if err != nil {
- return err
- }
- s.contactCacheLock.Lock()
- for _, entry := range entries {
- if cached, ok := s.contactCache[entry.JID]; ok && cached.RedactedPhone == entry.RedactedPhone {
- continue
- }
- delete(s.contactCache, entry.JID)
- }
- s.contactCacheLock.Unlock()
- return nil
- }
- func (s *SQLStore) getContact(ctx context.Context, user types.JID) (*types.ContactInfo, error) {
- cached, ok := s.contactCache[user]
- if ok {
- return cached, nil
- }
- var first, full, push, business, redactedPhone sql.NullString
- err := s.db.QueryRow(ctx, getContactQuery, s.JID, user).Scan(&first, &full, &push, &business, &redactedPhone)
- if err != nil && !errors.Is(err, sql.ErrNoRows) {
- return nil, err
- }
- info := &types.ContactInfo{
- Found: err == nil,
- FirstName: first.String,
- FullName: full.String,
- PushName: push.String,
- BusinessName: business.String,
- RedactedPhone: redactedPhone.String,
- }
- s.contactCache[user] = info
- return info, nil
- }
- func (s *SQLStore) GetContact(ctx context.Context, user types.JID) (types.ContactInfo, error) {
- s.contactCacheLock.Lock()
- info, err := s.getContact(ctx, user)
- s.contactCacheLock.Unlock()
- if err != nil {
- return types.ContactInfo{}, err
- }
- return *info, nil
- }
- func (s *SQLStore) GetAllContacts(ctx context.Context) (map[types.JID]types.ContactInfo, error) {
- s.contactCacheLock.Lock()
- defer s.contactCacheLock.Unlock()
- rows, err := s.db.Query(ctx, getAllContactsQuery, s.JID)
- if err != nil {
- return nil, err
- }
- output := make(map[types.JID]types.ContactInfo, len(s.contactCache))
- for rows.Next() {
- var jid types.JID
- var first, full, push, business, redactedPhone sql.NullString
- err = rows.Scan(&jid, &first, &full, &push, &business, &redactedPhone)
- if err != nil {
- return nil, fmt.Errorf("error scanning row: %w", err)
- }
- info := types.ContactInfo{
- Found: true,
- FirstName: first.String,
- FullName: full.String,
- PushName: push.String,
- BusinessName: business.String,
- RedactedPhone: redactedPhone.String,
- }
- output[jid] = info
- s.contactCache[jid] = &info
- }
- return output, nil
- }
// Queries for the whatsmeow_chat_settings table, keyed by (our_jid, chat_jid).
const (
	// putChatSettingQuery is a fmt template: %[1]s is replaced with the column name to upsert
	// (the callers in this file pass "muted_until", "pinned" or "archived").
	putChatSettingQuery = `
INSERT INTO whatsmeow_chat_settings (our_jid, chat_jid, %[1]s) VALUES ($1, $2, $3)
ON CONFLICT (our_jid, chat_jid) DO UPDATE SET %[1]s=excluded.%[1]s
`
	getChatSettingsQuery = `
SELECT muted_until, pinned, archived FROM whatsmeow_chat_settings WHERE our_jid=$1 AND chat_jid=$2
`
)
- func (s *SQLStore) PutMutedUntil(ctx context.Context, chat types.JID, mutedUntil time.Time) error {
- var val int64
- if mutedUntil == store.MutedForever {
- val = -1
- } else if !mutedUntil.IsZero() {
- val = mutedUntil.Unix()
- }
- _, err := s.db.Exec(ctx, fmt.Sprintf(putChatSettingQuery, "muted_until"), s.JID, chat, val)
- return err
- }
- func (s *SQLStore) PutPinned(ctx context.Context, chat types.JID, pinned bool) error {
- _, err := s.db.Exec(ctx, fmt.Sprintf(putChatSettingQuery, "pinned"), s.JID, chat, pinned)
- return err
- }
- func (s *SQLStore) PutArchived(ctx context.Context, chat types.JID, archived bool) error {
- _, err := s.db.Exec(ctx, fmt.Sprintf(putChatSettingQuery, "archived"), s.JID, chat, archived)
- return err
- }
- func (s *SQLStore) GetChatSettings(ctx context.Context, chat types.JID) (settings types.LocalChatSettings, err error) {
- var mutedUntil int64
- err = s.db.QueryRow(ctx, getChatSettingsQuery, s.JID, chat).Scan(&mutedUntil, &settings.Pinned, &settings.Archived)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- } else if err != nil {
- return
- } else {
- settings.Found = true
- }
- if mutedUntil < 0 {
- settings.MutedUntil = store.MutedForever
- } else if mutedUntil > 0 {
- settings.MutedUntil = time.Unix(mutedUntil, 0)
- }
- return
- }
// Queries for the whatsmeow_message_secrets table.
const (
	// putMsgSecret inserts a secret; an existing row for the same message is left untouched.
	putMsgSecret = `
INSERT INTO whatsmeow_message_secrets (our_jid, chat_jid, sender_jid, message_id, key)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (our_jid, chat_jid, sender_jid, message_id) DO NOTHING
`
	// getMsgSecret looks up a secret by message ID ($4). The chat ($2) and sender ($3) match
	// either directly or through their counterpart in whatsmeow_lid_map: an @lid JID is also
	// tried as the mapped @s.whatsapp.net JID and vice versa. A JID on any other server only
	// matches directly, because the CASE has no ELSE and yields NULL.
	getMsgSecret = `
SELECT key, sender_jid
FROM whatsmeow_message_secrets
WHERE our_jid=$1 AND (chat_jid=$2 OR chat_jid=(
CASE
WHEN $2 LIKE '%@lid'
THEN (SELECT pn || '@s.whatsapp.net' FROM whatsmeow_lid_map WHERE lid=replace($2, '@lid', ''))
WHEN $2 LIKE '%@s.whatsapp.net'
THEN (SELECT lid || '@lid' FROM whatsmeow_lid_map WHERE pn=replace($2, '@s.whatsapp.net', ''))
END
)) AND message_id=$4 AND (sender_jid=$3 OR sender_jid=(
CASE
WHEN $3 LIKE '%@lid'
THEN (SELECT pn || '@s.whatsapp.net' FROM whatsmeow_lid_map WHERE lid=replace($3, '@lid', ''))
WHEN $3 LIKE '%@s.whatsapp.net'
THEN (SELECT lid || '@lid' FROM whatsmeow_lid_map WHERE pn=replace($3, '@s.whatsapp.net', ''))
END
))
`
)
- func (s *SQLStore) PutMessageSecrets(ctx context.Context, inserts []store.MessageSecretInsert) (err error) {
- if len(inserts) == 0 {
- return nil
- }
- return s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
- for _, insert := range inserts {
- _, err = s.db.Exec(ctx, putMsgSecret, s.JID, insert.Chat.ToNonAD(), insert.Sender.ToNonAD(), insert.ID, insert.Secret)
- if err != nil {
- return err
- }
- }
- return nil
- })
- }
- func (s *SQLStore) PutMessageSecret(ctx context.Context, chat, sender types.JID, id types.MessageID, secret []byte) (err error) {
- _, err = s.db.Exec(ctx, putMsgSecret, s.JID, chat.ToNonAD(), sender.ToNonAD(), id, secret)
- return
- }
- func (s *SQLStore) GetMessageSecret(ctx context.Context, chat, sender types.JID, id types.MessageID) (secret []byte, realSender types.JID, err error) {
- err = s.db.QueryRow(ctx, getMsgSecret, s.JID, chat.ToNonAD(), sender.ToNonAD(), id).Scan(&secret, &realSender)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
- }
// Queries for the whatsmeow_privacy_tokens table, keyed by (our_jid, their_jid).
const (
	// putPrivacyTokens is a single-row template. PutPrivacyTokens replaces the literal text
	// "($1, $2, $3, $4)" with one tuple per token, so that text must stay exactly as written.
	putPrivacyTokens = `
INSERT INTO whatsmeow_privacy_tokens (our_jid, their_jid, token, timestamp)
VALUES ($1, $2, $3, $4)
ON CONFLICT (our_jid, their_jid) DO UPDATE SET token=EXCLUDED.token, timestamp=EXCLUDED.timestamp
`
	// getPrivacyToken returns the newest token for $2 or for its LID/PN counterpart from
	// whatsmeow_lid_map. JIDs on other servers fall back to matching $2 itself.
	getPrivacyToken = `
SELECT token, timestamp FROM whatsmeow_privacy_tokens WHERE our_jid=$1 AND (their_jid=$2 OR their_jid=(
CASE
WHEN $2 LIKE '%@lid'
THEN (SELECT pn || '@s.whatsapp.net' FROM whatsmeow_lid_map WHERE lid=replace($2, '@lid', ''))
WHEN $2 LIKE '%@s.whatsapp.net'
THEN (SELECT lid || '@lid' FROM whatsmeow_lid_map WHERE pn=replace($2, '@s.whatsapp.net', ''))
ELSE $2
END
))
ORDER BY timestamp DESC LIMIT 1
`
)
- func (s *SQLStore) PutPrivacyTokens(ctx context.Context, tokens ...store.PrivacyToken) error {
- args := make([]any, 1+len(tokens)*3)
- placeholders := make([]string, len(tokens))
- args[0] = s.JID
- for i, token := range tokens {
- args[i*3+1] = token.User.ToNonAD().String()
- args[i*3+2] = token.Token
- args[i*3+3] = token.Timestamp.Unix()
- placeholders[i] = fmt.Sprintf("($1, $%d, $%d, $%d)", i*3+2, i*3+3, i*3+4)
- }
- query := strings.ReplaceAll(putPrivacyTokens, "($1, $2, $3, $4)", strings.Join(placeholders, ","))
- _, err := s.db.Exec(ctx, query, args...)
- return err
- }
- func (s *SQLStore) GetPrivacyToken(ctx context.Context, user types.JID) (*store.PrivacyToken, error) {
- var token store.PrivacyToken
- token.User = user.ToNonAD()
- var ts int64
- err := s.db.QueryRow(ctx, getPrivacyToken, s.JID, token.User).Scan(&token.Token, &ts)
- if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- } else if err != nil {
- return nil, err
- } else {
- token.Timestamp = time.Unix(ts, 0)
- return &token, nil
- }
- }
// Queries for the whatsmeow_event_buffer table. Rows are looked up by ciphertext hash.
// server_timestamp holds Unix seconds and insert_timestamp holds Unix milliseconds.
const (
	getBufferedEventQuery = `
SELECT plaintext, server_timestamp, insert_timestamp FROM whatsmeow_event_buffer WHERE our_jid = $1 AND ciphertext_hash = $2
`
	putBufferedEventQuery = `
INSERT INTO whatsmeow_event_buffer (our_jid, ciphertext_hash, plaintext, server_timestamp, insert_timestamp)
VALUES ($1, $2, $3, $4, $5)
`
	// clearBufferedEventPlaintextQuery keeps the row (and its hash) but drops the plaintext.
	clearBufferedEventPlaintextQuery = `
UPDATE whatsmeow_event_buffer SET plaintext = NULL WHERE our_jid = $1 AND ciphertext_hash = $2
`
	// deleteOldBufferedHashesQuery is not filtered by our_jid, so it prunes old rows for every
	// device sharing the table.
	deleteOldBufferedHashesQuery = `
DELETE FROM whatsmeow_event_buffer WHERE insert_timestamp < $1
`
)
- func (s *SQLStore) GetBufferedEvent(ctx context.Context, ciphertextHash [32]byte) (*store.BufferedEvent, error) {
- var insertTimeMS, serverTimeSeconds int64
- var buf store.BufferedEvent
- err := s.db.QueryRow(ctx, getBufferedEventQuery, s.JID, ciphertextHash[:]).Scan(&buf.Plaintext, &serverTimeSeconds, &insertTimeMS)
- if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- } else if err != nil {
- return nil, err
- }
- buf.ServerTime = time.Unix(serverTimeSeconds, 0)
- buf.InsertTime = time.UnixMilli(insertTimeMS)
- return &buf, nil
- }
- func (s *SQLStore) PutBufferedEvent(ctx context.Context, ciphertextHash [32]byte, plaintext []byte, serverTimestamp time.Time) error {
- _, err := s.db.Exec(ctx, putBufferedEventQuery, s.JID, ciphertextHash[:], plaintext, serverTimestamp.Unix(), time.Now().UnixMilli())
- return err
- }
- func (s *SQLStore) DoDecryptionTxn(ctx context.Context, fn func(context.Context) error) error {
- ctx = context.WithValue(ctx, dbutil.ContextKeyDoTxnCallerSkip, 2)
- return s.db.DoTxn(ctx, nil, fn)
- }
- func (s *SQLStore) ClearBufferedEventPlaintext(ctx context.Context, ciphertextHash [32]byte) error {
- _, err := s.db.Exec(ctx, clearBufferedEventPlaintextQuery, s.JID, ciphertextHash[:])
- return err
- }
- func (s *SQLStore) DeleteOldBufferedHashes(ctx context.Context) error {
- // The WhatsApp servers only buffer events for 14 days,
- // so we can safely delete anything older than that.
- _, err := s.db.Exec(ctx, deleteOldBufferedHashesQuery, time.Now().Add(-14*24*time.Hour).UnixMilli())
- return err
- }
|