307 lines
8.7 KiB
Go
307 lines
8.7 KiB
Go
// Shared connection handler after both the client and server handshake successfully
|
|
|
|
package commons
|
|
|
|
import (
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Maximum size for a TCP packet
|
|
var ConnStandardMaxBufSize = 65535
|
|
var cryptHeaderSize = 43
|
|
|
|
// Connection used after both the Bismuth client and server negotiate and start
|
|
// transmitting data.
|
|
//
|
|
// Note that using this has the same API as the net.Conn, but it isn't conformant to
|
|
// the interface due to using pointers rather than copies to access the struct.
|
|
//
|
|
// If you need the same interface, wrap this in WrappedBismuthConn.
|
|
// Wrapping BismuthConn is done automatically by the client and server.
|
|
type BismuthConn struct {
|
|
Aead cipher.AEAD
|
|
PassedConn net.Conn
|
|
|
|
lock *sync.Mutex
|
|
|
|
initDone bool
|
|
|
|
contentBuf []byte
|
|
contentBufPos int
|
|
contentBufSize int
|
|
|
|
// Maximum buffer size to be used to internally buffer packets
|
|
MaxBufSize int
|
|
// If true, it enables using the content buffer maximum size
|
|
// instead of the TCP packet maximum size
|
|
AllowNonstandardPacketSizeLimit bool
|
|
}
|
|
|
|
func (bmConn *BismuthConn) DoInitSteps() {
|
|
bmConn.lock = &sync.Mutex{}
|
|
bmConn.contentBuf = make([]byte, bmConn.MaxBufSize)
|
|
|
|
bmConn.initDone = true
|
|
}
|
|
|
|
func (bmConn *BismuthConn) encryptMessage(msg []byte) ([]byte, error) {
|
|
nonce := make([]byte, bmConn.Aead.NonceSize(), bmConn.Aead.NonceSize()+len(msg)+bmConn.Aead.Overhead())
|
|
|
|
if _, err := rand.Read(nonce); err != nil {
|
|
return []byte{}, err
|
|
}
|
|
|
|
encryptedMsg := bmConn.Aead.Seal(nonce, nonce, msg, nil)
|
|
return encryptedMsg, nil
|
|
}
|
|
|
|
func (bmConn *BismuthConn) decryptMessage(encMsg []byte) ([]byte, error) {
|
|
if len(encMsg) < bmConn.Aead.NonceSize() {
|
|
return []byte{}, fmt.Errorf("ciphertext too short")
|
|
}
|
|
|
|
// Split nonce and ciphertext.
|
|
nonce, ciphertext := encMsg[:bmConn.Aead.NonceSize()], encMsg[bmConn.Aead.NonceSize():]
|
|
|
|
// Decrypt the message and check it wasn't tampered with.
|
|
decryptedData, err := bmConn.Aead.Open(nil, nonce, ciphertext, nil)
|
|
|
|
if err != nil {
|
|
return []byte{}, err
|
|
}
|
|
|
|
return decryptedData, nil
|
|
}
|
|
|
|
// After you update the property `bmConn.MaxBufSize`, call this function to resize the content buffer
|
|
func (bmConn *BismuthConn) ResizeContentBuf() error {
|
|
if !bmConn.initDone {
|
|
return fmt.Errorf("bmConn not initialized")
|
|
}
|
|
|
|
bmConn.lock.Lock()
|
|
|
|
if bmConn.contentBufSize != 0 {
|
|
// TODO: switch this to do append() instead, when I finally decide to consider this "optimization" in the main buffer logic
|
|
// This code below basically, instead of growing it, gets the actually unused cache data, then grows it and copies it over.
|
|
//
|
|
// This is probably unneccesary, but it saves some hassle I guess.
|
|
currentContentBufData := bmConn.contentBuf[bmConn.contentBufPos:bmConn.contentBufSize]
|
|
bmConn.contentBufSize = bmConn.contentBufSize - bmConn.contentBufPos
|
|
bmConn.contentBufPos = 0
|
|
|
|
bmConn.contentBuf = make([]byte, bmConn.MaxBufSize)
|
|
copy(bmConn.contentBuf[len(currentContentBufData):], currentContentBufData)
|
|
} else {
|
|
bmConn.contentBuf = make([]byte, bmConn.MaxBufSize)
|
|
}
|
|
|
|
bmConn.lock.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Reads specifically from the buffer only. If nothing is in the buffer, nothing is returned.
|
|
func (bmConn *BismuthConn) ReadFromBuffer(b []byte) (n int, err error) {
|
|
bmConn.lock.Lock()
|
|
defer bmConn.lock.Unlock()
|
|
|
|
calcContentBufSize := bmConn.contentBufSize - bmConn.contentBufPos
|
|
providedBufferSize := len(b)
|
|
|
|
if bmConn.contentBufSize == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
if calcContentBufSize <= providedBufferSize {
|
|
copy(b, bmConn.contentBuf[bmConn.contentBufPos:bmConn.contentBufSize])
|
|
|
|
bmConn.contentBufPos = 0
|
|
bmConn.contentBufSize = 0
|
|
bmConn.contentBuf = make([]byte, bmConn.MaxBufSize)
|
|
|
|
return calcContentBufSize, nil
|
|
} else if calcContentBufSize > providedBufferSize {
|
|
newContentBufSize := bmConn.contentBufPos + providedBufferSize
|
|
copy(b, bmConn.contentBuf[bmConn.contentBufPos:newContentBufSize])
|
|
|
|
bmConn.contentBufPos = newContentBufSize
|
|
|
|
return providedBufferSize, nil
|
|
}
|
|
|
|
return 0, nil
|
|
}
|
|
|
|
// Reads specifically from the network. Be careful as using only this may overflow the buffer.
|
|
func (bmConn *BismuthConn) ReadFromNetwork(b []byte) (n int, err error) {
|
|
bmConn.lock.Lock()
|
|
defer bmConn.lock.Unlock()
|
|
|
|
bufferSize := len(b)
|
|
|
|
encryptedContentLengthBytes := make([]byte, 3)
|
|
|
|
if _, err = bmConn.PassedConn.Read(encryptedContentLengthBytes); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
encryptedContentLength := Int24ToInt32(encryptedContentLengthBytes)
|
|
encryptedContent := make([]byte, encryptedContentLength)
|
|
|
|
// Check to see if we can fit the packet inside either:
|
|
// - the max TCP packet size (64k) if 'AllowNonstandardPacketSizeLimit' isn't set; or
|
|
// - the max buffer size if 'AllowNonstandardPacketSizeLimit' is set
|
|
// We check AFTER we read to make sure that we don't corrupt any future packets, because if we don't read the packet,
|
|
// it will think that the actual packet will be the start of the packet, and that would cause loads of problems.
|
|
if !bmConn.AllowNonstandardPacketSizeLimit && encryptedContentLength > uint32(65535+cryptHeaderSize) {
|
|
return 0, fmt.Errorf("packet too large")
|
|
} else if bmConn.AllowNonstandardPacketSizeLimit && encryptedContentLength > uint32(bmConn.MaxBufSize) {
|
|
return 0, fmt.Errorf("packet too large")
|
|
}
|
|
|
|
totalPosition := 0
|
|
|
|
for totalPosition != int(encryptedContentLength) {
|
|
currentPosition, err := bmConn.PassedConn.Read(encryptedContent[totalPosition:encryptedContentLength])
|
|
totalPosition += currentPosition
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
decryptedContent, err := bmConn.decryptMessage(encryptedContent)
|
|
decryptedContentSize := len(decryptedContent)
|
|
|
|
calcSize := min(decryptedContentSize, bufferSize)
|
|
copy(b[:calcSize], decryptedContent)
|
|
|
|
if bufferSize < int(decryptedContentSize) {
|
|
newSlice := decryptedContent[calcSize:]
|
|
|
|
if bmConn.contentBufSize+len(newSlice) > bmConn.MaxBufSize {
|
|
return 0, fmt.Errorf("ran out of room in the buffer to store data! (can't overflow the buffer...)")
|
|
}
|
|
|
|
copy(bmConn.contentBuf[bmConn.contentBufSize:bmConn.contentBufSize+len(newSlice)], newSlice)
|
|
bmConn.contentBufSize += len(newSlice)
|
|
}
|
|
|
|
if err != nil {
|
|
return calcSize, err
|
|
}
|
|
|
|
return calcSize, nil
|
|
}
|
|
|
|
// Reads from the Bismuth connection, using both the buffered and network methods
|
|
func (bmConn *BismuthConn) Read(b []byte) (n int, err error) {
|
|
if !bmConn.initDone {
|
|
return 0, fmt.Errorf("bmConn not initialized")
|
|
}
|
|
|
|
bufferReadSize, err := bmConn.ReadFromBuffer(b)
|
|
|
|
if err != nil {
|
|
return bufferReadSize, err
|
|
}
|
|
|
|
if bufferReadSize == len(b) {
|
|
return bufferReadSize, nil
|
|
}
|
|
|
|
networkReadSize, err := bmConn.ReadFromNetwork(b[bufferReadSize:])
|
|
|
|
if err != nil {
|
|
return bufferReadSize + networkReadSize, err
|
|
}
|
|
|
|
return bufferReadSize + networkReadSize, nil
|
|
}
|
|
|
|
// Encrypts and sends off a message
|
|
func (bmConn *BismuthConn) Write(b []byte) (n int, err error) {
|
|
encryptedMessage, err := bmConn.encryptMessage(b)
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
encryptedMessageSize := make([]byte, 3)
|
|
Int32ToInt24(encryptedMessageSize, uint32(len(encryptedMessage)))
|
|
|
|
bmConn.PassedConn.Write(encryptedMessageSize)
|
|
bmConn.PassedConn.Write(encryptedMessage)
|
|
|
|
return len(b), nil
|
|
}
|
|
|
|
func (bmConn *BismuthConn) Close() error {
|
|
return bmConn.PassedConn.Close()
|
|
}
|
|
|
|
func (bmConn *BismuthConn) LocalAddr() net.Addr {
|
|
return bmConn.PassedConn.LocalAddr()
|
|
}
|
|
|
|
func (bmConn *BismuthConn) RemoteAddr() net.Addr {
|
|
return bmConn.PassedConn.RemoteAddr()
|
|
}
|
|
|
|
func (bmConn *BismuthConn) SetDeadline(time time.Time) error {
|
|
return bmConn.PassedConn.SetDeadline(time)
|
|
}
|
|
|
|
func (bmConn *BismuthConn) SetReadDeadline(time time.Time) error {
|
|
return bmConn.PassedConn.SetReadDeadline(time)
|
|
}
|
|
|
|
func (bmConn *BismuthConn) SetWriteDeadline(time time.Time) error {
|
|
return bmConn.PassedConn.SetWriteDeadline(time)
|
|
}
|
|
|
|
// Wrapped BismuthConn struct. This is conformant to net.Conn, unlike above.
|
|
// To get the raw Bismuth struct, just get the Bismuth property:
|
|
//
|
|
// `bmConn.Bismuth` -> `BismuthConn`
|
|
type BismuthConnWrapped struct {
|
|
Bismuth *BismuthConn
|
|
}
|
|
|
|
func (bmConn BismuthConnWrapped) Read(b []byte) (n int, err error) {
|
|
return bmConn.Bismuth.Read(b)
|
|
}
|
|
|
|
func (bmConn BismuthConnWrapped) Write(b []byte) (n int, err error) {
|
|
return bmConn.Bismuth.Write(b)
|
|
}
|
|
|
|
func (bmConn BismuthConnWrapped) Close() error {
|
|
return bmConn.Bismuth.Close()
|
|
}
|
|
|
|
func (bmConn BismuthConnWrapped) LocalAddr() net.Addr {
|
|
return bmConn.Bismuth.LocalAddr()
|
|
}
|
|
|
|
func (bmConn BismuthConnWrapped) RemoteAddr() net.Addr {
|
|
return bmConn.Bismuth.RemoteAddr()
|
|
}
|
|
|
|
func (bmConn BismuthConnWrapped) SetDeadline(time time.Time) error {
|
|
return bmConn.Bismuth.SetDeadline(time)
|
|
}
|
|
|
|
func (bmConn BismuthConnWrapped) SetReadDeadline(time time.Time) error {
|
|
return bmConn.Bismuth.SetReadDeadline(time)
|
|
}
|
|
|
|
func (bmConn BismuthConnWrapped) SetWriteDeadline(time time.Time) error {
|
|
return bmConn.Bismuth.SetWriteDeadline(time)
|
|
}
|