Skip to content

Instantly share code, notes, and snippets.

@tallclair
Last active October 2, 2025 04:11
Show Gist options
  • Save tallclair/8368028cf978a384741fc2b8c7821ddc to your computer and use it in GitHub Desktop.
Save tallclair/8368028cf978a384741fc2b8c7821ddc to your computer and use it in GitHub Desktop.
A script to convert consecutive calls to SetFeatureGateDuringTest to use the multi-gate SetFeatureGatesDuringTest instead (generated by gemini-2.5-pro)
"""
This script refactors Go test files to consolidate multiple consecutive calls
to `featuregatetesting.SetFeatureGateDuringTest` into a single call to
`featuregatetesting.SetFeatureGatesDuringTest`.
It operates in three main stages:
1. Find candidate files: It uses `git grep` to quickly find the `_test.go` files
that contain the target function `SetFeatureGateDuringTest`. This is much
faster than walking the entire directory tree.
2. Refactor content: For each candidate file, it reads the content and uses
regular expressions to find blocks of consecutive calls. It then replaces
these blocks with a single, equivalent `SetFeatureGatesDuringTest` call.
It also collapses `for k, v := range m` loops whose only statement is
`SetFeatureGateDuringTest(t, gate, k, v)` into `SetFeatureGatesDuringTest(t, gate, m)`.
3. Format code: The script does not format its output itself. Run the
project's standard Go formatting script (`hack/update-gofmt.sh`) afterwards
so that the changes are correctly formatted.
"""
import re
import os
import sys
import subprocess
import time
# One complete, single-line call:
#   <indent>featuregatetesting.SetFeatureGateDuringTest(<t>, <gate>, <feature>, <value>)
# <t>, <gate> and <feature> must be simple operands (no whitespace, commas or
# parentheses). <value> may contain spaces but no commas or parentheses.
# Calls whose arguments cannot be split reliably (e.g. a value of `isOn(x, y)`)
# therefore never match and are left untouched instead of being mangled.
# `cr` captures a trailing carriage return so CRLF files keep their line endings.
_CALL_LINE_RE = re.compile(
    r"(?P<indent>[ \t]*)featuregatetesting\.SetFeatureGateDuringTest\("
    r"[ \t]*(?P<t>[^,()\s]+)[ \t]*,"
    r"[ \t]*(?P<gate>[^,()\s]+)[ \t]*,"
    r"[ \t]*(?P<feature>[^,()\s]+)[ \t]*,"
    r"[ \t]*(?P<value>[^,()]+?)[ \t]*\)[ \t]*(?P<cr>\r?)"
)

# A range loop whose only statement sets one gate per map entry:
#   for k, v := range m {
#       featuregatetesting.SetFeatureGateDuringTest(t, gate, k, v)
#   }
# Every piece is confined to its own line. Blank lines are allowed inside the
# braces. The closing brace must end its line.
_RANGE_LOOP_RE = re.compile(
    r"^(?P<indent>[ \t]*)for[ \t]+(?P<key>\w+)[ \t]*,[ \t]*(?P<val>\w+)[ \t]*:=[ \t]*"
    r"range[ \t]+(?P<src>[^{}\n]+?)[ \t]*\{[ \t]*\r?\n"
    r"(?:[ \t]*\r?\n)*"
    r"[ \t]*featuregatetesting\.SetFeatureGateDuringTest\("
    r"[ \t]*(?P<t>[^,()\s]+)[ \t]*,"
    r"[ \t]*(?P<gate>[^,()\s]+)[ \t]*,"
    r"[ \t]*(?P<feature>\w+)[ \t]*,"
    r"[ \t]*(?P<value>\w+)[ \t]*\)[ \t]*\r?\n"
    r"(?:[ \t]*\r?\n)*"
    r"[ \t]*\}(?=[ \t]*(?:\r?\n|$))",
    re.MULTILINE,
)


def _build_consolidated_call(calls):
    """Return replacement lines for a run of parsed calls, or None to keep the run.

    Args:
        calls: `_CALL_LINE_RE` matches, in source order.

    Returns:
        A list of lines without "\\n" (each carrying the run's "\\r" if any).
        Returns None when there are fewer than two calls, when the calls
        disagree on the test or gate argument, or when a feature is repeated.
    """
    if len(calls) < 2:
        return None
    first = calls[0]
    if any(c["t"] != first["t"] or c["gate"] != first["gate"] for c in calls):
        return None
    features = [c["feature"] for c in calls]
    if len(set(features)) != len(features):
        # A repeated constant key in a Go map literal is a compile error.
        return None
    indent = first["indent"]
    eol = calls[-1]["cr"]
    new_lines = [
        f"{indent}featuregatetesting.SetFeatureGatesDuringTest({first['t']}, {first['gate']}, featuregatetesting.FeatureOverrides{{",
        *(f"{indent}\t{c['feature']}: {c['value']}," for c in calls),
        f"{indent}}})",
    ]
    return [line + eol for line in new_lines]


def _consolidate_consecutive_calls(content):
    """Merge each run of consecutive SetFeatureGateDuringTest lines into one call.

    A run is a sequence of call lines that may be separated by blank lines.
    Blank lines inside a merged run are dropped. Blank lines before or after
    it are kept. Runs that cannot be merged are copied unchanged.
    """
    lines = content.split("\n")
    out = []
    i = 0
    while i < len(lines):
        first = _CALL_LINE_RE.fullmatch(lines[i])
        if first is None:
            out.append(lines[i])
            i += 1
            continue
        calls = [first]
        last = i  # index of the last call line in the run
        j = i + 1
        while j < len(lines):
            if not lines[j].strip():
                j += 1
                continue
            nxt = _CALL_LINE_RE.fullmatch(lines[j])
            if nxt is None:
                break
            calls.append(nxt)
            last = j
            j += 1
        replacement = _build_consolidated_call(calls)
        out.extend(lines[i:last + 1] if replacement is None else replacement)
        i = last + 1
    return "\n".join(out)


def _replace_range_loop(match):
    """Collapse one `_RANGE_LOOP_RE` match into a SetFeatureGatesDuringTest call.

    The loop is kept as-is unless it passes the loop key and value straight
    through. It is also kept when the test or gate argument depends on a loop
    variable, because hoisting it out of the loop would change its meaning.
    """
    key, val = match["key"], match["val"]
    if match["feature"] != key or match["value"] != val:
        return match.group(0)
    loop_var = re.compile(rf"\b(?:{re.escape(key)}|{re.escape(val)})\b")
    if loop_var.search(match["t"]) or loop_var.search(match["gate"]):
        return match.group(0)
    # NOTE(review): assumes the ranged expression is a map[featuregate.Feature]bool
    # (assignable to FeatureOverrides); the Go compiler will reject other types.
    return f"{match['indent']}featuregatetesting.SetFeatureGatesDuringTest({match['t']}, {match['gate']}, {match['src']})"


def refactor_feature_gate_calls(file_path):
    """Rewrite one Go test file to use SetFeatureGatesDuringTest where possible.

    Two rewrites are applied, in order:
      1. Runs of two or more consecutive
         `featuregatetesting.SetFeatureGateDuringTest(t, gate, f, v)` lines that
         share `t` and `gate` become one call with a FeatureOverrides literal.
      2. `for k, v := range m { SetFeatureGateDuringTest(t, gate, k, v) }` loops
         become `SetFeatureGatesDuringTest(t, gate, m)`.

    The file is read and written as UTF-8 with newline translation disabled,
    so untouched lines keep their exact bytes and line endings.

    Args:
        file_path: Path to the Go source file.

    Returns:
        True if the file was changed and written. False if nothing changed or
        an I/O or decoding error occurred; errors are reported on stderr.
    """
    try:
        with open(file_path, encoding="utf-8", newline="") as f:
            content = f.read()
    except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
        print(f"Error reading {file_path}: {e}", file=sys.stderr)
        return False

    new_content = _RANGE_LOOP_RE.sub(_replace_range_loop, _consolidate_consecutive_calls(content))
    # Compare content rather than counting regex matches: a match whose
    # replacement is the original text is not a change.
    if new_content == content:
        return False

    try:
        with open(file_path, "w", encoding="utf-8", newline="") as f:
            f.write(new_content)
    except OSError as e:
        print(f"Error writing to {file_path}: {e}", file=sys.stderr)
        return False
    print(f"Refactored {file_path}")
    return True
def find_and_refactor_files(root_dir):
    """Refactor every Go test file under *root_dir* that calls SetFeatureGateDuringTest.

    Candidate files are listed with ``git grep``, which is much faster than
    walking the tree in Python. Each candidate is passed to
    ``refactor_feature_gate_calls``.

    Args:
        root_dir: Directory inside a git work tree. The search is limited to it.

    Returns:
        The number of files that were rewritten.

    Exits the process with status 1 if ``git grep`` cannot be started or
    reports an error.
    """
    print("Finding candidate files using 'git grep'...")
    # -l: list file names only.
    # -z: NUL-terminate names, so unusual paths are not C-quoted.
    # -F: treat the needle as a literal string.
    # The pathspec after "--" limits the search to Go test files.
    cmd = ["git", "grep", "-l", "-z", "-F", "SetFeatureGateDuringTest", "--", "*_test.go"]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=False,  # exit status 1 means "no matches", handled below
            cwd=root_dir,
        )
    except OSError as e:
        # e.g. git is not installed, or root_dir is not a usable directory.
        print(f"Error running 'git grep': {e}. Please ensure 'git' is installed and you are in a git repository.", file=sys.stderr)
        sys.exit(1)

    # git grep exits with 0 when something matched and 1 when nothing did.
    # Any other status is a real failure (e.g. not a git repository).
    if result.returncode not in (0, 1):
        detail = result.stderr.strip() or f"exit status {result.returncode}"
        print(f"Error running 'git grep': {detail}. Please ensure 'git' is installed and you are in a git repository.", file=sys.stderr)
        sys.exit(1)

    files = [path for path in result.stdout.split("\0") if path]
    if not files:
        print("No candidate files found.")
        return 0

    print(f"Found {len(files)} candidate files.")
    refactoring_count = 0
    for rel_path in files:
        # git grep prints paths relative to its cwd (root_dir).
        if refactor_feature_gate_calls(os.path.join(root_dir, rel_path)):
            refactoring_count += 1
    return refactoring_count
# This is the standard entry point for a Python script.
# The code inside this block only runs when the script is executed directly.
if __name__ == "__main__":
    # Project root: the first command-line argument, defaulting to ".".
    project_root = sys.argv[1] if len(sys.argv) > 1 else "."
    abs_project_root = os.path.abspath(project_root)
    # Fail early with a clear message instead of a misleading git error.
    if not os.path.isdir(abs_project_root):
        print(f"Error: {abs_project_root} is not a directory.", file=sys.stderr)
        sys.exit(1)
    print(f"Starting refactoring in directory: {abs_project_root}")
    refactored_count = find_and_refactor_files(abs_project_root)
    print(f"Refactoring complete. {refactored_count} file(s) were refactored.")
    if refactored_count:
        # The script never runs a formatter itself. The generated
        # FeatureOverrides entries use a single space after each key, which may
        # not match gofmt's alignment, so remind the user to format.
        print("Note: run 'hack/update-gofmt.sh' (or 'gofmt -w') to format the refactored files.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment