137 lines
3.4 KiB
Go
137 lines
3.4 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"strings"
|
|
"time"
|
|
|
|
"waifu_gallery/runpod"
|
|
"waifu_gallery/s3"
|
|
)
|
|
|
|
// AdvancedParam is the "advanced_params" section of a Fooocus txt2img
// request. Both fields are tagged omitempty, so zero values are left out of
// the encoded JSON.
type AdvancedParam struct {
	// OverwriteStep overrides the sampling step count (buildTxt2ImgQuery
	// fills it from config.Steps); omitted when 0.
	OverwriteStep int `json:"overwrite_step,omitempty"`
	// SamplerName selects the sampler by name (buildTxt2ImgQuery fills it
	// from config.Sampler); omitted when empty.
	SamplerName string `json:"sampler_name,omitempty"`
}
|
|
|
|
// Lora is one LoRA entry in the "loras" list of a txt2img request.
type Lora struct {
	// ModelName identifies the LoRA model; always serialized.
	ModelName string `json:"model_name"`
	// Weight is the LoRA strength. It is omitted when 0; the original note
	// says the server then defaults it to 0.5, which means an explicit
	// weight of 0 cannot be sent.
	Weight float64 `json:"weight,omitempty"`
	// Enabled is always serialized; the original note says to set it to true.
	Enabled bool `json:"enabled"`
}
|
|
|
|
// FooocusTxt2Img holds the parameters of a Fooocus text-to-image request.
// buildTxt2ImgQuery fills it from the global config and wraps it in a
// FooocusInput before marshaling.
type FooocusTxt2Img struct {
	// Api_name selects the endpoint; buildTxt2ImgQuery sets "txt2img".
	// NOTE(review): non-idiomatic name (APIName); kept because the field is
	// exported and may be referenced elsewhere.
	Api_name string `json:"api_name"`
	// RequireBase64 must be true: createPicture reads the generated image
	// from the "base64" key of the job output.
	RequireBase64 bool `json:"require_base64,omitempty"`
	// Prompt is the positive prompt; always serialized.
	Prompt string `json:"prompt"`
	// NegativePrompt is omitted when empty.
	NegativePrompt string `json:"negative_prompt,omitempty"`
	// BaseModelName is the checkpoint to use; omitted when empty.
	BaseModelName string `json:"base_model_name,omitempty"`
	// Loras is omitted when nil or empty.
	Loras []Lora `json:"loras,omitempty"`
	// Styles has no omitempty, so it is always sent (a nil slice encodes as
	// JSON null).
	Styles []string `json:"style_selections"`
	// AspectRatio is omitted when empty.
	AspectRatio string `json:"aspect_ratios_selection,omitempty"`
	// ImageSeed is omitted when 0.
	ImageSeed int `json:"image_seed,omitempty"`
	// Sharpness is omitted when 0.
	Sharpness float64 `json:"sharpness,omitempty"`
	// GuidanceScale is omitted when 0.
	GuidanceScale float64 `json:"guidance_scale,omitempty"`
	// AdvancedParam holds the advanced sampler settings.
	// NOTE(review): encoding/json ignores omitempty on struct-typed fields,
	// so "advanced_params" is always emitted (possibly as {}).
	AdvancedParam AdvancedParam `json:"advanced_params,omitempty"`
}
|
|
|
|
// FooocusInput is the top-level request envelope that buildTxt2ImgQuery
// marshals: the txt2img parameters nested under the "input" key.
type FooocusInput struct {
	// Input carries the actual txt2img parameters.
	Input FooocusTxt2Img `json:"input"`
}
|
|
|
|
func buildTxt2ImgQuery(prompt string) (data []byte, err error) {
|
|
var query FooocusTxt2Img
|
|
|
|
query.Api_name = "txt2img"
|
|
query.RequireBase64 = true // necessary
|
|
query.Styles = []string{""}
|
|
query.Loras = config.Loras
|
|
query.AspectRatio = config.AspectRatio
|
|
query.GuidanceScale = config.GuidanceScale
|
|
query.Sharpness = config.Sharpness
|
|
query.AdvancedParam.OverwriteStep = config.Steps
|
|
query.AdvancedParam.SamplerName = config.Sampler
|
|
if len(config.BaseModel) != 0 {
|
|
query.BaseModelName = config.BaseModel
|
|
}
|
|
if config.ImageSeed != 0 {
|
|
query.ImageSeed = config.ImageSeed
|
|
}
|
|
// Now we build the prompt
|
|
var randPrompt string
|
|
if len(config.RandomPrompt) > 0 {
|
|
randPrompt = config.RandomPrompt[rand.Intn(len(config.RandomPrompt))]
|
|
}
|
|
if len(config.BasePrompt) > 0 {
|
|
query.Prompt = fmt.Sprintf("%s, %s, (%s:%f)", randPrompt, config.BasePrompt, prompt, config.UserPromptBias)
|
|
} else {
|
|
query.Prompt = prompt
|
|
}
|
|
query.NegativePrompt = config.NegativePrompt
|
|
|
|
data, err = json.Marshal(FooocusInput{Input: query})
|
|
if err != nil {
|
|
ErrorLogger.Println(err)
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func createPicture(input string) (err error) {
|
|
query, err := buildTxt2ImgQuery(input)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
resp, err := runpod.Run(query)
|
|
if err != nil {
|
|
return
|
|
}
|
|
status, err := runpod.Poll(resp.Id, 250*time.Millisecond)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if len(status.Output) == 0 {
|
|
return errors.New("no output in response")
|
|
}
|
|
output := status.Output[0]
|
|
|
|
// Decode base64
|
|
decodedBytes, err := base64.StdEncoding.DecodeString(output["base64"])
|
|
if err != nil {
|
|
return
|
|
}
|
|
// Upload file
|
|
err = s3.Upload(
|
|
"image/png",
|
|
input,
|
|
strings.NewReader(string(decodedBytes)),
|
|
int64(len(decodedBytes)),
|
|
map[string]string{
|
|
"Technique": "Digital Art", // TODO generate this via vision/LLM?
|
|
"Width": fmt.Sprintf("%d", config.state.width),
|
|
"Height": fmt.Sprintf("%d", config.state.height),
|
|
"Seed": output["seed"],
|
|
})
|
|
if err != nil {
|
|
ErrorLogger.Println(err)
|
|
}
|
|
// Upload Query
|
|
err = s3.Upload(
|
|
"application/json",
|
|
input+".json",
|
|
bytes.NewReader(query),
|
|
int64(len(query)),
|
|
map[string]string{
|
|
"Seed": output["seed"],
|
|
},
|
|
)
|
|
return
|
|
}
|