init commit
This commit is contained in:
180
internal/network/tcp_server.go
Normal file
180
internal/network/tcp_server.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"customServer/internal/handlers"
|
||||
"customServer/internal/protocol"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func StartTCPServer(addr string) error {
|
||||
tlsConfig, err := generateSelfSignedCert()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate self-signed cert: %v", err)
|
||||
}
|
||||
|
||||
listener, err := tls.Listen("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("TCP TLS listener failed: %v", err)
|
||||
}
|
||||
fmt.Printf("[TCP] Scalable Server (TLS): localhost%s\n", addr)
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
fmt.Printf("Accept error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
go handleTCPConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func handleTCPConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
fmt.Printf("[TCP] New TLS connection from %s\n", conn.RemoteAddr())
|
||||
|
||||
for {
|
||||
// Read length (4 bytes, Big Endian)
|
||||
length, err := protocol.ReadPacketLength(conn)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
fmt.Printf("[TCP] Read length error: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Read packet
|
||||
data := make([]byte, length)
|
||||
_, err = io.ReadFull(conn, data)
|
||||
if err != nil {
|
||||
fmt.Printf("[TCP] Read data error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("[TCP] Received packet of %d bytes\n", length)
|
||||
|
||||
packetID := int64(0)
|
||||
requestNumber := int32(0)
|
||||
var payloadBytes []byte
|
||||
|
||||
reader := bytes.NewReader(data)
|
||||
for {
|
||||
tag, err := protocol.ReadVarint(reader)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fieldNum := tag >> 3
|
||||
wireType := tag & 0x7
|
||||
|
||||
if fieldNum == 1 && wireType == 0 { // Packet.id
|
||||
packetID, _ = protocol.ReadVarintInt64(reader)
|
||||
} else if fieldNum == 3 && wireType == 2 { // Packet.payload (Message)
|
||||
payloadLen, _ := protocol.ReadVarint(reader)
|
||||
payloadBytes = make([]byte, payloadLen)
|
||||
reader.Read(payloadBytes)
|
||||
|
||||
payloadReader := bytes.NewReader(payloadBytes)
|
||||
for {
|
||||
pTag, err := protocol.ReadVarint(payloadReader)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
pFieldNum := pTag >> 3
|
||||
pWireType := pTag & 0x7
|
||||
if pFieldNum == 1 && pWireType == 0 { // Message.request_number
|
||||
reqNum64, _ := protocol.ReadVarintInt64(payloadReader)
|
||||
requestNumber = int32(reqNum64)
|
||||
} else {
|
||||
protocol.SkipField(payloadReader, pWireType)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
protocol.SkipField(reader, wireType)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("[TCP] Got Request ID: %d, Number: %d\n", packetID, requestNumber)
|
||||
|
||||
responsePayload, responseFieldNum := handlers.Dispatch(conn, packetID, requestNumber, payloadBytes)
|
||||
if responsePayload != nil {
|
||||
sendTCPResponse(conn, packetID, responseFieldNum, responsePayload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sendTCPResponse(conn net.Conn, packetID int64, fieldNum int, payload []byte) {
|
||||
// Construct Message wrapper
|
||||
message := make([]byte, 0)
|
||||
// Field 1: request_number
|
||||
message = append(message, 0x08)
|
||||
message = append(message, protocol.EncodeVarint(uint64(fieldNum))...)
|
||||
|
||||
// Field [fieldNum]: payload (WireType 2)
|
||||
message = append(message, protocol.EncodeVarint(uint64(fieldNum<<3|2))...)
|
||||
message = append(message, protocol.EncodeVarint(uint64(len(payload)))...)
|
||||
message = append(message, payload...)
|
||||
|
||||
// Construct Packet wrapper
|
||||
packet := make([]byte, 0)
|
||||
// Field 1: Id
|
||||
packet = append(packet, 0x08)
|
||||
packet = append(packet, protocol.EncodeVarint(uint64(packetID))...)
|
||||
// Field 3: Payload
|
||||
packet = append(packet, 0x1a)
|
||||
packet = append(packet, protocol.EncodeVarint(uint64(len(message)))...)
|
||||
packet = append(packet, message...)
|
||||
|
||||
// Send length + packet
|
||||
lengthBuf := make([]byte, 4)
|
||||
length := uint32(len(packet))
|
||||
lengthBuf[0] = byte(length >> 24)
|
||||
lengthBuf[1] = byte(length >> 16)
|
||||
lengthBuf[2] = byte(length >> 8)
|
||||
lengthBuf[3] = byte(length)
|
||||
|
||||
conn.Write(lengthBuf)
|
||||
conn.Write(packet)
|
||||
}
|
||||
|
||||
func generateSelfSignedCert() (*tls.Config, error) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Custom Server Mod"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24 * 365),
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert := tls.Certificate{
|
||||
Certificate: [][]byte{derBytes},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user