Hello,
Thank you for your work.
I am doing domain adaptation with DANN. I would like to save the best model with a ModelCheckpoint callback, based on the loss value of the task network:
import os
from tensorflow.keras.callbacks import ModelCheckpoint

chk = ModelCheckpoint(os.path.join(model_directory_name, 'Model'),
                      monitor="loss",
                      verbose=1,
                      save_best_only=True,
                      save_weights_only=False,
                      mode='min',
                      save_freq=1)
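To make the intent of `save_best_only=True` with `mode='min'` concrete, here is a minimal pure-Python sketch of the "save only when the monitored value improves" rule. `BestLossTracker` and the loss values are hypothetical names and data made up for illustration; this is not Keras code and does not reproduce the callback's internals.

```python
import math


class BestLossTracker:
    """Hypothetical helper: report whether a monitored value improved."""

    def __init__(self, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        # Start from the worst possible value so the first observation counts.
        self.best = math.inf if mode == "min" else -math.inf

    def update(self, value):
        """Return True if `value` improves on the best value seen so far."""
        if self.mode == "min":
            improved = value < self.best
        else:
            improved = value > self.best
        if improved:
            self.best = value
        return improved


tracker = BestLossTracker(mode="min")
losses = [0.9, 0.7, 0.8, 0.5]  # made-up loss values, one per check
saved_at = [i for i, loss in enumerate(losses) if tracker.update(loss)]
print(saved_at)  # indices at which a "save" would be triggered
```

Note that an integer `save_freq` is counted in batches, so `save_freq=1` asks for this check after every batch, while `save_freq="epoch"` runs it once per epoch.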
During training, I keep receiving this warning:
WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.
This is the training code:
# Define callbacks
callbacks_list = [chk]
# Build model
model = DANN(encoder=encoder(), task=task(), discriminator=discriminitor(),
             Xt=Xt, lambda_=0.1, metrics=["acc"], random_state=0)
# Start training
model_log = model.fit(Xs, ys, epochs=2, callbacks=callbacks_list,
                      verbose=1, class_weight=class_weights)
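The warning quoted above asks for a `get_config()` method because the constructor arguments contain objects that cannot be serialized. In the call above, the networks and `Xt` are passed to the constructor, so they are plausible candidates, although that is an assumption. The sketch below illustrates the general `get_config()` / `from_config()` pattern in plain Python. `TinyModel` and `non_serializable_args` are hypothetical names, the JSON round trip only approximates what a saving routine needs, and none of this is the DANN or Keras implementation.

```python
import json


class TinyModel:
    """Hypothetical stand-in for a subclassed model; not a Keras class."""

    def __init__(self, units=16, lambda_=0.1, random_state=0):
        self.units = units
        self.lambda_ = lambda_
        self.random_state = random_state

    def get_config(self):
        # Only plain Python values, so the dict can be written to JSON.
        return {
            "units": self.units,
            "lambda_": self.lambda_,
            "random_state": self.random_state,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild the object from the saved constructor arguments.
        return cls(**config)


def non_serializable_args(kwargs):
    """Return the names of arguments that json.dumps cannot encode."""
    bad = []
    for name, value in kwargs.items():
        try:
            json.dumps(value)
        except TypeError:
            bad.append(name)
    return bad


model = TinyModel(units=32, lambda_=0.1, random_state=0)
restored = TinyModel.from_config(json.loads(json.dumps(model.get_config())))
print(restored.get_config() == model.get_config())

# object() stands in for a network or data array passed to the constructor.
print(non_serializable_args({"encoder": object(), "lambda_": 0.1}))
```

If only the weights are needed, setting `save_weights_only=True` in the callback is a possible alternative, since it saves the weights without the full model.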