Skip to content

Instantly share code, notes, and snippets.

@glaforge
Last active November 29, 2024 15:56
Show Gist options
  • Save glaforge/4e45fa4222dd803d6d8bbf2b9335e90d to your computer and use it in GitHub Desktop.
Save glaforge/4e45fa4222dd803d6d8bbf2b9335e90d to your computer and use it in GitHub Desktop.
Version #1 of ProgrammingIdioms with embeddings of the title, the description, and the code
package experiments;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Type;
import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.vertexai.VertexAiEmbeddingModel;
import dev.langchain4j.model.vertexai.VertexAiGeminiChatModel;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;
/**
 * Semantic code search over a JSON export of programming idioms.
 *
 * <p>On the first run, every non-blank code implementation found in the idioms file is embedded
 * with a Vertex AI embedding model and the resulting in-memory store is serialized to disk. On
 * later runs the serialized store is loaded instead. A hard-coded question is then embedded and,
 * in parallel, classified by a Gemini chat model to find out whether it targets a specific
 * programming language; the store is searched (optionally filtered by that language) and the
 * matches are printed to standard output.
 *
 * <p>The Google Cloud project and location are read from the {@code GCP_PROJECT_ID} and
 * {@code GCP_LOCATION} environment variables.
 */
public class ProgrammingIdioms {

    /** Value the chat model is instructed to reply with when no language can be identified. */
    private static final String UNKNOWN_LANGUAGE = "UNKNOWN";

    /** Metadata key holding "title: description"; also handed to the embedding models as title key. */
    private static final String TITLE_METADATA_KEY = "titleAndDescription";

    /** Location of the JSON file containing all idioms. */
    private static final Path IDIOMS_JSON_PATH =
            Path.of("src/main/resources/programming-idioms-all.json");

    /** Location where the embedding store is serialized to, and loaded from when present. */
    private static final Path EMBEDDING_STORE_PATH =
            Path.of("src/main/resources/embedding-store.json");

    /** Maximum number of matches requested from the embedding store. */
    private static final int MAX_RESULTS = 5;

    /** Minimum similarity score a match must reach to be returned. */
    private static final double MIN_SCORE = 0.8;

    /** Matches a word that starts with an underscore; group 1 is the word without the underscore. */
    private static final Pattern UNDERSCORE_PREFIXED_WORD = Pattern.compile("\\b_(\\w+)");

    private static final String GCP_PROJECT_ID = System.getenv("GCP_PROJECT_ID");
    private static final String GCP_LOCATION = System.getenv("GCP_LOCATION");

    /** Closed list of answers the chat model may give, {@code UNKNOWN} included. */
    private static final List<String> KNOWN_PROGRAMMING_LANGUAGES =
            List.of(UNKNOWN_LANGUAGE,
                    "Go", "Rust", "Python", "Perl", "Ruby", "Java", "JS",
                    "C#", "Dart", "Pascal", "PHP", "C++", "Haskell", "D",
                    "Lua", "Clojure", "Fortran", "Elixir", "Kotlin",
                    "Erlang", "C", "Lisp", "VB", "Groovy", "Ada", "Scala",
                    "Scheme", "Smalltalk", "Obj-C", "Cobol", "Prolog", "Caml"
            );

    /** Chat model whose response schema is a string enum of {@link #KNOWN_PROGRAMMING_LANGUAGES}. */
    private static final ChatLanguageModel GEMINI_MODEL =
            VertexAiGeminiChatModel.builder()
                    .project(GCP_PROJECT_ID)
                    .location(GCP_LOCATION)
                    .modelName("gemini-1.5-flash-002")
                    .responseSchema(Schema.newBuilder()
                            .setType(Type.STRING)
                            .addAllEnum(KNOWN_PROGRAMMING_LANGUAGES)
                            .build())
                    .build();

    /** Embedding model used to embed the code segments (document side of the retrieval). */
    private static final VertexAiEmbeddingModel EMBEDDING_MODEL = VertexAiEmbeddingModel.builder()
            .project(GCP_PROJECT_ID)
            .location(GCP_LOCATION)
            .modelName("text-embedding-005")
            .publisher("google")
            .taskType(VertexAiEmbeddingModel.TaskType.RETRIEVAL_DOCUMENT)
            .titleMetadataKey(TITLE_METADATA_KEY)
            .maxSegmentsPerBatch(150)
            .build();

    /** Embedding model used to embed the user question (query side of the retrieval). */
    private static final VertexAiEmbeddingModel EMBEDDING_MODEL_FOR_RETRIEVAL = VertexAiEmbeddingModel.builder()
            .project(GCP_PROJECT_ID)
            .location(GCP_LOCATION)
            .modelName("text-embedding-005")
            .publisher("google")
            .taskType(VertexAiEmbeddingModel.TaskType.CODE_RETRIEVAL_QUERY)
            .titleMetadataKey(TITLE_METADATA_KEY)
            .build();

    /**
     * One idiom as found in the idioms JSON file; each component is bound to the JSON key named
     * in its {@code @SerializedName} annotation.
     */
    record Idiom(
            @SerializedName("Id")
            long id,
            @SerializedName("Title")
            String title,
            @SerializedName("LeadParagraph")
            String description,
            @SerializedName("ExtraKeywords")
            String keywords,
            @SerializedName("Implementations")
            Implementation[] implementations
    ) {
        /** One implementation of an idiom in a given language. */
        record Implementation(
                @SerializedName("Id")
                long id,
                @SerializedName("LanguageName")
                String language,
                @SerializedName("CodeBlock")
                String code,
                @SerializedName("AuthorComment")
                String comment
        ) {
        }
    }

    /**
     * Builds or loads the embedding store, then runs one sample query against it and prints the
     * matches.
     *
     * @param args ignored
     * @throws IOException if the idioms file cannot be read
     * @throws InterruptedException if interrupted while waiting for the parallel model calls
     * @throws ExecutionException if the language recognition or the query embedding fails
     */
    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
        InMemoryEmbeddingStore<TextSegment> embeddingStore = loadOrBuildEmbeddingStore();
        System.out.println("Ready for querying!");

        // Other sample queries to try:
        //   "How to count the characters in a string?"
        //   "How can I make an HTTP POST request in Java?"
        //   "How to use the LibXML parser in Perl?"
        //   "How to use the LibXML parser?"
        String question = "How can I make an HTTP POST request?";
        System.out.format("Query: %s%n", question);

        // Run both remote calls concurrently. Typed futures avoid the unchecked casts that a
        // heterogeneous invokeAll() would force on us.
        Future<String> languageFuture;
        Future<Embedding> embeddingFuture;
        try (ExecutorService executorService = Executors.newFixedThreadPool(2)) {
            languageFuture = executorService.submit(() -> recognizeProgrammingLanguage(question));
            embeddingFuture = executorService.submit(() -> embedQuery(question));
        } // close() waits for both tasks to finish
        String programmingLanguageRecognised = languageFuture.get();
        Embedding queryEmbedding = embeddingFuture.get();

        System.out.println("Searching...\n");
        var searchRequestBuilder = EmbeddingSearchRequest.builder()
                .maxResults(MAX_RESULTS)
                .minScore(MIN_SCORE)
                .queryEmbedding(queryEmbedding);
        // Only restrict the search to one language when the model actually identified one.
        boolean languageUnknown = programmingLanguageRecognised == null
                || UNKNOWN_LANGUAGE.equals(programmingLanguageRecognised);
        var searchRequest = languageUnknown
                ? searchRequestBuilder.build()
                : searchRequestBuilder.filter(new IsEqualTo("language", programmingLanguageRecognised)).build();

        EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(searchRequest);
        searchResult.matches().forEach(match -> printMatch(match.embedded(), match.score()));
    }

    /**
     * Returns the embedding store, loading it from {@link #EMBEDDING_STORE_PATH} when that file
     * exists, and otherwise embedding all idiom implementations and serializing the result there.
     *
     * @throws IOException if the idioms file cannot be read
     */
    private static InMemoryEmbeddingStore<TextSegment> loadOrBuildEmbeddingStore() throws IOException {
        if (Files.exists(EMBEDDING_STORE_PATH)) {
            System.out.println("Loading serialized embedding store...");
            return InMemoryEmbeddingStore.fromFile(EMBEDDING_STORE_PATH);
        }

        System.out.println("Preparing text segments to embed");
        List<TextSegment> allCodeSegments = buildCodeSegments(loadIdioms());

        System.out.println("Embedding all code segments...");
        List<Embedding> allEmbeddings = EMBEDDING_MODEL.embedAll(allCodeSegments).content();

        System.out.println("Loading and serializing embedding store...");
        InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        embeddingStore.addAll(allEmbeddings, allCodeSegments);
        embeddingStore.serializeToFile(EMBEDDING_STORE_PATH);
        return embeddingStore;
    }

    /**
     * Turns every non-blank code implementation into a text segment whose text is the code and
     * whose metadata carries the idiom and implementation details.
     *
     * <p>Idioms without implementations are skipped, and missing string values are stored as
     * empty strings so that a partially filled entry does not abort the whole indexing run.
     */
    private static List<TextSegment> buildCodeSegments(Idiom[] idioms) {
        List<TextSegment> segments = new ArrayList<>();
        for (Idiom idiom : idioms) {
            System.out.println("-> " + idiom.title());
            if (idiom.implementations() == null) {
                continue; // key absent from the JSON: nothing to embed for this idiom
            }
            String title = nullToEmpty(idiom.title());
            String description = nullToEmpty(idiom.description());
            for (Idiom.Implementation implementation : idiom.implementations()) {
                String code = implementation.code();
                if (code == null || code.isBlank()) {
                    continue;
                }
                segments.add(new TextSegment(
                        code,
                        new Metadata()
                                .put("idiomId", idiom.id())
                                .put("title", title)
                                .put("description", description)
                                .put(TITLE_METADATA_KEY, title + ": " + description)
                                .put("keywords", nullToEmpty(idiom.keywords()))
                                .put("implementationId", implementation.id())
                                .put("language", nullToEmpty(implementation.language()))
                ));
            }
        }
        return segments;
    }

    /** Prints one search match: language, score, title, description and code. */
    private static void printMatch(TextSegment matchedSegment, Double score) {
        System.out.format("""
                ——— %s ——— (score: %4.5f) —————————
                Title: %s
                Description: %s
                Code:
                %s
                """,
                matchedSegment.metadata().getString("language"),
                score,
                matchedSegment.metadata().getString("title"),
                processString(matchedSegment.metadata().getString("description")),
                matchedSegment.text()
        );
    }

    /** Returns {@code value}, or the empty string when {@code value} is {@code null}. */
    private static String nullToEmpty(String value) {
        return value == null ? "" : value;
    }

    /**
     * Removes the underscore from every word that starts with one (for example {@code _x}
     * becomes {@code x}); underscores inside a word are left alone.
     *
     * @param input text to clean up, may be {@code null}
     * @return the cleaned text, or the empty string when {@code input} is {@code null}
     */
    private static String processString(String input) {
        if (input == null) {
            return "";
        }
        return UNDERSCORE_PREFIXED_WORD.matcher(input).replaceAll("$1");
    }

    /** Embeds the question with the query-side embedding model. */
    private static Embedding embedQuery(String question) {
        System.out.println("Embedding query...");
        return EMBEDDING_MODEL_FOR_RETRIEVAL.embed(question).content();
    }

    /**
     * Asks the chat model which programming language, if any, the question is about.
     *
     * @return the text of the model's reply; the prompt asks for one entry of
     *         {@link #KNOWN_PROGRAMMING_LANGUAGES}, {@code UNKNOWN} when no language is specified
     */
    private static String recognizeProgrammingLanguage(String question) {
        System.out.println("Calling gemini to find programming language...");
        String programmingLanguageRecognised =
                GEMINI_MODEL.generate(
                        SystemMessage.from("""
                                Your role is to classify the user message to decide
                                if it is a question about a particular programming language or not.
                                If you don't know, or if the programming language is not specified, reply with `UNKNOWN`,
                                otherwise reply with just the name of the programming language recognized among the following list:
                                """ + KNOWN_PROGRAMMING_LANGUAGES),
                        UserMessage.from(question)
                ).content().text();
        System.out.println("Programming language specified: " + programmingLanguageRecognised);
        return programmingLanguageRecognised;
    }

    /**
     * Reads and parses the idioms JSON file.
     *
     * @return the parsed idioms; an empty array when the file holds no JSON document
     * @throws IOException if the file cannot be read
     */
    private static Idiom[] loadIdioms() throws IOException {
        System.out.println("Loading idioms...");
        String idiomsJson = Files.readString(IDIOMS_JSON_PATH);
        Idiom[] idioms = new Gson().fromJson(idiomsJson, Idiom[].class);
        // Gson yields null for an empty document; callers iterate over the result.
        return idioms != null ? idioms : new Idiom[0];
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment