diff --git a/FinalCommander.go b/FinalCommander.go index 22d896a..6b14fdb 100644 --- a/FinalCommander.go +++ b/FinalCommander.go @@ -21,7 +21,6 @@ import ( "time" ) -var privateKey ed25519.PrivateKey var publicKey ed25519.PublicKey var debugOutput = false @@ -179,7 +178,7 @@ func handleHexHash(pathElements []string, ctx *httputils.RequestContext, host st return } - ctx.DoRedirect(contentServer.GetContentURL(entry, privateKey, skip)+host, http.StatusFound) + ctx.DoRedirect(contentServer.GetContentURL(entry, skip)+host, http.StatusFound) } else { contentServer := selectNextContentServer(skip) if contentServer == nil { @@ -197,7 +196,7 @@ func handleHexHash(pathElements []string, ctx *httputils.RequestContext, host st continue } - result, err := c.CheckEntryKey(key, privateKey) + result, err := c.CheckEntryKey(key) if result != nil { if e == nil { e = &content.Entry{ @@ -232,7 +231,7 @@ func handleHexHash(pathElements []string, ctx *httputils.RequestContext, host st } }() - ctx.DoRedirect(contentServer.GetHashURL(mh, privateKey, skip)+host, http.StatusFound) + ctx.DoRedirect(contentServer.GetHashURL(mh, skip)+host, http.StatusFound) } } @@ -307,7 +306,7 @@ func getContentEntry(key *content.HashIdentifier) *content.Entry { continue } - h, err := c.CheckEntryKey(&e.Key, privateKey) + h, err := c.CheckEntryKey(&e.Key) if h == nil && err == nil { newInvalidList = append(newInvalidList, c.Index) } @@ -328,14 +327,14 @@ func checkContentServers() { } func main() { - //TODO: OCSP + debugOption := flag.Bool("debug", false, "Enable debug output.") certificatePath := flag.String("certificate", "", "Path to SSL certificate file.") keypairPath := flag.String("keypair", "", "Path to SSL key file.") databasePath := flag.String("dbpath", "database", "Path to key/value database.") - listenAddress := flag.String("listen", ":7777", "Address/port to lisent on.") + listenAddress := flag.String("listen", ":7777", "address/port to listen on.") - weightedServerList := flag.String("servers", "", "Weighted list of servers to use. All use HTTPs. Format address:PORT/WEIGHT,[...]") + weightedServerList := flag.String("servers", "", "Weighted list of servers to use. All will use HTTPs. Allowed protocols: orbt. Format [protocol=]address:PORT/WEIGHT,[...]") sniAddressOption := flag.String("sni", "", "Define SNI address if desired. Empty will serve any requests regardless.") @@ -343,6 +342,8 @@ func main() { var err error + debugOutput = *debugOption + privateKeyEnv := os.Getenv("PRIVATE_KEY") if privateKeyEnv != "" { @@ -373,7 +374,7 @@ func main() { log.Fatal("Wrong Private key length") } - privateKey = ed25519.NewKeyFromSeed(privateSeed) + privateKey := ed25519.NewKeyFromSeed(privateSeed) publicKey = make([]byte, ed25519.PublicKeySize) copy(publicKey, privateKey[ed25519.PublicKeySize:]) log.Printf("Loaded Private Ed25519 key, Public %s", MakyuuIchaival.Bech32Encoding.EncodeToString(publicKey)) @@ -385,7 +386,7 @@ func main() { defer db.Close() for i, s := range strings.Split(*weightedServerList, ",") { - cs, err := content.NewContentServerFromArgument(s, i) + cs, err := content.NewContentServerFromArgument(s, i, privateKey) if err != nil { log.Fatal(err) diff --git a/content/server.go b/content/server.go index c1ba745..893559f 100644 --- a/content/server.go +++ b/content/server.go @@ -16,17 +16,42 @@ import ( "time" ) +type ServerProtocol int + +const ( + ProtocolOrbitalBeatV1 ServerProtocol = iota +) + type Server struct { Index int Address string + Protocol ServerProtocol + key []byte Weight uint LastCheckResult bool lastCheckMutex sync.RWMutex } -func NewContentServerFromArgument(arg string, index int) (*Server, error) { - //Format address:PORT/WEIGHT[/publicKey], - p := strings.Split(arg, "/") +func NewContentServerFromArgument(arg string, index int, defaultKey []byte) (*Server, error) { + //Format Address:PORT/WEIGHT[/publicKey], + + protos := strings.Split(arg, "=") + + serverProtocol := ProtocolOrbitalBeatV1 + + serverKey := defaultKey + + if len(protos) > 1 { + switch protos[0] { + case "orbt": + serverProtocol = ProtocolOrbitalBeatV1 + + default: + return nil, fmt.Errorf("invalid server Protocol %s", arg) + } + } + + p := strings.Split(protos[len(protos)-1], "/") if len(p) < 2 { return nil, fmt.Errorf("invalid weighted server %s", arg) } @@ -39,6 +64,8 @@ func NewContentServerFromArgument(arg string, index int) (*Server, error) { cs := &Server{ Index: index, Address: p[0], + Protocol: serverProtocol, + key: serverKey, Weight: uint(weight), LastCheckResult: false, } @@ -46,19 +73,29 @@ func NewContentServerFromArgument(arg string, index int) (*Server, error) { return cs, nil } -func (s *Server) GetContentURL(content *Entry, key ed25519.PrivateKey, skip []int) string { - message := contentmessage.NewContentMessageV1(content.Multihash(), key) - skip = append(skip, s.Index) - return s.getURL(MakyuuIchaival.Bech32Encoding.EncodeToString(message.Encode()), MakyuuIchaival.Bech32Encoding.EncodeToString(utilities.EncodeIntegerList(skip))) +func (s *Server) GetContentURL(content *Entry, skip []int) string { + switch s.Protocol { + case ProtocolOrbitalBeatV1: + message := contentmessage.NewContentMessageV1(content.Multihash(), ed25519.PrivateKey(s.key)) + skip = append(skip, s.Index) + return s.getBaseURL(MakyuuIchaival.Bech32Encoding.EncodeToString(message.Encode()), MakyuuIchaival.Bech32Encoding.EncodeToString(utilities.EncodeIntegerList(skip))) + default: + return "" + } } -func (s *Server) GetHashURL(mh multihash.Multihash, key ed25519.PrivateKey, skip []int) string { - message := contentmessage.NewContentMessageV1(mh, key) - skip = append(skip, s.Index) - return s.getURL(MakyuuIchaival.Bech32Encoding.EncodeToString(message.Encode()), MakyuuIchaival.Bech32Encoding.EncodeToString(utilities.EncodeIntegerList(skip))) +func (s *Server) GetHashURL(mh multihash.Multihash, skip []int) string { + switch s.Protocol { + case ProtocolOrbitalBeatV1: + message := contentmessage.NewContentMessageV1(mh, ed25519.PrivateKey(s.key)) + skip = append(skip, s.Index) + return s.getBaseURL(MakyuuIchaival.Bech32Encoding.EncodeToString(message.Encode()), MakyuuIchaival.Bech32Encoding.EncodeToString(utilities.EncodeIntegerList(skip))) + default: + return "" + } } -func (s *Server) getURL(args ...string) string { +func (s *Server) getBaseURL(args ...string) string { return fmt.Sprintf("https://%s/%s", s.Address, strings.Join(args, "/")) } @@ -81,7 +118,7 @@ func (s *Server) Check() { Transport: customTransport, Timeout: 5 * time.Second, } - response, err := client.Head(s.getURL()) + response, err := client.Head(s.getBaseURL()) if err != nil { s.setCheckResult(false) @@ -97,14 +134,14 @@ func (s *Server) Check() { s.setCheckResult(true) } -func (s *Server) CheckEntryKey(key *HashIdentifier, privateKey ed25519.PrivateKey) (*HashIdentifier, error) { +func (s *Server) CheckEntryKey(key *HashIdentifier) (*HashIdentifier, error) { customTransport := http.DefaultTransport.(*http.Transport).Clone() customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} client := &http.Client{ Transport: customTransport, Timeout: 5 * time.Second, } - response, err := client.Head(s.GetHashURL(key.Hash(), privateKey, []int{})) + response, err := client.Head(s.GetHashURL(key.Hash(), []int{})) if err != nil { return nil, err