// TLS-over-TCP front end: accepts client connections and exchanges
// length-prefixed, protobuf-wire-format packets with them.

package network
|
||
|
|
|
||
|
|
import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"errors"
	"fmt"
	"io"
	"math/big"
	"net"
	"time"

	"customServer/internal/handlers"
	"customServer/internal/protocol"
)
|
||
|
|
|
||
|
|
func StartTCPServer(addr string) error {
|
||
|
|
tlsConfig, err := generateSelfSignedCert()
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to generate self-signed cert: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
listener, err := tls.Listen("tcp", addr, tlsConfig)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("TCP TLS listener failed: %v", err)
|
||
|
|
}
|
||
|
|
fmt.Printf("[TCP] Scalable Server (TLS): localhost%s\n", addr)
|
||
|
|
|
||
|
|
for {
|
||
|
|
conn, err := listener.Accept()
|
||
|
|
if err != nil {
|
||
|
|
fmt.Printf("Accept error: %v\n", err)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
go handleTCPConnection(conn)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func handleTCPConnection(conn net.Conn) {
|
||
|
|
defer conn.Close()
|
||
|
|
fmt.Printf("[TCP] New TLS connection from %s\n", conn.RemoteAddr())
|
||
|
|
|
||
|
|
for {
|
||
|
|
// Read length (4 bytes, Big Endian)
|
||
|
|
length, err := protocol.ReadPacketLength(conn)
|
||
|
|
if err != nil {
|
||
|
|
if err != io.EOF {
|
||
|
|
fmt.Printf("[TCP] Read length error: %v\n", err)
|
||
|
|
}
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Read packet
|
||
|
|
data := make([]byte, length)
|
||
|
|
_, err = io.ReadFull(conn, data)
|
||
|
|
if err != nil {
|
||
|
|
fmt.Printf("[TCP] Read data error: %v\n", err)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
fmt.Printf("[TCP] Received packet of %d bytes\n", length)
|
||
|
|
|
||
|
|
packetID := int64(0)
|
||
|
|
requestNumber := int32(0)
|
||
|
|
var payloadBytes []byte
|
||
|
|
|
||
|
|
reader := bytes.NewReader(data)
|
||
|
|
for {
|
||
|
|
tag, err := protocol.ReadVarint(reader)
|
||
|
|
if err != nil {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
fieldNum := tag >> 3
|
||
|
|
wireType := tag & 0x7
|
||
|
|
|
||
|
|
if fieldNum == 1 && wireType == 0 { // Packet.id
|
||
|
|
packetID, _ = protocol.ReadVarintInt64(reader)
|
||
|
|
} else if fieldNum == 3 && wireType == 2 { // Packet.payload (Message)
|
||
|
|
payloadLen, _ := protocol.ReadVarint(reader)
|
||
|
|
payloadBytes = make([]byte, payloadLen)
|
||
|
|
reader.Read(payloadBytes)
|
||
|
|
|
||
|
|
payloadReader := bytes.NewReader(payloadBytes)
|
||
|
|
for {
|
||
|
|
pTag, err := protocol.ReadVarint(payloadReader)
|
||
|
|
if err != nil {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
pFieldNum := pTag >> 3
|
||
|
|
pWireType := pTag & 0x7
|
||
|
|
if pFieldNum == 1 && pWireType == 0 { // Message.request_number
|
||
|
|
reqNum64, _ := protocol.ReadVarintInt64(payloadReader)
|
||
|
|
requestNumber = int32(reqNum64)
|
||
|
|
} else {
|
||
|
|
protocol.SkipField(payloadReader, pWireType)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
protocol.SkipField(reader, wireType)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fmt.Printf("[TCP] Got Request ID: %d, Number: %d\n", packetID, requestNumber)
|
||
|
|
|
||
|
|
responsePayload, responseFieldNum := handlers.Dispatch(conn, packetID, requestNumber, payloadBytes)
|
||
|
|
if responsePayload != nil {
|
||
|
|
sendTCPResponse(conn, packetID, responseFieldNum, responsePayload)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func sendTCPResponse(conn net.Conn, packetID int64, fieldNum int, payload []byte) {
|
||
|
|
// Construct Message wrapper
|
||
|
|
message := make([]byte, 0)
|
||
|
|
// Field 1: request_number
|
||
|
|
message = append(message, 0x08)
|
||
|
|
message = append(message, protocol.EncodeVarint(uint64(fieldNum))...)
|
||
|
|
|
||
|
|
// Field [fieldNum]: payload (WireType 2)
|
||
|
|
message = append(message, protocol.EncodeVarint(uint64(fieldNum<<3|2))...)
|
||
|
|
message = append(message, protocol.EncodeVarint(uint64(len(payload)))...)
|
||
|
|
message = append(message, payload...)
|
||
|
|
|
||
|
|
// Construct Packet wrapper
|
||
|
|
packet := make([]byte, 0)
|
||
|
|
// Field 1: Id
|
||
|
|
packet = append(packet, 0x08)
|
||
|
|
packet = append(packet, protocol.EncodeVarint(uint64(packetID))...)
|
||
|
|
// Field 3: Payload
|
||
|
|
packet = append(packet, 0x1a)
|
||
|
|
packet = append(packet, protocol.EncodeVarint(uint64(len(message)))...)
|
||
|
|
packet = append(packet, message...)
|
||
|
|
|
||
|
|
// Send length + packet
|
||
|
|
lengthBuf := make([]byte, 4)
|
||
|
|
length := uint32(len(packet))
|
||
|
|
lengthBuf[0] = byte(length >> 24)
|
||
|
|
lengthBuf[1] = byte(length >> 16)
|
||
|
|
lengthBuf[2] = byte(length >> 8)
|
||
|
|
lengthBuf[3] = byte(length)
|
||
|
|
|
||
|
|
conn.Write(lengthBuf)
|
||
|
|
conn.Write(packet)
|
||
|
|
}
|
||
|
|
|
||
|
|
// generateSelfSignedCert creates a fresh 2048-bit RSA key and a self-signed
// server certificate for "localhost", 127.0.0.1 and ::1, valid for roughly
// one year. It returns a tls.Config that presents that certificate.
//
// A new key and certificate are produced on every call, so each server start
// serves a different certificate.
func generateSelfSignedCert() (*tls.Config, error) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, err
	}

	// RFC 5280 requires serial numbers to be unique per issuer. A new key is
	// minted on every start while the issuer name stays the same, so a fixed
	// serial makes some clients (e.g. NSS/Firefox) reject the certificate as
	// a reused issuer+serial pair. Use 128 random bits instead.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, err
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Organization: []string{"Custom Server Mod"},
		},
		// Backdate slightly so clients whose clock lags the server's don't
		// see a not-yet-valid certificate right after startup.
		NotBefore: now.Add(-time.Hour),
		NotAfter:  now.Add(365 * 24 * time.Hour),

		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback},
		DNSNames:              []string{"localhost"},
	}

	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return nil, err
	}

	cert := tls.Certificate{
		Certificate: [][]byte{derBytes},
		PrivateKey:  priv,
	}

	return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
}
|