diff --git a/cryptpng.go b/cryptpng.go index e8cfbe1..6fa181c 100644 --- a/cryptpng.go +++ b/cryptpng.go @@ -25,6 +25,7 @@ func check(err error) { } } +const saltChunkName = "saLt" const chunkName = "crPt" const chunkSize = 0x100000 @@ -69,8 +70,10 @@ func EncryptDataPng(f *os.File, fin *os.File, fout *os.File) { check(err) inputData, err := ioutil.ReadAll(fin) check(err) - inputData, err = encryptData(inputData) + inputData, salt := encryptData(inputData) check(err) + saltChunk := CreateChunk(salt, saltChunkName) + png.AddMetaChunk(saltChunk) chunkCount := int(math.Ceil(float64(len(inputData)) / chunkSize)) for i := 0; i < chunkCount; i++ { dataStart := i * chunkSize @@ -87,12 +90,17 @@ func DecryptDataPng(f *os.File, fout *os.File) { png := PngData{} err := png.Read(f) check(err) + salt := make([]byte, 0) + saltChunk := png.GetChunk(saltChunkName) + if saltChunk != nil { + salt = append(salt, saltChunk.data...) + } var data []byte for _, cryptChunk := range png.GetChunksByName(chunkName) { data = append(data, cryptChunk.data...) } if len(data) > 0 { - data, err = decryptData(data) + data, err = decryptData(data, salt) if err != nil { log.Println("\nThe provided password is probably incorrect.") } @@ -105,26 +113,36 @@ func DecryptDataPng(f *os.File, fout *os.File) { } // creates an encrypted png chunk -func encryptData(data []byte) ([]byte, error) { - key := readPassword() - return encrypt(key, data) +func encryptData(data []byte) ([]byte, []byte) { + key, salt := readPassword(nil) + encData, err := encrypt(key, data) + check(err) + return encData, salt } // decrypts the data of a png chunk -func decryptData(data []byte) ([]byte, error) { - key := readPassword() +func decryptData(data []byte, salt []byte) ([]byte, error) { + key, _ := readPassword(&salt) return decrypt(key, data) } // reads a password from the terminal // turns off the input for the typing of the password -func readPassword() []byte { +func readPassword(passwordSalt *[]byte) ([]byte, []byte) { fmt.Print("Password: ") bytePw, err := terminal.ReadPassword(int(syscall.Stdin)) check(err) hash := sha512.New512_256() - hash.Write(bytePw) - return hash.Sum(nil) + if passwordSalt != nil { + hash.Write(append(*passwordSalt, bytePw...)) + return hash.Sum(nil), *passwordSalt + } else { + salt := make([]byte, 32) + _, err = io.ReadFull(rand.Reader, salt) + check(err) + hash.Write(append(salt, bytePw...)) + return hash.Sum(nil), salt + } } // encrypt and decrypt functions taken from