waifu_gallery/fooocus.go

137 lines
3.4 KiB
Go

package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/rand"
"strings"
"time"
"waifu_gallery/runpod"
"waifu_gallery/s3"
)
// AdvancedParam carries the Fooocus "advanced_params" fields this program
// sets (see buildTxt2ImgQuery, which fills them from config.Steps and
// config.Sampler). Zero values are omitted from the JSON via omitempty;
// presumably the server then applies its own defaults — TODO confirm.
type AdvancedParam struct {
	OverwriteStep int    `json:"overwrite_step,omitempty"` // number of steps; omitted when 0
	SamplerName   string `json:"sampler_name,omitempty"`   // omitted when ""
}
// Lora selects one LoRA model to apply to a txt2img request.
type Lora struct {
	ModelName string `json:"model_name"`
	// Weight defaults to 0.5 (presumably the server-side default when the
	// field is omitted — TODO confirm). Because of omitempty, an explicit
	// weight of 0 is never sent.
	Weight  float64 `json:"weight,omitempty"`
	Enabled bool    `json:"enabled"` // set to true; always sent, even when false
}
// FooocusTxt2Img is the body of a Fooocus txt2img job, nested under "input"
// by FooocusInput. Field order determines the key order of the marshalled
// JSON.
type FooocusTxt2Img struct {
	// Api_name selects the endpoint; buildTxt2ImgQuery sets "txt2img".
	// (Underscore name kept for compatibility with existing callers.)
	Api_name string `json:"api_name"`
	// RequireBase64 must be true: createPicture reads the generated image
	// from the "base64" field of the job output.
	RequireBase64  bool   `json:"require_base64,omitempty"` // set this to true!!
	Prompt         string `json:"prompt"`
	NegativePrompt string `json:"negative_prompt,omitempty"`
	BaseModelName  string `json:"base_model_name,omitempty"`
	Loras          []Lora `json:"loras,omitempty"`
	// Styles has no omitempty, so it is always sent (as null when nil).
	Styles      []string `json:"style_selections"`
	AspectRatio string   `json:"aspect_ratios_selection,omitempty"`
	// ImageSeed is omitted when 0, so seed 0 cannot be requested explicitly.
	ImageSeed     int     `json:"image_seed,omitempty"`
	Sharpness     float64 `json:"sharpness,omitempty"`
	GuidanceScale float64 `json:"guidance_scale,omitempty"`
	// NOTE(review): encoding/json ignores omitempty on struct-typed fields,
	// so "advanced_params" is always emitted (possibly as {}).
	AdvancedParam AdvancedParam `json:"advanced_params,omitempty"`
}
// FooocusInput is the top-level request envelope marshalled by
// buildTxt2ImgQuery: the job parameters are nested under the "input" key.
type FooocusInput struct {
	Input FooocusTxt2Img `json:"input"`
}
// buildTxt2ImgQuery builds the JSON body of a Fooocus "txt2img" request from
// the package-level config and the user-supplied prompt. The returned bytes
// encode a FooocusInput, i.e. {"input": {...}}.
//
// Prompt assembly: when config.BasePrompt is non-empty the prompt becomes
//
//	<random>, <BasePrompt>, (<prompt>:<UserPromptBias>)
//
// where <random> is one entry picked at random from config.RandomPrompt. If
// no random prompt is configured, or the picked entry is empty, the
// "<random>, " prefix is left out entirely rather than producing a prompt
// that starts with ", ". When config.BasePrompt is empty, prompt is sent
// verbatim and the random prompt and bias are not used.
//
// Empty/zero config values (base model, seed, ...) are dropped from the JSON
// by the omitempty tags on FooocusTxt2Img.
//
// A json.Marshal error (e.g. a NaN float in config) is logged to ErrorLogger
// and returned wrapped.
func buildTxt2ImgQuery(prompt string) (data []byte, err error) {
	query := FooocusTxt2Img{
		Api_name:       "txt2img",
		RequireBase64:  true, // necessary: createPicture reads the "base64" output
		Prompt:         prompt,
		NegativePrompt: config.NegativePrompt,
		BaseModelName:  config.BaseModel, // "" is dropped by omitempty
		Loras:          config.Loras,
		// NOTE(review): a single empty style name; presumably meant to override
		// the server's default style selection — confirm against the API.
		Styles:        []string{""},
		AspectRatio:   config.AspectRatio,
		ImageSeed:     config.ImageSeed, // 0 is dropped by omitempty
		Sharpness:     config.Sharpness,
		GuidanceScale: config.GuidanceScale,
		AdvancedParam: AdvancedParam{
			OverwriteStep: config.Steps,
			SamplerName:   config.Sampler,
		},
	}

	if len(config.BasePrompt) > 0 {
		var randPrompt string
		if n := len(config.RandomPrompt); n > 0 {
			randPrompt = config.RandomPrompt[rand.Intn(n)]
		}
		query.Prompt = fmt.Sprintf("%s, (%s:%f)", config.BasePrompt, prompt, config.UserPromptBias)
		// Only prepend the random prompt when there is one, so the prompt never
		// begins with an empty ", " segment.
		if randPrompt != "" {
			query.Prompt = randPrompt + ", " + query.Prompt
		}
	}

	data, err = json.Marshal(FooocusInput{Input: query})
	if err != nil {
		ErrorLogger.Println(err)
		return nil, fmt.Errorf("encoding txt2img query: %w", err)
	}
	return data, nil
}
// createPicture generates an image for input via a Fooocus txt2img job on
// RunPod and stores the result in S3.
//
// Steps:
//  1. build the request with buildTxt2ImgQuery;
//  2. submit it with runpod.Run and poll the job every 250ms with runpod.Poll;
//  3. base64-decode the "base64" field of the first output entry;
//  4. upload the image under the key input, with Technique/Width/Height
//     metadata and the "seed" reported by the job;
//  5. upload the request JSON under input+".json" with the same seed.
//
// Any failure aborts the remaining steps and is returned with context. In
// particular, the query JSON is only uploaded after the image upload has
// succeeded, so no orphaned sidecar is written and an image-upload failure
// is no longer masked by a later successful upload.
func createPicture(input string) error {
	query, err := buildTxt2ImgQuery(input)
	if err != nil {
		return err
	}

	resp, err := runpod.Run(query)
	if err != nil {
		return fmt.Errorf("submitting runpod job: %w", err)
	}
	status, err := runpod.Poll(resp.Id, 250*time.Millisecond)
	if err != nil {
		return fmt.Errorf("polling runpod job %v: %w", resp.Id, err)
	}
	if len(status.Output) == 0 {
		return errors.New("no output in response")
	}
	output := status.Output[0]

	// A missing key yields "", which would decode to zero bytes and upload an
	// empty image; treat it as an error instead.
	encoded := output["base64"]
	if encoded == "" {
		return errors.New("no image data in response")
	}
	decodedBytes, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return fmt.Errorf("decoding image: %w", err)
	}
	seed := output["seed"]

	// NOTE(review): input (the user's prompt) is used verbatim as the S3 object
	// key; keys containing "/" or longer than the store's limit may misbehave.
	err = s3.Upload(
		"image/png",
		input,
		strings.NewReader(string(decodedBytes)),
		int64(len(decodedBytes)),
		map[string]string{
			"Technique": "Digital Art", // TODO generate this via vision/LLM?
			"Width":     fmt.Sprintf("%d", config.state.width),
			"Height":    fmt.Sprintf("%d", config.state.height),
			"Seed":      seed,
		},
	)
	if err != nil {
		return fmt.Errorf("uploading image %q: %w", input, err)
	}

	queryKey := input + ".json"
	err = s3.Upload(
		"application/json",
		queryKey,
		bytes.NewReader(query),
		int64(len(query)),
		map[string]string{
			"Seed": seed,
		},
	)
	if err != nil {
		return fmt.Errorf("uploading query %q: %w", queryKey, err)
	}
	return nil
}