noisesocket.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. // Copyright (c) 2025 Tulir Asokan
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this
  5. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. package socket
  7. import (
  8. "context"
  9. "crypto/cipher"
  10. "encoding/binary"
  11. "sync"
  12. "sync/atomic"
  13. "github.com/coder/websocket"
  14. )
  15. type NoiseSocket struct {
  16. fs *FrameSocket
  17. onFrame FrameHandler
  18. writeKey cipher.AEAD
  19. readKey cipher.AEAD
  20. writeCounter uint32
  21. readCounter uint32
  22. writeLock sync.Mutex
  23. destroyed atomic.Bool
  24. stopConsumer chan struct{}
  25. }
  26. type DisconnectHandler func(ctx context.Context, socket *NoiseSocket, remote bool)
  27. type FrameHandler func(context.Context, []byte)
  28. func newNoiseSocket(
  29. ctx context.Context,
  30. fs *FrameSocket,
  31. writeKey, readKey cipher.AEAD,
  32. frameHandler FrameHandler,
  33. disconnectHandler DisconnectHandler,
  34. ) (*NoiseSocket, error) {
  35. ns := &NoiseSocket{
  36. fs: fs,
  37. writeKey: writeKey,
  38. readKey: readKey,
  39. onFrame: frameHandler,
  40. stopConsumer: make(chan struct{}),
  41. }
  42. fs.OnDisconnect = func(ctx context.Context, remote bool) {
  43. disconnectHandler(ctx, ns, remote)
  44. }
  45. go ns.consumeFrames(ctx, fs.Frames)
  46. return ns, nil
  47. }
  48. func (ns *NoiseSocket) consumeFrames(ctx context.Context, frames <-chan []byte) {
  49. if ctx == nil {
  50. // ctx being nil implies the connection already closed somehow
  51. return
  52. }
  53. ctxDone := ctx.Done()
  54. for {
  55. select {
  56. case frame := <-frames:
  57. ns.receiveEncryptedFrame(ctx, frame)
  58. case <-ctxDone:
  59. return
  60. case <-ns.stopConsumer:
  61. return
  62. }
  63. }
  64. }
  65. func generateIV(count uint32) []byte {
  66. iv := make([]byte, 12)
  67. binary.BigEndian.PutUint32(iv[8:], count)
  68. return iv
  69. }
  70. func (ns *NoiseSocket) Stop(disconnect bool) {
  71. if ns.destroyed.CompareAndSwap(false, true) {
  72. close(ns.stopConsumer)
  73. ns.fs.OnDisconnect = nil
  74. if disconnect {
  75. ns.fs.Close(websocket.StatusNormalClosure)
  76. }
  77. }
  78. }
  79. func (ns *NoiseSocket) SendFrame(ctx context.Context, plaintext []byte) error {
  80. ns.writeLock.Lock()
  81. defer ns.writeLock.Unlock()
  82. if ctx.Err() != nil {
  83. return ctx.Err()
  84. }
  85. // Don't reuse plaintext slice for storage as it may be needed for retries
  86. ciphertext := ns.writeKey.Seal(nil, generateIV(ns.writeCounter), plaintext, nil)
  87. ns.writeCounter++
  88. doneChan := make(chan error, 1)
  89. go func() {
  90. doneChan <- ns.fs.SendFrame(ciphertext)
  91. }()
  92. select {
  93. case <-ctx.Done():
  94. return ctx.Err()
  95. case retErr := <-doneChan:
  96. return retErr
  97. }
  98. }
  99. func (ns *NoiseSocket) receiveEncryptedFrame(ctx context.Context, ciphertext []byte) {
  100. plaintext, err := ns.readKey.Open(ciphertext[:0], generateIV(ns.readCounter), ciphertext, nil)
  101. ns.readCounter++
  102. if err != nil {
  103. ns.fs.log.Warnf("Failed to decrypt frame: %v", err)
  104. return
  105. }
  106. ns.onFrame(ctx, plaintext)
  107. }
  108. func (ns *NoiseSocket) IsConnected() bool {
  109. return ns.fs.IsConnected()
  110. }