Created
March 28, 2025 09:36
-
-
Save kishida/ffdd5199544bd46b7af246d8fda8fffe to your computer and use it in GitHub Desktop.
LangChain4JとGemma3でFunction Calling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.mycompany.langsample;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.http.client.jdk.JdkHttpClient;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Font;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.net.http.HttpClient;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.swing.ImageIcon;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JProgressBar;
import javax.swing.JTextField;
import javax.swing.SwingUtilities;
public class ModelManipulationSample { | |
static abstract class GraphObject { | |
String id; | |
int left; | |
int top; | |
int width; | |
int height; | |
String color; | |
public GraphObject(String id, int left, int top, int width, int height, String color) { | |
this.id = id; | |
this.left = left; | |
this.top = top; | |
this.width = width; | |
this.height = height; | |
this.color = color; | |
} | |
public String getId() { | |
return id; | |
} | |
abstract void draw(Graphics2D g); | |
@Override | |
public String toString() { | |
return String.format("id:%s, left:%d, top:%d, width:%d, height:%d, color:%s", | |
id, left, top, width, height, color); | |
} | |
} | |
static class Rectangle extends GraphObject { | |
public Rectangle(String id, int left, int top, int width, int height, String color) { | |
super(id, left, top, width, height, color); | |
} | |
@Override | |
public void draw(Graphics2D g) { | |
g.setColor(colors.get(color)); | |
g.fillRect(left, top, width, height); | |
} | |
} | |
static class Triangle extends GraphObject { | |
public Triangle(String id, int left, int top, int width, int height, String color) { | |
super(id, left, top, width, height, color); | |
} | |
@Override | |
public void draw(Graphics2D g) { | |
g.setColor(colors.get(color)); | |
g.fillPolygon(new int[]{left, left + width, left + width / 2}, | |
new int[]{top + height, top + height, top}, 3); | |
} | |
} | |
static class ImageObj extends GraphObject { | |
Image image; | |
public ImageObj(String id, int left, int top, int width, int height, String path) { | |
super(id, left, top, width, height, "black"); | |
image = new ImageIcon(path).getImage(); | |
this.height = image.getHeight(null) * width / image.getWidth(null); | |
} | |
@Override | |
public void draw(Graphics2D g) { | |
g.drawImage(image, left, top, width, height, null); | |
} | |
} | |
static class GraphObjects { | |
Map<String, GraphObject> objectMap; | |
public GraphObjects(List<GraphObject> objs) { | |
objectMap = objs.stream().collect(Collectors.toMap(GraphObject::getId, Function.identity())); | |
} | |
String objectListStr() { | |
return objectMap.values().stream().map(GraphObject::toString) | |
.collect(Collectors.joining("\\n")); | |
} | |
int objectCount () { | |
return objectMap.size(); | |
} | |
void draw(Graphics2D g) { | |
g.setColor(Color.WHITE); | |
g.fillRect(0, 0, 800, 600); | |
objectMap.values().forEach(obj -> obj.draw(g)); | |
} | |
void repaint() { | |
Graphics2D g = image.createGraphics(); | |
draw(g); | |
g.dispose(); | |
imageLabel.repaint(); | |
} | |
String doFunction(String name, String id, Consumer<GraphObject> cons) { | |
var obj = objectMap.get(id); | |
if (obj == null) { | |
System.out.println("obj does not found in set" + name); | |
return "obj does not found:" + id; | |
} | |
cons.accept(obj); | |
var msg = "%s %s changed".formatted(id, name); | |
System.out.println(msg); | |
repaint(); | |
return "ok"; | |
} | |
@Tool("move the object to the specified position") | |
String setPosition(String id, int left, int top) { | |
return doFunction("position", id, obj -> { | |
obj.left = left; | |
obj.top = top; | |
}); | |
} | |
@Tool("change size of the object to the specified size") | |
String setSize(String id, int width, int height) { | |
return doFunction("size", id, obj -> { | |
obj.width = width; | |
obj.height = height; | |
}); | |
} | |
@Tool("change color of the object to the specified color") | |
String setColor(String id, String color) { | |
return doFunction("color", id, obj -> { | |
obj.color = color; | |
}); | |
} | |
} | |
static Map<String, Color> colors = Map.of( | |
"red", Color.RED, | |
"blue", Color.BLUE, | |
"green", Color.GREEN, | |
"yellow", Color.YELLOW, | |
"black", Color.BLACK, | |
"white", Color.WHITE); | |
static BufferedImage image; | |
static JLabel imageLabel; | |
static JTextField textField; | |
static JProgressBar progress; | |
static Assistant assistant; | |
static GraphObjects objects; | |
public static void main(String[] args) throws Exception { | |
// オブジェクト一覧 | |
objects = new GraphObjects(List.of( | |
new Rectangle("rect", 300, 50, 150, 100, "red"), | |
new Triangle("triangle", 600, 200, 170, 150, "blue"), | |
new ImageObj("image", 250, 240, 240, 160, "cat.jpg"))); | |
OpenAiStreamingChatModel model = OpenAiStreamingChatModel.builder() | |
.baseUrl("http://localhost:1234/v1") | |
//.modelName("gemma-3-4b-it") | |
//.modelName("google_gemma-3-27b-it") | |
.modelName("gemma-3-12b-it") | |
.httpClientBuilder(JdkHttpClient.builder().httpClientBuilder( | |
// 1.1を指定しないとHTTP2を使おうとしてLM Studioでは失敗する | |
HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1))) | |
.build(); | |
assistant = AiServices.builder(Assistant.class) | |
.streamingChatLanguageModel(model) | |
.systemMessageProvider(memId -> | |
"You are object manipulator. field size is 800, 600. we have %d objects below.\\n %s" | |
.formatted(objects.objectCount(), objects.objectListStr())) | |
.tools(objects) | |
.build(); | |
// テキストフィールドとボタンを持ったGUIを作成 | |
var frame = new JFrame("Function API Sample"); | |
textField = new JTextField(30); | |
textField.setFont(new Font("Sans Serif", Font.PLAIN, 24)); | |
textField.addActionListener(e -> goPrompt()); | |
var panel = new JPanel(); | |
var button = new JButton("Send"); | |
button.addActionListener(e -> goPrompt()); | |
panel.add(textField); | |
panel.add(button); | |
frame.add(BorderLayout.NORTH, panel); | |
image = new BufferedImage(800, 600, BufferedImage.TYPE_INT_RGB); | |
Graphics2D g = image.createGraphics(); | |
objects.draw(g); | |
g.dispose(); | |
imageLabel = new JLabel(new ImageIcon(image)); | |
frame.add(BorderLayout.CENTER, imageLabel); | |
progress = new JProgressBar(); | |
frame.add(BorderLayout.SOUTH, progress); | |
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); | |
frame.setLocation(100, 100); | |
frame.setSize(800, 600); | |
frame.setVisible(true); | |
} | |
static void goPrompt() { | |
String prompt = textField.getText(); | |
gptRequest(prompt); | |
} | |
interface Assistant { | |
TokenStream chat(String userMessage); | |
} | |
static void gptRequest(String prompt) { | |
//progress.setIndeterminate(true); | |
TokenStream chat = assistant.chat(prompt); | |
chat.onPartialResponse(str -> System.out.print(str)) | |
.onCompleteResponse(resp -> { | |
System.out.println(); | |
}) | |
.onError(th -> { | |
System.out.println(th); | |
}) | |
.onToolExecuted(exe -> System.out.println(exe.request().name())) | |
.start(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
temp.mp4