108 lines
3.1 KiB
Go
108 lines
3.1 KiB
Go
package main
|
|
|
|
import (
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"regexp"
	"strconv"
	"strings"

	"waifu_gallery/s3"
)
|
|
|
|
type Config struct {
|
|
// Fooocus serverless api endpoint
|
|
RunpodApi struct {
|
|
Endpoint string `json:"endpoint"`
|
|
Authorization string `json:"Authorization"`
|
|
} `json:"runpodApi"`
|
|
|
|
S3 s3.Config `json:"s3"`
|
|
|
|
// Where our web server listens
|
|
ListenAddress string `json:"listenAddress"`
|
|
MaxConcurrentJobs int `json:"maxConcurrentJobs"` // set this to limit the number of generation jobs that can be set in parallel, 0 to disable
|
|
|
|
// Content generation properties
|
|
Loras []Lora `json:"loras"`
|
|
AspectRatio string `json:"aspectRatio"` // necessary! used to compute display properties. e.g: 512*512
|
|
ImageSeed int `json:"imageSeed"` // set it to make image generation deterministic
|
|
RandomPrompt []string `json:"randomPrompt"` // prepended to base prompt, used to spice things up a bit (e.g multiple artist styles)
|
|
BasePrompt string `json:"basePrompt"` // user prompt will be appended to this
|
|
NegativePrompt string `json:"negativePrompt"` // self explanatory
|
|
UserPromptBias float64 `json:"userPromptBias"` // how much do we favor the user's input
|
|
GuidanceScale float64 `json:"GuidanceScale"` // The guidance scale, default 7.0
|
|
Sharpness float64 `json:"sharpness"` // Sharpness, defaults to 2.0
|
|
Steps int `json:"steps"` // Number of generation steps, default 25
|
|
Sampler string `json:"sampler"` // Which sampler to use, default euler_a
|
|
BaseModel string `json:"baseModel"` // Base model, leave it empty to use fooocus' default
|
|
BannedWords []string `json:"bannedWords"` // user input will be blocked if it contains one of these
|
|
|
|
// Internal app state
|
|
state struct {
|
|
width int
|
|
height int
|
|
|
|
bannedWordsRegex *regexp.Regexp
|
|
}
|
|
}
|
|
|
|
func parseConfig(name string, cfg *Config) (err error) {
|
|
// read file
|
|
file, err := os.ReadFile(name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// unmarshal
|
|
err = json.Unmarshal(file, &cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// set defaults
|
|
if len(cfg.ListenAddress) == 0 {
|
|
cfg.ListenAddress = "127.0.0.1:8080"
|
|
}
|
|
if len(cfg.AspectRatio) == 0 {
|
|
cfg.AspectRatio = "1024*1024"
|
|
}
|
|
if len(cfg.Sampler) == 0 {
|
|
cfg.Sampler = "euler_ancestral"
|
|
}
|
|
if cfg.UserPromptBias == 0 {
|
|
cfg.UserPromptBias = 0.5
|
|
}
|
|
if cfg.GuidanceScale == 0 {
|
|
cfg.GuidanceScale = 7.0
|
|
}
|
|
if cfg.Steps == 0 {
|
|
cfg.Steps = 25
|
|
}
|
|
if cfg.Sharpness == 0 {
|
|
cfg.Sharpness = 2.0
|
|
}
|
|
|
|
// initialize aspect ratio
|
|
parts := strings.Split(cfg.AspectRatio, "*")
|
|
if len(parts) != 2 {
|
|
return errors.New("invalid aspect ratio format (must be something like 1024*1024)")
|
|
}
|
|
cfg.state.width, err = strconv.Atoi(parts[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cfg.state.height, err = strconv.Atoi(parts[1])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if cfg.state.width <= 0 || cfg.state.height <= 0 {
|
|
return errors.New("invalid aspect ratio (must be something like 1024*1024)")
|
|
}
|
|
// initialize badword regex
|
|
config.state.bannedWordsRegex = regexp.MustCompile(strings.Join(cfg.BannedWords, "|"))
|
|
|
|
return nil
|
|
}
|