How can I use a weighted loss with LightningCLI? #17953
-
I would like to use a weighted loss function in a LightningModule. The weights can only be computed once the training dataset of the LightningDataModule has been created, so they are not available until after `DataModule.setup()` has run. However, LightningCLI only calls the respective `__init__` methods. I would need to set up the DataModule before initializing the model. What is the best way to achieve this?
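For context, a "weighted loss" here usually means a class-weighted cross-entropy, where each sample's loss is scaled by the weight of its target class. The stdlib-only sketch below reproduces what `torch.nn.CrossEntropyLoss(weight=w)` computes with the default `reduction="mean"`. It shows why the ordering problem arises: the weight vector must exist when the loss is built. The function name `weighted_cross_entropy` is hypothetical and used only for illustration.

```python
import math
from typing import Sequence


def weighted_cross_entropy(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    weights: Sequence[float],
) -> float:
    """Class-weighted cross-entropy with a weighted-mean reduction.

    Mirrors torch.nn.CrossEntropyLoss(weight=weights) with reduction="mean":
    each per-sample loss is scaled by the weight of its target class, and the
    sum is divided by the sum of those weights (not by the batch size).
    """
    total, weight_sum = 0.0, 0.0
    for row, y in zip(logits, targets):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        nll = log_sum_exp - row[y]  # equals -log(softmax(row)[y])
        total += weights[y] * nll
        weight_sum += weights[y]
    return total / weight_sum
```

Because the denominator is the sum of the target-class weights, upweighting a rare class shifts the balance between samples without simply inflating the overall loss scale.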
-
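You can use an argument link applied on instantiation; see argument-linking. This could be done in several ways, depending on when/where you want `setup()` to be called. One possibility is for the `LightningModule` to receive the instance of the `LightningDataModule`. This is something like:

```python
from lightning.pytorch import LightningDataModule, LightningModule  # or pytorch_lightning
from lightning.pytorch.cli import LightningCLI


class MyModule(LightningModule):
    def __init__(self, data_module: LightningDataModule):
        super().__init__()
        # data setup could be called here or later
        ...


class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # After instantiation, pass the data module instance to the model's `data_module` argument
        parser.link_arguments("data", "model.data_module", apply_on="instantiate")


cli = MyLightningCLI(MyModule, MyDataModule)  # MyDataModule: your own LightningDataModule subclass
```

Other possibilities are:

The link target would need…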