diff --git a/crypt.go b/crypt.go index 71a9b80..d1e59e4 100644 --- a/crypt.go +++ b/crypt.go @@ -79,7 +79,7 @@ func encrypt(args []string) { } var infd io.Reader = os.Stdin - var outfd io.Writer = os.Stdout + var outfd io.WriteCloser = os.Stdout var inf *os.File if len(args) > 1 { diff --git a/sign/encrypt.go b/sign/encrypt.go index 2cb162f..0731acb 100644 --- a/sign/encrypt.go +++ b/sign/encrypt.go @@ -83,7 +83,8 @@ type Encryptor struct { sender *PrivateKey started bool - buf []byte + buf []byte + stream bool } // Create a new Encryption context and use the optional private key 'sk' for @@ -150,7 +151,11 @@ func (e *Encryptor) AddRecipient(pk *PublicKey) error { } // Encrypt the input stream 'rd' and write encrypted stream to 'wr' -func (e *Encryptor) Encrypt(rd io.Reader, wr io.Writer) error { +func (e *Encryptor) Encrypt(rd io.Reader, wr io.WriteCloser) error { + if e.stream { + return fmt.Errorf("encrypt: can't use Encrypt() after using streaming I/O") + } + if !e.started { err := e.start(wr) if err != nil { @@ -182,7 +187,8 @@ func (e *Encryptor) Encrypt(rd io.Reader, wr io.Writer) error { i++ } } - return nil + + return wr.Close() } // Begin the encryption process by writing the header @@ -281,8 +287,9 @@ type Decryptor struct { buf []byte // Decrypted key - key []byte - eof bool + key []byte + eof bool + stream bool } // Create a new decryption context and if 'pk' is given, check that it matches @@ -420,8 +427,12 @@ func (d *Decryptor) Decrypt(wr io.Writer) error { return fmt.Errorf("decrypt: wrapped-key not decrypted (missing SetPrivateKey()?") } + if d.stream { + return fmt.Errorf("decrypt: can't use Decrypt() after using streaming I/O") + } + if d.eof { - return fmt.Errorf("decrypt: input stream has reached EOF") + return io.EOF } var i uint32 @@ -441,7 +452,6 @@ func (d *Decryptor) Decrypt(wr io.Writer) error { d.eof = true return nil } - } return nil } diff --git a/sign/encrypt_test.go b/sign/encrypt_test.go index ce538af..a89e8f5 100644 --- a/sign/encrypt_test.go +++ b/sign/encrypt_test.go @@ -22,6 +22,14 @@ import ( "testing" ) +type Buffer struct { + bytes.Buffer +} + +func (b *Buffer) Close() error { + return nil +} + // one sender, one receiver no verification of sender func TestEncryptSimple(t *testing.T) { assert := newAsserter(t) @@ -45,7 +53,7 @@ func TestEncryptSimple(t *testing.T) { assert(err == nil, "can't add recipient: %s", err) rd := bytes.NewBuffer(buf) - wr := bytes.Buffer{} + wr := Buffer{} err = ee.Encrypt(rd, &wr) assert(err == nil, "encrypt fail: %s", err) @@ -58,7 +66,7 @@ func TestEncryptSimple(t *testing.T) { err = dd.SetPrivateKey(&receiver.Sec, nil) assert(err == nil, "decryptor can't add SK: %s", err) - wr = bytes.Buffer{} + wr = Buffer{} err = dd.Decrypt(&wr) assert(err == nil, "decrypt fail: %s", err) @@ -91,7 +99,7 @@ func TestEncryptCorrupted(t *testing.T) { assert(err == nil, "can't add recipient: %s", err) rd := bytes.NewReader(buf) - wr := bytes.Buffer{} + wr := Buffer{} err = ee.Encrypt(rd, &wr) assert(err == nil, "encrypt fail: %s", err) @@ -136,7 +144,7 @@ func TestEncryptSenderVerified(t *testing.T) { assert(err == nil, "can't add recipient: %s", err) rd := bytes.NewBuffer(buf) - wr := bytes.Buffer{} + wr := Buffer{} err = ee.Encrypt(rd, &wr) assert(err == nil, "encrypt fail: %s", err) @@ -149,7 +157,7 @@ func TestEncryptSenderVerified(t *testing.T) { err = dd.SetPrivateKey(&receiver.Sec, &sender.Pub) assert(err == nil, "decryptor can't add SK: %s", err) - wr = bytes.Buffer{} + wr = Buffer{} err = dd.Decrypt(&wr) assert(err == nil, "decrypt fail: %s", err) @@ -190,7 +198,7 @@ func TestEncryptMultiReceiver(t *testing.T) { } rd := bytes.NewBuffer(buf) - wr := bytes.Buffer{} + wr := Buffer{} err = ee.Encrypt(rd, &wr) assert(err == nil, "encrypt fail: %s", err) @@ -205,7 +213,7 @@ func TestEncryptMultiReceiver(t *testing.T) { err = dd.SetPrivateKey(&rx[i].Sec, &sender.Pub) assert(err == nil, "decryptor can't add SK %d: %s", i, err) - wr = bytes.Buffer{} + wr = Buffer{} err = dd.Decrypt(&wr) assert(err == nil, "decrypt %d fail: %s", i, err) @@ -216,6 +224,88 @@ func TestEncryptMultiReceiver(t *testing.T) { } } +// Test stream write and read +func TestStreamIO(t *testing.T) { + assert := newAsserter(t) + + receiver, err := NewKeypair() + assert(err == nil, "receiver keypair gen failed: %s", err) + + var blkSize int = 1024 + var size int = (blkSize * 10) + + // cleartext + buf := make([]byte, size) + for i := 0; i < len(buf); i++ { + buf[i] = byte(i & 0xff) + } + + ee, err := NewEncryptor(nil, uint64(blkSize)) + assert(err == nil, "encryptor create fail: %s", err) + + err = ee.AddRecipient(&receiver.Pub) + assert(err == nil, "can't add recipient: %s", err) + + wr := Buffer{} + wio, err := ee.NewStreamWriter(&wr) + assert(err == nil, "can't start stream writer: %s", err) + + // chunksize for writing to stream + csize := 19 + rbuf := buf + for len(rbuf) > 0 { + m := csize + if len(rbuf) < m { + m = len(rbuf) + } + + n, err := wio.Write(rbuf[:m]) + assert(err == nil, "stream write failed: %s", err) + assert(n == m, "stream write mismatch: exp %d, saw %d", m, n) + + rbuf = rbuf[m:] + } + err = wio.Close() + assert(err == nil, "stream close failed: %s", err) + + _, err = wio.Write(buf[:csize]) + assert(err != nil, "stream write accepted I/O after close: %s", err) + + rd := bytes.NewBuffer(wr.Bytes()) + + dd, err := NewDecryptor(rd) + assert(err == nil, "decryptor create fail: %s", err) + + err = dd.SetPrivateKey(&receiver.Sec, nil) + assert(err == nil, "decryptor can't add SK: %s", err) + + rio, err := dd.NewStreamReader() + assert(err == nil, "stream reader failed: %s", err) + + rbuf = make([]byte, csize) + wr = Buffer{} + n := 0 + for { + m, err := rio.Read(rbuf) + assert(err == nil || err == io.EOF, "streamread fail: %s", err) + + if m > 0 { + wr.Write(rbuf[:m]) + n += m + } + if err == io.EOF || m == 0 { + break + } + } + + b := wr.Bytes() + assert(n == len(b), "streamread: bad buflen; exp %d, saw %d", n, len(b)) + assert(n == len(buf), "streamread: decrypt len mismatch; exp %d, saw %d", len(buf), n) + + assert(byteEq(b, buf), "decrypt content mismatch") + +} + func randint() int { var b [4]byte diff --git a/sign/stream.go b/sign/stream.go new file mode 100644 index 0000000..92e13eb --- /dev/null +++ b/sign/stream.go @@ -0,0 +1,164 @@ +// stream.go - Streaming io.Reader, io.Writer interface to encryption/decryption +// +// (c) 2016 Sudhi Herle +// +// Licensing Terms: GPLv2 +// +// If you need a commercial license for this work, please contact +// the author. +// +// This software does not come with any express or implied +// warranty; it is provided "as is". No claim is made to its +// suitability for any purpose. +// + +package sign + +import ( + "errors" + "fmt" + "io" +) + +// encWriter buffers partial writes until a full chunk is accumulated. +// It's methods implement the io.WriteCloser interface. +type encWriter struct { + buf []byte + n int // # of bytes written + wr io.WriteCloser + e *Encryptor + blk uint32 + err error +} + +// NewStreamWriter begins stream encryption to an underlying destination writer 'wr'. +// It returns an io.WriteCloser. +func (e *Encryptor) NewStreamWriter(wr io.WriteCloser) (io.WriteCloser, error) { + if !e.started { + err := e.start(wr) + if err != nil { + return nil, err + } + } + + w := &encWriter{ + buf: make([]byte, e.ChunkSize), + wr: wr, + e: e, + } + + e.stream = true + return w, nil +} + +// Write implements the io.Writer interface +func (w *encWriter) Write(b []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + + n := len(b) + if n == 0 { + return 0, nil + } + + max := int(w.e.ChunkSize) + for len(b) > 0 { + buf := w.buf[w.n:] + z := copy(buf, b) + b = b[z:] + w.n += z + + // We only flush if we have more data remaining in the input buffer. + // This way, we don't flush a potentially last block here; that happens + // when the caller eventually closes the stream. + if w.n == max && len(b) > 0 { + w.err = w.e.encrypt(w.buf, w.wr, w.blk, false) + if w.err != nil { + return 0, w.err + } + + w.n = 0 + w.blk += 1 + } + } + return n, nil +} + +// Close implements the io.Close interface +func (w *encWriter) Close() error { + if w.err != nil { + return w.err + } + + err := w.e.encrypt(w.buf[:w.n], w.wr, w.blk, true) + if err != nil { + w.err = err + return err + } + + w.n = 0 + w.err = errClosed + return w.wr.Close() +} + +// encReader buffers partial reads and it's methods implement the io.Reader interface. +type encReader struct { + buf []byte + unread []byte + d *Decryptor + blk uint32 +} + +// NewStreamReader returns an io.Reader to read from the decrypted stream +func (d *Decryptor) NewStreamReader() (io.Reader, error) { + if d.key == nil { + return nil, fmt.Errorf("streamReader: wrapped-key not decrypted (missing SetPrivateKey()?") + } + + if d.eof { + return nil, io.EOF + } + + d.stream = true + return &encReader{ + buf: make([]byte, d.ChunkSize), + d: d, + }, nil +} + +// Read implements io.Reader interface +func (r *encReader) Read(b []byte) (int, error) { + if r.d.eof && len(r.unread) == 0 { + return 0, io.EOF + } + + if len(r.unread) > 0 { + n := copy(b, r.unread) + r.unread = r.unread[n:] + return n, nil + } + + buf, eof, err := r.d.decrypt(r.blk) + if err != nil { + return 0, err + } + + r.blk += 1 + + n := copy(b, buf) + buf = buf[n:] + + copy(r.buf, buf) + r.unread = r.buf[:len(buf)] + + if eof { + r.d.eof = true + } + + return n, nil +} + +var ( + errClosed = errors.New("encrypt: stream already closed") +) diff --git a/version b/version index 6f4eebd..ac39a10 100644 --- a/version +++ b/version @@ -1 +1 @@ -0.8.1 +0.9.0