You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

549 lines
84 KiB
Plaintext

2 years ago
{
"cells": [
{
"cell_type": "markdown",
"id": "c1d7ea96-5551-4bfe-885c-cc70bdeee9b9",
"metadata": {},
"source": [
"Книга для обучения GCAEC классификатора с использованием Tensorflow из статьи https://gitlab.com/protsenkovi/efd_nn/"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c32ace82-fff8-4b9f-8eb9-0945079f3776",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num GPUs Available: 1\n",
"1 Physical GPUs, 1 Logical GPUs\n",
"2.12.0\n"
]
}
],
"source": [
"import os\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' \n",
"import random\n",
"import tensorflow as tf\n",
"# tf.autograph.set_verbosity(1)\n",
"# tf.get_logger().setLevel('INFO')\n",
"import tensorflow_addons as tfa\n",
"from IPython.display import display, clear_output\n",
"import seaborn as sns\n",
"import pandas as pd\n",
"import numpy as np\n",
"import warnings\n",
"\n",
"from tensorflow.keras import regularizers\n",
"from tensorflow.keras.callbacks import LearningRateScheduler\n",
"# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"-1\"\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Dropout, Reshape, SimpleRNN, GRU, LSTM, PReLU, MaxPooling1D, Flatten, AveragePooling1D, \\\n",
" GaussianNoise\n",
"from tensorflow.keras.layers import Conv1D, BatchNormalization\n",
"from tensorflow.keras.optimizers import Adam, Adamax, SGD\n",
"from tensorflow.keras import losses \n",
"from tensorflow.keras import metrics as kmetrics\n",
"from tensorflow.keras.saving import load_model\n",
"from tensorflow.keras.utils import plot_model\n",
"from sklearn.neural_network import MLPClassifier\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"from sklearn.metrics import confusion_matrix\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold\n",
"from time import time\n",
"\n",
"np.random.seed(42)\n",
"random.seed(42)\n",
"clear_output()\n",
"\n",
"tf.keras.utils.set_random_seed(42)\n",
"tf.config.experimental.enable_op_determinism()\n",
"print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))\n",
"\n",
"gpus = tf.config.list_physical_devices('GPU')\n",
"if gpus:\n",
" try:\n",
" # Currently, memory growth needs to be the same across GPUs\n",
" for gpu in gpus:\n",
" tf.config.experimental.set_memory_growth(gpu, True)\n",
" logical_gpus = tf.config.list_logical_devices('GPU')\n",
" print(len(gpus), \"Physical GPUs,\", len(logical_gpus), \"Logical GPUs\")\n",
" except RuntimeError as e:\n",
" # Memory growth must be set before GPUs have been initialized\n",
" print(e)\n",
" \n",
"print(tf.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "050e0c0a-629f-41d1-8c9f-9fcf72720bca",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((116600, 24, 6), (29150, 24, 6), (11054, 24, 6))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"window = 24\n",
"\n",
"X_train = np.load(f\"../assets/X_train_{window}.npy\")\n",
"y_train = np.load(f\"../assets/y_train_{window}.npy\")\n",
"X_val = np.load(f\"../assets/X_val_{window}.npy\")\n",
"y_val = np.load(f\"../assets/y_val_{window}.npy\")\n",
"X_test = np.load(f\"../assets/X_test_{window}.npy\")\n",
"y_test = np.load(f\"../assets/y_test_{window}.npy\")\n",
"\n",
"shuffled_index = np.arange(X_train.shape[0])\n",
"np.random.shuffle(shuffled_index)\n",
"X_train = X_train[shuffled_index]\n",
"y_train = y_train[shuffled_index]\n",
"\n",
"shuffled_index = np.arange(X_val.shape[0])\n",
"np.random.shuffle(shuffled_index)\n",
"X_val = X_val[shuffled_index]\n",
"y_val = y_val[shuffled_index]\n",
"\n",
"shuffled_index = np.arange(X_test.shape[0])\n",
"np.random.shuffle(shuffled_index)\n",
"X_test = X_test[shuffled_index]\n",
"y_test = y_test[shuffled_index]\n",
"\n",
"X_train.shape, X_val.shape, X_test.shape"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "89fef4d0-bc58-4aa1-847a-26278045e8a5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import tensorflow.keras.layers as kl\n",
"import tensorflow.keras as k\n",
"\n",
"def new_pred(predict):\n",
" pred = []\n",
" for i in predict:\n",
" if i < 0.5:\n",
" pred.append(0)\n",
" else:\n",
" pred.append(1)\n",
" return pred\n",
"\n",
"conv_num = -1\n",
"def conv_block_name():\n",
" global conv_num\n",
" conv_num = conv_num + 1\n",
" return 'cb_{}'.format(conv_num)\n",
"\n",
"def conv_block(output_timesteps, output_channels, kernel_size, name, activity_regu=None):\n",
" def f(preceding_layer):\n",
" input_timesteps, input_channels = preceding_layer.get_shape().as_list()[1:]\n",
" \n",
" inputs = k.Input(shape=(input_timesteps, input_channels))\n",
" \n",
" act = kl.Conv1D(output_channels, kernel_size=kernel_size, activation='linear', padding='same', name=name+'_conv_features1', activity_regularizer=activity_regu)(inputs)\n",
" gate = kl.Conv1D(output_channels, kernel_size=kernel_size, activation='sigmoid', padding='same', name=name+'_conv_memory')(inputs)\n",
" gated_act = kl.Multiply()([tfa.layers.InstanceNormalization()(kl.PReLU()(act)), gate]) \n",
" \n",
" a = kl.Permute((2,1))(gated_act)\n",
" b = kl.Dense(output_timesteps, use_bias=False)(a)\n",
" c = kl.Permute((2,1))(b)\n",
" \n",
" m = k.Model(inputs=inputs, outputs=c, name=name)\n",
"# m.summary()\n",
" return m(preceding_layer)\n",
" return f\n",
"\n",
"def model_1D(input_shape):\n",
" inputs = k.Input(shape=input_shape)\n",
" e = conv_block(output_timesteps=window//2, output_channels=128, kernel_size=3, name=conv_block_name())(inputs)\n",
" e = conv_block(output_timesteps=2, output_channels=128, kernel_size=3, name=conv_block_name(), activity_regu=regularizers.l1(1e-3))(e)\n",
" c = kl.Flatten()(e)\n",
" c = Dense(1, activation='sigmoid', name='sigmoid_layer')(c)\n",
" d = conv_block(output_timesteps=window//2, output_channels=128, kernel_size=3, name=conv_block_name())(e)\n",
" d = conv_block(output_timesteps=input_shape[0], output_channels=input_shape[1], kernel_size=3, name=conv_block_name())(d)\n",
" decoder_output = kl.GaussianNoise(1e-2, name='decoder_output')(d)\n",
" \n",
" classifier_model = k.Model(inputs=inputs, outputs=c, name='classifier')\n",
" model = k.Model(inputs=inputs, outputs=[decoder_output, c], name='autoencoder')\n",
" return model, classifier_model\n",
"\n",
"class ProgressCallback(k.callbacks.Callback):\n",
" def on_epoch_end(self, epoch, logs={}):\n",
" results = pd.DataFrame(data=np.array([v for k,v in logs.items()]).reshape(2,4), \n",
" columns=['loss', 'ae_loss', 'classifier_loss', 'accuracy'],\n",
" index=['train','val'])\n",
" results = results.style.set_caption(f\"{epoch}\")\n",
" clear_output(wait=True)\n",
" display(results)\n",
"\n",
"\n",
"def train(X_train, y_train, X_val, y_val, model, epochs=50, lr=1e-2):\n",
" model.compile(loss={\n",
" 'decoder_output':losses.MSE,\n",
" 'sigmoid_layer':losses.BinaryCrossentropy()\n",
" }, \n",
" optimizer=Adam(learning_rate=lr), \n",
" metrics={\n",
" 'sigmoid_layer': kmetrics.BinaryAccuracy(),\n",
" })\n",
" try:\n",
" history = model.fit(X_train, \n",
" [X_train, y_train], \n",
" epochs=epochs,\n",
" batch_size=4096, \n",
" verbose=0,\n",
" validation_data=(X_val, (X_val, y_val)),\n",
" callbacks=[ProgressCallback()])\n",
" except KeyboardInterrupt as e:\n",
" history = []\n",
" return history"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "650486b7-542b-4fc1-a38f-028b74d61e42",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"autoencoder\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_1 (InputLayer) [(None, 24, 6)] 0 [] \n",
" \n",
" cb_0 (Functional) (None, 12, 128) 8480 ['input_1[0][0]'] \n",
" \n",
" cb_1 (Functional) (None, 2, 128) 100376 ['cb_0[0][0]'] \n",
" \n",
" cb_2 (Functional) (None, 12, 128) 99096 ['cb_1[0][0]'] \n",
" \n",
" cb_3 (Functional) (None, 24, 6) 4992 ['cb_2[0][0]'] \n",
" \n",
" flatten (Flatten) (None, 256) 0 ['cb_1[0][0]'] \n",
" \n",
" decoder_output (GaussianNoise) (None, 24, 6) 0 ['cb_3[0][0]'] \n",
" \n",
" sigmoid_layer (Dense) (None, 1) 257 ['flatten[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 213,201\n",
"Trainable params: 213,201\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"gcaec, classifier = model_1D(input_shape=X_train.shape[1:])\n",
"gcaec.summary()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d59ed374-3085-4a36-94dd-cc8670d9d6e3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_0c7ee\">\n",
" <caption>999</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_0c7ee_level0_col0\" class=\"col_heading level0 col0\" >loss</th>\n",
" <th id=\"T_0c7ee_level0_col1\" class=\"col_heading level0 col1\" >ae_loss</th>\n",
" <th id=\"T_0c7ee_level0_col2\" class=\"col_heading level0 col2\" >classifier_loss</th>\n",
" <th id=\"T_0c7ee_level0_col3\" class=\"col_heading level0 col3\" >accuracy</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_0c7ee_level0_row0\" class=\"row_heading level0 row0\" >train</th>\n",
" <td id=\"T_0c7ee_row0_col0\" class=\"data row0 col0\" >0.003677</td>\n",
" <td id=\"T_0c7ee_row0_col1\" class=\"data row0 col1\" >0.002197</td>\n",
" <td id=\"T_0c7ee_row0_col2\" class=\"data row0 col2\" >0.000171</td>\n",
" <td id=\"T_0c7ee_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0c7ee_level0_row1\" class=\"row_heading level0 row1\" >val</th>\n",
" <td id=\"T_0c7ee_row1_col0\" class=\"data row1 col0\" >4.735565</td>\n",
" <td id=\"T_0c7ee_row1_col1\" class=\"data row1 col1\" >0.002891</td>\n",
" <td id=\"T_0c7ee_row1_col2\" class=\"data row1 col2\" >4.731774</td>\n",
" <td id=\"T_0c7ee_row1_col3\" class=\"data row1 col3\" >0.619776</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f71403ebe20>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Training time = 11.216969986756643\n"
]
}
],
"source": [
"start = time()\n",
"\n",
"history = train(X_train=X_train, y_train=y_train, X_val=X_test, y_val=y_test, model=gcaec, lr=1e-4, epochs=1000)\n",
"\n",
"print(\"\\nTraining time = \", (time() - start) / 60)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8a019554-3ea7-4713-a0ae-e95c31d83822",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f71403e7fa0>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAGdCAYAAAD3zLwdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACEeElEQVR4nO3dd3hUVd4H8O/0yaT3XiihhJIAAaRDiGJQpCg2xAgrlg0rGBvsLuiuIth4cdcIqytgAWF1AV1BEAKIINKDQCC0QICQ3utMZu77x0kmGRIggSQzId/P88yTmXvv3Hvuycw9vzntyiRJkkBEREREkFs7AURERES2goERERERUTUGRkRERETVGBgRERERVWNgRERERFSNgRERERFRNQZGRERERNUYGBERERFVU1o7AW2NyWRCeno6HB0dIZPJrJ0cIiIiagRJklBcXAw/Pz/I5devF2Jg1ETp6ekIDAy0djKIiIjoFly6dAkBAQHXXc/AqIkcHR0BiIx1cnKycmqIiIioMYqKihAYGGgux6+HgVEjJSQkICEhAUajEQDg5OTEwIiIiKiNuVk3GBlvIts0RUVFcHZ2RmFhIQMjIiKiNqKx5TdHpRERERFVY2DUSAkJCQgLC0P//v2tnRQiIiJqIWxKayI2pRERkVWU5gAn1gNXjwIlWYBaB8iV4qHUAl7dAYUacAkEdB6AvSfg5AdwahkAjS+/2fmaiEiSWHiQbSvLA/6vJ1BV3vT3ugQBnt0BtT3g4F0bOElGwGQEPLuKAMrRF5Armj/tbQwDIyJq305vAdY/Bwx/BXDyBYquAvYegE9vwDkA0DhYO4W2xVAOZJ4A/PuJYPL8z8CpH4C8VKAkE1CogPICy/fYe4iCWOsMaBxFAa1zB1RawNFP1Hio7cU6B2+xnalKLKsJWtti4GooB0qzRWDSGJIkAhVFddFsNIjH7v8DLv5qGRT1fRLw7Caem4xAWS6QdRIwlAH5F4DCS7XbFqSJR2OoHQA7N0AGwCVYpL2qAnDwARyr/zcVhYDWRdRSqbRAZUn1/9hdLPPoIj4LMjmgLxU1WwWXREBWlgec3wl0jhbnIwFwDQZUdsDlgyI4k0yAV7fGpbcFsCmtkeoO1z99+jSb0ojuFG8433i9nSvgGiIu5C7BopDw6CJ+ZSvU1Rd+e0DjVFugtZaaoKG8QBRCWicRtBjKRYF28nuRbtcQsf3v3wCX9gHRrwMyBfDL+0DgQLGfX/8BhI0Hej4oCtUtfwUu7hbvDb1HBImGCmDn22JfMe+JZp20X1vu/GQKcV4qXW2AKlOIQlhtL/4fkgmQy4GidECuEuucA8V6fSmgLxHnYKoSBbpSK5Yr1ICDpyiINU5A4WWgLEccT20vAhrvniIgyTkN+PcFKotFEFJVARj1QOovgE9Psf8qPZCfCvSYKIIClQ7Y9LJIs1cPoDgdKM8Xxw0eIgKMlE0izYbS2nOWKwGfXkBprtiPsdIyT3o+CDz4WeMCxfICIPccUFkE5J0T/7+CiyKv9KUigMpPBVT2lmmwBfNyRJDdjBrblMbAqInYx4jIhkiSeMjlQOEV4Ow2IOwBEcyYTED6EcA7TBRyuWfFr96M30XBd2G3KCT3LLHcp1IrCqYrh0Sh21hKrQgeZArRt6MgTbz2DQd0buKXf/FV8es4oL8IXlJ+BCCJghwQBW/eefHrPzMZ8OgsamIqi8UvaP9+Ik3nfxYFXF0yhWgauR7fCOBqUuPPp7HkKiD8UXGedq6iKcbBG6LKASJN+RdFMFFZJGoX9CXif1JVKc5XXyrOy1AughO6voe/EAFsc9OXif9JabYIjGUKIPO4CCYrSwClWvwtviq2lyRAXyz+h4BYV3T59tMhk4vaqD8dEt+bZsTAqIUwMCJqBEkSgYXOTVxgy/OBn98Rv3Z7PSR+dRdniF/rxkoRVGQlA/v+VVsT4N9X/Lq+sFsUtAUXAa8wUQ3/+1qgokAcS6EWv97r8uklquyLrkAU0I28zHn1AJ7ZKQoBk0lc+M9sBS7tFwV2Wa64cGeniAKiKYFTW6ZxFk00JoPlcoUGmLm/tkaqORgqROAEmchzhUoU2vpSkedKrfg/mAxAcaaoOTEagJIMEYy6dxb7MFWJ/RkrxftlMtFUZzKKz0xVhQjESjJE4KnzAJQa8bkBRO1bxnHxGZApgIBI0cSkUIl0OPqK2qSyHBHoVhSKmiU7V1HrU1UpAkFA1DzpS0VNo0sQYOcCXDkszsVkALJPiRqk+z8U6cxOEZ9pn54iuD+9BTjzk9jXrKPNm9/NSZLEuSg1gL2XqPFy6yi+JzI5cHGPqJ3z7Aqk7RX55Bwozrc8H+gyRuRDCzWbMjBqIQyM6I5gqBCFQ1muKECcfGvXFV4W/T70xcCBf4tCx8EL8OsjLvq73hOBi3tn4PDnIlCJeALoGgNc+k0UAuufbVp6FJr6TQbWMHU90CmqcdvqS0V/JEOZ6FuReVzUfmhdxC9nmUJc4HPOijwuviryUTIBZxNFodwpShSSMjlw+YAoKCRJFKgKlajhKskQTR5KO9Enw8FHBIVVFUD4Y8DRr0WhHjkNuLAH8OwCDH4B+N8sUfjc/Wb1/qqb/fo+CeSeAVI2i4LaVCX6ieSlAt3uF4VWeZ4o0ABRE6CyE+m+ehRI3gD8+k+xrsu9wONrm/3fYHNutXO+ydg8nZm3/Q3YvVg8f72gbfa3sgEcldbMrr0lCJHVGatEoekcIJ5Dqm2TNxlFM42dqyh49aWiCamiUPxiS/x77a9Z4PYDk6SvxOOWz6URx7b3FIV4eT7gFCA6cgYOELVINYFYx5EiUOsSA3QYJvrPaBxFPsgUIr9Kc8Tr7FNA2AQR6O16V7zfJbjxaVbbi6auGs7+jX+vySjORalp2ntk8vqF4sRlDW//xDoRSGkb6EPlGiz6Hl2PnUvtc3Pn8+pak9Rdtet8IxqR8DvArQYizTXCa8AM4MhXov8Sg6IWxxqjJmKNEbWKml+oRenAyR+q+67IRNV+SYb4pb9vmaihkMlrq6ptqWkn4gngzBbRZwEAPLqKJgsnP9FHQq4QtVPuncSIFUcfoP/T4hxSNgFBg4ErB4FTG4HeD4ugx1AuquevbUrISxV5FTQIKM0S+2qs4kzggy7i+V8yRe0JXd+Bz4CN8eL5o6uBbvdZNz1EjcQaIyJbknNWBDcFaaKmQ64UvwB//Yf41V2eXz301Q44sU68R66q36ejITXB0I2CIq1LbZ+curx6iPSc3lx/XU3tS6+HgW5jRdNZcYZoUqssFjUv5flA59Gi9injOOAXAWSdAnx7i9qQ8gIg4xgQMrRpv3R7TBR/u91nWfCq7BruX+HWQTyApgVFgBiCPO5DMUyZQdHNqXS1z316WS8dRC2EgRHRzRgNIpApvCw6AIcMra3RqSwRHS7zzosOkgUXxfBYjaPo+KvSNRyQ1JV3vuHlNwqKrq0d6jBcdHZM3iCaaAIGAPcuAtIPi4Cm0yjR3Hb8W1Hj0jfWsrkk/6IIgvo8cePq/5qC0N6jNhABRNNUh2HieWCd2+bYudQut2X9nrJ2CtoOQ1nt85rRdER3EDalNRGb0u5Q2adFHxw7VxFIqB1E59aT3wOb57Tsse1cAbdOIuC6tE90mO0xSQzN9u0thqF7dhFNURoH0ZwkV7X+nDlEgBi19e/RoqP2PW9aOzVEjcamNGqfCi9XdzaVif4tTgG1Q0NzzwLHvhVzuQx4Bji1qXnm3dA6i47BQXeJZqPSHDGi58IvovBQ6WonmfMKE/PqZBwXw3Ejpze9g6bK7vbTTHSrdG7AC0esnQqiFsPAqJE4Ks2GGCrEcGOlnRgeXnhZNCsdXVN/0rvr2f9J/WVqh+r5Uxrg4AOERosmNUOFGN4eNFhMXhd6T9ODG/9+TdueiIhaBZvSmohNaS1IkoCcM2JIuUdnIOlrMUlgzfT7eali+vqaidtuV0B/MaGYRxfRnBU8VEzbX3RVjJTSOoumLb8+rKUhImrj2JRGtin/oggy0pNE0KGvvj/PmZ9EwNPYGh8A5lsOAIBHqJiJNmSo2IdKBwyZJWp4ynJFx2R7D1GzVJ4
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"f, [a1,a2] = plt.subplots(2,1)\n",
"a1.plot(history.history['loss'], label='loss')\n",
"a1.plot(history.history['val_loss'], label='val_loss')\n",
"a1.set_yscale('log')\n",
"a1.legend()\n",
"a2.plot(history.history['sigmoid_layer_binary_accuracy'], label='accuracy')\n",
"a2.plot(history.history['val_sigmoid_layer_binary_accuracy'], label='val_accuracy')\n",
"a2.set_yscale('log')\n",
"a2.legend()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "255de68d-e904-4b25-b9cc-aca0a4c18e35",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"threshold = 0.5"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6ceac140-7f0f-4061-92cb-c290f7376a36",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3644/3644 [==============================] - 8s 2ms/step\n",
"Train Accuracy: 1.0\n",
"F1: 1.0\n"
]
}
],
"source": [
"pred_train = classifier.predict(X_train) > threshold\n",
"conf_train = confusion_matrix(y_train, pred_train)\n",
"acc_train = (conf_train[0][0] + conf_train[1][1]) / (conf_train[0][0] + conf_train[1][1] + conf_train[0][1] + conf_train[1][0])\n",
"f1_train = (2 * conf_train[1][1]) / (2 * conf_train[1][1] + conf_train[0][1] + conf_train[1][0])\n",
"print('Train Accuracy: ', acc_train)\n",
"print('F1: ', f1_train)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5d66fa98-4ce0-4084-8271-997bc7af6335",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"911/911 [==============================] - 2s 2ms/step\n",
"Val Accuracy: 0.9884734133790738\n",
"F1: 0.9885464957731115\n"
]
}
],
"source": [
"pred_val = classifier.predict(X_val) > threshold\n",
"conf_val = confusion_matrix(y_val, pred_val)\n",
"acc_val = (conf_val[0][0] + conf_val[1][1]) / (conf_val[0][0] + conf_val[1][1] + conf_val[0][1] + conf_val[1][0])\n",
"f1_val = (2 * conf_val[1][1]) / (2 * conf_val[1][1] + conf_val[0][1] + conf_val[1][0])\n",
"print('Val Accuracy: ', acc_val)\n",
"print('F1: ', f1_val)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "18b0ba57-d083-4c52-b797-af034c726810",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"346/346 [==============================] - 1s 2ms/step\n",
"Test Accuracy: 0.6197756468246789\n",
"F1: 0.45969919012726573\n"
]
}
],
"source": [
"predict_cnn = classifier.predict(X_test) > threshold\n",
"conf_test = confusion_matrix(y_test, predict_cnn)\n",
"acc_test = (conf_test[0][0] + conf_test[1][1]) / (\n",
" conf_test[0][0] + conf_test[1][1] + conf_test[0][1] + conf_test[1][0])\n",
"f1_test = (2 * conf_test[1][1]) / (2 * conf_test[1][1] + conf_test[0][1] + conf_test[1][0])\n",
"print('Test Accuracy: ', acc_test)\n",
"print('F1: ', f1_test)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b621e097-30f1-4910-894c-cadd370f00ae",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAhAAAAGdCAYAAABDxkoSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA6mUlEQVR4nO3deVxV1d7H8S8gIIOAoICkqGkqOFxTS9DqZqJUNFh2S7O0NEtDfRxSs2tmWg+lqWnm0JXEBqdyuCmpOWIqOVA4y2PlDQsBJ0RU5v38kZzrAeqwEYLs8369zusla6+9zjol+OW31t7bzjAMQwAAACbYV/UEAADAnw8BAgAAmEaAAAAAphEgAACAaQQIAABgGgECAACYRoAAAACmESAAAIBpBAgAAGBajaqeQJG0pcuqegpAtVPrvnuregpAteTq6Vmp41fkv0l+vZ6osLGqEyoQAABUExMnTpSdnZ3Vq0WLFpbj2dnZioyMlI+Pj9zd3dWzZ0+lpaVZjZGcnKyIiAi5urrK19dXo0ePVn5+vlWfbdu2qV27dnJ2dlbTpk0VExNjeq4ECAAAqpGWLVvq1KlTlteOHTssx0aMGKE1a9bos88+U1xcnFJSUvToo49ajhcUFCgiIkK5ubnatWuXFi1apJiYGE2YMMHS58SJE4qIiFCXLl2UmJio4cOH67nnntOGDRtMzbPaLGEAAFBt2NlV2VvXqFFD/v7+JdovXLig6OhoLV68WPfcc48kaeHChQoKCtI333yjkJAQffXVVzpy5Ig2bdokPz8/tW3bVpMnT9bYsWM1ceJEOTk5ad68eWrcuLGmTZsmSQoKCtKOHTs0Y8YMhYeHl3meVCAAAKhEOTk5yszMtHrl5OT8Zv/jx48rICBAN998s/r06aPk5GRJUkJCgvLy8hQWFmbp26JFCwUGBio+Pl6SFB8fr9atW8vPz8/SJzw8XJmZmTp8+LClz7VjFPUpGqOsCBAAAFSiqKgoeXp6Wr2ioqJK7duxY0fFxMRo/fr1mjt3rk6cOKE777xTFy9eVGpqqpycnOTl5WV1jp+fn1JTUyVJqampVuGh6HjRsd/rk5mZqStXrpT5c7GEAQBAJRo3bpxGjhxp1ebs7Fxq3/vuu8/y5zZt2qhjx45q2LChli9fLhcXl0qdp1lUIAAAqETOzs7y8PCwev1WgCjOy8tLzZo10/fffy9/f3/l5uYqIyPDqk9aWpplz4S/v3+JqzKKvrbVx8PDw1RIIUAAAFCcnV3Fva5DVlaWfvjhB9WrV0/t27eXo6OjNm/ebDmelJSk5ORkhYaGSpJCQ0N18OBBpaenW/ps3LhRHh4eCg4OtvS5doyiPkVjlBUBAgCAauKll15SXFyc/vOf/2jXrl165JFH5ODgoN69e8vT01MDBgzQyJEjtXXrViUkJOjZZ59VaGioQkJCJEndu3dXcHCwnn76ae3fv18bNmzQ+PHjFRkZaal6DBo0SD/++KPGjBmjY8eOac6cOVq+fLlGjBhhaq7sgQAAoJr4+eef1bt3b509e1Z169bVHXfcoW+++UZ169aVJM2YMUP29vbq2bOncnJyFB4erjlz5ljOd3Bw0Nq1azV48GCFhobKzc1N/fr106RJkyx9GjdurNjYWI0YMUIzZ85U/fr1tWDBAlOXcEqSnWEYRsV87OvDrayBkriVNVC6Sr+V9fLPKmwsv8f/UWFjVScsYQAAANNYwgAAoLgqvBPlnwUBAgCAYogPtrGEAQAATCNAAAAA0wgQAADANPZAAABQHJsobaICAQAATCNAAAAA0wgQAADANAIEAAAwjU2UAAAUxyZKm6hAAAAA0wgQAADANAIEAAAwjQABAABMI0AAAADTuAoDAIDiuArDJioQAADANAIEAAAwjQABAABMYw8EAADFsQXCJgIEAAAlkCBsYQkDAACYRoAAAACmESAAAIBpBAgAAGAamygBACiOO1HaRAUCAACYRoAAAACmESAAAIBp7IEAAKAYtkDYRoAAAKAEEoQtLGEAAADTCBAAAMA0AgQAADCNAAEAAExjEyUAAMVxGYZNVCAAAIBpBAgAAGAaAQIAAJjGHggAAIpjC4RNVCAAAIBpVCAAACiBEoQtVCAAAIBpBAgAAGAaAQIAAJjGHggAAIrjTpQ2UYEAAACmESAAAIBpBAgAAGAaeyAAACiOLRA2UYEAAACmUYEAAKAEShC2UIEAAACmESAAAIBpBAgAAGAaeyAAACiOO1HaRAUCAACYRoAAAACmsYQBAEAxLGDYRoAAAKA4EoRNLGEAAADTqEAAAFACJQhbqEAAAADTCBAAAMC0ci1hGIahzz//XFu3blV6eroKCwutjq9cubJCJgcAAKqncgWI4cOHa/78+erSpYv8/Pxkxx27AAA3Ev5ds6lcAeLjjz/WypUrdf/991f0fAAAwJ9AufZAeHp66uabb67ouQAAgKveeust2dnZafjw4Za27OxsRUZGysfHR+7u7urZs6fS0tKszktOTlZERIRcXV3l6+ur0aNHKz8/36rPtm3b1K5dOzk7O6tp06aKiYkxPb9yBYiJEyfq9ddf15UrV8pzOgAA+B179+7V/Pnz1aZNG6v2ESNGaM2aNfrss88UFxenlJQUPfroo5bjBQUFioiIUG5urnbt2qVFixYpJiZGEyZMsPQ5ceKEIiIi1KVLFyUmJmr48OF67rnntGHDBlNztDMMwzD7wa5cuaJHHnlEO3fuVKNGjeTo6Gh1/NtvvzU7pNKWLjN9DnCjq3XfvVU9BaBacvX0rNTxz5j8x/T31AkPN9U/KytL7dq105w5c/TGG2+obdu2evfdd3XhwgXVrVtXixcv1mOPPSZJOnbsmIKCghQfH6+QkBCtW7dODzzwgFJSUuTn5ydJmjdvnsaOHavTp0/LyclJY8eOVWxsrA4dOmR5z169eikjI0Pr168v8zzLtQeiX79+SkhI0FNPPcUmSgAAfkdOTo5ycnKs2pydneXs7Fxq/8jISEVERCgsLExvvPGGpT0hIUF5eXkKCwuztLVo0UKBgYGWABEfH6/WrVtbwoMkhYeHa/DgwTp8+LBuvfVWxcfHW41R1OfapZKyKFeAiI2N1YYNG3THHXeU53QAAKq5ivvFOCoqSq+//rpV22uvvaaJEyeW6Lt06VJ9++232rt3b4ljqampcnJykpeXl1W7n5+fUlNTLX2uDQ9Fx4uO/V6fzMxMXblyRS4uLmX6XOUKEA0aNJCHh0d5TgUA4C9l3LhxGjlypFVbadWHkydP6n/+53+0ceNG1axZ84+aXrmVK0BMmzZNY8aM0bx589SoUaMKnhKu9eHWLYrZts2qLbBOHX0ydFip/U+kpyt6yxb936kUpWZkaMi99+rx0E6VPs+thw8pessWpWZk6CZvbw3q1l2hzZpZjn+4dYu2HDqk9AsXVMPBQc0DAjSwa1cF129Q6XPDje/DRYv03vvv68levTS62A/qa128eFGz587Vlq1bdSEzU/X8/fXSyJG6s3PnSpvbxk2bNGf+fKWcOqXABg00bMgQy/vl5edrzty52rFrl37+5Re5u7ur4223adiQIfKtW7fS5oQ/1u8tV1wrISFB6enpateunaWtoKBA27dv1+zZs7Vhwwbl5uYqIyPDqgqRlpYmf39/SZK/v7/27NljNW7RVRrX9il+5UZaWpo8PDzKXH2QyhkgnnrqKV2+fFlNmjSRq6triU2U586dK8+w+A2NfX01vW8/y9cO9r998Ux2Xp4CatdWl5Yt9d76dRXy/t+dOKGo1au0fETpP5gPJidr0uef6/muYQpt3lybDhzQP5cu0YIXBunmq2WyBj51NPz+CAXUrq2c/Hwtj9+lUR99pCX/M1xebm4VMk/8NR0+ckQrVq7ULU2b/m6/vLw8DRoyRN7e3pr61lvyrVtXKampquXuXu733peQoAmTJunLf/+71OOJBw5o3KuvauiLL+rOO+7Qug0bNHL0aC35+GM1bdJE2dnZOpqUpIH9+6tZs2bKzMzU1OnTNXzUKC3+6KNyzwt/Tl27dtXBgwet2p599lm1aNFCY8eOVYMGDeTo6Kj
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"aa = pd.DataFrame(data=conf_test, columns=['Norm', 'Anom'], index=['Norm', 'Anom'])\n",
"sns.heatmap(aa, annot=True, cmap=sns.blend_palette(['#f5f0f0','#e8a7a8'], as_cmap=True));"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e164894d-c7f3-45a6-87a7-07a5ac13a9b6",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
]
}
],
"source": [
"model_file_path = f\"../results/gcaec_{window}.h5\"\n",
"classifier.save(model_file_path, save_format=\"h5\")\n",
"pd.DataFrame(\n",
" data=[[model_file_path, acc_train, f1_train, acc_val, f1_val, acc_test, f1_test]], \n",
" columns=['model', 'acc_train', 'f1_train', 'acc_val', 'f1_val', 'acc_test', 'f1_test']\n",
").to_csv(f\"{model_file_path}_stats.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6ed0f38-3ac6-4a8b-855d-5959bfdbcf20",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}