Created
July 1, 2023 23:51
-
-
Save kishida/c0a7ae4e7a5db7e7e440fbde1886a5f6 to your computer and use it in GitHub Desktop.
Manipulate an app with natural language, without OpenAI or any other LLM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.atilika.kuromoji.ipadic.Tokenizer; | |
import java.awt.*; | |
import java.util.List; | |
import java.util.Map; | |
import java.util.stream.Collectors; | |
public class CaseFrameGrammer { | |
enum Degree { | |
NONE, LITTLE; | |
} | |
interface GeometoryCommand { | |
void execute(FunctionApiSample.GraphObject obj, Degree degree, Dimension fieldSize); | |
} | |
interface ColorCommand { | |
void execute(FunctionApiSample.GraphObject obj, String color); | |
} | |
static Map<String, List<String>> normalizeData = Map.ofEntries( | |
Map.entry("left", List.of("左")), | |
Map.entry("right", List.of("右")), | |
Map.entry("up", List.of("上")), | |
Map.entry("down", List.of("下")), | |
Map.entry("center", List.of("中央", "真ん中")), | |
Map.entry("large", List.of("大きい")), | |
Map.entry("small", List.of("小さい")), | |
Map.entry("little", List.of("少し", "ちょっと")), | |
Map.entry("red", List.of("赤い", "あかい", "赤", "あか", "赤色")), | |
Map.entry("blue", List.of("青", "あお", "青い", "あおい", "青色")), | |
Map.entry("yellow", List.of("黄色", "きいろ", "黄色い", "きいろい")), | |
Map.entry("green", List.of("緑", "みどり", "緑色")), | |
Map.entry("black", List.of("黒", "くろ", "黒い", "くろい", "黒色")), | |
Map.entry("white", List.of("白", "しろ", "白い", "しろい", "白色")), | |
Map.entry("rect", List.of("四角", "しかく")), | |
Map.entry("triangle", List.of("三角", "さんかく")), | |
Map.entry("image", List.of("画像", "写真")) | |
); | |
static Map<String, String> normalize = normalizeData.entrySet().stream() | |
.flatMap(e -> e.getValue().stream().map(v -> Map.entry(v, e.getKey()))) | |
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); | |
GeometoryCommand geometoryCommand = null; | |
ColorCommand colorCommand = null; | |
FunctionApiSample.GraphObject obj = null; | |
Tokenizer tokenizer = new Tokenizer(); | |
public void parse(String input, Map<String, FunctionApiSample.GraphObject> objectMap, Dimension fieldSize) { | |
var tokens = tokenizer.tokenize(input); | |
Degree degree = Degree.NONE; | |
for (var token : tokens) { | |
System.out.println(token.getAllFeatures()); | |
var baseForm = token.getBaseForm(); | |
var command = normalize.get(baseForm); | |
if (command == null) { | |
continue; | |
} | |
System.out.println("command:" + command); | |
switch (command) { | |
case "rect", "triangle", "image": | |
obj = objectMap.get(command); | |
break; | |
case "red", "blue", "yellow", "green", "black", "white": | |
colorCommand = colorCommand(command); | |
geometoryCommand = null; | |
break; | |
case "left": | |
geometoryCommand = this::toLeft; | |
colorCommand = null; | |
break; | |
case "right": | |
geometoryCommand = this::toRight; | |
colorCommand = null; | |
break; | |
case "up": | |
geometoryCommand = this::toUp; | |
colorCommand = null; | |
break; | |
case "down": | |
geometoryCommand = this::toDown; | |
colorCommand = null; | |
break; | |
case "large": | |
geometoryCommand = this::toLarge; | |
colorCommand = null; | |
break; | |
case "center": | |
geometoryCommand = this::toCenter; | |
colorCommand = null; | |
break; | |
case "small": | |
geometoryCommand = this::toSmall; | |
colorCommand = null; | |
break; | |
case "little": | |
degree = Degree.LITTLE; | |
break; | |
} | |
} | |
if (obj != null) { | |
if (geometoryCommand != null) { | |
geometoryCommand.execute(obj, degree, fieldSize); | |
} else if (colorCommand != null) { | |
colorCommand.execute(obj, "red"); | |
} | |
} | |
} | |
void toLeft(FunctionApiSample.GraphObject obj, Degree degree, Dimension fieldSize) { | |
obj.left -= fieldSize.width / (degree == Degree.LITTLE ? 10 : 5); | |
obj.left = Math.max(obj.left, 0); | |
} | |
void toRight(FunctionApiSample.GraphObject obj, Degree degree, Dimension fieldSize) { | |
obj.left += fieldSize.width / (degree == Degree.LITTLE ? 10 : 5); | |
obj.left = Math.min(obj.left, fieldSize.width - obj.width); | |
} | |
void toUp(FunctionApiSample.GraphObject obj, Degree degree, Dimension fieldSize) { | |
obj.top -= fieldSize.height / (degree == Degree.LITTLE ? 10 : 5); | |
obj.top = Math.max(obj.top, 0); | |
} | |
void toDown(FunctionApiSample.GraphObject obj, Degree degree, Dimension fieldSize) { | |
obj.top += fieldSize.height / (degree == Degree.LITTLE ? 10 : 5); | |
obj.top = Math.min(obj.top, fieldSize.height - obj.height); | |
} | |
void toCenter(FunctionApiSample.GraphObject obj, Degree degree, Dimension fieldSize) { | |
System.out.println("obj.width:" + obj.width + ", obj.height:" + obj.height); | |
System.out.println("fieldSize.width:" + fieldSize.width + ", fieldSize.height:" + fieldSize.height); | |
obj.left = (fieldSize.width - obj.width) / 2; | |
obj.top = (fieldSize.height - obj.height) / 2; | |
} | |
//大きくする | |
void toLarge(FunctionApiSample.GraphObject obj, Degree degree, Dimension fieldSize) { | |
var widen = obj.width / (degree == Degree.LITTLE ? 10 : 5); | |
var heighten = obj.height / (degree == Degree.LITTLE ? 10 : 5); | |
obj.width += widen; | |
obj.height += heighten; | |
obj.left -= widen / 2; | |
obj.top -= heighten / 2; | |
} | |
//小さくする | |
void toSmall(FunctionApiSample.GraphObject obj, Degree degree, Dimension fieldSize) { | |
var narrow = obj.width / (degree == Degree.LITTLE ? 10 : 5); | |
var shorten = obj.height / (degree == Degree.LITTLE ? 10 : 5); | |
obj.width -= narrow; | |
obj.height -= shorten; | |
obj.left += narrow / 2; | |
obj.top += shorten / 2; | |
} | |
ColorCommand colorCommand(String color) { | |
return (obj, c) -> obj.color = color; | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.fasterxml.jackson.core.JsonProcessingException; | |
import com.fasterxml.jackson.databind.ObjectMapper; | |
import lombok.AllArgsConstructor; | |
import lombok.Data; | |
import javax.swing.*; | |
import java.awt.*; | |
import java.awt.image.BufferedImage; | |
import java.net.URI; | |
import java.net.http.HttpClient; | |
import java.net.http.HttpRequest; | |
import java.net.http.HttpResponse; | |
import java.util.ArrayDeque; | |
import java.util.Deque; | |
import java.util.List; | |
import java.util.Map; | |
import java.util.function.Function; | |
import java.util.stream.Collectors; | |
public class FunctionApiSample { | |
@Data | |
@AllArgsConstructor | |
static abstract class GraphObject { | |
String id; | |
int left; | |
int top; | |
int width; | |
int height; | |
String color; | |
boolean visible = true; | |
abstract void draw(Graphics2D g); | |
public String toString() { | |
return String.format("id:%s, left:%d, top:%d, width:%d, height:%d, color:%s", | |
id, left, top, width, height, color); | |
} | |
} | |
static class Rectangle extends GraphObject { | |
public Rectangle(String id, int left, int top, int width, int height, String color) { | |
super(id, left, top, width, height, color, true); | |
} | |
@Override | |
public void draw(Graphics2D g) { | |
g.setColor(colors.get(color)); | |
g.fillRect(left, top, width, height); | |
} | |
} | |
static class Triangle extends GraphObject { | |
public Triangle(String id, int left, int top, int width, int height, String color) { | |
super(id, left, top, width, height, color, true); | |
} | |
@Override | |
public void draw(Graphics2D g) { | |
g.setColor(colors.get(color)); | |
g.fillPolygon(new int[]{left, left + width, left + width / 2}, | |
new int[]{top + height, top + height, top}, 3); | |
} | |
} | |
static class ImageObj extends GraphObject { | |
Image image; | |
public ImageObj(String id, int left, int top, int width, int height, String path) { | |
super(id, left, top, width, height, "black", true); | |
image = new ImageIcon(path).getImage(); | |
this.height = image.getHeight(null) * width / image.getWidth(null); | |
} | |
@Override | |
public void draw(Graphics2D g) { | |
g.drawImage(image, left, top, width, height, null); | |
} | |
} | |
/** HttpClientの準備 */ | |
static HttpClient client = HttpClient.newHttpClient(); | |
/** リクエストトークンを環境変数から取得 */ | |
static String token = System.getenv("OPENAI_TOKEN"); | |
static Map<String, GraphObject> objectMap; | |
static Map<String, Color> colors = Map.of( | |
"red", Color.RED, | |
"blue", Color.BLUE, | |
"green", Color.GREEN, | |
"yellow", Color.YELLOW, | |
"black", Color.BLACK, | |
"white", Color.WHITE); | |
static BufferedImage image; | |
static JLabel imageLabel; | |
static JTextField textField; | |
record ChatLog(String role, String content) {} | |
static Deque<ChatLog> history = new ArrayDeque<>(); | |
static JProgressBar progress; | |
static JCheckBox chkUseGpt; | |
public static void main(String[] args) throws Exception { | |
// オブジェクト一覧 | |
List<GraphObject> objects = List.of( | |
new Rectangle("rect", 300, 50, 150, 100, "red"), | |
new Triangle("triangle", 600, 200, 170, 150, "blue"), | |
new ImageObj("image", 150, 240, 400, 200, "redhair_girl.png")); | |
objectMap = objects.stream() | |
.collect(Collectors.toMap(GraphObject::getId, Function.identity())); | |
// テキストフィールドとボタンを持ったGUIを作成 | |
var frame = new JFrame("Function API Sample"); | |
textField = new JTextField(30); | |
textField.setFont(new Font("Sans Serif", Font.PLAIN, 24)); | |
textField.addActionListener(e -> goPrompt()); | |
var panel = new JPanel(); | |
var button = new JButton("Send"); | |
button.addActionListener(e -> goPrompt()); | |
chkUseGpt = new JCheckBox("Use GPT"); | |
panel.add(textField); | |
panel.add(button); | |
panel.add(chkUseGpt); | |
frame.add(BorderLayout.NORTH, panel); | |
image = new BufferedImage(800, 600, BufferedImage.TYPE_INT_RGB); | |
Graphics2D g = image.createGraphics(); | |
draw(g); | |
g.dispose(); | |
imageLabel = new JLabel(new ImageIcon(image)); | |
frame.add(BorderLayout.CENTER, imageLabel); | |
progress = new JProgressBar(); | |
frame.add(BorderLayout.SOUTH, progress); | |
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); | |
frame.setLocation(100, 100); | |
frame.setSize(830, 710); | |
frame.setVisible(true); | |
} | |
static CaseFrameGrammer caseFrameGrammer = new CaseFrameGrammer(); | |
static void goPrompt() { | |
String prompt = textField.getText(); | |
if (chkUseGpt.isSelected()) { | |
gptRequest(prompt); | |
} else { | |
caseFrameGrammer.parse(prompt, objectMap, new Dimension(800, 600)); | |
// 画面を再描画 | |
Graphics2D g = image.createGraphics(); | |
draw(g); | |
g.dispose(); | |
imageLabel.repaint(); | |
textField.setText(""); | |
} | |
} | |
static void draw(Graphics2D g) { | |
g.setColor(Color.WHITE); | |
g.fillRect(0, 0, 800, 600); | |
objectMap.values().stream().filter(obj -> obj.visible).forEach(obj -> obj.draw(g)); | |
} | |
static void gptRequest(String prompt) { | |
history.addLast(new ChatLog("user", prompt)); | |
while (history.size() > 10) history.removeFirst(); | |
// リクエストJSONの作成 | |
String promptStr = history.stream() | |
.map(log -> "{\"role\": \"%s\", \"content\": \"%s\"}".formatted(log.role(), log.content())) | |
.collect(Collectors.joining(",\n")); | |
String objectsStr = objectMap.values().stream().map(GraphObject::toString).collect(Collectors.joining("\\n")); | |
String req = requestJson.formatted(objectsStr, promptStr); | |
System.out.println(req); | |
// リクエストの作成 | |
HttpRequest request = HttpRequest.newBuilder() | |
.uri(URI.create("https://api.openai.com/v1/chat/completions")) | |
.header("Content-Type", "application/json") | |
.header("Authorization", "Bearer " + token) | |
.POST(HttpRequest.BodyPublishers.ofString(req)) | |
.build(); | |
// リクエストの送信 | |
progress.setIndeterminate(true); | |
client.sendAsync(request, HttpResponse.BodyHandlers.ofString()) | |
.thenApply(HttpResponse::body) | |
.thenAccept(FunctionApiSample::apiResponse) | |
.whenComplete((result, e) -> { | |
progress.setIndeterminate(false); | |
textField.setText(""); | |
}); | |
} | |
/** | |
* 次のようなJSONを解析する | |
* { | |
* "id" : "chatcmpl-7SZ4df34uEA9IvyYHhqxw8L6qytNQ", | |
* "object" : "chat.completion", | |
* "created" : 1687042363, | |
* "model" : "gpt-3.5-turbo-0613", | |
* "choices" : [ { | |
* "index" : 0, | |
* "message" : { | |
* "role" : "assistant", | |
* "content" : null, | |
* "function_call" : { | |
* "name" : "set_position", | |
* "arguments" : "{\n \"id\": \"triangle\",\n \"left\": 800,\n \"top\": 200\n}" | |
* } | |
* }, | |
* "finish_reason" : "function_call" | |
* } ], | |
* "usage" : { | |
* "prompt_tokens" : 274, | |
* "completion_tokens" : 29, | |
* "total_tokens" : 303 | |
* } | |
* } | |
* @param json | |
*/ | |
static void apiResponse(String json) { | |
try { | |
System.out.println("---"); | |
System.out.println(json); | |
// jsonをjacksonでパース | |
ObjectMapper mapper = new ObjectMapper(); | |
var tree = mapper.readTree(json); | |
// function_callを得る | |
var functionCall = tree.at("/choices/0/message/function_call"); | |
// argumentsを得る | |
var arguments = functionCall.at("/arguments"); | |
// argumentsをパース | |
var args = mapper.readValue(arguments.asText(), Map.class); | |
var obj = objectMap.get(args.get("id")); | |
switch(functionCall.at("/name").asText()) { | |
case "set_position" -> { | |
var oldLeft = obj.getLeft(); | |
var oldTop = obj.getTop(); | |
// オブジェクトを移動 | |
obj.setLeft((int) args.get("left")); | |
obj.setTop((int) args.get("top")); | |
history.addLast(new ChatLog("assistant", "I moved the %s from (%d, %d) to (%d, %d)" | |
.formatted(obj.getId(), oldLeft, oldTop, obj.getLeft(), obj.getTop()))); | |
} | |
case "set_color" -> { | |
var oldColor = obj.getColor(); | |
// オブジェクトの色を変更 | |
obj.setColor(args.get("color").toString()); | |
history.addLast(new ChatLog("assistant", "I changed the %s color from %s to %s" | |
.formatted(obj.getId(), oldColor, obj.getColor()))); | |
} | |
case "set_size" -> { | |
var oldWidth = obj.getWidth(); | |
var oldHeight = obj.getHeight(); | |
// オブジェクトのサイズを変更 | |
obj.setWidth((int) args.get("width")); | |
obj.setHeight((int) args.get("height")); | |
history.addLast(new ChatLog("assistant", "I changed the %s size from (%d, %d) to (%d, %d)" | |
.formatted(obj.getId(), oldWidth, oldHeight, obj.getWidth(), obj.getHeight()))); | |
} | |
case "set_visible" -> { | |
var oldVisible = obj.isVisible(); | |
// オブジェクトの表示/非表示を変更 | |
obj.setVisible((boolean) args.get("visible")); | |
history.addLast(new ChatLog("assistant", "I changed the %s visibility from %s to %s" | |
.formatted(obj.getId(), oldVisible, obj.isVisible()))); | |
} | |
default -> { | |
// それ以外の関数は無視 | |
history.addLast(new ChatLog("assistant", "I don't know how to do that.")); | |
} | |
} | |
// 画面を再描画 | |
Graphics2D g = image.createGraphics(); | |
draw(g); | |
g.dispose(); | |
imageLabel.repaint(); | |
} catch (JsonProcessingException e) { | |
System.out.println("JSON parse error"); | |
System.out.println(json); | |
throw new RuntimeException(e); | |
} | |
} | |
/** リクエストJSONのテンプレート | |
* model gpt-4-0613 or gpt-3.5-turbo-0613 | |
*/ | |
static String requestJson = """ | |
{ | |
"model": "gpt-3.5-turbo-0613", | |
"messages": [ | |
{"role": "system", "content": "You are object manipulator. field size is 800, 600. we have 3 objects below.\\n %s"}, | |
%s | |
], | |
"functions": [ | |
{ | |
"name": "set_position", | |
"description": "Set the position of an object", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"id": { | |
"type": "string", | |
"description": "The object ID to move" | |
}, | |
"left": { | |
"type": "integer", | |
"description": "The left position in pixels" | |
}, | |
"top": { | |
"type": "integer", | |
"description": "The top position in pixels" | |
} | |
}, | |
"required": ["id", "left", "top"] | |
} | |
}, | |
{ | |
"name": "set_size", | |
"description": "Set the size of an object", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"id": { | |
"type": "string", | |
"description": "The object ID to resize" | |
}, | |
"width": { | |
"type": "integer", | |
"description": "The width in pixels" | |
}, | |
"height": { | |
"type": "integer", | |
"description": "The height in pixels" | |
} | |
}, | |
"required": ["id", "width", "height"] | |
} | |
}, | |
{ | |
"name": "set_color", | |
"description": "Set the color of an object", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"id": { | |
"type": "string", | |
"description": "The object ID to resize" | |
}, | |
"color": { | |
"type": "string", | |
"description": "The color. color can be 'blue', 'red', 'green', 'yellow', 'black', 'white'" | |
} | |
}, | |
"required": ["id", "color"] | |
} | |
}, | |
{ | |
"name": "set_visible", | |
"description": "Set the visibility of an object", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"id": { | |
"type": "string", | |
"description": "The object ID to resize" | |
}, | |
"visible": { | |
"type": "boolean", | |
"description": "The visibility" | |
} | |
}, | |
"required": ["id", "visible"] | |
} | |
} | |
] | |
} | |
"""; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the graph manipulation sample (Java 20). -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>naoki</groupId>
    <artifactId>GraphManipulater</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <!-- lombok: annotation processor only, so it is not needed at runtime.
             The version must support the JDK selected in the properties below. -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.28</version>
            <scope>provided</scope>
        </dependency>
        <!-- jackson -->
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>2.12.7.1</version>
        </dependency>
        <!-- kuromoji -->
        <dependency>
            <groupId>com.atilika.kuromoji</groupId>
            <artifactId>kuromoji-ipadic</artifactId>
            <version>0.9.0</version>
        </dependency>
    </dependencies>

    <properties>
        <maven.compiler.source>20</maven.compiler.source>
        <maven.compiler.target>20</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
</project>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment