Last active
November 13, 2023 12:49
-
-
Save sebsto/f4f589470ee97e5b5943bc3ecaae1e40 to your computer and use it in GitHub Desktop.
Example Swift code that invokes the Cohere Embed model on Amazon Bedrock
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
import ClientRuntime | |
// Reduce the verbosity of the AWS SDK by setting its log level to `.warning`
// (presumably suppressing info/debug output — confirm against the ClientRuntime docs).
// This runs before any Bedrock client is created further down the script.
SDKLoggingSystem.initialize(logLevel: .warning) | |
import AWSBedrock | |
import AWSBedrockRuntime | |
// Create a Bedrock control-plane client and list the foundation models
// available from a single provider.
let provider = "cohere"
print("====== Models available for \(provider)")
let client = try BedrockClient(region: "us-east-1")
let input = ListFoundationModelsInput(byProvider: provider)
let output = try await client.listFoundationModels(input: input)
// `modelSummaries`, `modelName` and `modelId` are optionals in the SDK output;
// substitute placeholders instead of force-unwrapping so a missing field
// cannot crash the script.
let summaries = output.modelSummaries ?? []
print(summaries.map { "\($0.modelName ?? "<unnamed>") : \($0.modelId ?? "<no id>")" }.joined(separator: "\n"))
print("======")
// Create a Bedrock runtime client, send one document to the Cohere Embed
// model and print a short preview of the returned embedding.
let modelId = "cohere.embed-english-v3"
let runtime = try BedrockRuntimeClient(region: "us-east-1")
let document =
"""
This is a document that provides context about a business domain
"""
let payload = CohereEmbedDocument(texts: [document], inputType: .searchDocument)
do {
    // Encode inside `do` so a serialisation failure is reported by the same
    // handler as a network or decoding failure.
    let request = InvokeModelInput(body: try payload.encode(),
                                   contentType: "application/json",
                                   modelId: modelId)
    let invokeModelOutput = try await runtime.invokeModel(input: request)
    // `body` is optional in the SDK output; report its absence instead of
    // force-unwrapping it.
    if let body = invokeModelOutput.body {
        // To inspect the raw JSON: print(String(decoding: body, as: UTF8.self))
        let cohereResponse = try CohereEmbedResponse(from: body)
        print(cohereResponse)
    } else {
        print("The model returned an empty response body")
    }
    print("======")
} catch {
    print(error)
}
/// Value of the `input_type` field of a Cohere Embed request.
///
/// Reference: https://docs.cohere.com/reference/embed
enum CohereEmbedInputType: String, Encodable {
    case searchDocument = "search_document"
    case searchQuery = "search_query"
    case classification = "classification"
    case clustering = "clustering"

    /// Misspelled name of `classification`, kept so existing callers still compile.
    @available(*, deprecated, renamed: "classification")
    static let clasification = CohereEmbedInputType.classification
}
/// Value of the `truncate` field of a Cohere Embed request.
///
/// Each case encodes as its upper-case raw value. Because one case is named
/// `none`, spell it `CohereEmbedTruncating.none` wherever an optional of this
/// type could be inferred; otherwise `.none` may resolve to `Optional.none`.
enum CohereEmbedTruncating: String, Encodable {
    case none = "NONE", start = "START", end = "END"
}
/// Request body for the Cohere Embed model.
///
/// `encode()` converts property names to snake_case, so the JSON keys are
/// `texts`, `input_type` and `truncate`.
struct CohereEmbedDocument: Encodable {
    /// The strings to embed.
    let texts: [String]
    /// Intended use of the embeddings; sent as `input_type`.
    let inputType: CohereEmbedInputType
    /// How over-long input is handled; sent as `truncate`.
    let truncate: CohereEmbedTruncating

    /// Creates a request body.
    ///
    /// The previous `let truncate = .none` was excluded from the memberwise
    /// initializer, so callers could never choose a strategy. This initializer
    /// keeps the `init(texts:inputType:)` call shape and makes truncation
    /// configurable.
    /// - Parameters:
    ///   - texts: The strings to embed.
    ///   - inputType: Intended use of the embeddings.
    ///   - truncate: Truncation strategy; defaults to `.none`, the previously fixed value.
    init(texts: [String],
         inputType: CohereEmbedInputType,
         truncate: CohereEmbedTruncating = CohereEmbedTruncating.none) {
        self.texts = texts
        self.inputType = inputType
        self.truncate = truncate
    }

    /// Serialises this request as JSON with snake_case keys.
    /// - Throws: Any error raised by `JSONEncoder.encode(_:)`.
    func encode() throws -> Data {
        let encoder = JSONEncoder()
        encoder.keyEncodingStrategy = .convertToSnakeCase
        return try encoder.encode(self)
    }
}
/// Decoded JSON response of the Cohere Embed model.
struct CohereEmbedResponse: Decodable, CustomStringConvertible {
    /// The embedding vectors returned by the model.
    let embeddings: [[Double]]
    /// The `id` field of the response.
    let id: String
    /// The `texts` field of the response (presumably the request's input texts — confirm).
    let texts: [String]

    /// Number of leading vector components shown by `description`.
    private static let previewLength = 5

    /// Decodes a response from the raw JSON body returned by the model.
    /// - Parameter data: JSON with `embeddings`, `id` and `texts` keys.
    /// - Throws: An error from `JSONDecoder` if the data does not match this shape.
    init(from data: Data) throws {
        self = try JSONDecoder().decode(CohereEmbedResponse.self, from: data)
    }

    /// A short preview of the first embedding, e.g. `[0.100,0.200,...] (1024 elements)`.
    ///
    /// Returns `"no embeddings"` when the response holds none. The trailing
    /// `,...` appears only when the vector is longer than the preview, so short
    /// or empty vectors are not shown as truncated.
    var description: String {
        guard let first = embeddings.first else {
            return "no embeddings"
        }
        let preview = first.prefix(Self.previewLength).map { String(format: "%.3f", $0) }
        let ellipsis = first.count > preview.count ? ",..." : ""
        return "[" + preview.joined(separator: ",") + ellipsis + "] (\(first.count) elements)"
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment