diff --git a/nip44.go b/nip44.go index d0b3312..c4298f6 100644 --- a/nip44.go +++ b/nip44.go @@ -65,7 +65,9 @@ func Encrypt(conversationKey []byte, plaintext string, options *EncryptOptions) if ciphertext, err = chacha20_(enc, nonce, []byte(padded)); err != nil { return "", err } - hmac_ = sha256Hmac(auth, ciphertext, salt) + if hmac_, err = sha256Hmac(auth, ciphertext, salt); err != nil { + return "", err + } concat = append(concat, []byte{byte(version)}...) concat = append(concat, salt...) concat = append(concat, ciphertext...) @@ -81,6 +83,7 @@ func Decrypt(conversationKey []byte, ciphertext string) (string, error) { dLen int salt []byte ciphertext_ []byte + hmac []byte hmac_ []byte enc []byte nonce []byte @@ -111,7 +114,10 @@ func Decrypt(conversationKey []byte, ciphertext string) (string, error) { if enc, nonce, auth, err = messageKeys(conversationKey, salt); err != nil { return "", err } - if !bytes.Equal(hmac_, sha256Hmac(auth, ciphertext_, salt)) { + if hmac, err = sha256Hmac(auth, ciphertext_, salt); err != nil { + return "", err + } + if !bytes.Equal(hmac_, hmac) { return "", errors.New("invalid hmac") } if padded, err = chacha20_(enc, nonce, ciphertext_); err != nil { @@ -163,11 +169,14 @@ func randomBytes(n int) ([]byte, error) { return buf, nil } -func sha256Hmac(key []byte, ciphertext []byte, nonce []byte) []byte { +func sha256Hmac(key []byte, ciphertext []byte, aad []byte) ([]byte, error) { + if len(aad) != 32 { + return nil, errors.New("aad data must be 32 bytes") + } h := hmac.New(sha256.New, key) - h.Write(nonce) + h.Write(aad) h.Write(ciphertext) - return h.Sum(nil) + return h.Sum(nil), nil } func messageKeys(conversationKey []byte, salt []byte) ([]byte, []byte, []byte, error) { @@ -178,6 +187,12 @@ func messageKeys(conversationKey []byte, salt []byte) ([]byte, []byte, []byte, e auth []byte = make([]byte, 32) err error ) + if len(conversationKey) != 32 { + return nil, nil, nil, errors.New("conversation key must be 32 bytes") + } + if len(salt) != 32 { + return nil, nil, nil, errors.New("salt must be 32 bytes") + } r = hkdf.Expand(sha256.New, conversationKey, salt) if _, err = io.ReadFull(r, enc); err != nil { return nil, nil, nil, err