// Usage example: submit a request to the Fal.ai queue and wait for the result.
// ctx is a context.Context supplied by the caller; give it a deadline to bound
// how long polling may take.
client := NewClient(falKey)

call, err := client.Call(falAPIPath, RequestOptions{
	Method:  "POST",
	Payload: request,
})
if err != nil {
	// Without this check a failed submission would hand a nil call to the poller.
	return nil, fmt.Errorf("Fal request failed: %w", err)
}

response, err := client.PollUntilCompletion(ctx, call)
if err != nil {
	// PollUntilCompletion returns ctx.Err() unchanged, so a direct comparison works.
	if err == context.DeadlineExceeded {
		return nil, fmt.Errorf("Fal request timed out")
	}
	return nil, fmt.Errorf("Fal request failed: %w", err)
}

// Guard the index: a response without images would otherwise panic.
if len(response.Images) == 0 {
	return nil, fmt.Errorf("Fal response contained no images")
}

fmt.Println(response.Images[0].URL)
// fal.media/.../...png
Last active
October 14, 2024 20:53
-
-
Save azer/d49ab3ca55df9fc4160d2be8235e2643 to your computer and use it in GitHub Desktop.
Go Client for Fal.ai
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package fal | |
import ( | |
"bytes" | |
"context" | |
"encoding/json" | |
"fmt" | |
"io" | |
"net/http" | |
"time" | |
"github.com/azer/logger" | |
) | |
// Client is a handle for talking to the Fal.ai API. It holds nothing but the
// credential that is attached to outgoing requests.
type Client struct {
	// apiKey is sent in the Authorization header as "Key <apiKey>".
	apiKey string
}
// Call represents an ongoing API call to Fal.ai. | |
type Call struct { | |
client *Client | |
requestID string | |
statusURL string | |
responseURL string | |
cancelURL string | |
} | |
// RequestOptions describes how an API request should be sent.
type RequestOptions struct {
	// Method is the HTTP method used for the request (for example "POST").
	Method string
	// Payload is marshalled to JSON and sent as the request body.
	Payload interface{}
}
// Response represents the response from the Fal.ai API. | |
type Response struct { | |
Status string `json:"status,omitempty"` | |
RequestID string `json:"request_id,omitempty"` | |
ResponseURL string `json:"response_url,omitempty"` | |
StatusURL string `json:"status_url,omitempty"` | |
CancelURL string `json:"cancel_url,omitempty"` | |
Images []Image `json:"images,omitempty"` | |
Timings Timings `json:"timings,omitempty"` | |
Seed int `json:"seed,omitempty"` | |
HasNSFWConcepts []bool `json:"has_nsfw_concepts,omitempty"` | |
Prompt string `json:"prompt,omitempty"` | |
QueuePosition int `json:"queue_position,omitempty"` | |
} | |
// Image describes a single image entry in an API response.
type Image struct {
	// URL is where the image can be fetched from.
	URL string `json:"url"`
	// Width and Height give the image dimensions (presumably in pixels).
	Width  int `json:"width"`
	Height int `json:"height"`
	// ContentType is decoded from the "content_type" JSON field.
	ContentType string `json:"content_type"`
}
// Timings carries the timing figures reported alongside an API result.
type Timings struct {
	// Inference is decoded from the "inference" JSON field (presumably the
	// inference duration in seconds - TODO confirm the unit).
	Inference float64 `json:"inference"`
}
// NewClient creates a new Fal.ai API client. | |
func NewClient(apiKey string) *Client { | |
return &Client{ | |
apiKey: apiKey, | |
} | |
} | |
// Call initiates a new API call to Fal.ai. | |
func (c *Client) Call(path string, options RequestOptions) (*Call, error) { | |
timer := log.Timer() | |
defer timer.End("Initial request completed") | |
url := fmt.Sprintf("https://queue.fal.run/%s", path) | |
log.Info("Making initial request", logger.Attrs{"url": url, "method": options.Method}) | |
payload, err := json.Marshal(options.Payload) | |
if err != nil { | |
return nil, fmt.Errorf("Failed to marshal payload: %w", err) | |
} | |
req, err := http.NewRequest(options.Method, url, bytes.NewBuffer(payload)) | |
if err != nil { | |
return nil, fmt.Errorf("Failed to create request: %w", err) | |
} | |
req.Header.Set("Authorization", fmt.Sprintf("Key %s", c.apiKey)) | |
req.Header.Set("Content-Type", "application/json") | |
client := &http.Client{} | |
resp, err := client.Do(req) | |
if err != nil { | |
return nil, fmt.Errorf("Request failed: %w", err) | |
} | |
defer resp.Body.Close() | |
body, err := io.ReadAll(resp.Body) | |
if err != nil { | |
return nil, fmt.Errorf("Failed to read response body: %w", err) | |
} | |
var initialResponse Response | |
if err := json.Unmarshal(body, &initialResponse); err != nil { | |
return nil, fmt.Errorf("Failed to unmarshal response: %w", err) | |
} | |
log.Info("Initial response received", logger.Attrs{"requestID": initialResponse.RequestID}) | |
return &Call{ | |
client: c, | |
requestID: initialResponse.RequestID, | |
statusURL: initialResponse.StatusURL, | |
responseURL: initialResponse.ResponseURL, | |
cancelURL: initialResponse.CancelURL, | |
}, nil | |
} | |
// CheckStatus checks the current status of the API call. | |
func (call *Call) CheckStatus() (*Response, error) { | |
req, err := http.NewRequest("GET", call.statusURL, nil) | |
if err != nil { | |
return nil, fmt.Errorf("Failed to create status request: %w", err) | |
} | |
req.Header.Set("Authorization", fmt.Sprintf("Key %s", call.client.apiKey)) | |
client := &http.Client{} | |
resp, err := client.Do(req) | |
if err != nil { | |
return nil, fmt.Errorf("Status request failed: %w", err) | |
} | |
defer resp.Body.Close() | |
body, err := io.ReadAll(resp.Body) | |
if err != nil { | |
return nil, fmt.Errorf("Failed to read status response body: %w", err) | |
} | |
var statusResponse Response | |
if err := json.Unmarshal(body, &statusResponse); err != nil { | |
return nil, fmt.Errorf("Failed to unmarshal status response: %w", err) | |
} | |
return &statusResponse, nil | |
} | |
// FetchResponse retrieves the final response of the API call. | |
func (call *Call) FetchResponse() (*Response, error) { | |
req, err := http.NewRequest("GET", call.responseURL, nil) | |
if err != nil { | |
return nil, fmt.Errorf("Failed to create final response request: %w", err) | |
} | |
req.Header.Set("Authorization", fmt.Sprintf("Key %s", call.client.apiKey)) | |
client := &http.Client{} | |
resp, err := client.Do(req) | |
if err != nil { | |
return nil, fmt.Errorf("Final response request failed: %w", err) | |
} | |
defer resp.Body.Close() | |
body, err := io.ReadAll(resp.Body) | |
if err != nil { | |
return nil, fmt.Errorf("Failed to read final response body: %w", err) | |
} | |
var finalResponse Response | |
if err := json.Unmarshal(body, &finalResponse); err != nil { | |
return nil, fmt.Errorf("Failed to unmarshal final response: %w", err) | |
} | |
log.Info("Final response received", logger.Attrs{"requestID": finalResponse.RequestID}) | |
return &finalResponse, nil | |
} | |
// Cancel cancels the ongoing API call. | |
func (call *Call) Cancel() error { | |
req, err := http.NewRequest("POST", call.cancelURL, nil) | |
if err != nil { | |
return fmt.Errorf("Failed to create cancel request: %w", err) | |
} | |
req.Header.Set("Authorization", fmt.Sprintf("Key %s", call.client.apiKey)) | |
client := &http.Client{} | |
resp, err := client.Do(req) | |
if err != nil { | |
return fmt.Errorf("Cancel request failed: %w", err) | |
} | |
defer resp.Body.Close() | |
if resp.StatusCode != http.StatusOK { | |
return fmt.Errorf("Failed to cancel request, status code: %d", resp.StatusCode) | |
} | |
log.Info("Request cancelled", logger.Attrs{"requestID": call.requestID}) | |
return nil | |
} | |
// PollUntilCompletion polls the status of the API call until it's completed or failed. | |
func (c *Client) PollUntilCompletion(ctx context.Context, call *Call) (*Response, error) { | |
ticker := time.NewTicker(500 * time.Millisecond) | |
defer ticker.Stop() | |
for { | |
select { | |
case <-ctx.Done(): | |
return nil, ctx.Err() | |
case <-ticker.C: | |
statusResponse, err := call.CheckStatus() | |
if err != nil { | |
return nil, err | |
} | |
switch statusResponse.Status { | |
case "COMPLETED": | |
log.Info("Request completed", logger.Attrs{"request_id": request_id}) | |
return call.FetchResponse() | |
case "FAILED": | |
log.Error("Request failed", logger.Attrs{"request_id": request_id}) | |
return nil, fmt.Errorf("Request failed") | |
default: | |
//log.Info("Request in progress", logger.Attrs{ "rid": id, "status": statusResponse.Status, "queuePosition": statusResponse.QueuePosition }) | |
} | |
} | |
} | |
} | |
// PollWithProgress polls the status of the API call and provides progress updates through a channel. | |
func (c *Client) PollWithProgress(ctx context.Context, call *Call, progressCh chan<- *Response) (*Response, error) { | |
ticker := time.NewTicker(500 * time.Millisecond) | |
defer ticker.Stop() | |
for { | |
select { | |
case <-ctx.Done(): | |
return nil, ctx.Err() | |
case <-ticker.C: | |
statusResponse, err := call.CheckStatus() | |
if err != nil { | |
return nil, err | |
} | |
// Send the status update to the progress channel | |
select { | |
case progressCh <- statusResponse: | |
default: | |
// If the channel is full, we skip this update | |
} | |
switch statusResponse.Status { | |
case "COMPLETED": | |
return call.FetchResponse() | |
case "FAILED": | |
return nil, fmt.Errorf("Request failed") | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment