noisehandshake.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. // Copyright (c) 2025 Tulir Asokan
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this
  5. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. package socket
  7. import (
  8. "context"
  9. "crypto/cipher"
  10. "crypto/sha256"
  11. "fmt"
  12. "io"
  13. "sync/atomic"
  14. "golang.org/x/crypto/curve25519"
  15. "golang.org/x/crypto/hkdf"
  16. "go.mau.fi/whatsmeow/util/gcmutil"
  17. )
  // NoiseHandshake carries the symmetric state of an in-progress Noise
// handshake: the running hash, the HKDF salt, the current AEAD key and the
// nonce counter used with that key.
//
// Both the zero value and NewNoiseHandshake produce an empty state; Start must
// be called before Encrypt or Decrypt, since key is only set by Start and
// MixIntoKey.
type NoiseHandshake struct {
	// hash is the running handshake hash. It is passed as associated data to
	// every Seal/Open and is updated through Authenticate.
	// salt is the HKDF salt for the next key derivation. It is seeded from the
	// initial hash in Start and replaced by the first HKDF output in MixIntoKey.
	hash, salt []byte

	// key is the AEAD used by Encrypt/Decrypt during the handshake.
	key cipher.AEAD

	// counter is the per-key message counter turned into a nonce by generateIV.
	// It is incremented atomically by postIncrementCounter.
	counter uint32
}
  24. func NewNoiseHandshake() *NoiseHandshake {
  25. return &NoiseHandshake{}
  26. }
  // sha256Slice returns the SHA-256 digest of data as a freshly allocated
// 32-byte slice (rather than the fixed-size array returned by sha256.Sum256).
func sha256Slice(data []byte) []byte {
	digest := sha256.New()
	digest.Write(data) // hash.Hash.Write never returns an error
	return digest.Sum(nil)
}
  31. func (nh *NoiseHandshake) Start(pattern string, header []byte) {
  32. data := []byte(pattern)
  33. if len(data) == 32 {
  34. nh.hash = data
  35. } else {
  36. nh.hash = sha256Slice(data)
  37. }
  38. nh.salt = nh.hash
  39. var err error
  40. nh.key, err = gcmutil.Prepare(nh.hash)
  41. if err != nil {
  42. panic(err)
  43. }
  44. nh.Authenticate(header)
  45. }
  46. func (nh *NoiseHandshake) Authenticate(data []byte) {
  47. nh.hash = sha256Slice(append(nh.hash, data...))
  48. }
  49. func (nh *NoiseHandshake) postIncrementCounter() uint32 {
  50. count := atomic.AddUint32(&nh.counter, 1)
  51. return count - 1
  52. }
  53. func (nh *NoiseHandshake) Encrypt(plaintext []byte) []byte {
  54. ciphertext := nh.key.Seal(nil, generateIV(nh.postIncrementCounter()), plaintext, nh.hash)
  55. nh.Authenticate(ciphertext)
  56. return ciphertext
  57. }
  58. func (nh *NoiseHandshake) Decrypt(ciphertext []byte) (plaintext []byte, err error) {
  59. plaintext, err = nh.key.Open(nil, generateIV(nh.postIncrementCounter()), ciphertext, nh.hash)
  60. if err == nil {
  61. nh.Authenticate(ciphertext)
  62. }
  63. return
  64. }
  65. func (nh *NoiseHandshake) Finish(
  66. ctx context.Context,
  67. fs *FrameSocket,
  68. frameHandler FrameHandler,
  69. disconnectHandler DisconnectHandler,
  70. ) (*NoiseSocket, error) {
  71. if write, read, err := nh.extractAndExpand(nh.salt, nil); err != nil {
  72. return nil, fmt.Errorf("failed to extract final keys: %w", err)
  73. } else if writeKey, err := gcmutil.Prepare(write); err != nil {
  74. return nil, fmt.Errorf("failed to create final write cipher: %w", err)
  75. } else if readKey, err := gcmutil.Prepare(read); err != nil {
  76. return nil, fmt.Errorf("failed to create final read cipher: %w", err)
  77. } else if ns, err := newNoiseSocket(ctx, fs, writeKey, readKey, frameHandler, disconnectHandler); err != nil {
  78. return nil, fmt.Errorf("failed to create noise socket: %w", err)
  79. } else {
  80. return ns, nil
  81. }
  82. }
  83. func (nh *NoiseHandshake) MixSharedSecretIntoKey(priv, pub [32]byte) error {
  84. secret, err := curve25519.X25519(priv[:], pub[:])
  85. if err != nil {
  86. return fmt.Errorf("failed to do x25519 scalar multiplication: %w", err)
  87. }
  88. return nh.MixIntoKey(secret)
  89. }
  90. func (nh *NoiseHandshake) MixIntoKey(data []byte) error {
  91. nh.counter = 0
  92. write, read, err := nh.extractAndExpand(nh.salt, data)
  93. if err != nil {
  94. return fmt.Errorf("failed to extract keys for mixing: %w", err)
  95. }
  96. nh.salt = write
  97. nh.key, err = gcmutil.Prepare(read)
  98. if err != nil {
  99. return fmt.Errorf("failed to create new cipher while mixing keys: %w", err)
  100. }
  101. return nil
  102. }
  103. func (nh *NoiseHandshake) extractAndExpand(salt, data []byte) (write []byte, read []byte, err error) {
  104. h := hkdf.New(sha256.New, data, salt, nil)
  105. write = make([]byte, 32)
  106. read = make([]byte, 32)
  107. if _, err = io.ReadFull(h, write); err != nil {
  108. err = fmt.Errorf("failed to read write key: %w", err)
  109. } else if _, err = io.ReadFull(h, read); err != nil {
  110. err = fmt.Errorf("failed to read read key: %w", err)
  111. }
  112. return
  113. }