Skip to content

Instantly share code, notes, and snippets.

@glaforge
Created December 13, 2024 17:18
Show Gist options
  • Save glaforge/d6e845c673a5441823efc800d2d6bbf6 to your computer and use it in GitHub Desktop.
Save glaforge/d6e845c673a5441823efc800d2d6bbf6 to your computer and use it in GitHub Desktop.
SkyjoCardCounter.java with Gemini 2.0 Flash and LangChain4j
import com.google.gson.Gson;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.vertexai.SchemaHelper;
import dev.langchain4j.model.vertexai.VertexAiGeminiChatModel;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
/**
 * Evaluates Gemini's ability to count Skyjo card points from photos.
 *
 * <p>Each sample image's file name encodes the expected card values, separated by spaces
 * (e.g. {@code "3 -1 0 12.jpg"}). For every sample the model is asked to detect the cards. The
 * sum of the detected labels is then compared with the expected total.
 */
public class SkyjoCardCounter {

    /** Directory scanned for sample images when no path is given on the command line. */
    private static final Path DEFAULT_SAMPLES_DIR =
            Path.of("/Users/glaforge/Projects/skyjo-counter/samples");

    /** Pause between consecutive model calls (presumably to stay within API rate limits). */
    private static final long DELAY_BETWEEN_REQUESTS_MS = 5_000;

    /** Instructions describing the card-detection task and the expected JSON output. */
    private static final String SYSTEM_PROMPT = """
            Detect playing cards with numbers, with no more than 12 items.
            Output a JSON list of cards, where each entry contains the 2D bounding box in `boundingBox`
            and the `label` is the big number displayed in the center of the card.
            If you see the text "SKYJO" on the card, use 0 as the label in `label`.
            Ignore the small numbers in the corners of the cards.
            Ignore cards with text written on them.
            Be careful when reading the numbers, as sometimes some cards are tilted, cut, or upside down.
            """;

    /** Per-image user request sent alongside the image. */
    private static final String USER_PROMPT = """
            Detect the cards of this image.
            """;

    /**
     * One labelled sample: an image plus the card values encoded in its file name.
     *
     * @param imageFile path to the sample image
     * @param total     expected sum of all card values
     * @param cards     individual card values parsed from the file name
     */
    record CardsExample(Path imageFile, int total, int[] cards) {}

    /**
     * Collects the labelled samples located directly in {@code samplesPath} (no recursion).
     *
     * <p>Files whose names are not a whitespace-separated list of integers are reported on stderr
     * and skipped. Results are sorted by path so that runs are reproducible.
     *
     * @param samplesPath directory containing the sample images
     * @return the parsed samples, sorted by path
     * @throws IOException if the directory cannot be listed
     */
    private static List<CardsExample> processImageFiles(Path samplesPath) throws IOException {
        // Files.walk keeps a directory handle open until the stream is closed.
        try (var paths = Files.walk(samplesPath, 1, FileVisitOption.FOLLOW_LINKS)) {
            return paths
                    .filter(Files::isRegularFile)
                    .sorted()
                    .map(SkyjoCardCounter::parseExample)
                    .flatMap(Optional::stream)
                    .toList();
        }
    }

    /**
     * Parses a file named like {@code "v1 v2 ... vn.ext"} into a sample.
     *
     * @param path the candidate image file
     * @return the sample, or empty (after printing a warning) if the name is not a list of integers
     */
    private static Optional<CardsExample> parseExample(Path path) {
        String fileName = path.getFileName().toString();
        int dot = fileName.lastIndexOf('.');
        // A name without any '.' is used as-is instead of failing on substring(0, -1).
        String baseName = dot >= 0 ? fileName.substring(0, dot) : fileName;
        try {
            // Splitting on runs of whitespace tolerates names like "3  5.jpg".
            int[] numbers = Arrays.stream(baseName.strip().split("\\s+"))
                    .mapToInt(Integer::parseInt)
                    .toArray();
            int total = Arrays.stream(numbers).sum();
            return Optional.of(new CardsExample(path, total, numbers));
        } catch (NumberFormatException e) {
            // Expected for non-sample files such as ".DS_Store"; report and move on.
            System.err.println("Skipping file with invalid name: " + path.getFileName());
            return Optional.empty();
        }
    }

    /**
     * Returns the value of a required environment variable.
     *
     * @throws IllegalStateException if the variable is unset or blank
     */
    private static String requireEnv(String name) {
        String value = System.getenv(name);
        if (value == null || value.isBlank()) {
            throw new IllegalStateException("Environment variable " + name + " must be set");
        }
        return value;
    }

    /**
     * Runs the evaluation over every sample and prints, per image, either the detected points or
     * the mismatch details.
     *
     * @param args optional {@code args[0]}: samples directory (defaults to {@link #DEFAULT_SAMPLES_DIR})
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        Gson gson = new Gson();

        // Shape of each JSON entry; also used to derive the model's response schema.
        record Card(int label, BoundingBox boundingBox) {
            record BoundingBox(int x1, int y1, int x2, int y2) {}
        }

        var model = VertexAiGeminiChatModel.builder()
                .project(requireEnv("PROJECT_ID"))
                .location(requireEnv("LOCATION"))
                .modelName("gemini-2.0-flash-exp")
                .responseMimeType("application/json")
                .responseSchema(SchemaHelper.fromClass(Card[].class))
                .temperature(0.1f)
                .build();

        Path samplesDir = args.length > 0 ? Path.of(args[0]) : DEFAULT_SAMPLES_DIR;
        List<CardsExample> cardsExamples = processImageFiles(samplesDir);

        for (int i = 0; i < cardsExamples.size(); i++) {
            CardsExample example = cardsExamples.get(i);
            if (i > 0) {
                // Only pause between requests, not after the last one.
                Thread.sleep(DELAY_BETWEEN_REQUESTS_MS);
            }
            System.out.println("File: " + example.imageFile());
            try {
                Response<AiMessage> response = model.generate(
                        SystemMessage.from(SYSTEM_PROMPT),
                        UserMessage.from(
                                ImageContent.from(example.imageFile().toUri()),
                                TextContent.from(USER_PROMPT)));

                // Gson returns null for an empty/null reply; treat that as "no cards detected".
                Card[] parsed = gson.fromJson(response.content().text(), Card[].class);
                Card[] cardValues = parsed == null ? new Card[0] : parsed;

                int sum = Arrays.stream(cardValues).mapToInt(Card::label).sum();
                if (sum == example.total()) {
                    System.out.format(" ==> Your points: %d %n", sum);
                } else {
                    System.out.format(" ==> WRONG RESULT: %d points instead of %d %n", sum, example.total());
                    for (Card card : cardValues) {
                        System.out.format(" - %d (%s) %n", card.label(), card.boundingBox());
                    }
                }
            } catch (RuntimeException e) {
                // Per-sample boundary: a failed request or malformed reply must not abort the run.
                System.err.println(" ==> ERROR while processing this file: " + e);
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment