from __future__ import absolute_import, division, print_function, unicode_literals import glob import cv2 import numpy as np import tensorflow as tf # exit() from util import OUR_PIECES # (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data() # print(train_images[0]) training_img = [] training_labels = [] test_img = [] test_labels_ = [] for piece in OUR_PIECES: # training set for _ in range(10): for filename in glob.glob(f"../training_images/{piece}/*_square/*.png")[:-50]: training_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY)) training_labels.append(piece) # test set for _ in range(5): for filename in glob.glob(f"../training_images/{piece}/*_square/*.png")[-50:]: test_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY)) test_labels_.append(piece) width, height = training_img[0].shape training_img = np.array(training_img).reshape((len(training_img), width, height, 1)) test_img = np.array(test_img).reshape((len(test_img),width, height, 1)) # Normalize pixel values to be between 0 and 1 train_images, test_images = training_img / 255.0, test_img / 255.0 model = tf.keras.models.Sequential() model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu', input_shape=(width, height, 1))) model.add(tf.keras.layers.MaxPooling2D((2, 2))) model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu')) model.add(tf.keras.layers.MaxPooling2D((2, 2))) model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu')) model.add(tf.keras.layers.MaxPooling2D((2, 2))) model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu')) model.add(tf.keras.layers.Flatten()) model.add(tf.keras.layers.Dense(64, activation='relu')) model.add(tf.keras.layers.Dense(3, activation='softmax')) model.summary() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(train_images, training_labels, epochs=3) test_loss, test_acc = model.evaluate(test_images, test_labels_) print(test_acc) # Save entire model to a HDF5 file model.save('chess_model_3_pieces.h5')