Last active
September 16, 2023 00:45
-
-
Save tuhdo/4a550107957698ab7af03dd10b4e664e to your computer and use it in GitHub Desktop.
Let DecisionTreeRegressor learn how to perform addition and multiplication
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
def gen_dataset(fname, op, start=1, end=500):
    """Write every operand pair in ``[start, end]`` with its result to a CSV file.

    The file starts with the header row ``Operand1,Operand2,Sum`` followed by
    one row per ``(operand1, operand2)`` pair, with ``operand1`` varying in the
    outer loop. The result column is named ``Sum`` for both operations so that
    the same header works for either dataset.

    Args:
        fname: Path of the CSV file to create; an existing file is overwritten.
        op: ``"+"`` for addition or ``"*"`` for multiplication.
        start: Smallest operand value, inclusive.
        end: Largest operand value, inclusive.

    Raises:
        ValueError: If ``op`` is not ``"+"`` or ``"*"``. Raised before the file
            is opened, so no header-only file is left behind.
    """
    operations = {
        "+": lambda a, b: a + b,
        "*": lambda a, b: a * b,
    }
    try:
        compute = operations[op]
    except KeyError:
        raise ValueError(
            f"Unsupported operator {op!r}; expected one of {sorted(operations)}"
        ) from None

    operands = range(start, end + 1)
    with open(fname, mode='w', newline='', encoding='utf-8') as csv_file:
        csv_writer = csv.writer(
            csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL
        )
        csv_writer.writerow(['Operand1', 'Operand2', 'Sum'])
        # One pass over the full operand grid, shared by both operations.
        csv_writer.writerows(
            [operand1, operand2, compute(operand1, operand2)]
            for operand1 in operands
            for operand2 in operands
        )
    print(f"Arithmetic {op} dataset generated and saved to {fname}.")
# Produce the addition dataset first, then the multiplication dataset.
for dataset_file, operator_symbol in (
    ("add_dataset.csv", "+"),
    ("mul_dataset.csv", "*"),
):
    gen_dataset(dataset_file, operator_symbol)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
from sklearn.tree import DecisionTreeRegressor | |
from sklearn.ensemble import RandomForestRegressor | |
from sklearn.preprocessing import OrdinalEncoder | |
from sklearn.model_selection import train_test_split | |
from sklearn.tree import export_graphviz | |
import graphviz | |
# Choose the dataset to train on by leaving exactly one path uncommented.
dataset_path = 'add_dataset.csv'
# dataset_path = 'mul_dataset.csv'
df = pd.read_csv(dataset_path)

# Features are the two operand columns; the label is the 'Sum' column.
X = df.loc[:, ['Operand1', 'Operand2']]
y = df.loc[:, 'Sum']
# Hold out part of the data as a validation set.
# Author's observations (on their machine): with test_size = 0.85 the trained
# model got addition slightly wrong (e.g. 1 + 1 = 1.1), and for multiplication
# the results stopped being accurate at test_size = 0.66.
test_size, random_state, is_shuffle = 0.65, 1, True
split_options = {
    "test_size": test_size,
    "random_state": random_state,
    "shuffle": is_shuffle,
}
X_train, X_val, y_train, y_val = train_test_split(X, y, **split_options)

# Fit a single decision tree on the training split (fit() returns the estimator).
regressor = DecisionTreeRegressor(random_state=random_state).fit(X_train, y_train)

# Alternative: train a random forest instead of a single tree.
# regressor = RandomForestRegressor(random_state=random_state).fit(X_train, y_train)
# Create a feature vector with your input data (operands).
# In this example, we're using [353, 103] for prediction.
feature_vector = np.array([[353, 103]])  # add test
# feature_vector = np.array([[5, 15]])  # mul test

# The model was fitted on a DataFrame, so the query is wrapped in a DataFrame
# with the same column names. Passing the bare NumPy array makes scikit-learn
# warn that "X does not have valid feature names".
query = pd.DataFrame(feature_vector, columns=X.columns)

# Make a prediction
result = regressor.predict(query)
print(f"feature_vector: {feature_vector}")
print(f"The predicted sum is: {result[0]}")
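# Uncomment the code below to generate a tree visualization with Graphviz.
# It reads `regressor.estimators_`, so it requires the RandomForestRegressor
# variant; a single DecisionTreeRegressor has no `estimators_` attribute.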
# # Export the first decision tree to a DOT file | |
# tree = regressor.estimators_[0] # Get the first decision tree from the forest | |
# dot_data = export_graphviz(tree, out_file=None, | |
# feature_names=["Operand1", "Operand2"], # Replace with your feature names | |
# filled=True, rounded=True, special_characters=True) | |
# # Create a graph from the DOT data and display it | |
# graph = graphviz.Source(dot_data) | |
# graph.view(filename="add_dataset_sm") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment