Created
March 20, 2023 08:43
-
-
Save gertjana/86485fc496233ddb9a561622e2420f19 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"encoding/json" | |
"fmt" | |
"io/ioutil" | |
"net/http" | |
"os" | |
"strings" | |
"time" | |
) | |
// BASE_API is the root URL of the Leap v1 REST API; every endpoint path used
// in this program is appended to it.
const BASE_API = "https://api.tryleap.ai/api/v1"

var (
	// API_KEY is read once, at program start, from the LEAP_API_KEY
	// environment variable. Export it before running this program.
	API_KEY = os.Getenv("LEAP_API_KEY")

	// HEADERS are added to every request built by req: JSON in both
	// directions plus bearer-token authentication. The Authorization value
	// is computed once at package initialisation from API_KEY.
	HEADERS = map[string]string{
		"Accept":        "application/json",
		"Authorization": "Bearer " + API_KEY,
		"Content-Type":  "application/json",
	}
)
// If a modelId is passed in as the first argument, it will use that model. Otherwise, it will create a new model and train it. | |
func main() { | |
const modelName = "GertjanTest" | |
const prompt = "Photo's of @me with phycadelic colors, high contrast, and a lot of detail." | |
var sample_images = []string{ | |
"https://scontent-ams4-1.xx.fbcdn.net/v/t1.6435-9/38507713_10155827789237239_1223092985231572992_n.jpg?_nc_cat=102&ccb=1-7&_nc_sid=8bfeb9&_nc_ohc=SpI49TU-LZgAX9Tih_6&_nc_ht=scontent-ams4-1.xx&oh=00_AfCVa59wzV2QP-M1iatZt8adOKaQExqMRxFnnjA_R8QNrg&oe=643D0668", | |
"https://scontent-ams4-1.xx.fbcdn.net/v/t1.6435-9/36088073_10155734387182239_1766884150402351104_n.jpg?_nc_cat=109&ccb=1-7&_nc_sid=8bfeb9&_nc_ohc=jOgd9s36aT0AX9_l6hE&_nc_ht=scontent-ams4-1.xx&oh=00_AfDLUiY6iH33W4JiKlBWOP6yHewPJd2TXM_c4_Q6mL5dcw&oe=643D22B6", | |
"https://scontent-ams4-1.xx.fbcdn.net/v/t1.18169-9/1918748_144973647238_127197_n.jpg?_nc_cat=110&ccb=1-7&_nc_sid=cdbe9c&_nc_ohc=tz6x6BLDyU4AX_MMnSC&_nc_ht=scontent-ams4-1.xx&oh=00_AfDMre6PRPLBSKpGWxXPMW7x1sKuhpTV1Ayu6iEn3Gar2Q&oe=643D07F9", | |
"https://scontent-ams2-1.xx.fbcdn.net/v/t1.6435-9/78835646_10156940909177239_7040054755149742080_n.jpg?_nc_cat=105&ccb=1-7&_nc_sid=8bfeb9&_nc_ohc=JINnJArEVt4AX_znDbv&_nc_ht=scontent-ams2-1.xx&oh=00_AfBzLP8aM5BLymF9sdmCVMf6GRsBZ206iSeqQo1w70XdmQ&oe=643D0E79", | |
"https://scontent-ams4-1.xx.fbcdn.net/v/t1.6435-9/45342220_10156018880582239_634734194265686016_n.jpg?_nc_cat=107&ccb=1-7&_nc_sid=8bfeb9&_nc_ohc=6Ot2H9ah-QYAX-MpA4e&_nc_ht=scontent-ams4-1.xx&oh=00_AfABy-IuEvCRzXFfqyGytHEJLKVjSKfbDDnMyhRMtnNWRg&oe=643D1E00", | |
} | |
var modelId string | |
var err error | |
if len(os.Args) > 1 { | |
modelId = os.Args[1] | |
fmt.Println("Using model: ", modelId) | |
} else { | |
fmt.Println("Creating model: ", modelName) | |
if modelId, err = CreateModel(modelName); err != nil { | |
fmt.Println("Error creating model: ", err) | |
} | |
} | |
if err = TrainModel(modelId, sample_images); err != nil { | |
fmt.Println("Error creating model: ", err) | |
} | |
var generated_images []string | |
if generated_images, err = RunModel(modelId, prompt); err != nil { | |
fmt.Println("Error running model: ", err) | |
} else { | |
for _, image := range generated_images { | |
fmt.Println(image) | |
} | |
} | |
} | |
func CreateModel(title string) (string, error) { | |
url := fmt.Sprintf("%s/images/models", BASE_API) | |
payload := fmt.Sprintf(`{"title": "%s", "subjectKeyword": "@me"}`, title) | |
var err error | |
var response []byte | |
if response, err = post(url, payload); err != nil { | |
return "", err | |
} | |
data := make(map[string]string) | |
if err = json.Unmarshal(response, &data); err != nil { | |
return "", err | |
} | |
return data["id"], nil | |
} | |
func TrainModel(modelId string, sample_images []string) error { | |
fmt.Println("Uploading sample images") | |
var err error | |
if err = uploadImageSamples(modelId, sample_images); err != nil { | |
return err | |
} | |
fmt.Println("Queueing training job...") | |
if versionId, model_status, err := queueTrainingJob(modelId); err != nil { | |
return err | |
} else { | |
for model_status != "finished" { | |
if versionId, model_status, err = getModelVersion(modelId, versionId); err != nil { | |
return err | |
} | |
time.Sleep(10 * time.Second) | |
} | |
} | |
return nil | |
} | |
func RunModel(modelId string, prompt string) ([]string, error) { | |
fmt.Println("Generating image...") | |
var inferenceId, inference_status string | |
var err error | |
if inferenceId, inference_status, err = generateImage(modelId, prompt); err != nil { | |
return nil, err | |
} | |
fmt.Printf("inferenceId: %s, inference_status: %s", inferenceId, inference_status) | |
var images []string | |
for inference_status != "finished" { | |
if inferenceId, inference_status, images, err = getInferenceJob(modelId, inferenceId); err != nil { | |
return nil, err | |
} | |
time.Sleep(10 * time.Second) | |
} | |
return images, nil | |
} | |
func get(url string) ([]byte, error) { | |
return req(url, "", "GET") | |
} | |
func post(url string, payload string) ([]byte, error) { | |
return req(url, payload, "POST") | |
} | |
func req(url string, payload string, method string) ([]byte, error) { | |
var body []byte | |
var req *http.Request | |
var res *http.Response | |
var err error | |
if req, err = http.NewRequest(method, url, strings.NewReader(payload)); err != nil { | |
return nil, err | |
} | |
for key, value := range HEADERS { | |
req.Header.Add(key, value) | |
} | |
if res, err = http.DefaultClient.Do(req); err != nil { | |
return nil, err | |
} | |
defer res.Body.Close() | |
if body, err = ioutil.ReadAll(res.Body); err != nil { | |
return nil, err | |
} | |
return body, nil | |
} | |
func uploadImageSamples(model_id string, sample_images []string) error { | |
url := fmt.Sprintf("%s/images/models/%s/samples/url", BASE_API, model_id) | |
payload, _ := json.Marshal(map[string][]string{"images": sample_images}) | |
var err error | |
if _, err = post(url, string(payload)); err != nil { | |
return err | |
} | |
return nil | |
} | |
func queueTrainingJob(model_id string) (string, string, error) { | |
url := fmt.Sprintf("%s/images/models/%s/queue", BASE_API, model_id) | |
var err error | |
var response []byte | |
if response, err = post(url, ""); err != nil { | |
return "", "", err | |
} | |
data := make(map[string]string) | |
if err = json.Unmarshal(response, &data); err != nil { | |
return "", "", err | |
} | |
version_id := data["id"] | |
status := data["status"] | |
fmt.Printf("Version ID: %s, Status: %s\n", version_id, status) | |
return version_id, status, nil | |
} | |
func getModelVersion(modelId, versionId string) (string, string, error) { | |
url := fmt.Sprintf("%s/images/models/%s/versions/%s", BASE_API, modelId, versionId) | |
var err error | |
var response []byte | |
if response, err = get(url); err != nil { | |
return "", "", err | |
} | |
data := make(map[string]interface{}) | |
if err = json.Unmarshal(response, &data); err != nil { | |
return "", "", err | |
} | |
status := data["status"].(string) | |
fmt.Printf("Version ID: %s. Status: %s\n", versionId, status) | |
return versionId, status, nil | |
} | |
func generateImage(modelId, prompt string) (string, string, error) { | |
url := fmt.Sprintf("%s/images/models/%s/inferences", BASE_API, modelId) | |
payload, _ := json.Marshal(map[string]interface{}{ | |
"prompt": prompt, | |
"steps": 50, | |
"width": 512, | |
"height": 512, | |
"numberOfImages": 4, | |
"seed": 4523184, | |
"enhancePrompt": true, | |
"restoreFaces": true, | |
}) | |
var err error | |
var response []byte | |
if response, err = post(url, string(payload)); err != nil { | |
return "", "", err | |
} | |
data := make(map[string]interface{}) | |
if err = json.Unmarshal(response, &data); err != nil { | |
return "", "", err | |
} | |
inferenceId := data["id"].(string) | |
status := data["status"].(string) | |
fmt.Printf("InferenceId: %s, Status: %s\n", inferenceId, status) | |
return inferenceId, status, nil | |
} | |
func getInferenceJob(modelId string, inferenceId string) (string, string, []string, error) { | |
url := fmt.Sprintf("%s/images/models/%s/inferences/%s", BASE_API, modelId, inferenceId) | |
var err error | |
var response []byte | |
if response, err = get(url); err != nil { | |
return "", "", nil, err | |
} | |
data := make(map[string]interface{}) | |
if err = json.Unmarshal(response, &data); err != nil { | |
return "", "", nil, err | |
} | |
var state string = "" | |
if data["state"] != nil { | |
state = data["state"].(string) | |
} | |
var images []string = make([]string, 5) | |
if data["images"] != nil { | |
for _, image := range data["images"].([]interface{}) { | |
images = append(images, image.(map[string]interface{})["uri"].(string)) | |
} | |
} | |
fmt.Printf("Inference ID: %s. State: %s\n", inferenceId, state) | |
return inferenceId, state, images, nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment