Compare commits

...

2 Commits

Author SHA1 Message Date
Wojciech Janota b0ec18fedb Work 3 weeks ago
Wojciech Janota b2402f4336 Modify network arch 2 months ago

@ -18,27 +18,49 @@ def custom_loss_function(y_true, y_pred):
quality_coefficient = 0.7 quality_coefficient = 0.7
return abs(size_minimizer_part * size_coefficient + mse_original_jpeg_part * quality_coefficient) return abs(size_minimizer_part * size_coefficient + mse_original_jpeg_part * quality_coefficient)
@register_keras_serializable()
def custom_loss_function_entropy(y_true, y_pred):
mse_original_jpeg_part = K.mean(K.square(y_pred - y_true))
entropy_part = tf.reduce_mean(tf.losses.categorical_crossentropy(y_true, y_pred)) / 64
entropy_coefficient = 0.3
quality_coefficient = 0.7
return abs(quality_coefficient * mse_original_jpeg_part - entropy_coefficient * entropy_part)
class NeuralNetworkEncoder: class NeuralNetworkEncoder:
def __init__(self, pretrained_weights_path: str = None, internal_activation_function: str = None, external_activation_function: str = None, optimizer: str = None, loss_function: str = None, image_dimension_x: int = None, image_dimension_y: int = None): def __init__(self, pretrained_weights_path: str = None, internal_activation_function: str = None, external_activation_function: str = None, optimizer: str = None, loss_function: str = None, image_dimension_x: int = None, image_dimension_y: int = None):
tf.config.threading.set_intra_op_parallelism_threads(NUM_THREADS) tf.config.threading.set_intra_op_parallelism_threads(NUM_THREADS)
tf.config.threading.set_inter_op_parallelism_threads(NUM_THREADS) tf.config.threading.set_inter_op_parallelism_threads(NUM_THREADS)
if pretrained_weights_path: if pretrained_weights_path:
self.model = load_model(pretrained_weights_path, custom_objects={'custom_loss_function': custom_loss_function}) self.model = load_model(pretrained_weights_path, custom_objects={'custom_loss_function': custom_loss_function_entropy})
else: else:
self.model = keras.Sequential([ self.model = keras.Sequential([
layers.Reshape((512, 512, 1), input_shape=(262144,)), layers.Reshape((512, 512, 1), input_shape=(262144,)),
#layers.InputLayer(input_shape=(512 * 512, 1, 1)), #layers.InputLayer(input_shape=(512 * 512, 1, 1)),
layers.Conv2D(32, (3, 3), activation=internal_activation_function, padding='same'), layers.Conv2D(256, (3, 3), activation=internal_activation_function, padding='same'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation=internal_activation_function, padding='same'),
layers.MaxPooling2D((2, 2)), layers.MaxPooling2D((2, 2)),
layers.Conv2D(128, (3, 3), activation=internal_activation_function, padding='same'), layers.Conv2D(128, (3, 3), activation=internal_activation_function, padding='same'),
layers.MaxPooling2D((2, 2)), layers.MaxPooling2D((2, 2)),
layers.Conv2D(256, (3, 3), activation=internal_activation_function, padding='same'), layers.Conv2D(64, (3, 3), activation=internal_activation_function, padding='same'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(32, (3, 3), activation=internal_activation_function, padding='same'),
layers.Flatten(), layers.Flatten(),
layers.Dense(64, activation=external_activation_function) layers.Dense(64, activation=external_activation_function)
]) ])
# version 2 of the model for research purposes
# self.model = keras.Sequential([
# layers.Reshape((512, 512, 1), input_shape=(262144,)),
# #layers.InputLayer(input_shape=(512 * 512, 1, 1)),
# layers.Conv2D(32, (3, 3), activation=internal_activation_function, padding='same'),
# layers.MaxPooling2D((2, 2)),
# layers.Conv2D(64, (3, 3), activation=internal_activation_function, padding='same'),
# layers.MaxPooling2D((2, 2)),
# layers.Conv2D(128, (3, 3), activation=internal_activation_function, padding='same'),
# layers.MaxPooling2D((2, 2)),
# layers.Conv2D(256, (3, 3), activation=internal_activation_function, padding='same'),
# layers.Flatten(),
# layers.Dense(64, activation=external_activation_function)
# ])
self.model.compile(optimizer=optimizer, loss=custom_loss_function) self.model.compile(optimizer=optimizer, loss=custom_loss_function)

Loading…
Cancel
Save