Working image generation, small refactor

This commit is contained in:
Mememan 2024-04-18 00:03:45 +02:00
parent 1043e406b2
commit 98abd4a581
10 changed files with 349 additions and 135 deletions

5
.gitignore vendored
View file

@ -25,4 +25,7 @@ __debug_bin*
# temporary crap
tmp*
.swp*
.swp*
# binary
waifu_gallery

27
README.md Normal file
View file

@ -0,0 +1,27 @@
# Waifu Gallery
Uses [Fooocus API running on Runpod serverless](https://github.com/davefojtik/RunPod-Fooocus-API/tree/Standalone) to generate '''Beautiful''' artwork with the base constraints of your choosing and display it in a virtual gallery in an ironic twist.
In practice, visiting `painting_a_futurist_artwork.yourdomain.com` generates an image with the prompt `{{ random Prompt }}, {{ base Prompt }}, (painting a futurist artwork:UserBias)` if it hasn't already been created, and saves it to S3 to speed up later viewings.
## Installation
```shell
$ go get
$ go build
```
## Usage
Add a `config.json` configuration file in your current directory and then execute the program.
Example:
```json
{
}
```
## TODO
- [x] make imggen api request
- [x] upload result to S3
- [x] Poll request
- [ ] Guard from redundant API calls
- [ ] Make user wait for result and reload
- [ ] Better front

View file

@ -7,7 +7,7 @@ import (
"strconv"
"strings"
"imggen.reisen/s3"
"waifu_gallery/s3"
)
type Config struct {
@ -31,8 +31,10 @@ type Config struct {
NegativePrompt string `json:"negativePrompt"` // self explanatory
UserPromptBias float64 `json:"userPromptBias"` // how much do we favor the user's input
GuidanceScale float64 `json:"GuidanceScale"` // The guidance scale, default 7.0
Sharpness float64 `json:"sharpness"` // Sharpness, defaults to 2.0
Steps int `json:"steps"` // Number of generation steps, default 25
Sampler string `json:"sampler"` // Which sampler to use, default euler_a
BaseModel string `json:"baseModel"` // Base model, leave it empty to use fooocus' default
BannedWords []string `json:"bannedWords"` // user input will be blocked if it contains one of these
// Internal state from config
@ -74,6 +76,9 @@ func parseConfig(name string, cfg *Config) (err error) {
if cfg.Steps == 0 {
cfg.Steps = 25
}
if cfg.Sharpness == 0 {
cfg.Sharpness = 2.0
}
// parse config input
parts := strings.Split(cfg.AspectRatio, "*")

View file

@ -1,9 +1,17 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/rand"
"strings"
"time"
"waifu_gallery/runpod"
"waifu_gallery/s3"
)
type AdvancedParam struct {
@ -22,9 +30,11 @@ type FooocusTxt2Img struct {
RequireBase64 bool `json:"require_base64,omitempty"` // set this to true!!
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
BaseModelName string `json:"base_model_name,omitempty"`
Loras []Lora `json:"loras,omitempty"`
AspectRatio string `json:"aspect_ratios_selection,omitempty"`
ImageSeed int `json:"image_seed,omitempty"`
Sharpness float64 `json:"sharpness,omitempty"`
GuidanceScale float64 `json:"guidance_scale,omitempty"`
AdvancedParam AdvancedParam `json:"advanced_params,omitempty"`
}
@ -33,7 +43,7 @@ type FooocusInput struct {
Input FooocusTxt2Img `json:"input"`
}
func buildTxt2ImgQuery(prompt string, bias float64) (data []byte, err error) {
func buildTxt2ImgQuery(prompt string) (data []byte, err error) {
var query FooocusTxt2Img
query.Api_name = "txt2img"
@ -41,8 +51,12 @@ func buildTxt2ImgQuery(prompt string, bias float64) (data []byte, err error) {
query.Loras = config.Loras
query.AspectRatio = config.AspectRatio
query.GuidanceScale = config.GuidanceScale
query.Sharpness = config.Sharpness
query.AdvancedParam.OverwriteStep = config.Steps
query.AdvancedParam.SamplerName = config.Sampler
if len(config.BaseModel) != 0 {
query.BaseModelName = config.BaseModel
}
if config.ImageSeed != 0 {
query.ImageSeed = config.ImageSeed
}
@ -52,7 +66,7 @@ func buildTxt2ImgQuery(prompt string, bias float64) (data []byte, err error) {
randPrompt = config.RandomPrompt[rand.Intn(len(config.RandomPrompt))]
}
if len(config.BasePrompt) > 0 {
query.Prompt = fmt.Sprintf("%s, %s, (%s:%f)", randPrompt, config.BasePrompt, prompt, bias)
query.Prompt = fmt.Sprintf("%s, %s, (%s:%f)", randPrompt, config.BasePrompt, prompt, config.UserPromptBias)
} else {
query.Prompt = prompt
}
@ -60,14 +74,56 @@ func buildTxt2ImgQuery(prompt string, bias float64) (data []byte, err error) {
data, err = json.Marshal(FooocusInput{Input: query})
if err != nil {
print(err)
ErrorLogger.Println(err)
return
}
return
}
func sendQuery(data []byte) (err error) {
// err = runpod.Runsync(config.RunpodApi.Endpoint, config.RunpodApi.Authorization)
// createPicture generates a picture for input and stores it in S3.
//
// It builds the txt2img query, submits it with runpod.Run, polls the job every
// 500ms until it finishes, then uploads the first output image under the key
// input and the query that produced it under input+".json".
//
// It returns an error if any step fails. In particular a failed image upload
// is reported to the caller instead of being overwritten by the result of the
// query upload, so a nil error always means the image object was stored.
func createPicture(input string) (err error) {
	// TODO mutex and stuff
	query, err := buildTxt2ImgQuery(input)
	if err != nil {
		return err
	}

	resp, err := runpod.Run(query)
	if err != nil {
		return fmt.Errorf("starting runpod job: %w", err)
	}

	status, err := runpod.Poll(resp.Id, 500*time.Millisecond)
	if err != nil {
		return fmt.Errorf("polling runpod job %q: %w", resp.Id, err)
	}

	if len(status.Output) == 0 {
		return errors.New("no output in response")
	}
	output := status.Output[0]

	// Without this guard an absent field would be uploaded as an empty object.
	encoded := output["base64"]
	if len(encoded) == 0 {
		return errors.New("no image data in response")
	}

	// Upload file, decoding the base64 payload while streaming it.
	err = s3.Upload(
		"image/png",
		input,
		base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)),
		map[string]string{
			"Technique": "Digital Art", // TODO generate this via vision/LLM?
			"Width":     fmt.Sprintf("%d", config.state.width),
			"Height":    fmt.Sprintf("%d", config.state.height),
			"Seed":      output["seed"],
		})
	if err != nil {
		return fmt.Errorf("uploading picture %q: %w", input, err)
	}

	// Upload Query
	err = s3.Upload(
		"application/json",
		input+".json",
		bytes.NewReader(query),
		map[string]string{
			"Seed": output["seed"],
		},
	)
	if err != nil {
		return fmt.Errorf("uploading query for %q: %w", input, err)
	}
	return nil
}

8
go.mod
View file

@ -1,8 +1,11 @@
module imggen.reisen
module waifu_gallery
go 1.22.2
require golang.org/x/text v0.14.0
require (
github.com/minio/minio-go/v7 v7.0.69
golang.org/x/text v0.14.0
)
require (
github.com/dustin/go-humanize v1.0.1 // indirect
@ -11,7 +14,6 @@ require (
github.com/klauspost/compress v1.17.6 // indirect
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/minio-go/v7 v7.0.69 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect

127
http.go Normal file
View file

@ -0,0 +1,127 @@
package main
import (
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"time"
"waifu_gallery/s3"
)
// promptFilter matches every run of characters that is not an ASCII letter,
// an ASCII digit or a space. Compiled once instead of on every request.
var promptFilter = regexp.MustCompile(`[^a-zA-Z0-9 ]+`)

// handleSubdomain derives the user prompt from the request's Host header.
//
// Everything after the last "." of the host is dropped, and of what remains
// only the part after its last "." is kept (for a host "a.b.c" that is "b").
// "_" and "-" become spaces, any other run of characters outside
// [a-zA-Z0-9 ] becomes a single space, and the result is trimmed and
// lowercased.
//
// NOTE(review): the README example (`<prompt>.yourdomain.com`) suggests the
// leftmost label is expected, but this selects "yourdomain" for such a host.
// Confirm against the deployment (it works as-is for `<prompt>.localhost`).
//
// On failure an error is returned and a response has already been written:
// 400 when the host has no "." past its first character or the sanitized
// input is empty, 403 when a banned word is found. Callers must not write
// another response in that case.
func handleSubdomain(w http.ResponseWriter, r *http.Request) (input string, err error) {
	// get Subdomain
	domainInd := strings.LastIndex(r.Host, ".")
	if domainInd <= 0 {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("Invalid host, a subdomain is required"))
		return "", errors.New("Error: Host '" + r.Host + "' doesn't have a subdomain")
	}
	subdomain := r.Host[:domainInd]

	// Subdomain sanitization
	input = subdomain[strings.LastIndex(subdomain, ".")+1:] // only keep one level of subdomain
	// add spaces
	input = strings.ReplaceAll(input, "_", " ")
	input = strings.ReplaceAll(input, "-", " ")
	// replace anything that is not an ASCII letter, digit or space
	input = promptFilter.ReplaceAllString(input, " ")
	input = strings.ToLower(strings.TrimSpace(input))

	// Check if prompt has banned words. The comparison is case-insensitive and
	// also covers the sanitized prompt, so neither "BadWord" nor "bad_word"
	// slips past an entry such as "bad word".
	lowered := strings.ToLower(subdomain)
	for _, bannedWord := range config.BannedWords {
		needle := strings.ToLower(bannedWord)
		if strings.Contains(lowered, needle) || strings.Contains(input, needle) {
			w.WriteHeader(http.StatusForbidden)
			w.Write([]byte("This incident has been reported. Sick fuck."))
			return "", fmt.Errorf("banned words in req '%s'", subdomain)
		}
	}

	if len(input) <= 0 {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("Invalid input, subdomain cannot be empty (Accepted characters are '_-[A-Za-z]')"))
		return "", errors.New("invalid input")
	}
	return input, nil
}
// serveGallery renders the gallery page for the prompt encoded in the
// request's subdomain.
//
// Only the root path is served; every other path gets a 404 rather than an
// empty 200 response. If S3 reports "NoSuchKey" for the prompt, the picture is
// generated with createPicture and its metadata is fetched again. Any other
// failure is logged and answered with a 500.
func serveGallery(w http.ResponseWriter, r *http.Request) {
	// Check that the root got called
	if r.URL.Path != "/" {
		http.NotFound(w, r)
		return
	}

	// Validate user input. On error handleSubdomain has already written the
	// response, so only log here.
	input, err := handleSubdomain(w, r)
	if err != nil {
		InfoLogger.Println(err)
		return
	}

	metadata, err := s3.GetMetaData(input)
	if err != nil {
		if s3.ToErrorResponse(err).Code != "NoSuchKey" {
			w.WriteHeader(http.StatusInternalServerError)
			ErrorLogger.Println(err)
			return
		}

		// Picture doesn't exist, let's create it
		InfoLogger.Printf("generating new picture with input '%s'\n", input)
		if err = createPicture(input); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			ErrorLogger.Println(err)
			return
		}

		// get metadata again
		metadata, err = s3.GetMetaData(input)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			ErrorLogger.Println(err)
			return
		}
	}

	// Picture exists, let's show it
	err = writeDisplayPage(w, FileData{
		URL:      "img.png",
		Name:     metadata.Key,
		Size:     metadata.Size,
		Date:     metadata.LastModified,
		Metadata: metadata.UserMetadata,
	})
	if err != nil {
		ErrorLogger.Println(err)
	}
}
// servePicture streams the stored picture for the prompt encoded in the
// request's subdomain.
//
// It answers 404 when the object does not exist and 500 on any other storage
// error. The object handle is closed once the response has been written.
func servePicture(w http.ResponseWriter, r *http.Request) {
	// Validate user input. On error handleSubdomain has already written the
	// response, so only log here.
	input, err := handleSubdomain(w, r)
	if err != nil {
		InfoLogger.Println(err)
		return
	}

	// Fetch image from input in s3
	obj, err := s3.Download(input)
	if err != nil {
		writePictureError(w, err)
		return
	}
	defer obj.Close()

	// s3.Download wraps minio's GetObject, which defers the actual request:
	// a missing key only surfaces on first access. Stat forces that access so
	// the 404 below is reachable instead of failing inside ServeContent.
	info, err := obj.Stat()
	if err != nil {
		writePictureError(w, err)
		return
	}

	// Prefer the object's real modification time so conditional requests
	// (If-Modified-Since) can be honoured; fall back to now if it is unset.
	modTime := info.LastModified
	if modTime.IsZero() {
		modTime = time.Now()
	}
	http.ServeContent(w, r, input+".png", modTime, obj)
}

// writePictureError maps a storage error to an HTTP response: 404 for a
// missing object, otherwise a logged 500.
func writePictureError(w http.ResponseWriter, err error) {
	if s3.ToErrorResponse(err).Code == "NoSuchKey" {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte("This picture does not exist (yet)."))
		return
	}
	ErrorLogger.Println(err)
	w.WriteHeader(http.StatusInternalServerError)
}

118
main.go
View file

@ -1,128 +1,18 @@
package main
import (
"encoding/binary"
"errors"
"fmt"
"log"
"net/http"
"os"
"regexp"
"strings"
"time"
"imggen.reisen/s3"
"waifu_gallery/runpod"
"waifu_gallery/s3"
)
// Package-level state shared by the HTTP handlers and the image generator.
var (
	// config is filled in by parseConfig at startup.
	config Config

	// InfoLogger writes informational messages to stdout.
	InfoLogger = log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile)
	// ErrorLogger writes error messages to stderr.
	ErrorLogger = log.New(os.Stderr, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
)
// handleSubdomain derives the user prompt from the request's Host header.
//
// Everything after the last "." of the host is dropped, and of what remains
// only the part after its last "." is kept. On a banned word it writes a 403,
// on an empty sanitized input it writes a 400; in both cases an error is
// returned. When the host has no "." past its first character an error is
// returned without writing a response.
func handleSubdomain(w http.ResponseWriter, r *http.Request) (input string, err error) {
// get Subdomain
domainInd := strings.LastIndex(r.Host, ".")
if domainInd <= 0 {
return "", errors.New("Error: Host '" + r.Host + "' doesn't have a subdomain")
}
subdomain := r.Host[:domainInd]
// Check if prompt has banned words (case-sensitive substring match on the raw host prefix)
for _, bannedWord := range config.BannedWords {
if strings.Contains(subdomain, bannedWord) {
w.WriteHeader(http.StatusForbidden)
w.Write([]byte("This incident has been reported. Sick fuck."))
return "", fmt.Errorf("banned words in req '%s'", subdomain)
}
}
// Subdomain sanitization
input = subdomain[strings.LastIndex(subdomain, ".")+1:] // only keep one level of subdomain
// add spaces
input = strings.ReplaceAll(input, "_", " ")
input = strings.ReplaceAll(input, "-", " ")
// replace every run of characters outside [a-zA-Z0-9 ] with a single space
input = regexp.MustCompile(`[^a-zA-Z0-9 ]+`).ReplaceAllString(input, " ")
input = strings.TrimSpace(input)
input = strings.ToLower(input)
if len(input) <= 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Invalid input, subdomain cannot be empty (Accepted characters are '_-[A-Za-z]')"))
return "", errors.New("invalid input")
}
return
}
// serveGallery renders the gallery page for the prompt encoded in the
// request's subdomain when the picture already exists in S3.
//
// When S3 reports "NoSuchKey" it only builds the txt2img query and dumps it to
// stderr for debugging; nothing is generated and no page is written.
func serveGallery(w http.ResponseWriter, r *http.Request) {
// Check that the root got called; other paths return without writing anything
if r.RequestURI != "/" {
return
}
// Validate user input
input, err := handleSubdomain(w, r)
if err != nil {
InfoLogger.Println(err)
return
}
metadata, err := s3.GetMetaData(input)
if err != nil {
switch s3.ToErrorResponse(err).Code {
case "NoSuchKey":
// Picture doesn't exist, let's create it
// (this err shadows the outer one for the rest of the case)
query, err := buildTxt2ImgQuery(input, config.UserPromptBias)
if err != nil {
ErrorLogger.Println(err)
return
}
// debug: write the raw query bytes to stderr
binary.Write(os.Stderr, binary.LittleEndian, query)
default:
ErrorLogger.Println(err)
return
}
} else {
// Picture exists, let's show it
err = writeDisplayPage(w, FileData{
URL: "img.png",
Name: metadata.Key,
Size: metadata.Size,
Date: metadata.LastModified,
Metadata: metadata.UserMetadata,
})
if err != nil {
ErrorLogger.Println(err)
}
}
}
// servePicture streams the stored picture for the prompt encoded in the
// request's subdomain, answering 404 when S3 reports "NoSuchKey".
// Other download errors are logged without writing a response.
func servePicture(w http.ResponseWriter, r *http.Request) {
// Validate user input
input, err := handleSubdomain(w, r)
if err != nil {
InfoLogger.Println(err)
return
}
// Fetch image from input in s3
obj, err := s3.Download(input)
if err != nil {
switch s3.ToErrorResponse(err).Code {
case "NoSuchKey":
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("This picture does not exist (yet)"))
return
default:
ErrorLogger.Println(err)
return
}
}
// the reader reports an error itself if the object doesn't exist
// the ".png" name lets ServeContent pick the content type; modtime is always "now"
http.ServeContent(w, r, input+".png", time.Now(), obj)
}
func main() {
// Read config file
err := parseConfig("config.json", &config)
@ -130,10 +20,12 @@ func main() {
log.Fatal(err)
}
// Check for required values
// Initialize Runpod
if len(config.RunpodApi.Endpoint) <= 0 || len(config.RunpodApi.Authorization) <= 0 {
log.Fatal("Runpod credentials not set")
}
runpod.Endpoint = config.RunpodApi.Endpoint
runpod.Authorization = config.RunpodApi.Authorization
// Initialize S3
err = s3.InitMinio(s3.Config(config.S3))

View file

@ -1,10 +0,0 @@
package runpod
// Run is a placeholder: it ignores endpoint, authorization and data and always
// returns a nil error. The commented-out lines sketch the intended HTTP call.
func Run(endpoint string, authorization string, data []byte) (err error) {
// client := &http.Client{}
// req, _ := http.NewRequest("GET", url, nil)
// res, _ := client.Do(req)
// resp, err := http.Post(endpoint + "runsync")
return
}

105
runpod/runpod.go Normal file
View file

@ -0,0 +1,105 @@
package runpod
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"time"
)
// Connection settings; they must be set before any request is made.
var (
	// Endpoint is the base URL every request path is appended to.
	Endpoint string
	// Authorization is sent verbatim as the Authorization header.
	Authorization string
)
// Request is the reply decoded from a "run" call. It has no JSON tags, so
// encoding/json matches the keys by field name, ignoring case.
type Request struct {
	Id, Status string
}
// StatusResp is the reply decoded from a "status/<id>" call.
type StatusResp struct {
	// Timing figures as reported by the API (units not shown here — TODO confirm).
	DelayTime     int `json:"delayTime"`
	ExecutionTime int `json:"executionTime"`

	// Job identifier and its current state (e.g. "COMPLETED", "IN_QUEUE").
	Id     string `json:"id"`
	Status string `json:"status"`

	// One string-to-string map per generated result.
	Output []map[string]string `json:"output"`
}
// requestTimeout bounds a single API round trip so a stalled connection
// cannot block a caller forever.
// NOTE(review): 30s is a conservative guess — tune to the endpoint's latency.
const requestTimeout = 30 * time.Second

// apiClient is shared by every request instead of building one per call.
var apiClient = &http.Client{Timeout: requestTimeout}

// doReq sends a request for Endpoint+path with the JSON content type and the
// configured Authorization header. A nil data sends no body.
func doReq(method string, path string, data []byte) (*http.Response, error) {
	var (
		req *http.Request
		err error
	)
	if data == nil {
		req, err = http.NewRequest(method, Endpoint+path, nil)
	} else {
		req, err = http.NewRequest(method, Endpoint+path, bytes.NewReader(data))
	}
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", Authorization)
	return apiClient.Do(req)
}

// postReq POSTs data to Endpoint+path. The caller must close the body of a
// non-nil response.
func postReq(data []byte, path string) (resp *http.Response, err error) {
	return doReq(http.MethodPost, path, data)
}

// getReq GETs Endpoint+path. The caller must close the body of a non-nil
// response.
func getReq(path string) (resp *http.Response, err error) {
	return doReq(http.MethodGet, path, nil)
}
// Status returns the current status of the job identified by id.
//
// It fails when the request cannot be sent, when the server answers with a
// non-2xx HTTP status, or when the reply is not valid JSON. On error the
// returned StatusResp is the zero value.
func Status(id string) (status StatusResp, err error) {
	resp, err := getReq("status/" + id)
	if err != nil {
		return StatusResp{}, err
	}
	// Registered before decoding so the body is closed on every path,
	// including a decode failure.
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return StatusResp{}, fmt.Errorf("fetching status of job %q: unexpected HTTP status %s", id, resp.Status)
	}
	if err = json.NewDecoder(resp.Body).Decode(&status); err != nil {
		return StatusResp{}, fmt.Errorf("decoding status of job %q: %w", id, err)
	}
	return status, nil
}
// Run submits data as a new job and returns the decoded reply.
//
// It fails when the request cannot be sent, when the server answers with a
// non-2xx HTTP status, or when the reply is not valid JSON. On error the
// returned Request is the zero value.
func Run(data []byte) (req Request, err error) {
	resp, err := postReq(data, "run")
	if err != nil {
		return Request{}, err
	}
	// Registered before decoding so the body is closed on every path,
	// including a decode failure.
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return Request{}, fmt.Errorf("submitting job: unexpected HTTP status %s", resp.Status)
	}
	if err = json.NewDecoder(resp.Body).Decode(&req); err != nil {
		return Request{}, fmt.Errorf("decoding job submission reply: %w", err)
	}
	return req, nil
}
// Poll queries the status of job id until it leaves the queue.
//
// It returns the final status once the job reports "COMPLETED". While the job
// reports "IN_QUEUE" or "IN_PROGRESS" it sleeps for delay and asks again. Any
// other state, or a failing status request, ends the loop with an error.
func Poll(id string, delay time.Duration) (status StatusResp, err error) {
	for {
		if status, err = Status(id); err != nil {
			return status, err
		}

		state := status.Status
		if state == "COMPLETED" {
			return status, nil
		}
		if state != "IN_QUEUE" && state != "IN_PROGRESS" {
			return status, fmt.Errorf("job failed with status '%s'", status.Status)
		}

		// Still pending: wait before asking again.
		time.Sleep(delay)
	}
}

View file

@ -2,6 +2,7 @@ package s3
import (
"context"
"io"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
@ -46,5 +47,11 @@ func Download(name string) (data *minio.Object, err error) {
return minioClient.GetObject(ctx, bucketName, name, minio.GetObjectOptions{})
}
// upload a generated
func Upload() {}
func Upload(contentType string, name string, data io.Reader, metadata map[string]string) (err error) {
// TODO pass size and do single PUT
_, err = minioClient.PutObject(ctx, bucketName, name, data, -1, minio.PutObjectOptions{
ContentType: contentType,
UserMetadata: metadata,
})
return
}