package main import ( "bytes" "database/sql" "encoding/base64" "encoding/hex" "flag" "fmt" "git.gammaspectra.live/S.O.N.G/MakyuuIchaival" "git.gammaspectra.live/S.O.N.G/MakyuuIchaival/contentmessage" "git.gammaspectra.live/S.O.N.G/MakyuuIchaival/httputils" "git.gammaspectra.live/S.O.N.G/MakyuuIchaival/tlsutils" "github.com/cloudflare/circl/sign/ed25519" "github.com/ipfs/go-cid" _ "github.com/lib/pq" "github.com/multiformats/go-multihash" "log" "net/http" "net/url" "os" "path" "strings" "sync" "time" ) type ContentCacheEntry struct { Entry ContentEntry AccessTime time.Time } var dbHandle *sql.DB var sha256Statement *sql.Stmt var md5Statement *sql.Stmt var fdlimit int var objectCacheMutex sync.RWMutex var objectCache = make(map[string]*ContentCacheEntry) var trustedPublicKeys []ed25519.PublicKey var debugOutput = false func getFirstValidContentEntry(entries *[]ContentEntry) *ContentEntry { for _, entry := range *entries { stat, err := os.Stat(entry.Path) if err == nil && (entry.Size == 0 || uint64(stat.Size()) == entry.Size) { //TODO: update Size if not found copiedEntry := entry copiedEntry.Size = uint64(stat.Size()) return &copiedEntry } } return nil } func GetMimeTypeFromExtension(ext string) string { if len(ext) > 0 { switch strings.ToLower(ext[1:]) { //Audio types case "flac": return "audio/flac" case "mp3": return "audio/mpeg;codecs=mp3" case "m4a": return "audio/mp4" case "mka": return "audio/x-matroska" case "ogg": return "audio/ogg" case "opus": return "audio/opus" case "tta": return "audio/tta" case "aac": return "audio/aac" case "alac": return "audio/alac" case "wav": return "audio/wav" case "ape": return "audio/ape" //Image types case "png": return "image/png" case "jfif": fallthrough case "jpeg": fallthrough case "jpg": return "image/jpeg" case "gif": return "image/gif" case "svg": return "image/svg+xml" case "tiff": fallthrough case "tif": return "image/tiff" case "webp": return "image/webp" case "bmp": return "image/bmp" //Text types case "txt": return "text/plain" case "log": return "text/x-log" case "accurip": return "text/x-accurip" case "cue": return "text/x-cue" case "toc": return "text/x-toc" //Text subtitles case "lrc": return "text/x-subtitle-lrc" case "ssa": return "text/x-subtitle-ssa" case "ass": return "text/x-subtitle-ass" case "srt": return "text/x-subtitle-subrip" //Web types case "js": return "text/javascript" case "wasm": return "application/wasm" case "html": return "text/html" case "css": return "text/css" case "ttf": return "font/ttf" case "otf": return "font/otf" case "woff": return "font/woff" case "woff2": return "font/woff2" } } return "application/octet-stream" } func handleQueryRequest(ctx *httputils.RequestContext, identifier cid.Cid, extraArguments []string) { cTime := time.Now() var cacheEntry = tryGetCacheEntryForIdentifier(identifier) if cacheEntry == nil { result := getEntriesForCID(identifier) entry := getFirstValidContentEntry(&result) if entry != nil { cacheEntry = getCacheEntryForContentEntry(entry, identifier) } ctx.AddTimingInformational("ec", "Content Cache MISS") } else { ctx.AddTimingInformational("ec", "Content Cache HIT") } pTime := cTime cTime = time.Now() ctx.AddTiming("e", "Content Entry", cTime.Sub(pTime)) if cacheEntry == nil { var origin string if len(ctx.GetRequestHeader("Referer")) > 0 { origin = ctx.GetRequestHeader("Referer") } else if len(ctx.GetRequestHeader("Origin")) > 0 { origin = ctx.GetRequestHeader("Origin") } else if len(extraArguments) > 1 { origin = "https://" + extraArguments[1] } //Try to redirect back to origin if len(extraArguments) > 0 && origin != "" { mh, _ := multihash.Decode(identifier.Hash()) var kind string if mh.Code == multihash.SHA2_256 { kind = "sha256" } else if mh.Code == multihash.MD5 { kind = "md5" } if kind != "" { ctx.DoRedirect(fmt.Sprintf("%s/%s/%s/%s", origin, kind, hex.EncodeToString(mh.Digest), strings.Join(extraArguments, "/")), http.StatusFound) return } } ctx.SetResponseCode(http.StatusNotFound) return } mh, _ := multihash.Decode(cacheEntry.Entry.Identifier.Hash()) ctx.SetResponseHeader("Accept-Ranges", "bytes") ctx.SetResponseHeader("X-Request-CID", identifier.String()) ctx.SetResponseHeader("ETag", fmt.Sprintf("\"%s\"", cacheEntry.Entry.Identifier.String())) if mh.Code == multihash.SHA2_256 { ctx.SetResponseHeader("Digest", fmt.Sprintf("sha-256=%s", base64.StdEncoding.EncodeToString(mh.Digest))) } ctx.SetResponseHeader("Cache-Control", "public, max-age=2592000, immutable") filename := path.Base(cacheEntry.Entry.Path) //TODO: setting to hide filename ctx.SetResponseHeader("Content-Disposition", fmt.Sprintf("inline; filename*=utf-8''%s", url.PathEscape(filename))) pTime = cTime cTime = time.Now() ctx.AddTiming("s", "Content Serve", cTime.Sub(pTime)) mime := GetMimeTypeFromExtension(path.Ext(cacheEntry.Entry.Path)) if len(mime) > 0 { ctx.SetResponseHeader("Content-Type", mime) } ctx.ServeFile(cacheEntry.Entry.Path) } func setOtherHeaders(ctx *httputils.RequestContext) { ctx.SetResponseHeader("Server", "OrbitalBeat") ctx.SetResponseHeader("Vary", "Content-Encoding") ctx.SetResponseHeader("X-Content-Type-Options", "nosniff") ctx.SetResponseHeader("X-Robots-Tags", "noindex, nofollow, notranslate") for k, v := range ctx.GetExtraHeaders() { ctx.SetResponseHeader(k, v) } } func setCORSHeaders(ctx *httputils.RequestContext) { ctx.SetResponseHeader("Access-Control-Allow-Credentials", "true") ctx.SetResponseHeader("Access-Control-Max-Age", "7200") //Firefox caps this to 86400, Chrome to 7200. Default is 5 seconds (!!!) ctx.SetResponseHeader("Access-Control-Allow-Methods", "GET,HEAD,OPTIONS") ctx.SetResponseHeader("Access-Control-Allow-Headers", "DNT,ETag,Origin,Accept,Accept-Language,X-Requested-With,Range") ctx.SetResponseHeader("Access-Control-Allow-Origin", "*") ctx.SetResponseHeader("Access-Control-Expose-Headers", "*") //CORP, COEP, COOP ctx.SetResponseHeader("Cross-Origin-Embedder-Policy", "require-corp") ctx.SetResponseHeader("Cross-Origin-Resource-Policy", "cross-origin") ctx.SetResponseHeader("Cross-Origin-Opener-Policy", "unsafe-none") } func getCacheEntryForContentEntry(entry *ContentEntry, originalIdentifier cid.Cid) *ContentCacheEntry { cacheEntry := tryGetCacheEntryForIdentifier(entry.Identifier) if cacheEntry != nil { return cacheEntry } objectCacheMutex.Lock() defer objectCacheMutex.Unlock() if len(objectCache) >= fdlimit { //Find oldest value, remove it var item *ContentCacheEntry for _, e := range objectCache { if item == nil || e.AccessTime.Before(item.AccessTime) { item = e } } if item != nil { delete(objectCache, item.Entry.Identifier.String()) } } c := &ContentCacheEntry{ Entry: *entry, AccessTime: time.Now(), } objectCache[entry.Identifier.String()] = c if originalIdentifier.String() != entry.Identifier.String() { objectCache[originalIdentifier.String()] = c } return c } func tryGetCacheEntryForIdentifier(identifier cid.Cid) *ContentCacheEntry { objectCacheMutex.RLock() defer objectCacheMutex.RUnlock() cacheEntry, ok := objectCache[identifier.String()] if ok { cacheEntry.AccessTime = time.Now().UTC() return cacheEntry } return nil } func IsTrustedPublicKey(key ed25519.PublicKey) bool { for _, k := range trustedPublicKeys { if bytes.Compare(k, key) == 0 { return true } } return false } func handle(ctx *httputils.RequestContext) { if len(ctx.GetRequestHeader("Host")) > 0 && ctx.GetRequestHeader("Host") == ctx.GetTLSServerName() { //Prevents rebinding / DNS stuff ctx.SetResponseCode(http.StatusNotFound) return } cTime := time.Now() if ctx.IsFastHttp() { ctx.AddTiming("c", "Connection", ctx.GetRequestTime().Sub(ctx.GetConnectionTime())) } ctx.AddTiming("r", "Request Handler", cTime.Sub(ctx.GetRequestTime())) if ctx.IsGet() || ctx.IsHead() { if debugOutput { log.Printf("Serve %s", ctx.GetPath()) } setOtherHeaders(ctx) setCORSHeaders(ctx) pathElements := strings.Split(ctx.GetPath(), "/") if len(pathElements) < 2 { ctx.SetResponseCode(http.StatusBadRequest) return } messageBytes, err := MakyuuIchaival.Bech32Encoding.DecodeString(pathElements[1]) if err != nil { ctx.SetResponseCode(http.StatusBadRequest) return } message := contentmessage.DecodeContentMessage(messageBytes) if message == nil { ctx.SetResponseCode(http.StatusBadRequest) return } if !IsTrustedPublicKey(message.PublicKey) { ctx.SetResponseCode(http.StatusForbidden) return } pTime := cTime cTime := time.Now() ctx.AddTiming("d", "Decode", cTime.Sub(pTime)) result, cacheHit := message.Verify() pTime = cTime cTime = time.Now() if cacheHit { ctx.AddTimingInformational("vc", "Ed25519 Cache HIT") } else { ctx.AddTimingInformational("vc", "Ed25519 Cache MISS") } ctx.AddTiming("v", "Ed25519 Verify", cTime.Sub(pTime)) if !result { ctx.SetResponseCode(http.StatusForbidden) return } if debugOutput { log.Printf("Serving CID %s", message.Identifier.String()) } handleQueryRequest(ctx, message.Identifier, pathElements[2:]) } else if ctx.IsOptions() { setOtherHeaders(ctx) setCORSHeaders(ctx) ctx.SetResponseCode(http.StatusNoContent) } else { ctx.SetResponseCode(http.StatusNotImplemented) } } type ContentEntry struct { Identifier cid.Cid Path string Size uint64 } func handleQuery(rows *sql.Rows, err error) []ContentEntry { if err != nil { log.Print(err) return []ContentEntry{} } defer rows.Close() var result []ContentEntry for rows.Next() { var entry ContentEntry var sha256 multihash.Multihash var size sql.NullInt64 err := rows.Scan(&entry.Path, &size, &sha256) if err != nil { log.Print(err) break } if size.Valid { entry.Size = uint64(size.Int64) } mh, _ := multihash.Encode(sha256, multihash.SHA2_256) entry.Identifier = cid.NewCidV1(cid.Raw, mh) result = append(result, entry) } return result } func getEntriesForCID(identifier cid.Cid) []ContentEntry { mh, _ := multihash.Decode(identifier.Hash()) if mh.Code == multihash.SHA2_256 { return handleQuery(sha256Statement.Query(mh.Digest)) } else if mh.Code == multihash.MD5 { return handleQuery(md5Statement.Query(mh.Digest)) } return []ContentEntry{} } func main() { //TODO: OCSP certificatePath := flag.String("certificate", "", "Path to SSL certificate file.") keypairPath := flag.String("keypair", "", "Path to SSL key file.") pgConnStr := flag.String("connstr", "", "Postgres connection string for postgres database") listenAddress := flag.String("listen", ":7777", "Address/port to lisent on.") trustedKeys := flag.String("trusted_keys", "", "Trusted list of public keys, comma separated.") sniAddressOption := flag.String("sni", "", "Define SNI address if desired. Empty will serve any requests regardless.") fdLimitOption := flag.Int("fdlimit", 512, "Maximum number of lingering cached open files.") debugOption := flag.Bool("debug", false, "Output debug information.") http2Option := flag.Bool("http2", false, "Enable HTTP/2") http3Option := flag.Bool("http3", false, "Enable HTTP/3") signatureCacheLimitOption := flag.Int("siglimit", 4096, "Maximum number of lingering valid signature cache results.") flag.Parse() fdlimit = *fdLimitOption contentmessage.SetMessageCacheLimit(*signatureCacheLimitOption) debugOutput = *debugOption var err error for _, k := range strings.Split(*trustedKeys, ",") { var publicKey ed25519.PublicKey publicKey, err = MakyuuIchaival.Bech32Encoding.DecodeString(strings.Trim(k, " ")) if err != nil { log.Fatal(err) } if len(publicKey) != ed25519.PublicKeySize { continue } trustedPublicKeys = append(trustedPublicKeys, publicKey) log.Printf("Added public key %s", strings.Trim(k, " ")) } dbHandle, err = sql.Open("postgres", *pgConnStr) if err != nil { log.Fatal(err) } defer dbHandle.Close() sha256Statement, err = dbHandle.Prepare("SELECT path, size, sha256 FROM entries WHERE sha256 = $1;") if err != nil { log.Fatal(err) } defer sha256Statement.Close() md5Statement, err = dbHandle.Prepare("SELECT path, size, sha256 FROM entries WHERE md5 = $1;") if err != nil { log.Fatal(err) } defer md5Statement.Close() tlsConfiguration, err := tlsutils.NewTLSConfiguration(*certificatePath, *keypairPath, strings.ToLower(*sniAddressOption)) if err != nil { log.Fatal(err) } server := &httputils.Server{ ListenAddress: *listenAddress, TLSConfig: tlsConfiguration, EnableHTTP2: *http2Option, EnableHTTP3: *http3Option, Handler: handle, Debug: debugOutput, } server.Serve() }