Last active
April 24, 2025 12:26
-
-
Save zakki/3d5b573b8147fdf516f599beef6698a0 to your computer and use it in GitHub Desktop.
Go OCR development with ChatGPT 4o
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Go image to SGF
import os | |
import sys | |
import math | |
import re | |
import string | |
import cv2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import pytesseract | |
def preprocess_image(image_path):
    """Load an image and derive the grayscale and blurred variants used later.

    Args:
        image_path: Path of the image file handed to ``cv2.imread``.

    Returns:
        A ``(img, gray, blur)`` tuple: the image as loaded, its
        ``COLOR_BGR2GRAY`` conversion, and that grayscale image after a
        median blur with kernel size 5.

    Raises:
        OSError: If ``cv2.imread`` could not load ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # stop here with the offending path instead of failing in cvtColor.
        raise OSError(f"could not read image: {image_path!r}")
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blur = cv2.medianBlur(gray, 5)
    return img, gray, blur
def detect_stones_hough(gray_image):
    """Find circular stone candidates with the Hough circle transform.

    Args:
        gray_image: Single-channel image passed to ``cv2.HoughCircles``.

    Returns:
        The first element of the ``cv2.HoughCircles`` result (one row per
        circle, unpacked elsewhere in this file as ``x, y, r``), or an empty
        list when nothing was detected.
    """
    found = cv2.HoughCircles(
        gray_image,
        cv2.HOUGH_GRADIENT,
        dp=1.2,
        minDist=20,
        param1=50,
        param2=30,
        minRadius=8,
        maxRadius=40,
    )
    if found is None:
        print("検出された円の数:", 0)
        print(found)
        return []
    print("検出された円の数:", found.shape[1])
    print(found)
    return found[0]
def filter_stones_by_intensity(circles, gray_image, threshold=128, half_window=3):
    """Classify each detected circle as a black or white stone.

    The mean intensity of a small square patch around the circle centre is
    compared with ``threshold``: darker patches become ``'B'``, the rest
    ``'W'``.

    Args:
        circles: Iterable of ``(x, y, r)`` triples; each value is converted
            with ``int``.
        gray_image: 2-D grayscale image indexed as ``[row, column]``.
        threshold: Mean intensity below which a stone counts as black.
        half_window: Half-size of the sampled patch; the patch spans
            ``2 * half_window + 1`` pixels per axis unless clipped at the
            top/left image border.

    Returns:
        A list of ``((x, y), color)`` tuples. Circles whose centre lies
        outside the image are skipped, so the result can be shorter than
        ``circles``.
    """
    height, width = gray_image.shape[:2]
    stones = []
    for circle in circles:
        x, y, _radius = map(int, circle)
        if not (0 <= x < width and 0 <= y < height):
            continue
        patch = gray_image[
            max(0, y - half_window):y + half_window + 1,
            max(0, x - half_window):x + half_window + 1,
        ]
        color = 'B' if np.mean(patch) < threshold else 'W'
        stones.append(((x, y), color))
    return stones
def draw_detected_stones(img, stones, radius=8):
    """Return a copy of ``img`` with an outline drawn around every stone.

    Args:
        img: Image to annotate; it is copied, not modified.
        stones: Iterable of ``((x, y), color, num)`` tuples.
        radius: Radius of the drawn outline in pixels.

    Returns:
        The annotated copy. Stones whose color is ``'B'`` get the outline
        color ``(0, 0, 255)``; all others get ``(255, 255, 0)``.
    """
    outline_for_black = (0, 0, 255)
    outline_for_other = (255, 255, 0)
    canvas = img.copy()
    for center, stone_color, _num in stones:
        cx, cy = center
        outline = outline_for_black if stone_color == 'B' else outline_for_other
        cv2.circle(canvas, (cx, cy), radius, outline, 2)
        # TODO: draw the move number on the circle (not implemented).
    return canvas
def detect_lines(gray):
    """Detect near-axis-aligned straight lines with Canny + Hough.

    Args:
        gray: Grayscale image passed to ``cv2.Canny``.

    Returns:
        ``(vertical_lines, horizontal_lines)``: two lists of ``(rho, theta)``
        pairs. Lines with ``|cos(theta)| < 0.1`` go into the first list,
        lines with ``|sin(theta)| < 0.1`` into the second; everything else
        is discarded.

    NOTE(review): in OpenCV's ``(rho, theta)`` form theta is the angle of the
    line's normal, so the list named ``vertical_lines`` appears to hold lines
    that run horizontally across the image, and vice versa. ``detect_grid``
    in this file matches x positions against ``horizontal_lines``, so the two
    functions agree with each other; confirm before renaming either side.
    """
    # Edge map used for board-line detection.
    edge_map = cv2.Canny(gray, 50, 150, apertureSize=3)
    # Standard Hough transform: 1 px rho step, 1 degree theta step, 180 votes.
    hough = cv2.HoughLines(edge_map, 1, np.pi / 180, 180)

    vertical_lines, horizontal_lines = [], []
    if hough is None:
        return vertical_lines, horizontal_lines

    for rho, theta in hough[:, 0]:
        entry = (rho, theta)
        if abs(np.cos(theta)) < 0.1:
            vertical_lines.append(entry)
            continue
        if abs(np.sin(theta)) < 0.1:
            horizontal_lines.append(entry)
    return vertical_lines, horizontal_lines
def draw_lines(img, vertical_lines, horizontal_lines):
    """Return a copy of ``img`` with the given Hough lines drawn on it.

    Intended as a visual check of the line detection.

    Args:
        img: Image to annotate; it is copied, not modified.
        vertical_lines: List of ``(rho, theta)`` pairs.
        horizontal_lines: List of ``(rho, theta)`` pairs. The two lists are
            concatenated with ``+``, so both must support that.

    Returns:
        The annotated copy, each line drawn 1 px wide in ``(0, 255, 255)``.
    """
    half_length = 1000  # each line extends this far on both sides of its foot point
    canvas = img.copy()
    for rho, theta in vertical_lines + horizontal_lines:
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)
        # Foot of the perpendicular from the origin onto the line.
        foot_x = cos_t * rho
        foot_y = sin_t * rho
        start = (int(foot_x + half_length * (-sin_t)), int(foot_y + half_length * cos_t))
        end = (int(foot_x - half_length * (-sin_t)), int(foot_y - half_length * cos_t))
        cv2.line(canvas, start, end, (0, 255, 255), 1)
    return canvas
def detect_grid(img, circle_centers, vertical_lines, horizontal_lines,
                result_img=None, board_size=19, tolerance=10):
    """Estimate the pixel positions of the board's grid lines.

    The grid pitch per axis is taken from the dominant frequency of a
    histogram of stone coordinates. The grid is anchored at the median stone
    coordinate, and the anchor's index within the grid (``0`` to
    ``board_size - 1``) is chosen as the one for which the most detected
    Hough lines fall within ``tolerance`` pixels of a grid line.

    Args:
        img: Image whose ``shape`` supplies the histogram extents.
        circle_centers: Sequence of ``(x, y, r)`` entries for detected stones.
        vertical_lines: ``(rho, theta)`` pairs; matched against the y grid.
        horizontal_lines: ``(rho, theta)`` pairs; matched against the x grid.
            This x/y pairing mirrors how ``detect_lines`` in this file fills
            the two lists.
        result_img: Unused; kept so existing callers keep working.
        board_size: Number of grid lines per axis.
        tolerance: Maximum pixel distance for a Hough line to count as
            matching a grid line.

    Returns:
        ``(grid_x_lines, grid_y_lines)``: two lists of ``board_size`` integer
        pixel coordinates.

    Raises:
        ValueError: If ``circle_centers`` is empty, or the image is too small
            for any frequency to fall inside the searched band.
    """
    if len(circle_centers) == 0:
        raise ValueError("detect_grid needs at least one detected circle")

    height, width = img.shape[:2]
    # Stone centre coordinates.
    x_vals = np.array([center[0] for center in circle_centers])
    y_vals = np.array([center[1] for center in circle_centers])

    pitch_x = _dominant_pitch(x_vals, width)
    pitch_y = _dominant_pitch(y_vals, height)
    center_x = int(np.median(x_vals))
    center_y = int(np.median(y_vals))

    print(vertical_lines)
    print(horizontal_lines)

    # The Hough line positions do not depend on the candidate offset, so
    # compute them once instead of once per offset.
    x_positions = [_hough_line_midpoint(rho, theta)[0] for rho, theta in horizontal_lines]
    y_positions = [_hough_line_midpoint(rho, theta)[1] for rho, theta in vertical_lines]

    selected_x_lines = None
    selected_y_lines = None
    best_x_count = -1
    best_y_count = -1
    for offset in range(board_size):
        # Candidate grid in which the median stone sits on line ``offset``.
        x_lines = [int(center_x + (i - offset) * pitch_x) for i in range(board_size)]
        y_lines = [int(center_y + (i - offset) * pitch_y) for i in range(board_size)]

        match_x_count = _count_matches(x_positions, x_lines, tolerance)
        match_y_count = _count_matches(y_positions, y_lines, tolerance)

        # Strict ">" keeps the first offset that reaches the best score.
        if match_x_count > best_x_count:
            print("update x_line", offset, best_x_count, match_x_count, len(x_positions))
            best_x_count = match_x_count
            selected_x_lines = x_lines
        if match_y_count > best_y_count:
            print("update y_line", offset, best_y_count, match_y_count, len(y_positions))
            best_y_count = match_y_count
            selected_y_lines = y_lines
    return selected_x_lines, selected_y_lines


def _dominant_pitch(values, length):
    """Return the strongest period (in pixels) of ``values`` over ``[0, length)``."""
    hist, _ = np.histogram(values, bins=length, range=(0, length))
    spectrum = np.abs(np.fft.rfft(hist))
    freqs = np.fft.rfftfreq(len(hist))
    # Only search frequencies between 0.001 and 0.05 cycles per pixel.
    band = (freqs > 0.001) & (freqs < 0.05)
    if not band.any():
        raise ValueError(
            f"image extent {length} is too small to estimate the grid pitch"
        )
    return 1 / freqs[band][np.argmax(spectrum[band])]


def _hough_line_midpoint(rho, theta):
    """Return the integer ``(x, y)`` midpoint of the segment drawn for a Hough line."""
    a = np.cos(theta)
    b = np.sin(theta)
    x0 = a * rho
    y0 = b * rho
    pt1 = (int(x0 + 1000 * (-b)), int(y0 + 1000 * a))
    pt2 = (int(x0 - 1000 * (-b)), int(y0 - 1000 * a))
    return (pt1[0] + pt2[0]) // 2, (pt1[1] + pt2[1]) // 2


def _count_matches(positions, grid_lines, tolerance):
    """Count the positions lying within ``tolerance`` of any entry of ``grid_lines``."""
    grid = np.array(grid_lines)
    return sum(1 for pos in positions if np.any(abs(pos - grid) < tolerance))
# -------------------- 4. OCR of the number inside each stone ROI --------------------
def ocr_digit(roi):
    """Run Tesseract on one stone's image region and parse a move number.

    Args:
        roi: Image region passed to ``pytesseract.image_to_string``.

    Returns:
        The recognised number as an ``int`` when the stripped OCR text
        consists of one to three digits, otherwise ``None``.
    """
    # Restrict recognition to digits via the character whitelist.
    tesseract_config = "--psm 8 --dpi 72 -c tessedit_char_whitelist=0123456789"
    text = pytesseract.image_to_string(roi, config=tesseract_config).strip()
    matched = re.match(r'^(\d{1,3})$', text)
    if matched is None:
        return None
    return int(matched.group(1))
def stones_to_sgf(stones, grid_x, grid_y):
    """Convert detected stones into an SGF move sequence.

    Each stone is snapped to its nearest grid intersection. Stones carrying a
    move number are emitted in numeric order with colors alternating from
    black; gaps in the numbering become empty moves. Stones without a usable
    number, and numbered stones repeating the previous number, follow as
    moves in their detected color. Stones outside the board are emitted last
    as empty moves in their detected color.

    Args:
        stones: Iterable of ``((x, y), color, number)`` tuples; ``number``
            may be ``None`` (or anything ``int()`` rejects).
        grid_x: Pixel x coordinates of the grid lines (at least two whenever
            ``stones`` is non-empty).
        grid_y: Pixel y coordinates of the grid lines (same requirement).

    Returns:
        The SGF text, ``"(;" + moves + ")"``.
    """
    # SGF coordinates: index 0..18 -> 'a'..'s'.
    def sgf_coord(ix, iy):
        return f"{string.ascii_lowercase[ix]}{string.ascii_lowercase[iy]}"

    # Snap to the nearest intersection; anything more than one grid pitch
    # outside the board yields (None, None). Each axis uses its own pitch.
    def snap_to_grid(x, y, x_coords, y_coords):
        pitch_x = x_coords[1] - x_coords[0]
        pitch_y = y_coords[1] - y_coords[0]
        if x < x_coords[0] - pitch_x or x > x_coords[-1] + pitch_x:
            return None, None
        if y < y_coords[0] - pitch_y or y > y_coords[-1] + pitch_y:
            return None, None
        ix = np.argmin(np.abs(np.array(x_coords) - x))
        iy = np.argmin(np.abs(np.array(y_coords) - y))
        return ix, iy

    placed = set()
    numbered = []
    unumbered = []
    out_of_bounds = []

    # Map every stone to a grid intersection.
    for (x, y), col, number in stones:
        ix, iy = snap_to_grid(x, y, grid_x, grid_y)
        if ix is None or iy is None:
            out_of_bounds.append((number, col))
            continue
        coord = sgf_coord(ix, iy)
        if coord in placed:
            # A second stone on an already occupied intersection is ignored.
            continue
        placed.add(coord)
        try:
            move_number = int(number)
        except (TypeError, ValueError):
            # No readable move number (e.g. None from the OCR step).
            unumbered.append((coord, col))
        else:
            numbered.append((move_number, col, coord))

    # Order by move number.
    numbered.sort()
    print("手番順に並べた着手:", numbered)
    print("手順不明な着手:", unumbered)
    print("範囲外の着手:", out_of_bounds)

    # Alternate black/white, handling duplicated and missing numbers.
    sgf_moves = []
    previous_move_number = 0
    color = 'B'
    for move_number, col, coord in numbered:
        if previous_move_number == move_number:
            # Same number as the previous move: treat as an unnumbered stone.
            unumbered.append((coord, col))
            continue
        # Fill gaps in the numbering with empty moves.
        while previous_move_number + 1 < move_number:
            previous_move_number += 1
            sgf_moves.append(f";{color}[]")
            color = 'W' if color == 'B' else 'B'
        sgf_moves.append(f";{color}[{coord}]")
        previous_move_number = move_number
        color = 'W' if color == 'B' else 'B'

    # TODO: handle duplicated moves properly.
    for coord, col in unumbered:
        sgf_moves.append(f";{col}[{coord}]")
    for _number, col in out_of_bounds:
        sgf_moves.append(f";{col}[]")

    return "(;" + "".join(sgf_moves) + ")"
# -------------------- 5. Main processing --------------------
def main(image_path):
    """Detect stones, their move numbers and the board grid in an image.

    Side effects: writes ``roi_<i>_<pad>.png`` debug crops to the current
    directory, prints progress to stdout and shows a matplotlib figure.

    Args:
        image_path: Path of the board image.

    Returns:
        ``(stones, grid_x_lines, grid_y_lines)`` where ``stones`` is a list
        of ``((x, y), color, num)`` tuples (``num`` is whatever the OCR step
        produced, possibly ``None``). All three are empty lists when no
        stone was detected, so callers can always unpack three values.
    """
    img, gray, blur = preprocess_image(image_path)
    circles = detect_stones_hough(blur)
    if len(circles) == 0:
        print("石が検出されませんでした。")
        return [], [], []

    # Drop circles whose radius is more than two standard deviations off.
    radii = [int(circle[2]) for circle in circles]
    mean_radius = np.mean(radii)
    std_radius = np.std(radii)
    min_radius = mean_radius - 2 * std_radius
    max_radius = mean_radius + 2 * std_radius
    # Also drop circles centred outside the image here:
    # filter_stones_by_intensity skips them, which would otherwise leave
    # stones[i] and circles[i] out of step in the OCR loop below.
    height, width = gray.shape[:2]
    circles = [
        circle for circle in circles
        if min_radius <= circle[2] <= max_radius
        and 0 <= int(circle[0]) < width
        and 0 <= int(circle[1]) < height
    ]
    print("フィルタリング後の円の数:", len(circles))
    if not circles:
        print("石が検出されませんでした。")
        return [], [], []

    stones = filter_stones_by_intensity(circles, gray)

    # Read the number printed inside each circle with OCR.
    for i, ((x, y), color) in enumerate(stones):
        num = _read_stone_number(gray, x, y, circles[i][2], i)
        print(f"石の位置: ({x}, {y}), 色: {color}, 認識された数字: {num}")
        stones[i] = ((x, y), color, num)

    result_img = draw_detected_stones(img, stones, radius=int(mean_radius))

    # Detect board lines, then fit the grid to stones and lines.
    vertical_lines, horizontal_lines = detect_lines(gray)
    grid_x_lines, grid_y_lines = detect_grid(
        gray, circles, vertical_lines, horizontal_lines, result_img
    )
    print("検出された格子線の数:", len(grid_x_lines), len(grid_y_lines))
    print("格子線のx座標:", grid_x_lines)
    print("格子線のy座標:", grid_y_lines)

    # Overlay the estimated grid.
    for x in grid_x_lines:
        cv2.line(result_img, (x, 0), (x, result_img.shape[0]), (255, 0, 0), 1)
    for y in grid_y_lines:
        cv2.line(result_img, (0, y), (result_img.shape[1], y), (255, 0, 0), 1)

    plt.figure(figsize=(10, 10))
    plt.imshow(cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB))
    plt.title("改良された石の検出結果")
    plt.axis("off")
    plt.show()
    return stones, grid_x_lines, grid_y_lines


def _read_stone_number(gray, x, y, radius, index):
    """OCR the number on one stone, retrying with tighter crops and inversion."""
    num = None
    for pad in range(5):
        # Shrink the crop by 5% of the radius per attempt; the vertical
        # extent starts at 90% of the radius.
        pad_x = int(radius * (1.0 - pad * 0.05))
        pad_y = int(radius * (0.9 - pad * 0.05))
        roi = gray[max(0, y - pad_y):y + pad_y, max(0, x - pad_x):x + pad_x]
        cv2.imwrite(f"roi_{index}_{pad}.png", roi)  # debug dump
        # White stones carry dark digits and black stones light ones, so try
        # the crop as-is and inverted.
        for candidate in (roi, 255 - roi):
            num = ocr_digit(candidate)
            if num:
                return num
    return num
if __name__ == "__main__":
    # Usage: pass the path of the board image as the first argument.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    result = main(sys.argv[1])
    # main() may hand back an empty result when nothing was detected; there
    # is nothing to convert in that case.
    if not result or not result[0]:
        sys.exit(1)
    stones, grid_x_lines, grid_y_lines = result
    print("検出された石:", stones)
    sgf = stones_to_sgf(stones, grid_x_lines, grid_y_lines)
    print("SGF形式の出力:")
    print(sgf)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment