Last active
November 29, 2024 15:56
-
-
Save glaforge/4e45fa4222dd803d6d8bbf2b9335e90d to your computer and use it in GitHub Desktop.
Version #1 of ProgrammingIdioms with embeddings of the title, the description, and the code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package experiments;

import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Type;
import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.vertexai.VertexAiEmbeddingModel;
import dev.langchain4j.model.vertexai.VertexAiGeminiChatModel;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;
public class ProgrammingIdioms { | |
private static final List<String> KNOWN_PROGRAMMING_LANGUAGES = | |
List.of("UNKNOWN", | |
"Go", "Rust", "Python", "Perl", "Ruby", "Java", "JS", | |
"C#", "Dart", "Pascal", "PHP", "C++", "Haskell", "D", | |
"Lua", "Clojure", "Fortran", "Elixir", "Kotlin", | |
"Erlang", "C", "Lisp", "VB", "Groovy", "Ada", "Scala", | |
"Scheme", "Smalltalk", "Obj-C", "Cobol", "Prolog", "Caml" | |
); | |
private static final ChatLanguageModel GEMINI_MODEL = | |
VertexAiGeminiChatModel.builder() | |
.project(System.getenv("GCP_PROJECT_ID")) | |
.location(System.getenv("GCP_LOCATION")) | |
.modelName("gemini-1.5-flash-002") | |
.responseSchema(Schema.newBuilder() | |
.setType(Type.STRING) | |
.addAllEnum(KNOWN_PROGRAMMING_LANGUAGES) | |
.build()) | |
.build(); | |
private static final VertexAiEmbeddingModel EMBEDDING_MODEL = VertexAiEmbeddingModel.builder() | |
.project(System.getenv("GCP_PROJECT_ID")) | |
.location(System.getenv("GCP_LOCATION")) | |
.modelName("text-embedding-005") | |
.publisher("google") | |
.taskType(VertexAiEmbeddingModel.TaskType.RETRIEVAL_DOCUMENT) | |
.titleMetadataKey("titleAndDescription") | |
.maxSegmentsPerBatch(150) | |
.build(); | |
private static final VertexAiEmbeddingModel EMBEDDING_MODEL_FOR_RETRIEVAL = VertexAiEmbeddingModel.builder() | |
.project(System.getenv("GCP_PROJECT_ID")) | |
.location(System.getenv("GCP_LOCATION")) | |
.modelName("text-embedding-005") | |
.publisher("google") | |
.taskType(VertexAiEmbeddingModel.TaskType.CODE_RETRIEVAL_QUERY) | |
.titleMetadataKey("titleAndDescription") | |
.build(); | |
record Idiom( | |
@SerializedName("Id") | |
long id, | |
@SerializedName("Title") | |
String title, | |
@SerializedName("LeadParagraph") | |
String description, | |
@SerializedName("ExtraKeywords") | |
String keywords, | |
@SerializedName("Implementations") | |
Implementation[] implementations | |
) { | |
record Implementation( | |
@SerializedName("Id") | |
long id, | |
@SerializedName("LanguageName") | |
String language, | |
@SerializedName("CodeBlock") | |
String code, | |
@SerializedName("AuthorComment") | |
String comment | |
) { | |
} | |
} | |
public static void main(String[] args) throws IOException, InterruptedException, ExecutionException { | |
Path filePath = Path.of("src/main/resources/embedding-store.json"); | |
boolean serializedStoreExists = Files.exists(filePath); | |
InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>(); | |
if (!serializedStoreExists) { | |
System.out.println("Preparing text segments to embed"); | |
List<TextSegment> allCodeSegments = new ArrayList<>(); | |
Idiom[] idioms = loadIdioms(); | |
for (Idiom idiom : idioms) { | |
System.out.println("-> " + idiom.title); | |
for (Idiom.Implementation implementation : idiom.implementations) { | |
if (implementation.code != null && !implementation.code.isBlank()) { | |
allCodeSegments.add(new TextSegment( | |
implementation.code, | |
new Metadata() | |
.put("idiomId", idiom.id) | |
.put("title", idiom.title) | |
.put("description", idiom.description) | |
.put("titleAndDescription", idiom.title + ": " + idiom.description) | |
.put("keywords", idiom.keywords) | |
.put("implementationId", implementation.id) | |
.put("language", implementation.language) | |
)); | |
} | |
} | |
} | |
System.out.println("Embedding all code segments..."); | |
List<Embedding> allEmbeddings = EMBEDDING_MODEL.embedAll(allCodeSegments).content(); | |
System.out.println("Loading and serializing embedding store..."); | |
embeddingStore.addAll(allEmbeddings, allCodeSegments); | |
embeddingStore.serializeToFile(filePath); | |
} else { | |
System.out.println("Loading serialized embedding store..."); | |
embeddingStore = InMemoryEmbeddingStore.fromFile(filePath); | |
} | |
System.out.println("Ready for querying!"); | |
String question = | |
// "How to count the characters in a string?"; | |
"How can I make an HTTP POST request?"; | |
// "How can I make an HTTP POST request in Java?"; | |
// "How to use the LibXML parser in Perl?"; | |
// "How to use the LibXML parser?"; | |
System.out.format("Query: %s%n", question); | |
List<Future<Object>> futures; | |
try (ExecutorService executorService = Executors.newFixedThreadPool(2)) { | |
futures = executorService.invokeAll(List.of( | |
() -> recognizeProgrammingLanguage(question), | |
() -> embedQuery(question) | |
)); | |
} | |
String programmingLanguageRecognised = (String) futures.get(0).get(); | |
Embedding queryEmbedding = (Embedding) futures.get(1).get(); | |
System.out.println("Searching...\n"); | |
var searchRequestBuilder = EmbeddingSearchRequest.builder() | |
.maxResults(5) | |
.minScore(0.8) | |
.queryEmbedding(queryEmbedding); | |
var searchRequest = "UNKNOWN".equals(programmingLanguageRecognised) ? | |
searchRequestBuilder.build() : | |
searchRequestBuilder.filter(new IsEqualTo("language", programmingLanguageRecognised)).build(); | |
EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(searchRequest); | |
searchResult.matches().forEach(match -> { | |
TextSegment matchedSegment = match.embedded(); | |
System.out.format(""" | |
——— %s ——— (score: %4.5f) ————————— | |
Title: %s | |
Description: %s | |
Code: | |
%s | |
""", | |
matchedSegment.metadata().getString("language"), | |
match.score(), | |
matchedSegment.metadata().getString("title"), | |
processString(matchedSegment.metadata().getString("description")), | |
matchedSegment.text() | |
); | |
}); | |
} | |
private static String processString(String input) { | |
return input.replaceAll("\\b_([\\w]+)", "$1"); | |
} | |
private static Embedding embedQuery(String question) { | |
System.out.println("Embedding query..."); | |
Embedding queryEmbedding = EMBEDDING_MODEL_FOR_RETRIEVAL.embed(question).content(); | |
return queryEmbedding; | |
} | |
private static String recognizeProgrammingLanguage(String question) { | |
System.out.println("Calling gemini to find programming language..."); | |
String programmingLanguageRecognised = | |
GEMINI_MODEL.generate( | |
SystemMessage.from(""" | |
Your role is to classify the user message to decide | |
if it is a question about a particular programming language or not. | |
If you don't know, or if the programming language is not specified, reply with `UNKNOWN`, | |
otherwise reply with just the name of the programming language recognized among the following list: | |
""" + KNOWN_PROGRAMMING_LANGUAGES), | |
UserMessage.from(question) | |
).content().text(); | |
System.out.println("Programming language specified: " + programmingLanguageRecognised); | |
return programmingLanguageRecognised; | |
} | |
private static Idiom[] loadIdioms() throws IOException { | |
System.out.println("Loading idioms..."); | |
List<String> idiomsLines = Files.readAllLines(Path.of("src/main/resources/programming-idioms-all.json")); | |
String idiomsJson = String.join("\n", idiomsLines); | |
Idiom[] idioms = new Gson().fromJson(idiomsJson, Idiom[].class); | |
return idioms; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment