
Commit da245fb

Update ndBrainAgentDeterministicPolicyGradient_Trainer.cpp

1 parent 782f293 commit da245fb

1 file changed (+8 additions, -29 deletions)

newton-4.00/sdk/dBrain/ndBrainAgentDeterministicPolicyGradient_Trainer.cpp

Lines changed: 8 additions & 29 deletions
@@ -43,7 +43,7 @@
 #define ND_TD3_POLICY_MAX_PER_ACTION_SIGMA ndBrainFloat(1.0f)
 #define ND_MAX_SAC_ENTROPY_COEFFICIENT ndBrainFloat (2.0e-5f)
 
-#define ND_TD3_VARIANCE
+//#define ND_TD3_VARIANCE
 
 ndBrainAgentDeterministicPolicyGradient_Trainer::HyperParameters::HyperParameters()
 {
@@ -381,6 +381,9 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::BuildCriticClass()
 	layers.PushBack(new ndBrainLayerLinear(layers[layers.GetCount() - 1]->GetOutputSize(), m_parameters.m_hiddenLayersNumberOfNeurons));
 	layers.PushBack(new ndBrainLayerActivationTanh(layers[layers.GetCount() - 1]->GetOutputSize()));
 	layers.PushBack(new ndBrainLayerLinear(layers[layers.GetCount() - 1]->GetOutputSize(), 1));
+#ifdef ND_TD3_VARIANCE
+	layers.PushBack(new ndBrainLayerActivationRelu(layers[layers.GetCount() - 1]->GetOutputSize()));
+#endif
 
 	ndSharedPtr<ndBrain> critic(new ndBrain);
 	for (ndInt32 i = 0; i < layers.GetCount(); ++i)
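For orientation, a hedged reading of this hunk: the critic's final linear layer emits a single scalar, and the newly added guard only appends a ReLU when ND_TD3_VARIANCE is defined, clamping that scalar to be non-negative (as a variance-style head would need). Since the macro is commented out above, the plain unbounded value head is what actually gets built. A minimal standalone sketch of the two head shapes follows; it uses plain C++ with illustrative names, not the ndBrain API.

#include <algorithm>
#include <cstdio>
#include <vector>

// Illustrative sketch (not the ndBrain API): the tail of the critic as touched
// by this hunk. The last linear layer emits one scalar; the commit guards an
// extra ReLU behind ND_TD3_VARIANCE so the scalar is clamped to be
// non-negative, otherwise the unbounded value is used directly.
static float CriticHead(const std::vector<float>& hidden,
                        const std::vector<float>& weights,
                        float bias,
                        bool varianceHead)
{
	float q = bias;
	for (size_t i = 0; i < hidden.size(); ++i)
	{
		q += weights[i] * hidden[i];	// final linear layer, output size 1
	}
	if (varianceHead)
	{
		q = std::max(q, 0.0f);	// ReLU appended only under ND_TD3_VARIANCE
	}
	return q;
}

int main()
{
	const std::vector<float> hidden = {0.2f, -0.7f, 0.5f};
	const std::vector<float> weights = {0.3f, 0.8f, -0.1f};
	std::printf("plain head:    %f\n", CriticHead(hidden, weights, -0.25f, false));
	std::printf("variance head: %f\n", CriticHead(hidden, weights, -0.25f, true));
	return 0;
}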
@@ -740,21 +743,6 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnPolicyFunction()
 	const ndBrain& brain = **m_policyTrainer->GetBrain();
 	ndInt32 criticInputSize = brain.GetInputSize() + brain.GetOutputSize();
 
-#ifdef ND_TD3_VARIANCE
-	ndInt32 count = m_parameters.m_criticUpdatesCount * m_parameters.m_miniBatchSize;
-	m_miniBatchIndexBuffer.SetCount(0);
-	for (ndInt32 i = 0; i < count; ++i)
-	{
-		m_miniBatchIndexBuffer.PushBack(m_shuffleBuffer[m_shuffleBatchIndex]);
-		m_shuffleBatchIndex++;
-		if (m_shuffleBatchIndex >= m_shuffleBuffer.GetCount())
-		{
-			m_shuffleBatchIndex = 0;
-			m_shuffleBuffer.RandomShuffle(m_shuffleBuffer.GetCount());
-		}
-	}
-#endif
-
 	m_actionBatch.SetCount(m_parameters.m_miniBatchSize * brain.GetOutputSize());
 	m_obsevationsBatch.SetCount(m_parameters.m_miniBatchSize * brain.GetInputSize());
 	m_policyGradientBatch.SetCount(m_parameters.m_miniBatchSize * brain.GetOutputSize());
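The block removed here, previously compiled only under ND_TD3_VARIANCE, pre-drew m_criticUpdatesCount * m_miniBatchSize replay indices by walking a shuffled index buffer and reshuffling it whenever the cursor wrapped. A standalone sketch of that draw-and-reshuffle pattern follows, using standard-library containers rather than the ndBrain types; the function and parameter names are hypothetical.

#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

// Standalone sketch of the pattern in the deleted block: walk a pre-shuffled
// list of replay-buffer indices and reshuffle it whenever the cursor wraps.
// Names are hypothetical, not the ndBrain containers.
std::vector<int32_t> DrawMiniBatchIndices(std::vector<int32_t>& shuffleBuffer,
                                          size_t& cursor,
                                          std::mt19937& rng,
                                          size_t count)
{
	std::vector<int32_t> batch;
	batch.reserve(count);
	for (size_t i = 0; i < count; ++i)
	{
		batch.push_back(shuffleBuffer[cursor]);	// next index in shuffled order
		++cursor;
		if (cursor >= shuffleBuffer.size())
		{
			cursor = 0;	// wrapped around: reshuffle before reusing indices
			std::shuffle(shuffleBuffer.begin(), shuffleBuffer.end(), rng);
		}
	}
	return batch;
}

int main()
{
	std::mt19937 rng(42);
	std::vector<int32_t> indices = {0, 1, 2, 3, 4, 5, 6, 7};
	size_t cursor = 0;
	// draws 12 indices from a buffer of 8, reshuffling once at the wrap
	const std::vector<int32_t> batch = DrawMiniBatchIndices(indices, cursor, rng, 12);
	return batch.size() == 12 ? 0 : 1;
}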
@@ -788,19 +776,10 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnPolicyFunction()
 
 	for (ndInt32 i = 0; i < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++i)
 	{
-#ifdef ND_TD3_VARIANCE
-		m_referenceCriticTrainer[i]->MakePrediction(m_criticObservationActionBatch);
-		m_referenceCriticTrainer[i]->GetOutput(m_criticOutputGradients[i]);
-		m_criticTrainer[i]->MakePrediction(m_criticObservationActionBatch);
-		m_criticTrainer[i]->GetOutput(m_criticValue[i]);
-
-		m_criticOutputGradients[i].Sub(m_criticValue[i]);
-		m_criticOutputGradients[i].Scale(ndBrainFloat(-1.0f));
-#else
-		m_criticTrainer[i]->MakePrediction(m_criticObservationActionBatch);
-		m_criticTrainer[i]->GetOutput(m_criticValue[i]);
-		m_criticOutputGradients[i].Set(ndBrainFloat(1.0f));
-#endif
+		m_criticTrainer[i]->MakePrediction(m_criticObservationActionBatch);
+		m_criticTrainer[i]->GetOutput(m_criticValue[i]);
+		m_criticOutputGradients[i].Set(ndBrainFloat(1.0f));
+
 		if (m_parameters.m_entropyRegularizerCoef > ndBrainFloat(1.0e-6f))
 		{
 			ndAssert(0);