FinalCommander/FinalCommander.go
2022-01-18 19:54:39 +01:00

313 lines
8 KiB
Go

package main
import (
"bytes"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"encoding/hex"
"flag"
"fmt"
"git.gammaspectra.live/S.O.N.G/MakyuuIchaival"
"git.gammaspectra.live/S.O.N.G/MakyuuIchaival/contentmessage"
"git.gammaspectra.live/S.O.N.G/MakyuuIchaival/httputils"
"git.gammaspectra.live/S.O.N.G/MakyuuIchaival/tlsutils"
"github.com/cloudflare/circl/sign/ed25519"
"github.com/mroth/weightedrand"
"github.com/multiformats/go-multihash"
"log"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
)
// Package-level state shared by main, the request handler and the
// health-check loop.
var (
	// privateKey signs the content messages embedded in redirect URLs; main
	// decodes it from the PRIVATE_KEY environment variable.
	privateKey ed25519.PrivateKey
	// debugOutput enables per-request logging in handle and is passed to the
	// HTTP server as its Debug option. Nothing in this file changes it.
	debugOutput = false
	// contentServers is the backend pool, filled by main from the -servers
	// flag before the server starts serving.
	contentServers []*ContentServer
)
// ContentServer is one backend that clients can be redirected to, together
// with the cached outcome of its most recent health check.
type ContentServer struct {
	// Index identifies the server inside skip lists; main assigns it from the
	// server's position in the -servers flag.
	Index int
	// Address is the address:port (from -servers) used to build https:// URLs.
	Address string
	// Weight is the relative selection weight used by selectNextContentServer.
	Weight uint

	// LastCheckMutex guards LastCheckResult, which is written by the health
	// check and read while choosing a server for a request.
	LastCheckMutex  sync.RWMutex
	LastCheckResult bool
}
// getURL returns an https URL on this server's Address whose path is args
// joined with "/". With no args it yields the server root, e.g.
// "https://host:port/".
func (s *ContentServer) getURL(args ...string) string {
	path := strings.Join(args, "/")
	return fmt.Sprintf("https://%s/%s", s.Address, path)
}
// getCheckResult reports the result of the most recent health check. It only
// takes the read lock, so concurrent readers do not block each other.
func (s *ContentServer) getCheckResult() bool {
	s.LastCheckMutex.RLock()
	healthy := s.LastCheckResult
	s.LastCheckMutex.RUnlock()
	return healthy
}
// setCheckResult stores the outcome of a health check under the write lock.
func (s *ContentServer) setCheckResult(result bool) {
	s.LastCheckMutex.Lock()
	s.LastCheckResult = result
	s.LastCheckMutex.Unlock()
}
// check probes the server with a HEAD request to its root URL and records the
// outcome via setCheckResult. The server counts as healthy only when it
// answers with 400 Bad Request; a transport error (including the 5 s timeout)
// or any other status marks it unhealthy.
func (s *ContentServer) check() {
	customTransport := http.DefaultTransport.(*http.Transport).Clone()
	// Certificate verification is skipped for this liveness probe.
	// NOTE(review): presumably content servers may present certificates that do
	// not match Address — confirm this is intended.
	customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	// Every probe builds its own Transport. Release its pooled keep-alive
	// connections when done; otherwise they linger until IdleConnTimeout,
	// accumulating across the periodic checks.
	defer customTransport.CloseIdleConnections()

	client := &http.Client{
		Transport: customTransport,
		Timeout:   5 * time.Second,
	}
	response, err := client.Head(s.getURL())
	if err != nil {
		s.setCheckResult(false)
		return
	}
	defer response.Body.Close()

	s.setCheckResult(response.StatusCode == http.StatusBadRequest)
}
// selectNextContentServer picks one content server using weightedrand,
// weighted by each server's Weight. Servers whose Index appears in skip are
// ignored, as are servers whose last health check failed. It returns nil when
// weightedrand cannot build a chooser, e.g. no eligible server remains.
func selectNextContentServer(skip []int) *ContentServer {
	excluded := make(map[int]struct{}, len(skip))
	for _, idx := range skip {
		excluded[idx] = struct{}{}
	}

	candidates := make([]weightedrand.Choice, 0, len(contentServers))
	for _, server := range contentServers {
		if _, skipped := excluded[server.Index]; skipped {
			continue
		}
		if !server.getCheckResult() {
			continue
		}
		candidates = append(candidates, weightedrand.NewChoice(server, server.Weight))
	}

	chooser, err := weightedrand.NewChooser(candidates...)
	if err != nil {
		return nil
	}
	return chooser.Pick().(*ContentServer)
}
// setOtherHeaders sets the generic (non-CORS) response headers that handle
// attaches to GET, HEAD and OPTIONS responses.
func setOtherHeaders(ctx *httputils.RequestContext) {
	ctx.SetResponseHeader("Server", "FinalCommander")
	// NOTE(review): Vary should list *request* headers that influence the
	// response. Content-Encoding is a response header, so this value does not
	// affect cache keying; Accept-Encoding may have been intended — confirm.
	ctx.SetResponseHeader("Vary", "Content-Encoding")
	ctx.SetResponseHeader("X-Content-Type-Options", "nosniff")
	// The header crawlers recognise is X-Robots-Tag (singular); the former
	// "X-Robots-Tags" spelling was silently ignored by them.
	ctx.SetResponseHeader("X-Robots-Tag", "noindex, nofollow, notranslate")
	ctx.SetResponseHeader("Referrer-Policy", "origin")
}
// setCORSHeaders sets permissive CORS headers plus the cross-origin isolation
// headers (CORP, COEP, COOP), writing them in a fixed order.
//
// NOTE(review): browsers reject credentialed CORS requests when
// Access-Control-Allow-Origin is "*", so Allow-Credentials: true has no effect
// next to the wildcard origin — confirm whether credentials are ever needed.
func setCORSHeaders(ctx *httputils.RequestContext) {
	headers := [...][2]string{
		{"Access-Control-Allow-Credentials", "true"},
		// Firefox caps this to 86400, Chrome to 7200. Default is 5 seconds (!!!)
		{"Access-Control-Max-Age", "7200"},
		{"Access-Control-Allow-Methods", "GET,HEAD,OPTIONS"},
		{"Access-Control-Allow-Headers", "DNT,ETag,Origin,Accept,Accept-Language,X-Requested-With,Range"},
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Expose-Headers", "*"},
		// CORP, COEP, COOP
		{"Cross-Origin-Embedder-Policy", "require-corp"},
		{"Cross-Origin-Resource-Policy", "cross-origin"},
		{"Cross-Origin-Opener-Policy", "unsafe-none"},
	}
	for _, kv := range headers {
		ctx.SetResponseHeader(kv[0], kv[1])
	}
}
// encodeIntegerList serialises data as back-to-back signed varints, in the
// format written by binary.PutVarint. A nil or empty input yields nil.
// decodeIntegerList reverses the encoding.
func encodeIntegerList(data []int) []byte {
	var out []byte
	var scratch [binary.MaxVarintLen64]byte
	for idx := 0; idx < len(data); idx++ {
		written := binary.PutVarint(scratch[:], int64(data[idx]))
		out = append(out, scratch[:written]...)
	}
	return out
}
// decodeIntegerList parses back-to-back signed varints (as produced by
// encodeIntegerList). handle feeds it data taken from the request path.
// Decoding stops at the first malformed or truncated varint, and the values
// read before that point are returned. An empty input yields nil.
func decodeIntegerList(data []byte) []int {
	reader := bytes.NewReader(data)
	var values []int
	for reader.Len() > 0 {
		v, err := binary.ReadVarint(reader)
		if err != nil {
			//TODO: maybe should error
			return values
		}
		values = append(values, int(v))
	}
	return values
}
// handle is the HTTP entry point for every request.
//
// GET and HEAD requests for /<hashType>/<hex digest>[/<bech32 skip list>] are
// answered by serveContentRedirect. OPTIONS receives the common and CORS
// headers with 204 No Content. Any other method gets 501 Not Implemented.
func handle(ctx *httputils.RequestContext) {
	// Requests whose Host header equals the TLS server name are rejected with 404.
	// NOTE(review): this is described as preventing DNS rebinding, but a
	// rebinding attack presents a *different* Host, so the comparison looks
	// inverted. Whether Host is readable via GetRequestHeader, and whether it
	// includes a port, depends on httputils — confirm before flipping it to !=.
	if host := ctx.GetRequestHeader("Host"); len(host) > 0 && host == ctx.GetTLSServerName() { //Prevents rebinding / DNS stuff
		ctx.SetResponseCode(http.StatusNotFound)
		return
	}

	switch {
	case ctx.IsGet() || ctx.IsHead():
		if debugOutput {
			// %q escapes control characters so a crafted path cannot forge
			// extra log lines.
			log.Printf("Serve %q", ctx.GetPath())
		}
		setOtherHeaders(ctx)
		setCORSHeaders(ctx)
		serveContentRedirect(ctx)
	case ctx.IsOptions():
		setOtherHeaders(ctx)
		setCORSHeaders(ctx)
		ctx.SetResponseCode(http.StatusNoContent)
	default:
		ctx.SetResponseCode(http.StatusNotImplemented)
	}
}

// serveContentRedirect validates the request path, picks a content server not
// listed in the request's skip list, and answers 302 Found with a redirect to
// it. The redirect URL carries a freshly signed content message and the
// updated skip list. It responds 404 when no content server is available.
func serveContentRedirect(ctx *httputils.RequestContext) {
	mh, skip, status := parseContentPath(ctx.GetPath())
	if status != 0 {
		ctx.SetResponseCode(status)
		return
	}

	contentServer := selectNextContentServer(skip)
	if contentServer == nil {
		ctx.SetResponseCode(http.StatusNotFound)
		return
	}

	//TODO: signing cache
	message := contentmessage.NewContentMessageV1(mh, privateKey)
	// The chosen server is appended to the skip list that travels in the
	// redirect URL. If that list comes back here as the third path segment,
	// selectNextContentServer will exclude every server in it.
	skip = append(skip, contentServer.Index)
	ctx.DoRedirect(contentServer.getURL(
		MakyuuIchaival.Bech32Encoding.EncodeToString(message.Encode()),
		MakyuuIchaival.Bech32Encoding.EncodeToString(encodeIntegerList(skip)),
	), http.StatusFound)
}

// parseContentPath validates a path of the form
// /<hashType>/<hex digest>[/<bech32 skip list>].
//
// On success it returns the digest as a multihash, the decoded skip list (nil
// when the segment is absent) and status 0. Otherwise status is the HTTP code
// to answer with. It is 400 for a short path or an undecodable digest or skip
// list. It is 404 for an unsupported hash type or a digest of the wrong length.
func parseContentPath(path string) (multihash.Multihash, []int, int) {
	pathElements := strings.Split(path, "/")
	if len(pathElements) < 3 {
		return nil, nil, http.StatusBadRequest
	}

	hashType := strings.ToLower(pathElements[1])
	hash, err := hex.DecodeString(pathElements[2])
	if err != nil {
		return nil, nil, http.StatusBadRequest
	}

	var skip []int
	if len(pathElements) > 3 {
		data, err := MakyuuIchaival.Bech32Encoding.DecodeString(pathElements[3])
		if err != nil {
			return nil, nil, http.StatusBadRequest
		}
		skip = decodeIntegerList(data)
	}

	// Encode errors are ignored: only SHA2_256 / MD5 with a matching digest
	// length reach it. NOTE(review): assumed not to fail for those — confirm.
	var mh multihash.Multihash
	switch {
	case hashType == "sha256" && len(hash) == 32:
		mh, _ = multihash.Encode(hash, multihash.SHA2_256)
	case hashType == "md5" && len(hash) == 16:
		mh, _ = multihash.Encode(hash, multihash.MD5)
	default:
		return nil, nil, http.StatusNotFound
	}
	return mh, skip, 0
}
// checkContentServers health-checks every configured content server in turn,
// so each server's cached result gets refreshed.
func checkContentServers() {
	servers := contentServers
	for i := range servers {
		servers[i].check()
	}
}
// main parses the flags and sets up the Ed25519 signing identity. When
// PRIVATE_KEY is unset, it generates a new identity, prints it and exits.
// Otherwise it decodes PRIVATE_KEY, registers the weighted content servers,
// checks them once and then every minute, and serves over TLS with HTTP/2
// and HTTP/3 enabled.
func main() {
	//TODO: OCSP
	certificatePath := flag.String("certificate", "", "Path to SSL certificate file.")
	keypairPath := flag.String("keypair", "", "Path to SSL key file.")
	listenAddress := flag.String("listen", ":7777", "Address/port to listen on.")
	weightedServerList := flag.String("servers", "", "Weighted list of servers to use. All use HTTPs. Format address:PORT/WEIGHT,[...]")
	sniAddressOption := flag.String("sni", "", "Define SNI address if desired. Empty will serve any requests regardless.")
	flag.Parse()

	var err error

	privateKeyEnv := os.Getenv("PRIVATE_KEY")
	if privateKeyEnv == "" {
		log.Print("No PRIVATE_KEY environment variable specified, generating new identity")
		// These locals intentionally shadow the package-level privateKey: the
		// new identity is only printed for the operator, then main exits.
		publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("Public Ed25519 key (share this): %s", MakyuuIchaival.Bech32Encoding.EncodeToString(publicKey))
		log.Printf("Private Ed25519 key (keep safe!): %s", MakyuuIchaival.Bech32Encoding.EncodeToString(privateKey))
		return
	}

	privateKey, err = MakyuuIchaival.Bech32Encoding.DecodeString(privateKeyEnv)
	if err != nil {
		log.Fatal(err)
	}
	// Validate the length before slicing into the key: a key shorter than 32
	// bytes would otherwise panic on privateKey[32:] instead of failing here.
	if len(privateKey) != ed25519.PrivateKeySize {
		log.Fatal("Wrong Private key length")
	}
	// The public key is the tail of the private key (seed || public key layout).
	publicKey := make([]byte, ed25519.PublicKeySize)
	copy(publicKey, privateKey[32:])
	log.Printf("Loaded Private Ed25519 key, Public %s", MakyuuIchaival.Bech32Encoding.EncodeToString(publicKey))

	if *weightedServerList == "" {
		log.Fatal("No content servers specified, use -servers address:PORT/WEIGHT,[...]")
	}
	for i, s := range strings.Split(*weightedServerList, ",") {
		// Tolerate whitespace around entries, e.g. "a:443/1, b:443/2".
		p := strings.Split(strings.TrimSpace(s), "/")
		if len(p) != 2 {
			log.Fatalf("Invalid weighted server %s", s)
		}
		weight, err := strconv.ParseUint(p[1], 10, 32)
		if err != nil {
			log.Fatal(err)
		}
		contentServers = append(contentServers, &ContentServer{
			Index:           i,
			Address:         p[0],
			Weight:          uint(weight),
			LastCheckResult: false,
		})
	}

	// Probe all servers once before serving, then refresh once a minute for
	// the rest of the process lifetime.
	checkContentServers()
	go func() {
		ticker := time.NewTicker(1 * time.Minute)
		for range ticker.C {
			checkContentServers()
		}
	}()

	tlsConfiguration, err := tlsutils.NewTLSConfiguration(*certificatePath, *keypairPath, *sniAddressOption)
	if err != nil {
		log.Fatal(err)
	}

	server := &httputils.Server{
		ListenAddress: *listenAddress,
		TLSConfig:     tlsConfiguration,
		EnableHTTP2:   true,
		EnableHTTP3:   true,
		Handler:       handle,
		Debug:         debugOutput,
	}
	server.Serve()
}