| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- // Copyright (c) 2025 Tulir Asokan
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this
- // file, You can obtain one at http://mozilla.org/MPL/2.0/.
- package store
- import (
- "context"
- "fmt"
- "github.com/rs/zerolog"
- "go.mau.fi/libsignal/state/record"
- "go.mau.fi/util/exsync"
- )
// contextKey is a private type for the keys this package stores in a
// context.Context, so they cannot collide with keys set by other packages.
type contextKey int

// contextKeySessionCache is the context key under which the per-operation
// session cache (*sessionCache) is stored.
const contextKeySessionCache contextKey = iota
- type sessionCacheEntry struct {
- Dirty bool
- Found bool
- Record *record.Session
- }
- type sessionCache = exsync.Map[string, sessionCacheEntry]
- func getSessionCache(ctx context.Context) *sessionCache {
- if ctx == nil {
- return nil
- }
- val := ctx.Value(contextKeySessionCache)
- if val == nil {
- return nil
- }
- if cache, ok := val.(*sessionCache); ok {
- return cache
- }
- return nil
- }
- func getCachedSession(ctx context.Context, addr string) *record.Session {
- cache := getSessionCache(ctx)
- if cache == nil {
- return nil
- }
- sess, ok := cache.Get(addr)
- if !ok {
- return nil
- }
- return sess.Record
- }
- func putCachedSession(ctx context.Context, addr string, record *record.Session) bool {
- cache := getSessionCache(ctx)
- if cache == nil {
- return false
- }
- cache.Set(addr, sessionCacheEntry{
- Dirty: true,
- Found: true,
- Record: record,
- })
- return true
- }
- func (device *Device) WithCachedSessions(ctx context.Context, addresses []string) (map[string]bool, context.Context, error) {
- if len(addresses) == 0 {
- return nil, ctx, nil
- }
- sessions, err := device.Sessions.GetManySessions(ctx, addresses)
- if err != nil {
- return nil, ctx, fmt.Errorf("failed to prefetch sessions: %w", err)
- }
- wrapped := make(map[string]sessionCacheEntry, len(sessions))
- existingSessions := make(map[string]bool, len(sessions))
- for addr, rawSess := range sessions {
- var sessionRecord *record.Session
- var found bool
- if rawSess == nil {
- sessionRecord = record.NewSession(SignalProtobufSerializer.Session, SignalProtobufSerializer.State)
- } else {
- found = true
- sessionRecord, err = record.NewSessionFromBytes(rawSess, SignalProtobufSerializer.Session, SignalProtobufSerializer.State)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).
- Str("address", addr).
- Msg("Failed to deserialize session")
- continue
- }
- }
- existingSessions[addr] = found
- wrapped[addr] = sessionCacheEntry{Record: sessionRecord, Found: found}
- }
- ctx = context.WithValue(ctx, contextKeySessionCache, (*sessionCache)(exsync.NewMapWithData(wrapped)))
- return existingSessions, ctx, nil
- }
- func (device *Device) PutCachedSessions(ctx context.Context) error {
- cache := getSessionCache(ctx)
- if cache == nil {
- return nil
- }
- dirtySessions := make(map[string][]byte)
- for addr, item := range cache.Iter() {
- if item.Dirty {
- dirtySessions[addr] = item.Record.Serialize()
- }
- }
- if len(dirtySessions) > 0 {
- err := device.Sessions.PutManySessions(ctx, dirtySessions)
- if err != nil {
- return fmt.Errorf("failed to store cached sessions: %w", err)
- }
- }
- cache.Clear()
- return nil
- }
|