cbc.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. /*
  2. CBC describes a block cipher mode. In cryptography, a block cipher mode of operation is an algorithm that uses a
  3. block cipher to provide an information service such as confidentiality or authenticity. A block cipher by itself
  4. is only suitable for the secure cryptographic transformation (encryption or decryption) of one fixed-length group of
  5. bits called a block. A mode of operation describes how to repeatedly apply a cipher's single-block operation to
  6. securely transform amounts of data larger than a block.
  7. This package simplifies the usage of AES-256-CBC.
  8. */
  9. package cbcutil
  10. /*
  11. Some code is provided by the GitHub user locked (github.com/locked):
  12. https://gist.github.com/locked/b066aa1ddeb2b28e855e
  13. Thanks!
  14. */
  15. import (
  16. "bytes"
  17. "crypto/aes"
  18. "crypto/cipher"
  19. "crypto/hmac"
  20. "crypto/rand"
  21. "crypto/sha256"
  22. "errors"
  23. "fmt"
  24. "io"
  25. "os"
  26. )
  27. /*
  28. Decrypt is a function that decrypts a given cipher text with a provided key and initialization vector(iv).
  29. */
  30. func Decrypt(key, iv, ciphertext []byte) ([]byte, error) {
  31. block, err := aes.NewCipher(key)
  32. if err != nil {
  33. return nil, err
  34. } else if len(ciphertext) < aes.BlockSize {
  35. return nil, fmt.Errorf("ciphertext is shorter then block size: %d / %d", len(ciphertext), aes.BlockSize)
  36. }
  37. cbc := cipher.NewCBCDecrypter(block, iv)
  38. cbc.CryptBlocks(ciphertext, ciphertext)
  39. return unpad(ciphertext)
  40. }
  41. type File interface {
  42. io.Reader
  43. io.WriterAt
  44. Truncate(size int64) error
  45. Stat() (os.FileInfo, error)
  46. }
  47. func DecryptFile(key, iv []byte, file File) error {
  48. block, err := aes.NewCipher(key)
  49. if err != nil {
  50. return err
  51. }
  52. cbc := cipher.NewCBCDecrypter(block, iv)
  53. stat, err := file.Stat()
  54. if err != nil {
  55. return fmt.Errorf("failed to stat file: %w", err)
  56. }
  57. fileSize := stat.Size()
  58. if fileSize%aes.BlockSize != 0 {
  59. return fmt.Errorf("file size is not a multiple of the block size: %d / %d", fileSize, aes.BlockSize)
  60. }
  61. var bufSize int64 = 32 * 1024
  62. if fileSize < bufSize {
  63. bufSize = fileSize
  64. }
  65. buf := make([]byte, bufSize)
  66. var writePtr int64
  67. var lastByte byte
  68. for writePtr < fileSize {
  69. if writePtr+bufSize > fileSize {
  70. buf = buf[:fileSize-writePtr]
  71. }
  72. var n int
  73. n, err = io.ReadFull(file, buf)
  74. if err != nil {
  75. return fmt.Errorf("failed to read file: %w", err)
  76. } else if n != len(buf) {
  77. return fmt.Errorf("failed to read full buffer: %d / %d", n, len(buf))
  78. }
  79. cbc.CryptBlocks(buf, buf)
  80. n, err = file.WriteAt(buf, writePtr)
  81. if err != nil {
  82. return fmt.Errorf("failed to write file: %w", err)
  83. } else if n != len(buf) {
  84. return fmt.Errorf("failed to write full buffer: %d / %d", n, len(buf))
  85. }
  86. writePtr += int64(len(buf))
  87. lastByte = buf[len(buf)-1]
  88. }
  89. if int64(lastByte) > fileSize {
  90. return fmt.Errorf("padding is greater then the length: %d / %d", lastByte, fileSize)
  91. }
  92. err = file.Truncate(fileSize - int64(lastByte))
  93. if err != nil {
  94. return fmt.Errorf("failed to truncate file to remove padding: %w", err)
  95. }
  96. return nil
  97. }
  98. /*
  99. Encrypt is a function that encrypts plaintext with a given key and an optional initialization vector(iv).
  100. */
  101. func Encrypt(key, iv, plaintext []byte) ([]byte, error) {
  102. sizeOfLastBlock := len(plaintext) % aes.BlockSize
  103. paddingLen := aes.BlockSize - sizeOfLastBlock
  104. plaintextStart := plaintext[:len(plaintext)-sizeOfLastBlock]
  105. lastBlock := append(plaintext[len(plaintext)-sizeOfLastBlock:], bytes.Repeat([]byte{byte(paddingLen)}, paddingLen)...)
  106. if len(plaintextStart)%aes.BlockSize != 0 {
  107. panic(fmt.Errorf("plaintext is not the correct size: %d %% %d != 0", len(plaintextStart), aes.BlockSize))
  108. }
  109. if len(lastBlock) != aes.BlockSize {
  110. panic(fmt.Errorf("last block is not the correct size: %d != %d", len(lastBlock), aes.BlockSize))
  111. }
  112. block, err := aes.NewCipher(key)
  113. if err != nil {
  114. return nil, err
  115. }
  116. var ciphertext []byte
  117. if iv == nil {
  118. ciphertext = make([]byte, aes.BlockSize+len(plaintext)+paddingLen)
  119. iv := ciphertext[:aes.BlockSize]
  120. if _, err := io.ReadFull(rand.Reader, iv); err != nil {
  121. return nil, err
  122. }
  123. cbc := cipher.NewCBCEncrypter(block, iv)
  124. cbc.CryptBlocks(ciphertext[aes.BlockSize:], plaintextStart)
  125. cbc.CryptBlocks(ciphertext[aes.BlockSize+len(plaintextStart):], lastBlock)
  126. } else {
  127. ciphertext = make([]byte, len(plaintext)+paddingLen, len(plaintext)+paddingLen+10)
  128. cbc := cipher.NewCBCEncrypter(block, iv)
  129. cbc.CryptBlocks(ciphertext, plaintextStart)
  130. cbc.CryptBlocks(ciphertext[len(plaintextStart):], lastBlock)
  131. }
  132. return ciphertext, nil
  133. }
  134. func unpad(src []byte) ([]byte, error) {
  135. length := len(src)
  136. padLen := int(src[length-1])
  137. if padLen > length {
  138. return nil, fmt.Errorf("padding is greater then the length: %d / %d", padLen, length)
  139. }
  140. return src[:(length - padLen)], nil
  141. }
  142. func EncryptStream(key, iv, macKey []byte, plaintext io.Reader, ciphertext io.Writer) ([]byte, []byte, uint64, uint64, error) {
  143. block, err := aes.NewCipher(key)
  144. if err != nil {
  145. return nil, nil, 0, 0, fmt.Errorf("failed to create cipher: %w", err)
  146. }
  147. cbc := cipher.NewCBCEncrypter(block, iv)
  148. plainHasher := sha256.New()
  149. cipherHasher := sha256.New()
  150. cipherMAC := hmac.New(sha256.New, macKey)
  151. cipherMAC.Write(iv)
  152. writerAt, hasWriterAt := ciphertext.(io.WriterAt)
  153. buf := make([]byte, 32*1024)
  154. var size, extraSize int
  155. var writePtr int64
  156. hasMore := true
  157. for hasMore {
  158. var n int
  159. n, err = io.ReadFull(plaintext, buf)
  160. plainHasher.Write(buf[:n])
  161. size += n
  162. if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
  163. padding := aes.BlockSize - size%aes.BlockSize
  164. buf = append(buf[:n], bytes.Repeat([]byte{byte(padding)}, padding)...)
  165. extraSize = padding
  166. hasMore = false
  167. } else if err != nil {
  168. return nil, nil, 0, 0, fmt.Errorf("failed to read file: %w", err)
  169. }
  170. cbc.CryptBlocks(buf, buf)
  171. cipherMAC.Write(buf)
  172. cipherHasher.Write(buf)
  173. if hasWriterAt {
  174. _, err = writerAt.WriteAt(buf, writePtr)
  175. writePtr += int64(len(buf))
  176. } else {
  177. _, err = ciphertext.Write(buf)
  178. }
  179. if err != nil {
  180. return nil, nil, 0, 0, fmt.Errorf("failed to write file: %w", err)
  181. }
  182. }
  183. mac := cipherMAC.Sum(nil)[:10]
  184. extraSize += 10
  185. cipherHasher.Write(mac)
  186. if hasWriterAt {
  187. _, err = writerAt.WriteAt(mac, writePtr)
  188. } else {
  189. _, err = ciphertext.Write(mac)
  190. }
  191. if err != nil {
  192. return nil, nil, 0, 0, fmt.Errorf("failed to write checksum to file: %w", err)
  193. }
  194. return plainHasher.Sum(nil), cipherHasher.Sum(nil), uint64(size), uint64(size + extraSize), nil
  195. }