// waifu_gallery/openrouter/api.go

// Package openrouter implements a small client for the OpenRouter chat
// completions API. It should also work with regular OpenAI ChatGPT-style API
// endpoints, aside from the routing options and some optional features.
package openrouter

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

const (
	RoleAssistant = "assistant"
	RoleSystem    = "system"
	RoleUser      = "user"
)

// OpenRouterClient is a minimal client for an OpenRouter-style chat
// completions endpoint.
type OpenRouterClient struct {
	endpoint      string
	authorization string
	httpClient    *http.Client
}

// Message is a single entry in the chat history.
type Message struct {
	Role    string `json:"role"`    // required, use the Role* consts
	Content string `json:"content"` // required
}
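
// Usage sketch (illustrative, not part of the original file): building a chat
// history with the Role* constants. The prompt strings are placeholders.
//
//	history := []Message{
//		{Role: RoleSystem, Content: "You are a helpful assistant."},
//		{Role: RoleUser, Content: "Describe this gallery in one sentence."},
//	}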

// ProviderPreferences specifies the OpenRouter routing options.
type ProviderPreferences struct {
	Order             []string `json:"order,omitempty"`              // providers OpenRouter will use for this request, in order of preference
	RequireParameters bool     `json:"require_parameters,omitempty"` // require the provider to support all request parameters (beta)
	DataCollection    string   `json:"data_collection,omitempty"`    // "allow" or "deny" provider logging; default "allow"
	AllowFallbacks    *bool    `json:"allow_fallbacks,omitempty"`    // set to false (a pointer, so the value survives omitempty) to guarantee the request is only served by the top (lowest-cost) provider
}
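
// Usage sketch (illustrative, not part of the original file): routing
// preferences that pin the request to specific providers and disable
// fallbacks. The provider names are placeholders.
//
//	noFallbacks := false
//	prefs := &ProviderPreferences{
//		Order:          []string{"SomeProvider", "AnotherProvider"},
//		DataCollection: "deny",
//		AllowFallbacks: &noFallbacks,
//	}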

// ChatCompletions is the request body for the chat completions endpoint.
type ChatCompletions struct {
	Model    string    `json:"model"`    // required, name of the model to use
	Messages []Message `json:"messages"` // required, the chat history

	// OpenRouter specific: routing preferences. The API expects these under
	// the JSON key "provider"; a nil pointer omits them from the request.
	Preferences *ProviderPreferences `json:"provider,omitempty"`

	MaxTokens   int      `json:"max_tokens,omitempty"`  // maximum number of tokens to generate
	Temperature float64  `json:"temperature,omitempty"` // sampling temperature between 0 and 2; higher values like 0.8 make the output more random, lower values like 0.2 more focused and deterministic. Altering this or top_p, but not both, is generally recommended.
	TopP        float64  `json:"top_p,omitempty"`       // nucleus sampling, an alternative to temperature: only tokens within the top_p probability mass are considered, so 0.1 keeps only the top 10% probability mass. Altering this or temperature, but not both, is generally recommended.
	TopK        int      `json:"top_k,omitempty"`       // top-k sampling: only the k most probable next tokens are kept and the probability mass is redistributed among them; k controls the number of candidates at each generation step
	Stop        []string `json:"stop,omitempty"`        // up to 4 sequences where the API will stop generating further tokens; the returned text will contain the stop sequence

	// Optional features

	// FrequencyPenalty, between -2.0 and 2.0: positive values penalize new
	// tokens based on their existing frequency in the text so far, decreasing
	// the model's likelihood to repeat the same line verbatim. Around 0.1 to 1
	// reduces repetition somewhat; values up to 2 strongly suppress it but can
	// noticeably degrade sample quality. Negative values encourage repetition.
	// See also PresencePenalty, which penalizes tokens that have appeared at
	// least once at a fixed rate.
	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`

	// PresencePenalty, between -2.0 and 2.0: positive values penalize tokens
	// that already appear in the text so far, increasing the model's
	// likelihood to talk about new topics. The same reasonable ranges as
	// FrequencyPenalty apply; FrequencyPenalty instead penalizes at an
	// increasing rate depending on how often a token appears.
	PresencePenalty float64 `json:"presence_penalty,omitempty"`

	// ResponseFormat constrains the LLM output to a given structure, e.g. the
	// `json_object` type lets you provide a desired JSON schema for the output
	// to follow. Nil omits it from the request.
	ResponseFormat json.RawMessage `json:"response_format,omitempty"`

	// Stream bool `json:"stream,omitempty"` // Indicates whether the response should be streamed
}
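
// Usage sketch (illustrative, not part of the original file): a minimal
// request body. The model name is a placeholder; `history` and `prefs` are
// the values built in the sketches above.
//
//	req := ChatCompletions{
//		Model:       "some-vendor/some-model",
//		Messages:    history,
//		Preferences: prefs,
//		MaxTokens:   256,
//		Temperature: 0.7,
//	}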

// Usage reports token usage and cost for a generation.
type Usage struct {
	CompletionTokens int     `json:"completion_tokens"` // equivalent to "native_tokens_completion" in the /generation API
	PromptTokens     int     `json:"prompt_tokens"`     // equivalent to "native_tokens_prompt"
	TotalTokens      int     `json:"total_tokens"`      // sum of the two fields above
	TotalCost        float64 `json:"total_cost"`        // number of credits used by this generation
}

// NonStreamingChoice is a single completion choice in a non-streaming response.
type NonStreamingChoice struct {
	FinishReason string `json:"finish_reason"`
	Message      struct {
		Message string `json:"content"` // the generated text
		Role    string `json:"role"`
	} `json:"message"`
}

// NonStreamingChatOutput is the non-streaming output result.
type NonStreamingChatOutput struct {
	Id      string               `json:"id"`
	Model   string               `json:"model"`
	Object  string               `json:"object"`
	Created int64                `json:"created"`
	Usage   Usage                `json:"usage"`
	Output  []NonStreamingChoice `json:"choices"`
}
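
// Usage sketch (illustrative, not part of the original file): reading the
// generated text and token usage out of a decoded response `out`. The
// "stop" finish reason is just an example value.
//
//	if len(out.Output) > 0 {
//		fmt.Println(out.Output[0].Message.Message) // generated text
//		fmt.Println(out.Output[0].FinishReason)    // e.g. "stop"
//	}
//	fmt.Printf("%d tokens used\n", out.Usage.TotalTokens)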

func (c *OpenRouterClient) postReq(data []byte) (resp *http.Response, err error) {
	req, err := http.NewRequest("POST", c.endpoint, bytes.NewBuffer(data))
	if err != nil {
		return
	}
	req.Header.Add("Content-Type", "application/json")
	req.Header.Add("Authorization", c.authorization)
	// Reuse the client's http.Client and check the transport error before
	// touching the response.
	resp, err = c.httpClient.Do(req)
	if err != nil {
		return
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return resp, fmt.Errorf("LLM api resp HTTP code '%d'", resp.StatusCode)
	}
	return
}

// New returns a client for the given chat completions endpoint. The token is
// sent verbatim as the Authorization header, so include the "Bearer " prefix
// if the API expects one.
func New(endpoint, token string) *OpenRouterClient {
	return &OpenRouterClient{
		endpoint:      endpoint,
		authorization: token,
		httpClient:    &http.Client{},
	}
}
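
// Usage sketch (illustrative, not part of the original file): constructing a
// client against OpenRouter's chat completions endpoint. The key is a
// placeholder and, per the note above, is passed with the "Bearer " prefix.
//
//	apiKey := "<your OpenRouter API key>" // placeholder
//	client := New(
//		"https://openrouter.ai/api/v1/chat/completions",
//		"Bearer "+apiKey,
//	)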

// ChatComplete sends a non-streaming chat completion request and decodes the
// response body.
func (c *OpenRouterClient) ChatComplete(input ChatCompletions) (output NonStreamingChatOutput, err error) {
	data, err := json.Marshal(input)
	if err != nil {
		return
	}
	resp, err := c.postReq(data)
	if resp != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		return
	}
	err = json.NewDecoder(resp.Body).Decode(&output)
	return
}
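
// Usage sketch (illustrative, not part of the original file): a full
// non-streaming round trip with the `client` from the previous sketch. The
// model name is a placeholder, and the caller is assumed to import "fmt" and
// "log".
//
//	out, err := client.ChatComplete(ChatCompletions{
//		Model:    "some-vendor/some-model",
//		Messages: []Message{{Role: RoleUser, Content: "Hello!"}},
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	if len(out.Output) > 0 {
//		fmt.Println(out.Output[0].Message.Message)
//	}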