bismuthd/commons/conn.go
2024-10-19 18:06:49 -04:00

307 lines
8.7 KiB
Go

// Package commons provides the shared connection handler used after both the client and server handshake successfully.
package commons
import (
	"crypto/cipher"
	"crypto/rand"
	"fmt"
	"io"
	"net"
	"sync"
	"time"
)
var (
	// ConnStandardMaxBufSize is the maximum size for a TCP packet: 65535 bytes.
	ConnStandardMaxBufSize = 65535

	// cryptHeaderSize is the extra allowance added on top of 65535 bytes when
	// checking the size of an incoming encrypted packet.
	// NOTE(review): 43 would match a 3-byte length prefix + 24-byte nonce +
	// 16-byte AEAD tag — confirm against the cipher actually in use.
	cryptHeaderSize = 43
)
// BismuthConn is the connection used once the Bismuth client and server have
// negotiated and started transmitting data. Each Write is sealed with Aead and
// framed with a 3-byte length prefix before being sent over PassedConn; Read
// reverses this, keeping any decrypted bytes that do not fit in the caller's
// slice in an internal buffer.
//
// Every method uses a pointer receiver, so it is *BismuthConn (not the value
// type BismuthConn) that carries the net.Conn method set. If you need a value
// that satisfies net.Conn, wrap it in BismuthConnWrapped; the client and server
// already do this wrapping automatically.
//
// DoInitSteps must be called before reading.
type BismuthConn struct {
	// Aead seals outgoing packets and opens incoming ones.
	Aead cipher.AEAD
	// PassedConn is the underlying transport that carries the framed packets.
	PassedConn net.Conn

	// lock guards the content buffer; it is created by DoInitSteps.
	lock *sync.Mutex
	// initDone is set once DoInitSteps has run.
	initDone bool
	// contentBuf holds decrypted bytes not yet returned to a caller; the
	// unread region is contentBuf[contentBufPos:contentBufSize].
	contentBuf                    []byte
	contentBufPos, contentBufSize int

	// MaxBufSize is the size of the internal buffer used to hold decrypted
	// data that did not fit into a Read call.
	MaxBufSize int
	// AllowNonstandardPacketSizeLimit, when true, limits incoming packets by
	// MaxBufSize instead of the standard TCP packet maximum.
	AllowNonstandardPacketSizeLimit bool
}
// DoInitSteps prepares the connection for reading: it allocates a content
// buffer of MaxBufSize bytes, creates the mutex that guards it, and marks the
// connection as initialized. Read and ResizeContentBuf return an error until
// this has been called.
func (bmConn *BismuthConn) DoInitSteps() {
	bmConn.contentBuf = make([]byte, bmConn.MaxBufSize)
	bmConn.lock = new(sync.Mutex)
	bmConn.initDone = true
}
// encryptMessage seals msg with the connection's AEAD under a freshly generated
// random nonce and returns nonce || sealed message. The output slice is sized
// up front so Seal can append without reallocating.
func (bmConn *BismuthConn) encryptMessage(msg []byte) ([]byte, error) {
	nonceSize := bmConn.Aead.NonceSize()
	out := make([]byte, nonceSize, nonceSize+len(msg)+bmConn.Aead.Overhead())
	if _, err := rand.Read(out); err != nil {
		return []byte{}, err
	}
	// out currently holds only the nonce; Seal appends the sealed message
	// after it, producing the nonce-prefixed wire format.
	return bmConn.Aead.Seal(out, out[:nonceSize], msg, nil), nil
}
// decryptMessage is the inverse of encryptMessage: it treats the first
// NonceSize bytes of encMsg as the nonce and opens (decrypts and authenticates)
// the remainder. It fails if encMsg is shorter than a nonce or if Open rejects
// the ciphertext (for example because it was tampered with).
func (bmConn *BismuthConn) decryptMessage(encMsg []byte) ([]byte, error) {
	nonceSize := bmConn.Aead.NonceSize()
	if nonceSize > len(encMsg) {
		return []byte{}, fmt.Errorf("ciphertext too short")
	}

	plaintext, err := bmConn.Aead.Open(nil, encMsg[:nonceSize], encMsg[nonceSize:], nil)
	if err == nil {
		return plaintext, nil
	}
	return []byte{}, err
}
// ResizeContentBuf reallocates the internal content buffer to MaxBufSize bytes.
// Call it after changing bmConn.MaxBufSize.
//
// Data that is still buffered (received but not yet returned by Read) is moved
// to the start of the new buffer. If the new MaxBufSize is too small to hold
// that pending data, nothing is changed and an error is returned.
func (bmConn *BismuthConn) ResizeContentBuf() error {
	if !bmConn.initDone {
		return fmt.Errorf("bmConn not initialized")
	}
	bmConn.lock.Lock()
	defer bmConn.lock.Unlock()

	pending := bmConn.contentBuf[bmConn.contentBufPos:bmConn.contentBufSize]
	if len(pending) > bmConn.MaxBufSize {
		return fmt.Errorf("new buffer size %d is too small for %d bytes of pending data", bmConn.MaxBufSize, len(pending))
	}

	newBuf := make([]byte, bmConn.MaxBufSize)
	// Compact the unread bytes to offset 0 so that pos/size can describe them
	// as newBuf[0:len(pending)].
	copy(newBuf, pending)
	bmConn.contentBuf = newBuf
	bmConn.contentBufPos = 0
	bmConn.contentBufSize = len(pending)
	return nil
}
// ReadFromBuffer copies decrypted data that is already waiting in the internal
// content buffer into b, without touching the network. It returns the number
// of bytes copied, which is 0 when nothing is buffered or b is empty.
//
// If b cannot hold everything, the remainder stays buffered for the next call.
func (bmConn *BismuthConn) ReadFromBuffer(b []byte) (n int, err error) {
	// Without DoInitSteps the lock is nil; report it instead of panicking.
	if !bmConn.initDone {
		return 0, fmt.Errorf("bmConn not initialized")
	}
	bmConn.lock.Lock()
	defer bmConn.lock.Unlock()

	n = copy(b, bmConn.contentBuf[bmConn.contentBufPos:bmConn.contentBufSize])
	bmConn.contentBufPos += n

	if bmConn.contentBufPos == bmConn.contentBufSize {
		// Fully drained: rewind so the whole buffer is free for the next
		// overflow. The existing allocation is reused; it is only replaced
		// when MaxBufSize no longer matches its length.
		bmConn.contentBufPos = 0
		bmConn.contentBufSize = 0
		if len(bmConn.contentBuf) != bmConn.MaxBufSize {
			bmConn.contentBuf = make([]byte, bmConn.MaxBufSize)
		}
	}
	return n, nil
}
// ReadFromNetwork reads exactly one encrypted packet from PassedConn, decrypts
// it, and copies as much of the plaintext as fits into b. Plaintext that does
// not fit is appended to the internal content buffer for a later
// ReadFromBuffer/Read. It returns the number of bytes copied into b.
//
// On the wire a packet is a 3-byte length prefix (decoded with Int24ToInt32)
// followed by that many bytes of nonce+ciphertext.
//
// The packet must fit within:
//   - 65535+cryptHeaderSize bytes if AllowNonstandardPacketSizeLimit is unset; or
//   - MaxBufSize bytes if AllowNonstandardPacketSizeLimit is set.
//
// Oversized packets are drained from the connection before the error is
// returned, so the stream stays aligned on packet boundaries.
//
// Be careful: calling only this method (never draining the buffer) can fill the
// content buffer, in which case an error is returned.
func (bmConn *BismuthConn) ReadFromNetwork(b []byte) (n int, err error) {
	// Without DoInitSteps the lock is nil; report it instead of panicking.
	if !bmConn.initDone {
		return 0, fmt.Errorf("bmConn not initialized")
	}
	bmConn.lock.Lock()
	defer bmConn.lock.Unlock()

	// A single Read may legally return fewer bytes than requested, so use
	// io.ReadFull for both the length prefix and the packet body.
	lengthBytes := make([]byte, 3)
	if _, err := io.ReadFull(bmConn.PassedConn, lengthBytes); err != nil {
		return 0, err
	}
	encryptedContentLength := Int24ToInt32(lengthBytes)

	maxPacketSize := uint32(65535 + cryptHeaderSize)
	if bmConn.AllowNonstandardPacketSizeLimit {
		maxPacketSize = uint32(bmConn.MaxBufSize)
	}
	if encryptedContentLength > maxPacketSize {
		// Consume (without allocating) the oversized packet; otherwise its body
		// would be misinterpreted as the start of the next packet.
		if _, err := io.CopyN(io.Discard, bmConn.PassedConn, int64(encryptedContentLength)); err != nil {
			return 0, err
		}
		return 0, fmt.Errorf("packet too large")
	}

	encryptedContent := make([]byte, encryptedContentLength)
	if _, err := io.ReadFull(bmConn.PassedConn, encryptedContent); err != nil {
		return 0, err
	}

	decryptedContent, err := bmConn.decryptMessage(encryptedContent)
	if err != nil {
		return 0, err
	}

	n = copy(b, decryptedContent)
	overflow := decryptedContent[n:]
	if len(overflow) == 0 {
		return n, nil
	}

	// Bound by the real buffer length so the slice below can never panic,
	// even if MaxBufSize was changed without calling ResizeContentBuf.
	if bmConn.contentBufSize+len(overflow) > len(bmConn.contentBuf) {
		// n bytes were already delivered into b, so report them with the error.
		return n, fmt.Errorf("ran out of room in the buffer to store data! (can't overflow the buffer...)")
	}
	bmConn.contentBufSize += copy(bmConn.contentBuf[bmConn.contentBufSize:], overflow)
	return n, nil
}
// Read reads decrypted data from the Bismuth connection, satisfying the
// io.Reader contract.
//
// Data left over from an earlier packet is returned first. If any is
// available, Read returns it immediately, possibly fewer than len(b) bytes,
// rather than blocking on the network for more, as io.Reader conventionally
// does. Only when nothing is buffered does Read wait for the next packet.
func (bmConn *BismuthConn) Read(b []byte) (n int, err error) {
	if !bmConn.initDone {
		return 0, fmt.Errorf("bmConn not initialized")
	}

	n, err = bmConn.ReadFromBuffer(b)
	if err != nil || n > 0 || len(b) == 0 {
		// Return buffered data (or the error) now. Blocking here could
		// deadlock a peer that is waiting for our reply to this very data.
		return n, err
	}

	// Nothing was buffered: wait for the next packet.
	return bmConn.ReadFromNetwork(b)
}
// Write encrypts b as a single packet and sends it on PassedConn, prefixed with
// its 3-byte length (encoded with Int32ToInt24).
//
// The prefix and ciphertext are sent in one Write call so they go out
// together. Write returns len(b) on success. On failure it returns 0 and the
// error, because a partially sent packet cannot be mapped back to plaintext
// bytes.
func (bmConn *BismuthConn) Write(b []byte) (n int, err error) {
	encryptedMessage, err := bmConn.encryptMessage(b)
	if err != nil {
		return 0, err
	}

	// The length prefix is only 3 bytes wide; a larger length would be
	// silently truncated and desynchronize the peer.
	// NOTE(review): assumes the 24-bit prefix is unsigned.
	const maxPacketLength = 1<<24 - 1
	if len(encryptedMessage) > maxPacketLength {
		return 0, fmt.Errorf("packet too large")
	}

	packet := make([]byte, 3, 3+len(encryptedMessage))
	Int32ToInt24(packet, uint32(len(encryptedMessage)))
	packet = append(packet, encryptedMessage...)

	if _, err := bmConn.PassedConn.Write(packet); err != nil {
		return 0, err
	}
	return len(b), nil
}
// Compile-time check that *BismuthConn has the full net.Conn method set.
var _ net.Conn = (*BismuthConn)(nil)

// Close closes the underlying connection (PassedConn).
func (bmConn *BismuthConn) Close() error {
	return bmConn.PassedConn.Close()
}

// LocalAddr returns PassedConn's local network address.
func (bmConn *BismuthConn) LocalAddr() net.Addr {
	return bmConn.PassedConn.LocalAddr()
}

// RemoteAddr returns PassedConn's remote network address.
func (bmConn *BismuthConn) RemoteAddr() net.Addr {
	return bmConn.PassedConn.RemoteAddr()
}

// SetDeadline forwards to PassedConn.SetDeadline.
func (bmConn *BismuthConn) SetDeadline(t time.Time) error {
	return bmConn.PassedConn.SetDeadline(t)
}

// SetReadDeadline forwards to PassedConn.SetReadDeadline.
func (bmConn *BismuthConn) SetReadDeadline(t time.Time) error {
	return bmConn.PassedConn.SetReadDeadline(t)
}

// SetWriteDeadline forwards to PassedConn.SetWriteDeadline.
func (bmConn *BismuthConn) SetWriteDeadline(t time.Time) error {
	return bmConn.PassedConn.SetWriteDeadline(t)
}
// BismuthConnWrapped is a value type that satisfies net.Conn by delegating
// every method to the wrapped *BismuthConn. To get the raw Bismuth connection,
// read the Bismuth field:
//
//	bmConn.Bismuth // -> *BismuthConn
type BismuthConnWrapped struct {
	Bismuth *BismuthConn
}

// Compile-time check that BismuthConnWrapped satisfies net.Conn.
var _ net.Conn = BismuthConnWrapped{}

// Read delegates to (*BismuthConn).Read.
func (bmConn BismuthConnWrapped) Read(b []byte) (n int, err error) {
	return bmConn.Bismuth.Read(b)
}

// Write delegates to (*BismuthConn).Write.
func (bmConn BismuthConnWrapped) Write(b []byte) (n int, err error) {
	return bmConn.Bismuth.Write(b)
}

// Close delegates to (*BismuthConn).Close.
func (bmConn BismuthConnWrapped) Close() error {
	return bmConn.Bismuth.Close()
}

// LocalAddr delegates to (*BismuthConn).LocalAddr.
func (bmConn BismuthConnWrapped) LocalAddr() net.Addr {
	return bmConn.Bismuth.LocalAddr()
}

// RemoteAddr delegates to (*BismuthConn).RemoteAddr.
func (bmConn BismuthConnWrapped) RemoteAddr() net.Addr {
	return bmConn.Bismuth.RemoteAddr()
}

// SetDeadline delegates to (*BismuthConn).SetDeadline.
func (bmConn BismuthConnWrapped) SetDeadline(t time.Time) error {
	return bmConn.Bismuth.SetDeadline(t)
}

// SetReadDeadline delegates to (*BismuthConn).SetReadDeadline.
func (bmConn BismuthConnWrapped) SetReadDeadline(t time.Time) error {
	return bmConn.Bismuth.SetReadDeadline(t)
}

// SetWriteDeadline delegates to (*BismuthConn).SetWriteDeadline.
func (bmConn BismuthConnWrapped) SetWriteDeadline(t time.Time) error {
	return bmConn.Bismuth.SetWriteDeadline(t)
}