Skip to content

Instantly share code, notes, and snippets.

@azer
Last active October 14, 2024 20:53
Show Gist options
  • Save azer/d49ab3ca55df9fc4160d2be8235e2643 to your computer and use it in GitHub Desktop.
Save azer/d49ab3ca55df9fc4160d2be8235e2643 to your computer and use it in GitHub Desktop.
Go Client for Fal.ai
package fal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/azer/logger"
)
// Client talks to the Fal.ai queue API on behalf of a single API key.
// Construct one with NewClient.
type Client struct {
	// apiKey is sent as "Authorization: Key <apiKey>" on every request.
	apiKey string
}
// Call is a handle to one request already submitted to the Fal.ai queue.
// It remembers the follow-up endpoints the queue handed back at submit time.
type Call struct {
	client      *Client // supplies the API key for follow-up requests
	requestID   string  // queue-assigned ID, used in log attributes
	statusURL   string  // polled by CheckStatus
	responseURL string  // fetched by FetchResponse
	cancelURL   string  // target of Cancel
}
// RequestOptions describes the initial request that Client.Call submits.
type RequestOptions struct {
	// Method is the HTTP method passed to http.NewRequest (e.g. "POST").
	Method string
	// Payload is encoded with json.Marshal to form the request body.
	Payload interface{}
}
// Response is the decoded JSON body of every Fal.ai queue reply this package
// reads: the submit reply, status replies, and the final result. Any single
// reply populates only a subset of the fields.
type Response struct {
	// Queue bookkeeping. Status is compared against "COMPLETED" and "FAILED"
	// by the polling helpers; the URL fields are copied into a Call.
	Status      string `json:"status,omitempty"`
	RequestID   string `json:"request_id,omitempty"`
	ResponseURL string `json:"response_url,omitempty"`
	StatusURL   string `json:"status_url,omitempty"`
	CancelURL   string `json:"cancel_url,omitempty"`

	// Result payload; this field set presumably matches image-generation
	// models — confirm for other endpoints.
	Images []Image `json:"images,omitempty"`
	// NOTE(review): encoding/json ignores omitempty on struct-typed fields,
	// so Timings is always emitted when a Response is marshaled.
	Timings         Timings `json:"timings,omitempty"`
	Seed            int     `json:"seed,omitempty"`
	HasNSFWConcepts []bool  `json:"has_nsfw_concepts,omitempty"`
	Prompt          string  `json:"prompt,omitempty"`

	// QueuePosition is presumably reported while the request is still queued.
	QueuePosition int `json:"queue_position,omitempty"`
}
// Image is one generated image listed in a result reply.
type Image struct {
	URL         string `json:"url"`          // where the image can be downloaded
	Width       int    `json:"width"`        // width as reported by the API
	Height      int    `json:"height"`       // height as reported by the API
	ContentType string `json:"content_type"` // media type string, e.g. "image/png" (assumed)
}
// Timings holds the timing figures a result reply reports.
type Timings struct {
	// Inference is the reported inference time; units are not stated here
	// (presumably seconds — confirm against the API docs).
	Inference float64 `json:"inference"`
}
// NewClient returns a Client that authenticates every request with apiKey.
func NewClient(apiKey string) *Client {
	c := new(Client)
	c.apiKey = apiKey
	return c
}
// Call submits a request to the Fal.ai queue and returns a handle for tracking it.
//
// The request goes to "https://queue.fal.run/<path>" using options.Method, with
// options.Payload JSON-encoded as the body and the client's key in the
// Authorization header. The returned Call carries the request ID and the
// status/response/cancel URLs from the queue's reply.
//
// A non-2xx reply is returned as an error that includes the response body. This
// surfaces API rejections (bad key, unknown model, invalid payload) here, instead
// of producing a Call with empty URLs that fails confusingly later.
func (c *Client) Call(path string, options RequestOptions) (*Call, error) {
	timer := log.Timer()
	defer timer.End("Initial request completed")

	url := fmt.Sprintf("https://queue.fal.run/%s", path)
	log.Info("Making initial request", logger.Attrs{"url": url, "method": options.Method})

	payload, err := json.Marshal(options.Payload)
	if err != nil {
		return nil, fmt.Errorf("Failed to marshal payload: %w", err)
	}

	req, err := http.NewRequest(options.Method, url, bytes.NewBuffer(payload))
	if err != nil {
		return nil, fmt.Errorf("Failed to create request: %w", err)
	}
	req.Header.Set("Authorization", fmt.Sprintf("Key %s", c.apiKey))
	req.Header.Set("Content-Type", "application/json")

	// Bound the request so a stalled connection cannot block the caller forever.
	client := &http.Client{Timeout: 60 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("Request failed: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("Failed to read response body: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("Initial request failed with status %d: %s", resp.StatusCode, body)
	}

	var initialResponse Response
	if err := json.Unmarshal(body, &initialResponse); err != nil {
		return nil, fmt.Errorf("Failed to unmarshal response: %w", err)
	}

	log.Info("Initial response received", logger.Attrs{"requestID": initialResponse.RequestID})
	return &Call{
		client:      c,
		requestID:   initialResponse.RequestID,
		statusURL:   initialResponse.StatusURL,
		responseURL: initialResponse.ResponseURL,
		cancelURL:   initialResponse.CancelURL,
	}, nil
}
// CheckStatus fetches the call's current queue status from its status URL.
//
// The returned Response's Status field is what the polling helpers compare
// against "COMPLETED" and "FAILED".
//
// A non-2xx HTTP reply is returned as an error that includes the response body.
// Without this check, an error body that happens to be valid JSON would decode
// into a Response with an empty Status. A polling loop would then keep retrying
// until its context expired.
func (call *Call) CheckStatus() (*Response, error) {
	req, err := http.NewRequest("GET", call.statusURL, nil)
	if err != nil {
		return nil, fmt.Errorf("Failed to create status request: %w", err)
	}
	req.Header.Set("Authorization", fmt.Sprintf("Key %s", call.client.apiKey))

	// This method takes no context, so a timeout is the only thing that stops
	// a stalled status request from blocking a poll loop indefinitely.
	client := &http.Client{Timeout: 60 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("Status request failed: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("Failed to read status response body: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("Status request failed with status %d: %s", resp.StatusCode, body)
	}

	var statusResponse Response
	if err := json.Unmarshal(body, &statusResponse); err != nil {
		return nil, fmt.Errorf("Failed to unmarshal status response: %w", err)
	}
	return &statusResponse, nil
}
// FetchResponse downloads the final result of the call from its response URL.
//
// The polling helpers call this only after a "COMPLETED" status.
//
// A non-2xx HTTP reply is returned as an error that includes the response body.
// Otherwise a failed result would come back as a Response with no Images, and
// callers indexing Images[0] would panic.
func (call *Call) FetchResponse() (*Response, error) {
	req, err := http.NewRequest("GET", call.responseURL, nil)
	if err != nil {
		return nil, fmt.Errorf("Failed to create final response request: %w", err)
	}
	req.Header.Set("Authorization", fmt.Sprintf("Key %s", call.client.apiKey))

	// Bound the request so a stalled connection cannot hang the caller.
	client := &http.Client{Timeout: 60 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("Final response request failed: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("Failed to read final response body: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("Final response request failed with status %d: %s", resp.StatusCode, body)
	}

	var finalResponse Response
	if err := json.Unmarshal(body, &finalResponse); err != nil {
		return nil, fmt.Errorf("Failed to unmarshal final response: %w", err)
	}

	log.Info("Final response received", logger.Attrs{"requestID": finalResponse.RequestID})
	return &finalResponse, nil
}
// Cancel asks the queue to cancel the call by sending a POST to its cancel URL.
//
// Any 2xx reply is treated as success; any other status is returned as an error.
//
// NOTE(review): Fal's queue documentation describes cancellation as a PUT to
// .../requests/{id}/cancel. Confirm which method the API accepts; it is left as
// POST here to avoid an unverified behavior change.
func (call *Call) Cancel() error {
	req, err := http.NewRequest("POST", call.cancelURL, nil)
	if err != nil {
		return fmt.Errorf("Failed to create cancel request: %w", err)
	}
	req.Header.Set("Authorization", fmt.Sprintf("Key %s", call.client.apiKey))

	// Bound the request so a stalled connection cannot hang the caller.
	client := &http.Client{Timeout: 60 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("Cancel request failed: %w", err)
	}
	defer resp.Body.Close()

	// Accept the whole 2xx range. A cancellation that has been accepted but not
	// yet carried out is conventionally acknowledged with 202 Accepted, which the
	// previous "== 200" check wrongly reported as a failure.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("Failed to cancel request, status code: %d", resp.StatusCode)
	}

	log.Info("Request cancelled", logger.Attrs{"requestID": call.requestID})
	return nil
}
// PollUntilCompletion checks call's status every 500ms until the request
// finishes or ctx is done.
//
// Outcomes:
//   - "COMPLETED": returns the result of call.FetchResponse.
//   - "FAILED": returns an error.
//   - ctx done first: returns ctx.Err().
//   - CheckStatus error: polling stops and that error is returned unchanged.
//
// Any other status (presumably queued / in progress) keeps polling.
//
// NOTE(review): CheckStatus takes no context, so an in-flight status request
// is not interrupted by cancellation; ctx is only observed between polls.
func (c *Client) PollUntilCompletion(ctx context.Context, call *Call) (*Response, error) {
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-ticker.C:
			statusResponse, err := call.CheckStatus()
			if err != nil {
				return nil, err
			}

			switch statusResponse.Status {
			case "COMPLETED":
				// The request ID lives on the Call; there is no local request_id.
				log.Info("Request completed", logger.Attrs{"request_id": call.requestID})
				return call.FetchResponse()
			case "FAILED":
				log.Error("Request failed", logger.Attrs{"request_id": call.requestID})
				return nil, fmt.Errorf("Request failed")
			default:
				// Not finished yet; wait for the next tick.
				//log.Info("Request in progress", logger.Attrs{"request_id": call.requestID, "status": statusResponse.Status, "queuePosition": statusResponse.QueuePosition})
			}
		}
	}
}
// PollWithProgress checks call's status every 500ms, publishes each status
// reply on progressCh, and stops when the request finishes or ctx is done.
//
// Publishing is non-blocking: if progressCh cannot take a value immediately
// (buffer full, or unbuffered with no ready receiver), that update is dropped.
// A nil progressCh therefore receives nothing. The channel is never closed here.
//
// Outcomes:
//   - "COMPLETED": returns call.FetchResponse().
//   - "FAILED": returns an error.
//   - ctx done first: returns ctx.Err().
//   - CheckStatus error: returned as-is.
//
// Unlike PollUntilCompletion, it writes no log entries.
func (c *Client) PollWithProgress(ctx context.Context, call *Call, progressCh chan<- *Response) (*Response, error) {
	tick := time.NewTicker(500 * time.Millisecond)
	defer tick.Stop()

	for {
		// Wait for either cancellation or the next poll slot.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-tick.C:
		}

		update, err := call.CheckStatus()
		if err != nil {
			return nil, err
		}

		// Best-effort delivery: never stall polling on a slow consumer.
		select {
		case progressCh <- update:
		default:
		}

		if update.Status == "COMPLETED" {
			return call.FetchResponse()
		}
		if update.Status == "FAILED" {
			return nil, fmt.Errorf("Request failed")
		}
	}
}
client := NewClient(falKey)

call, err := client.Call(falAPIPath, RequestOptions{
		Method:  "POST",
		Payload: request,
})

response, err := client.PollUntilCompletion(context, call)
if err != nil {
  if err == context.DeadlineExceeded {
	  return nil, fmt.Errorf("Fal request timed out")
	}

	return nil, fmt.Errorf("Fal request failed: %w", err)
}

response.Images[0].URL
// fal.media/.../...png
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment