Skip to content

Commit 8b3c8c3

Browse files
committed
refactoring
1 parent c9eaa36 commit 8b3c8c3

File tree

5 files changed

+27
-27
lines changed

5 files changed

+27
-27
lines changed
14.9 KB
Loading
47.2 KB
Loading

stock_prediction_deep_learning.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import os
1616
import secrets
1717
import pandas as pd
18-
import tensorflow as tf
1918
from datetime import datetime
2019

2120
from stock_prediction_class import StockPrediction
@@ -28,27 +27,18 @@
2827

2928
def train_LSTM_network(stock):
3029
data = StockData(stock)
31-
3230
plotter = Plotter(True, stock.get_project_folder(), data.get_stock_short_name(), data.get_stock_currency(), stock.get_ticker())
33-
34-
(x_train, y_train), (x_test, y_test), (min_max, test_data) = data.download_transform_to_numpy(TIME_STEPS)
35-
36-
print(x_test)
31+
(x_train, y_train), (x_test, y_test), (training_data, test_data) = data.download_transform_to_numpy(TIME_STEPS)
32+
plotter.plot_histogram_data_split(training_data, test_data, stock.get_validation_date())
3733

3834
lstm = LongShortTermMemory(stock.get_project_folder())
3935
model = lstm.create_model(x_train)
40-
41-
defined_metrics = [
42-
tf.keras.metrics.MeanSquaredError(name='MSE')
43-
]
44-
45-
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, mode='min', verbose=1)
46-
47-
model.compile(optimizer='adam', loss='mean_squared_error', metrics=defined_metrics)
36+
model.compile(optimizer='adam', loss='mean_squared_error', metrics=lstm.get_defined_metrics())
4837
history = model.fit(x_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_data=(x_test, y_test),
49-
callbacks=[callback])
38+
callbacks=[lstm.get_callback()])
5039
print("saving weights")
5140
model.save(os.path.join(stock.get_project_folder(), 'model_weights.h5'))
41+
5242
plotter.plot_loss(history)
5343
plotter.plot_mse(history)
5444

@@ -60,7 +50,7 @@ def train_LSTM_network(stock):
6050

6151
print("plotting prediction results")
6252
test_predictions_baseline = model.predict(x_test)
63-
test_predictions_baseline = min_max.inverse_transform(test_predictions_baseline)
53+
test_predictions_baseline = data.get_min_max().inverse_transform(test_predictions_baseline)
6454
test_predictions_baseline = pd.DataFrame(test_predictions_baseline)
6555
test_predictions_baseline.to_csv(os.path.join(stock.get_project_folder(), 'predictions.csv'))
6656

stock_prediction_lstm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ class LongShortTermMemory:
2222
def __init__(self, project_folder):
2323
self.project_folder = project_folder
2424

25+
def get_defined_metrics(self):
26+
defined_metrics = [
27+
tf.keras.metrics.MeanSquaredError(name='MSE')
28+
]
29+
return defined_metrics
30+
31+
def get_callback(self):
32+
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, mode='min', verbose=1)
33+
return callback
34+
2535
def create_model(self, x_train):
2636
model = Sequential()
2737
# 1st layer with Dropout regularisation

stock_prediction_numpy.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class StockData:
2626
def __init__(self, stock):
2727
self._stock = stock
2828
self._sec = yf.Ticker(self._stock.get_ticker())
29+
self._min_max = MinMaxScaler(feature_range=(0, 1))
2930

3031
def __data_verification(self, train):
3132
print('mean:', train.mean(axis=0))
@@ -36,28 +37,27 @@ def __data_verification(self, train):
3637
def get_stock_short_name(self):
3738
return self._sec.info['shortName']
3839

40+
def get_min_max(self):
41+
return self._min_max
42+
3943
def get_stock_currency(self):
4044
return self._sec.info['currency']
4145

4246
def download_transform_to_numpy(self, time_steps):
43-
min_max = MinMaxScaler(feature_range=(0, 1))
4447
end_date = datetime.today()
4548
print('End Date: ' + end_date.strftime("%Y-%m-%d"))
4649
data = yf.download([self._stock.get_ticker()], start=self._stock.get_start_date(), end=end_date)[['Close']]
4750
data = data.reset_index()
48-
print(data)
49-
50-
plotter = Plotter(True, self._stock.get_project_folder(), self._sec.info['shortName'], self._sec.info['currency'], self._stock.get_ticker())
51+
#print(data)
5152

5253
training_data = data[data['Date'] < self._stock.get_validation_date()].copy()
5354
test_data = data[data['Date'] >= self._stock.get_validation_date()].copy()
5455
training_data = training_data.set_index('Date')
5556
# Set the data frame index using column Date
5657
test_data = test_data.set_index('Date')
57-
print(test_data)
58-
plotter.plot_histogram_data_split(training_data, test_data, self._stock.get_validation_date())
58+
#print(test_data)
5959

60-
train_scaled = min_max.fit_transform(training_data)
60+
train_scaled = self._min_max.fit_transform(training_data)
6161
self.__data_verification(train_scaled)
6262

6363
# Training Data Transformation
@@ -72,7 +72,7 @@ def download_transform_to_numpy(self, time_steps):
7272

7373
total_data = pd.concat((training_data, test_data), axis=0)
7474
inputs = total_data[len(total_data) - len(test_data) - time_steps:]
75-
test_scaled = min_max.fit_transform(inputs)
75+
test_scaled = self._min_max.fit_transform(inputs)
7676

7777
# Testing Data Transformation
7878
x_test = []
@@ -83,16 +83,16 @@ def download_transform_to_numpy(self, time_steps):
8383

8484
x_test, y_test = np.array(x_test), np.array(y_test)
8585
x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))
86-
return (x_train, y_train), (x_test, y_test), (min_max, test_data)
86+
return (x_train, y_train), (x_test, y_test), (training_data, test_data)
8787

88-
def __daterange(self, start_date, end_date):
88+
def __date_range(self, start_date, end_date):
8989
for n in range(int((end_date - start_date).days)):
9090
yield start_date + timedelta(n)
9191

9292
def generate_future_data(self, time_steps, min_max, start_date, end_date):
9393
x_future = []
9494
y_future = []
95-
for single_date in self.__daterange(start_date, end_date):
95+
for single_date in self.__date_range(start_date, end_date):
9696
x_future.append(single_date)
9797
y_future.append(random.uniform(10, 100))
9898

0 commit comments

Comments
 (0)