Hm
This commit is contained in:
parent
42066f7065
commit
4d13dd3528
|
@ -18,19 +18,51 @@ training_labels = []
|
||||||
test_img = []
|
test_img = []
|
||||||
test_labels_ = []
|
test_labels_ = []
|
||||||
|
|
||||||
for piece in OUR_PIECES:
|
|
||||||
# training set
|
|
||||||
for _ in range(10):
|
|
||||||
for filename in glob.glob(f"../training_images/{piece}/*_square/*.png")[:-50]:
|
|
||||||
training_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
|
||||||
training_labels.append(piece)
|
|
||||||
|
|
||||||
# test set
|
# training set
|
||||||
for _ in range(5):
|
for _ in range(10):
|
||||||
for filename in glob.glob(f"../training_images/{piece}/*_square/*.png")[-50:]:
|
for filename in glob.glob(f"../training_images/rook/*_square/*.png")[:-50]:
|
||||||
test_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
training_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
||||||
test_labels_.append(piece)
|
training_labels.append(0)
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
for filename in glob.glob(f"../training_images/knight/*_square/*.png")[:-50]:
|
||||||
|
training_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
||||||
|
training_labels.append(1)
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
for filename in glob.glob(f"../training_images/bishop/*_square/*.png")[:-50]:
|
||||||
|
training_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
||||||
|
training_labels.append(2)
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
for filename in glob.glob(f"../training_images/empty/*_square/*.png")[:-7300]:
|
||||||
|
training_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
||||||
|
training_labels.append(3)
|
||||||
|
|
||||||
|
# test set
|
||||||
|
for _ in range(5):
|
||||||
|
for filename in glob.glob(f"../training_images/rook/*_square/*.png")[-50:]:
|
||||||
|
test_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
||||||
|
test_labels_.append(0)
|
||||||
|
|
||||||
|
# test set
|
||||||
|
for _ in range(5):
|
||||||
|
for filename in glob.glob(f"../training_images/knight/*_square/*.png")[-50:]:
|
||||||
|
test_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
||||||
|
test_labels_.append(1)
|
||||||
|
|
||||||
|
# test set
|
||||||
|
for _ in range(5):
|
||||||
|
for filename in glob.glob(f"../training_images/bishop/*_square/*.png")[-50:]:
|
||||||
|
test_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
||||||
|
test_labels_.append(2)
|
||||||
|
|
||||||
|
# test set
|
||||||
|
for _ in range(5):
|
||||||
|
for filename in glob.glob(f"../training_images/empty/*_square/*.png")[-50:]:
|
||||||
|
test_img.append(cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY))
|
||||||
|
test_labels_.append(3)
|
||||||
|
|
||||||
width, height = training_img[0].shape
|
width, height = training_img[0].shape
|
||||||
|
|
||||||
|
@ -51,7 +83,7 @@ model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
|
||||||
|
|
||||||
model.add(tf.keras.layers.Flatten())
|
model.add(tf.keras.layers.Flatten())
|
||||||
model.add(tf.keras.layers.Dense(64, activation='relu'))
|
model.add(tf.keras.layers.Dense(64, activation='relu'))
|
||||||
model.add(tf.keras.layers.Dense(3, activation='softmax'))
|
model.add(tf.keras.layers.Dense(4, activation='softmax'))
|
||||||
|
|
||||||
model.summary()
|
model.summary()
|
||||||
|
|
||||||
|
@ -66,5 +98,5 @@ test_loss, test_acc = model.evaluate(test_images, test_labels_)
|
||||||
print(test_acc)
|
print(test_acc)
|
||||||
|
|
||||||
# Save entire model to a HDF5 file
|
# Save entire model to a HDF5 file
|
||||||
model.save('chess_model_3_pieces.h5')
|
model.save('pls_model.h5')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user