Created
December 13, 2024 14:19
-
-
Save codepediair/788252afb7d5039e4ea8b401a820f81a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import required libraries
import pandas as pd | |
import yfinance as yf | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from sklearn.model_selection import train_test_split | |
from sklearn.linear_model import LinearRegression | |
from sklearn.metrics import mean_squared_error | |
import numpy as np | |
# Retrieve historical stock data
def get_historical_data(ticker, start_date, end_date):
    """
    Retrieves historical stock data for a given ticker symbol and date range.

    The data is downloaded with ``yf.download``. Depending on the yfinance
    version, the result may carry two-level (field, ticker) columns even when
    a single ticker was requested; in that case the constant ticker level is
    dropped so that ``data['Close']`` is a plain Series.

    Args:
        ticker (str): The ticker symbol of the stock.
        start_date (str): The start date of the date range.
        end_date (str): The end date of the date range.

    Returns:
        pd.DataFrame: A pandas DataFrame containing the historical stock data.
    """
    data = yf.download(ticker, start=start_date, end=end_date)

    columns = data.columns
    if isinstance(columns, pd.MultiIndex) and columns.nlevels == 2:
        # Only flatten when the second level carries a single value;
        # flattening a multi-ticker frame would create duplicate column names.
        if len(columns.get_level_values(1).unique()) == 1:
            data = data.droplevel(1, axis=1)

    return data
# Prepare data for training
def prepare_data(data):
    """
    Prepares the historical stock data for training a machine learning model.

    Adds three feature columns and drops every row that contains a missing
    value (which includes the first two rows, whose lags are undefined):

    * 'Date': the index converted to seconds since the Unix epoch.
    * 'Close_Lag1': the 'Close' value of the previous row.
    * 'Close_Lag2': the 'Close' value two rows earlier.

    The input frame is not modified; a new frame is returned.

    Args:
        data (pd.DataFrame): The historical stock data, indexed by date and
            containing a 'Close' column.

    Returns:
        pd.DataFrame: A pandas DataFrame containing the prepared data.
    """
    # Work on a copy so the caller's frame is left untouched.
    prepared = data.copy()

    timestamps = pd.to_datetime(prepared.index)
    if timestamps.tz is not None:
        # Express timezone-aware dates in UTC so they can be measured
        # against the (naive) Unix epoch below.
        timestamps = timestamps.tz_convert('UTC').tz_localize(None)
    epoch_seconds = (timestamps - pd.Timestamp('1970-01-01')) / pd.Timedelta(seconds=1)
    # Assign as a plain array so values are placed positionally.
    prepared['Date'] = epoch_seconds.to_numpy()

    close = prepared['Close']
    prepared['Close_Lag1'] = close.shift(1)
    prepared['Close_Lag2'] = close.shift(2)

    return prepared.dropna()
# Train a linear regression model
def train_model(data):
    """
    Fits a linear regression of 'Close' on the prepared feature columns.

    The rows are split 80/20 at random (seeded with 42) and the model is
    fitted on the 80% portion only; the remaining 20% is not used here.

    Args:
        data (pd.DataFrame): The prepared data, containing the 'Date',
            'Close_Lag1', 'Close_Lag2' and 'Close' columns.

    Returns:
        LinearRegression: A trained linear regression model.
    """
    feature_columns = ['Date', 'Close_Lag1', 'Close_Lag2']
    # The held-out halves of the split are intentionally discarded.
    features_train, _, target_train, _ = train_test_split(
        data[feature_columns],
        data['Close'],
        test_size=0.2,
        random_state=42,
    )
    return LinearRegression().fit(features_train, target_train)
# Make predictions using the trained model
def make_predictions(model, data):
    """
    Runs the trained model on the feature columns of the given data.

    Args:
        model (LinearRegression): The trained linear regression model.
        data (pd.DataFrame): The data to make predictions on; it must contain
            the 'Date', 'Close_Lag1' and 'Close_Lag2' columns.

    Returns:
        np.array: An array of predicted values.
    """
    feature_columns = ['Date', 'Close_Lag1', 'Close_Lag2']
    features = data[feature_columns]
    return model.predict(features)
# Evaluate the model
def evaluate_model(y_test, predictions):
    """
    Scores predictions against the actual values.

    Args:
        y_test (pd.Series): The actual values.
        predictions (np.array): The predicted values.

    Returns:
        float: The mean squared error between the actual and predicted values.
    """
    return mean_squared_error(y_test, predictions)
# Main function
def main(ticker='AAPL', start_date='2020-01-01', end_date='2022-12-31', test_size=0.2):
    """
    Downloads stock data, trains the model, and reports its hold-out error.

    The most recent ``test_size`` fraction of the prepared rows is held out:
    the model is trained on the earlier rows only, and the mean squared error
    is computed on the held-out rows, so the reported figure is not measured
    on data the model was fitted to. The actual closes and the hold-out
    predictions are then plotted against their dates.

    Args:
        ticker (str): The ticker symbol of the stock.
        start_date (str): The start date of the date range.
        end_date (str): The end date of the date range.
        test_size (float): Fraction of the prepared rows (strictly between
            0 and 1) held out for evaluation.

    Raises:
        ValueError: If ``test_size`` is outside (0, 1) or there are too few
            rows to form both a training and an evaluation set.
    """
    if not 0 < test_size < 1:
        raise ValueError('test_size must be strictly between 0 and 1')

    # Retrieve historical stock data
    data = get_historical_data(ticker, start_date, end_date)

    # Prepare data for training
    data = prepare_data(data)
    if len(data) < 2:
        raise ValueError(
            f'Not enough data for {ticker} between {start_date} and {end_date} '
            'to train and evaluate a model.'
        )

    # Hold out the latest rows; the split is chronological so the evaluation
    # rows come strictly after every training row.
    n_test = max(1, int(len(data) * test_size))
    train = data.iloc[:-n_test]
    test = data.iloc[-n_test:]

    # Train a linear regression model
    model = train_model(train)

    # Make predictions using the trained model
    predictions = make_predictions(model, test)

    # Evaluate the model on the held-out rows only
    mse = evaluate_model(test['Close'], predictions)
    print(f'Mean Squared Error: {mse}')

    # Plot the predicted values. The seaborn theme is set before the figure
    # is created so that it applies to this figure.
    sns.set()
    plt.figure(figsize=(10, 6))
    plt.plot(data.index, data['Close'], label='Actual')
    # Use the hold-out dates as x-values so both lines share the same axis.
    plt.plot(test.index, predictions, label='Predicted')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Explanation
This code retrieves historical stock data for a given ticker symbol and date range using the yfinance library. It then prepares the data for training a machine learning model by converting the date index to a numeric Unix timestamp and creating lagged features from the closing price. The code trains a linear regression model using the prepared data and makes predictions using the trained model. Finally, it evaluates the performance of the model using the mean squared error metric and plots the predicted values.
Example Use Case
To use this code, set the ticker symbol, start date, and end date in the main function as desired, then run the script. The code will retrieve the historical stock data, train a linear regression model, make predictions, and evaluate the model's performance.
Advice