bismuthd/signingserver/signingserver.go
2024-10-24 11:43:47 -04:00

198 lines
5.6 KiB
Go

package signingserver
import (
	"crypto/sha256"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"io"
	"net"
	"strings"

	core "git.greysoh.dev/imterah/bismuthd/commons"
	"git.greysoh.dev/imterah/bismuthd/server"
)
// InitializeServer prepares signServer for accepting connections.
//
// When AddVerifyHandler or VerifyServerHandler is nil, a built-in fallback
// backed by the in-memory builtInVerifyMapStore map is installed and a warning
// is printed to stdout. For each server address the fallback stores the key
// fingerprint under "<addr>.fingerprint" and a hex SHA-256 digest of the
// ':'-joined domain list under "<addr>.domainListHash". Verification succeeds
// only when both stored values exist and match the values being checked.
//
// Finally, connHandler is registered as the BismuthServer's connection
// handler. The returned error is currently always nil.
func (signServer *BismuthSigningServer) InitializeServer() error {
	// The fallback handlers below read and write this map, so make sure it
	// exists before either of them can run.
	if signServer.builtInVerifyMapStore == nil {
		signServer.builtInVerifyMapStore = map[string]string{}
	}

	// digestDomains collapses a domain list into the hex SHA-256 form that the
	// fallback handlers store and compare. Note that list order matters.
	digestDomains := func(domains []string) string {
		sum := sha256.Sum256([]byte(strings.Join(domains, ":")))
		return hex.EncodeToString(sum[:])
	}

	if signServer.AddVerifyHandler == nil {
		fmt.Println("WARN: You are using the default AddVerifyHandler in SignServer! This is a bad idea. Please write your own implementation!")

		// Fallback: unconditionally trust and remember whatever is submitted.
		signServer.AddVerifyHandler = func(serverAddr string, serverKeyFingerprint string, serverDomainList []string, additionalClientProvidedInfo string) (bool, error) {
			domainDigest := digestDomains(serverDomainList)
			signServer.builtInVerifyMapStore[serverAddr+".fingerprint"] = serverKeyFingerprint
			signServer.builtInVerifyMapStore[serverAddr+".domainListHash"] = domainDigest
			return true, nil
		}
	}

	if signServer.VerifyServerHandler == nil {
		fmt.Println("WARN: You are using the default VerifyServerHandler in SignServer! This is a bad idea. Please write your own implementation!")

		// Fallback: accept only an exact match on both the remembered
		// fingerprint and the remembered domain-list digest.
		signServer.VerifyServerHandler = func(serverAddr string, serverKeyFingerprint string, serverDomainList []string) (bool, error) {
			domainDigest := digestDomains(serverDomainList)

			knownFingerprint, found := signServer.builtInVerifyMapStore[serverAddr+".fingerprint"]
			if !found || knownFingerprint != serverKeyFingerprint {
				return false, nil
			}

			knownDigest, found := signServer.builtInVerifyMapStore[serverAddr+".domainListHash"]
			if !found || knownDigest != domainDigest {
				return false, nil
			}

			return true, nil
		}
	}

	signServer.BismuthServer.HandleConnection = signServer.connHandler

	return nil
}
// connHandler serves a single client connection for the signing server.
// conn is always closed on return. The handler reads one-byte request opcodes
// in a loop. It stops when a read or write fails (for example io.EOF when the
// client disconnects cleanly) or when a verification handler returns an error.
//
// Supported requests:
//   - core.AreDomainsValidForKey: a uint16 big-endian length plus the raw
//     key-fingerprint bytes, then a uint16 big-endian length plus a
//     "\n"-separated domain list. The fingerprint is hex-encoded and passed,
//     together with the client's host, to VerifyServerHandler. An empty domain
//     list is rejected with core.Failure without consulting the handler.
//   - core.ValidateKey: a length-prefixed "\n"-separated domain list followed
//     by length-prefixed additional arguments. These values, the client's host,
//     and the fingerprint of the client's own public key are passed to
//     AddVerifyHandler.
//
// Each recognised request is answered with one status byte: core.Success,
// core.Failure, or core.InternalError when the handler fails. In the
// InternalError case the handler's error is returned and the connection
// is closed.
func (signServer *BismuthSigningServer) connHandler(conn net.Conn, metadata *server.ClientMetadata) error {
	defer conn.Close()

	remoteAddr := conn.RemoteAddr().String()

	// SplitHostPort understands bracketed IPv6 addresses ("[::1]:443"), which a
	// naive search for the first ':' would truncate to "[".
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return fmt.Errorf("parsing remote address %q: %w", remoteAddr, err)
	}

	clientKeyFingerprint := metadata.ClientPublicKey.GetFingerprint()

	// readField reads a uint16 big-endian length prefix followed by exactly that
	// many bytes. Up to 65535 bytes is probably larger than any field needs, but
	// it errs on the side of caution. io.ReadFull is required because a single
	// conn.Read may legally return fewer bytes than requested.
	readField := func() ([]byte, error) {
		var lengthPrefix [2]byte
		if _, err := io.ReadFull(conn, lengthPrefix[:]); err != nil {
			return nil, err
		}

		field := make([]byte, binary.BigEndian.Uint16(lengthPrefix[:]))
		if _, err := io.ReadFull(conn, field); err != nil {
			return nil, err
		}

		return field, nil
	}

	// respond sends a single status byte back to the client.
	respond := func(status byte) error {
		_, err := conn.Write([]byte{status})
		return err
	}

	requestType := make([]byte, 1)

	for {
		// On a clean disconnect before any byte arrives, ReadFull returns io.EOF.
		if _, err := io.ReadFull(conn, requestType); err != nil {
			return err
		}

		switch requestType[0] {
		case core.AreDomainsValidForKey:
			keyFingerprintBytes, err := readField()
			if err != nil {
				return err
			}

			keyFingerprint := hex.EncodeToString(keyFingerprintBytes)

			serverDomainListBytes, err := readField()
			if err != nil {
				return err
			}

			// We can't trust anything if they aren't advertising any domains/IPs.
			// This is checked on the raw payload because strings.Split never
			// returns an empty slice.
			if len(serverDomainListBytes) == 0 {
				if err := respond(core.Failure); err != nil {
					return err
				}

				continue
			}

			serverDomainList := strings.Split(string(serverDomainListBytes), "\n")

			isVerified, err := signServer.VerifyServerHandler(host, keyFingerprint, serverDomainList)
			if err != nil {
				// Best-effort notification; the handler error is what gets reported.
				_ = respond(core.InternalError)
				return err
			}

			if isVerified {
				err = respond(core.Success)
			} else {
				err = respond(core.Failure)
			}

			if err != nil {
				return err
			}

		case core.ValidateKey:
			serverDomainListBytes, err := readField()
			if err != nil {
				return err
			}

			serverDomainList := strings.Split(string(serverDomainListBytes), "\n")

			// A zero-length field yields "", i.e. no additional arguments.
			additionalArgumentsBytes, err := readField()
			if err != nil {
				return err
			}

			additionalArguments := string(additionalArgumentsBytes)

			isAddedToTrust, err := signServer.AddVerifyHandler(host, clientKeyFingerprint, serverDomainList, additionalArguments)
			if err != nil {
				// Best-effort notification, mirroring the AreDomainsValidForKey path.
				_ = respond(core.InternalError)
				return err
			}

			if isAddedToTrust {
				err = respond(core.Success)
			} else {
				err = respond(core.Failure)
			}

			if err != nil {
				return err
			}

		default:
			// NOTE(review): unknown request types are silently ignored and the
			// next byte is read as a new opcode; the client gets no reply.
			// Confirm this is intended by the protocol.
		}
	}
}