sessioncache.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. // Copyright (c) 2025 Tulir Asokan
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this
  5. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. package store
  7. import (
  8. "context"
  9. "fmt"
  10. "github.com/rs/zerolog"
  11. "go.mau.fi/libsignal/state/record"
  12. "go.mau.fi/util/exsync"
  13. )
  14. type contextKey int
  15. const (
  16. contextKeySessionCache contextKey = iota
  17. )
  18. type sessionCacheEntry struct {
  19. Dirty bool
  20. Found bool
  21. Record *record.Session
  22. }
  23. type sessionCache = exsync.Map[string, sessionCacheEntry]
  24. func getSessionCache(ctx context.Context) *sessionCache {
  25. if ctx == nil {
  26. return nil
  27. }
  28. val := ctx.Value(contextKeySessionCache)
  29. if val == nil {
  30. return nil
  31. }
  32. if cache, ok := val.(*sessionCache); ok {
  33. return cache
  34. }
  35. return nil
  36. }
  37. func getCachedSession(ctx context.Context, addr string) *record.Session {
  38. cache := getSessionCache(ctx)
  39. if cache == nil {
  40. return nil
  41. }
  42. sess, ok := cache.Get(addr)
  43. if !ok {
  44. return nil
  45. }
  46. return sess.Record
  47. }
  48. func putCachedSession(ctx context.Context, addr string, record *record.Session) bool {
  49. cache := getSessionCache(ctx)
  50. if cache == nil {
  51. return false
  52. }
  53. cache.Set(addr, sessionCacheEntry{
  54. Dirty: true,
  55. Found: true,
  56. Record: record,
  57. })
  58. return true
  59. }
  60. func (device *Device) WithCachedSessions(ctx context.Context, addresses []string) (map[string]bool, context.Context, error) {
  61. if len(addresses) == 0 {
  62. return nil, ctx, nil
  63. }
  64. sessions, err := device.Sessions.GetManySessions(ctx, addresses)
  65. if err != nil {
  66. return nil, ctx, fmt.Errorf("failed to prefetch sessions: %w", err)
  67. }
  68. wrapped := make(map[string]sessionCacheEntry, len(sessions))
  69. existingSessions := make(map[string]bool, len(sessions))
  70. for addr, rawSess := range sessions {
  71. var sessionRecord *record.Session
  72. var found bool
  73. if rawSess == nil {
  74. sessionRecord = record.NewSession(SignalProtobufSerializer.Session, SignalProtobufSerializer.State)
  75. } else {
  76. found = true
  77. sessionRecord, err = record.NewSessionFromBytes(rawSess, SignalProtobufSerializer.Session, SignalProtobufSerializer.State)
  78. if err != nil {
  79. zerolog.Ctx(ctx).Err(err).
  80. Str("address", addr).
  81. Msg("Failed to deserialize session")
  82. continue
  83. }
  84. }
  85. existingSessions[addr] = found
  86. wrapped[addr] = sessionCacheEntry{Record: sessionRecord, Found: found}
  87. }
  88. ctx = context.WithValue(ctx, contextKeySessionCache, (*sessionCache)(exsync.NewMapWithData(wrapped)))
  89. return existingSessions, ctx, nil
  90. }
  91. func (device *Device) PutCachedSessions(ctx context.Context) error {
  92. cache := getSessionCache(ctx)
  93. if cache == nil {
  94. return nil
  95. }
  96. dirtySessions := make(map[string][]byte)
  97. for addr, item := range cache.Iter() {
  98. if item.Dirty {
  99. dirtySessions[addr] = item.Record.Serialize()
  100. }
  101. }
  102. if len(dirtySessions) > 0 {
  103. err := device.Sessions.PutManySessions(ctx, dirtySessions)
  104. if err != nil {
  105. return fmt.Errorf("failed to store cached sessions: %w", err)
  106. }
  107. }
  108. cache.Clear()
  109. return nil
  110. }