| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- // Copyright (c) 2025 Tulir Asokan
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this
- // file, You can obtain one at http://mozilla.org/MPL/2.0/.
- // Package sqlstore contains an SQL-backed implementation of the interfaces in the store package.
- package sqlstore
- import (
- "context"
- "database/sql"
- "errors"
- "fmt"
- "slices"
- "strings"
- "sync"
- "github.com/rs/zerolog"
- "go.mau.fi/util/dbutil"
- "go.mau.fi/util/exslices"
- "git.bobomao.top/joey/testwh/store"
- "git.bobomao.top/joey/testwh/types"
- )
- type CachedLIDMap struct {
- db *dbutil.Database
- pnToLIDCache map[string]string
- lidToPNCache map[string]string
- cacheFilled bool
- lidCacheLock sync.RWMutex
- }
// Compile-time check that *CachedLIDMap satisfies store.LIDStore.
var _ store.LIDStore = (*CachedLIDMap)(nil)
- func NewCachedLIDMap(db *dbutil.Database) *CachedLIDMap {
- return &CachedLIDMap{
- db: db,
- pnToLIDCache: make(map[string]string),
- lidToPNCache: make(map[string]string),
- }
- }
// SQL statements used by CachedLIDMap. All of them operate on whatsmeow_lid_map.
const (
	// deleteExistingLIDMappingQuery removes rows that map the phone number $2
	// to any LID other than $1.
	deleteExistingLIDMappingQuery = `DELETE FROM whatsmeow_lid_map WHERE (lid<>$1 AND pn=$2)`
	// putLIDMappingQuery inserts the (lid, pn) pair. When a row for the LID
	// already exists, its phone number is updated if it differs.
	putLIDMappingQuery = `
		INSERT INTO whatsmeow_lid_map (lid, pn)
		VALUES ($1, $2)
		ON CONFLICT (lid) DO UPDATE SET pn=excluded.pn WHERE whatsmeow_lid_map.pn<>excluded.pn
	`
	// getLIDForPNQuery looks up the LID stored for phone number $1.
	getLIDForPNQuery = `SELECT lid FROM whatsmeow_lid_map WHERE pn=$1`
	// getPNForLIDQuery looks up the phone number stored for LID $1.
	getPNForLIDQuery = `SELECT pn FROM whatsmeow_lid_map WHERE lid=$1`
	// getAllLIDMappingsQuery returns every (lid, pn) pair in the table.
	getAllLIDMappingsQuery = `SELECT lid, pn FROM whatsmeow_lid_map`
)
- func (s *CachedLIDMap) FillCache(ctx context.Context) error {
- s.lidCacheLock.Lock()
- defer s.lidCacheLock.Unlock()
- rows, err := s.db.Query(ctx, getAllLIDMappingsQuery)
- if err != nil {
- return err
- }
- err = s.scanManyLids(rows, nil)
- s.cacheFilled = err == nil
- return err
- }
- func (s *CachedLIDMap) scanManyLids(rows dbutil.Rows, fn func(lid, pn string)) error {
- if fn == nil {
- fn = func(lid, pn string) {}
- }
- for rows.Next() {
- var lid, pn string
- err := rows.Scan(&lid, &pn)
- if err != nil {
- return err
- }
- s.pnToLIDCache[pn] = lid
- s.lidToPNCache[lid] = pn
- fn(lid, pn)
- }
- err := rows.Close()
- if err != nil {
- return err
- }
- return rows.Err()
- }
- func (s *CachedLIDMap) getLIDMapping(ctx context.Context, source types.JID, targetServer, query string, sourceToTarget, targetToSource map[string]string) (types.JID, error) {
- s.lidCacheLock.RLock()
- targetUser, ok := sourceToTarget[source.User]
- cacheFilled := s.cacheFilled
- s.lidCacheLock.RUnlock()
- if ok || cacheFilled {
- if targetUser == "" {
- return types.JID{}, nil
- }
- return types.JID{User: targetUser, Device: source.Device, Server: targetServer}, nil
- }
- s.lidCacheLock.Lock()
- defer s.lidCacheLock.Unlock()
- err := s.db.QueryRow(ctx, query, source.User).Scan(&targetUser)
- if errors.Is(err, sql.ErrNoRows) {
- // continue with empty result
- } else if err != nil {
- return types.JID{}, err
- }
- sourceToTarget[source.User] = targetUser
- if targetUser != "" {
- targetToSource[targetUser] = source.User
- return types.JID{User: targetUser, Device: source.Device, Server: targetServer}, nil
- }
- return types.JID{}, nil
- }
- func (s *CachedLIDMap) GetLIDForPN(ctx context.Context, pn types.JID) (types.JID, error) {
- if pn.Server != types.DefaultUserServer {
- return types.JID{}, fmt.Errorf("invalid GetLIDForPN call with non-PN JID %s", pn)
- }
- return s.getLIDMapping(
- ctx, pn, types.HiddenUserServer, getLIDForPNQuery,
- s.pnToLIDCache, s.lidToPNCache,
- )
- }
- func (s *CachedLIDMap) GetPNForLID(ctx context.Context, lid types.JID) (types.JID, error) {
- if lid.Server != types.HiddenUserServer {
- return types.JID{}, fmt.Errorf("invalid GetPNForLID call with non-LID JID %s", lid)
- }
- return s.getLIDMapping(
- ctx, lid, types.DefaultUserServer, getPNForLIDQuery,
- s.lidToPNCache, s.pnToLIDCache,
- )
- }
- func (s *CachedLIDMap) GetManyLIDsForPNs(ctx context.Context, pns []types.JID) (map[types.JID]types.JID, error) {
- if len(pns) == 0 {
- return nil, nil
- }
- result := make(map[types.JID]types.JID, len(pns))
- s.lidCacheLock.RLock()
- missingPNs := make([]string, 0, len(pns))
- missingPNDevices := make(map[string][]types.JID)
- for _, pn := range pns {
- if pn.Server != types.DefaultUserServer {
- continue
- }
- if lidUser, ok := s.pnToLIDCache[pn.User]; ok && lidUser != "" {
- result[pn] = types.JID{User: lidUser, Device: pn.Device, Server: types.HiddenUserServer}
- } else if !s.cacheFilled {
- missingPNs = append(missingPNs, pn.User)
- missingPNDevices[pn.User] = append(missingPNDevices[pn.User], pn)
- }
- }
- s.lidCacheLock.RUnlock()
- if len(missingPNs) == 0 {
- return result, nil
- }
- s.lidCacheLock.Lock()
- defer s.lidCacheLock.Unlock()
- var rows dbutil.Rows
- var err error
- if s.db.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
- rows, err = s.db.Query(
- ctx,
- `SELECT lid, pn FROM whatsmeow_lid_map WHERE pn = ANY($1)`,
- PostgresArrayWrapper(missingPNs),
- )
- } else {
- placeholders := make([]string, len(missingPNs))
- for i := range missingPNs {
- placeholders[i] = fmt.Sprintf("$%d", i+1)
- }
- rows, err = s.db.Query(
- ctx,
- fmt.Sprintf(`SELECT lid, pn FROM whatsmeow_lid_map WHERE pn IN (%s)`, strings.Join(placeholders, ",")),
- exslices.CastToAny(missingPNs)...,
- )
- }
- if err != nil {
- return nil, err
- }
- err = s.scanManyLids(rows, func(lid, pn string) {
- for _, dev := range missingPNDevices[pn] {
- lidDev := dev
- lidDev.Server = types.HiddenUserServer
- lidDev.User = lid
- result[dev] = lidDev.ToNonAD()
- }
- })
- return result, err
- }
- func (s *CachedLIDMap) PutLIDMapping(ctx context.Context, lid, pn types.JID) error {
- if lid.Server != types.HiddenUserServer || pn.Server != types.DefaultUserServer {
- return fmt.Errorf("invalid PutLIDMapping call %s/%s", lid, pn)
- }
- s.lidCacheLock.Lock()
- defer s.lidCacheLock.Unlock()
- cachedLID, ok := s.pnToLIDCache[pn.User]
- if ok && cachedLID == lid.User {
- return nil
- }
- return s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
- return s.unlockedPutLIDMapping(ctx, lid, pn)
- })
- }
- func (s *CachedLIDMap) PutManyLIDMappings(ctx context.Context, mappings []store.LIDMapping) error {
- s.lidCacheLock.Lock()
- defer s.lidCacheLock.Unlock()
- mappings = slices.DeleteFunc(mappings, func(mapping store.LIDMapping) bool {
- if mapping.LID.Server != types.HiddenUserServer || mapping.PN.Server != types.DefaultUserServer {
- zerolog.Ctx(ctx).Debug().
- Stringer("entry_lid", mapping.LID).
- Stringer("entry_pn", mapping.PN).
- Msg("Ignoring invalid entry in PutManyLIDMappings")
- return true
- }
- cachedLID, ok := s.pnToLIDCache[mapping.PN.User]
- if ok && cachedLID == mapping.LID.User {
- return true
- }
- return false
- })
- mappings = exslices.DeduplicateUnsortedOverwrite(mappings)
- if len(mappings) == 0 {
- return nil
- }
- return s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
- for _, mapping := range mappings {
- err := s.unlockedPutLIDMapping(ctx, mapping.LID, mapping.PN)
- if err != nil {
- return err
- }
- }
- return nil
- })
- }
- func (s *CachedLIDMap) unlockedPutLIDMapping(ctx context.Context, lid, pn types.JID) error {
- if lid.Server != types.HiddenUserServer || pn.Server != types.DefaultUserServer {
- return fmt.Errorf("invalid PutLIDMapping call %s/%s", lid, pn)
- }
- _, err := s.db.Exec(ctx, deleteExistingLIDMappingQuery, lid.User, pn.User)
- if err != nil {
- return err
- }
- _, err = s.db.Exec(ctx, putLIDMappingQuery, lid.User, pn.User)
- if err != nil {
- return err
- }
- s.pnToLIDCache[pn.User] = lid.User
- s.lidToPNCache[lid.User] = pn.User
- return nil
- }
|