lidmap.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. // Copyright (c) 2025 Tulir Asokan
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this
  5. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. // Package sqlstore contains an SQL-backed implementation of the interfaces in the store package.
  7. package sqlstore
  8. import (
  9. "context"
  10. "database/sql"
  11. "errors"
  12. "fmt"
  13. "slices"
  14. "strings"
  15. "sync"
  16. "github.com/rs/zerolog"
  17. "go.mau.fi/util/dbutil"
  18. "go.mau.fi/util/exslices"
  19. "git.bobomao.top/joey/testwh/store"
  20. "git.bobomao.top/joey/testwh/types"
  21. )
  22. type CachedLIDMap struct {
  23. db *dbutil.Database
  24. pnToLIDCache map[string]string
  25. lidToPNCache map[string]string
  26. cacheFilled bool
  27. lidCacheLock sync.RWMutex
  28. }
  29. var _ store.LIDStore = (*CachedLIDMap)(nil)
  30. func NewCachedLIDMap(db *dbutil.Database) *CachedLIDMap {
  31. return &CachedLIDMap{
  32. db: db,
  33. pnToLIDCache: make(map[string]string),
  34. lidToPNCache: make(map[string]string),
  35. }
  36. }
  37. const (
  38. deleteExistingLIDMappingQuery = `DELETE FROM whatsmeow_lid_map WHERE (lid<>$1 AND pn=$2)`
  39. putLIDMappingQuery = `
  40. INSERT INTO whatsmeow_lid_map (lid, pn)
  41. VALUES ($1, $2)
  42. ON CONFLICT (lid) DO UPDATE SET pn=excluded.pn WHERE whatsmeow_lid_map.pn<>excluded.pn
  43. `
  44. getLIDForPNQuery = `SELECT lid FROM whatsmeow_lid_map WHERE pn=$1`
  45. getPNForLIDQuery = `SELECT pn FROM whatsmeow_lid_map WHERE lid=$1`
  46. getAllLIDMappingsQuery = `SELECT lid, pn FROM whatsmeow_lid_map`
  47. )
  48. func (s *CachedLIDMap) FillCache(ctx context.Context) error {
  49. s.lidCacheLock.Lock()
  50. defer s.lidCacheLock.Unlock()
  51. rows, err := s.db.Query(ctx, getAllLIDMappingsQuery)
  52. if err != nil {
  53. return err
  54. }
  55. err = s.scanManyLids(rows, nil)
  56. s.cacheFilled = err == nil
  57. return err
  58. }
  59. func (s *CachedLIDMap) scanManyLids(rows dbutil.Rows, fn func(lid, pn string)) error {
  60. if fn == nil {
  61. fn = func(lid, pn string) {}
  62. }
  63. for rows.Next() {
  64. var lid, pn string
  65. err := rows.Scan(&lid, &pn)
  66. if err != nil {
  67. return err
  68. }
  69. s.pnToLIDCache[pn] = lid
  70. s.lidToPNCache[lid] = pn
  71. fn(lid, pn)
  72. }
  73. err := rows.Close()
  74. if err != nil {
  75. return err
  76. }
  77. return rows.Err()
  78. }
  79. func (s *CachedLIDMap) getLIDMapping(ctx context.Context, source types.JID, targetServer, query string, sourceToTarget, targetToSource map[string]string) (types.JID, error) {
  80. s.lidCacheLock.RLock()
  81. targetUser, ok := sourceToTarget[source.User]
  82. cacheFilled := s.cacheFilled
  83. s.lidCacheLock.RUnlock()
  84. if ok || cacheFilled {
  85. if targetUser == "" {
  86. return types.JID{}, nil
  87. }
  88. return types.JID{User: targetUser, Device: source.Device, Server: targetServer}, nil
  89. }
  90. s.lidCacheLock.Lock()
  91. defer s.lidCacheLock.Unlock()
  92. err := s.db.QueryRow(ctx, query, source.User).Scan(&targetUser)
  93. if errors.Is(err, sql.ErrNoRows) {
  94. // continue with empty result
  95. } else if err != nil {
  96. return types.JID{}, err
  97. }
  98. sourceToTarget[source.User] = targetUser
  99. if targetUser != "" {
  100. targetToSource[targetUser] = source.User
  101. return types.JID{User: targetUser, Device: source.Device, Server: targetServer}, nil
  102. }
  103. return types.JID{}, nil
  104. }
  105. func (s *CachedLIDMap) GetLIDForPN(ctx context.Context, pn types.JID) (types.JID, error) {
  106. if pn.Server != types.DefaultUserServer {
  107. return types.JID{}, fmt.Errorf("invalid GetLIDForPN call with non-PN JID %s", pn)
  108. }
  109. return s.getLIDMapping(
  110. ctx, pn, types.HiddenUserServer, getLIDForPNQuery,
  111. s.pnToLIDCache, s.lidToPNCache,
  112. )
  113. }
  114. func (s *CachedLIDMap) GetPNForLID(ctx context.Context, lid types.JID) (types.JID, error) {
  115. if lid.Server != types.HiddenUserServer {
  116. return types.JID{}, fmt.Errorf("invalid GetPNForLID call with non-LID JID %s", lid)
  117. }
  118. return s.getLIDMapping(
  119. ctx, lid, types.DefaultUserServer, getPNForLIDQuery,
  120. s.lidToPNCache, s.pnToLIDCache,
  121. )
  122. }
  123. func (s *CachedLIDMap) GetManyLIDsForPNs(ctx context.Context, pns []types.JID) (map[types.JID]types.JID, error) {
  124. if len(pns) == 0 {
  125. return nil, nil
  126. }
  127. result := make(map[types.JID]types.JID, len(pns))
  128. s.lidCacheLock.RLock()
  129. missingPNs := make([]string, 0, len(pns))
  130. missingPNDevices := make(map[string][]types.JID)
  131. for _, pn := range pns {
  132. if pn.Server != types.DefaultUserServer {
  133. continue
  134. }
  135. if lidUser, ok := s.pnToLIDCache[pn.User]; ok && lidUser != "" {
  136. result[pn] = types.JID{User: lidUser, Device: pn.Device, Server: types.HiddenUserServer}
  137. } else if !s.cacheFilled {
  138. missingPNs = append(missingPNs, pn.User)
  139. missingPNDevices[pn.User] = append(missingPNDevices[pn.User], pn)
  140. }
  141. }
  142. s.lidCacheLock.RUnlock()
  143. if len(missingPNs) == 0 {
  144. return result, nil
  145. }
  146. s.lidCacheLock.Lock()
  147. defer s.lidCacheLock.Unlock()
  148. var rows dbutil.Rows
  149. var err error
  150. if s.db.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
  151. rows, err = s.db.Query(
  152. ctx,
  153. `SELECT lid, pn FROM whatsmeow_lid_map WHERE pn = ANY($1)`,
  154. PostgresArrayWrapper(missingPNs),
  155. )
  156. } else {
  157. placeholders := make([]string, len(missingPNs))
  158. for i := range missingPNs {
  159. placeholders[i] = fmt.Sprintf("$%d", i+1)
  160. }
  161. rows, err = s.db.Query(
  162. ctx,
  163. fmt.Sprintf(`SELECT lid, pn FROM whatsmeow_lid_map WHERE pn IN (%s)`, strings.Join(placeholders, ",")),
  164. exslices.CastToAny(missingPNs)...,
  165. )
  166. }
  167. if err != nil {
  168. return nil, err
  169. }
  170. err = s.scanManyLids(rows, func(lid, pn string) {
  171. for _, dev := range missingPNDevices[pn] {
  172. lidDev := dev
  173. lidDev.Server = types.HiddenUserServer
  174. lidDev.User = lid
  175. result[dev] = lidDev.ToNonAD()
  176. }
  177. })
  178. return result, err
  179. }
  180. func (s *CachedLIDMap) PutLIDMapping(ctx context.Context, lid, pn types.JID) error {
  181. if lid.Server != types.HiddenUserServer || pn.Server != types.DefaultUserServer {
  182. return fmt.Errorf("invalid PutLIDMapping call %s/%s", lid, pn)
  183. }
  184. s.lidCacheLock.Lock()
  185. defer s.lidCacheLock.Unlock()
  186. cachedLID, ok := s.pnToLIDCache[pn.User]
  187. if ok && cachedLID == lid.User {
  188. return nil
  189. }
  190. return s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
  191. return s.unlockedPutLIDMapping(ctx, lid, pn)
  192. })
  193. }
  194. func (s *CachedLIDMap) PutManyLIDMappings(ctx context.Context, mappings []store.LIDMapping) error {
  195. s.lidCacheLock.Lock()
  196. defer s.lidCacheLock.Unlock()
  197. mappings = slices.DeleteFunc(mappings, func(mapping store.LIDMapping) bool {
  198. if mapping.LID.Server != types.HiddenUserServer || mapping.PN.Server != types.DefaultUserServer {
  199. zerolog.Ctx(ctx).Debug().
  200. Stringer("entry_lid", mapping.LID).
  201. Stringer("entry_pn", mapping.PN).
  202. Msg("Ignoring invalid entry in PutManyLIDMappings")
  203. return true
  204. }
  205. cachedLID, ok := s.pnToLIDCache[mapping.PN.User]
  206. if ok && cachedLID == mapping.LID.User {
  207. return true
  208. }
  209. return false
  210. })
  211. mappings = exslices.DeduplicateUnsortedOverwrite(mappings)
  212. if len(mappings) == 0 {
  213. return nil
  214. }
  215. return s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
  216. for _, mapping := range mappings {
  217. err := s.unlockedPutLIDMapping(ctx, mapping.LID, mapping.PN)
  218. if err != nil {
  219. return err
  220. }
  221. }
  222. return nil
  223. })
  224. }
  225. func (s *CachedLIDMap) unlockedPutLIDMapping(ctx context.Context, lid, pn types.JID) error {
  226. if lid.Server != types.HiddenUserServer || pn.Server != types.DefaultUserServer {
  227. return fmt.Errorf("invalid PutLIDMapping call %s/%s", lid, pn)
  228. }
  229. _, err := s.db.Exec(ctx, deleteExistingLIDMappingQuery, lid.User, pn.User)
  230. if err != nil {
  231. return err
  232. }
  233. _, err = s.db.Exec(ctx, putLIDMappingQuery, lid.User, pn.User)
  234. if err != nil {
  235. return err
  236. }
  237. s.pnToLIDCache[pn.User] = lid.User
  238. s.lidToPNCache[lid.User] = pn.User
  239. return nil
  240. }