Created
February 26, 2025 14:03
-
-
Save glaforge/b3624ed112725e5b028c999255edfd76 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class SentenceWindowRetrieval { | |
public static void main(String[] args) throws IOException { | |
Document capitalDocument = Document.from("..."); | |
VertexAiEmbeddingModel embeddingModel = VertexAiEmbeddingModel.builder() | |
.project(System.getenv("GCP_PROJECT_ID")) | |
.endpoint(System.getenv("GCP_VERTEXAI_ENDPOINT")) | |
.location(System.getenv("GCP_LOCATION")) | |
.publisher("google") | |
.modelName("text-embedding-005") | |
.build(); | |
InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>(); | |
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder() | |
.documentSplitter(new DocumentBySentenceSplitter(200, 20)) | |
.embeddingStore(embeddingStore) | |
.embeddingModel(embeddingModel) | |
.textSegmentTransformer(SurroundingContextEnricher.textSegmentTransformer(2, 3)) | |
.build(); | |
IngestionResult ingestionResult = ingestor.ingest(capitalDocument); | |
VertexAiGeminiChatModel chatModel = VertexAiGeminiChatModel.builder() | |
.project(System.getenv("GCP_PROJECT_ID")) | |
.location(System.getenv("GCP_LOCATION")) | |
.modelName("gemini-2.0-flash-001") | |
.build(); | |
interface CapitalsAssistant { | |
Result<String> learnAboutCapitals(String query); | |
} | |
CapitalsAssistant assistant = AiServices.builder(CapitalsAssistant.class) | |
.chatLanguageModel(chatModel) | |
.retrievalAugmentor(DefaultRetrievalAugmentor.builder() | |
.contentRetriever(EmbeddingStoreContentRetriever.builder() | |
.embeddingModel(embeddingModel) | |
.embeddingStore(embeddingStore) | |
.maxResults(3) | |
.minScore(0.75) | |
.build()) | |
.contentInjector(SurroundingContextEnricher.contentInjector(PromptTemplate.from(""" | |
You are a helpful history and geography assistant knowing everything about the capitals of the world. | |
Here's the question from the user: | |
<question> | |
{{userMessage}} | |
</question> | |
Answer the question using the following information: | |
<excerpts> | |
{{contents}} | |
</excerpts> | |
"""))) | |
.build()) | |
.build(); | |
Result<String> response = assistant.learnAboutCapitals("How many inhabitants live in the capital of Somaliland?"); | |
System.out.println(response.content()); | |
response.sources().forEach(src -> { | |
System.out.println(" - " + src.textSegment().text()); | |
System.out.println(" surrounding context: " + src.textSegment().metadata().getString(SurroundingContextEnricher.METADATA_CONTEXT_KEY)); | |
}); | |
} | |
private static class SurroundingContextEnricher { | |
private static final String METADATA_CONTEXT_KEY = "Surrounding context"; | |
public static TextSegmentTransformer textSegmentTransformer(int nSegmentsBefore, int nSegmentsAfter) { | |
return new TextSegmentTransformer() { | |
@Override | |
public TextSegment transform(TextSegment segment) { | |
return transformAll(Collections.singletonList(segment)).getFirst(); | |
} | |
@Override | |
public List<TextSegment> transformAll(List<TextSegment> segments) { | |
if (segments == null || segments.isEmpty()) { | |
return Collections.emptyList(); | |
} | |
List<TextSegment> list = new ArrayList<>(); | |
for (int i = 0; i < segments.size(); i++) { | |
TextSegment textSegment = segments.get(i); | |
String context = IntStream.rangeClosed(i - nSegmentsBefore, i + nSegmentsAfter) | |
.filter(j -> j >= 0 && j < segments.size()) | |
.mapToObj(j -> segments.get(j).text()) | |
.collect(Collectors.joining(" ")); | |
Metadata metadata = new Metadata(textSegment.metadata().toMap()); | |
metadata.put(METADATA_CONTEXT_KEY, context); | |
list.add(TextSegment.from(textSegment.text(), metadata)); | |
} | |
return list; | |
} | |
}; | |
} | |
public static ContentInjector contentInjector(PromptTemplate promptTemplate) { | |
return (contents, userMessage) -> { | |
String excerpts = contents.stream() | |
.map(content -> | |
content | |
.textSegment() | |
.metadata() | |
.getString(METADATA_CONTEXT_KEY)) | |
.collect(Collectors.joining("\n\n")); | |
return promptTemplate.apply(Map.of( | |
"userMessage", userMessage.singleText(), | |
"contents", excerpts | |
)).toUserMessage(); | |
}; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment