@@ -401,7 +401,8 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::BuildCriticClass()
 	}
 }
 
-ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32 index, const ndBrainVector& distribution) const
+//ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32 index, const ndBrainVector& distribution) const
+ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32, const ndBrainVector&) const
 {
 	ndAssert(0);
 	return 0;
@@ -447,7 +448,8 @@ ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyPro
 }
 
 //#pragma optimize( "", off )
-ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32 index) const
+//ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32 index) const
+ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32) const
 {
 	ndAssert(0);
 	return 0;
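A note on the two hunks above: dropping the parameter names from the stubbed-out CalculatePolicyProbability overloads keeps their signatures intact while avoiding unused-parameter warnings, since both bodies currently just assert and return 0. A minimal stand-alone illustration of the idiom (the names here are invented for the example, not taken from ndBrain):

// unnamed_param_stub.cpp - illustrative only; compile with: g++ -Wall -Wextra -c unnamed_param_stub.cpp
struct ExampleTrainer
{
	// A named parameter in a stub like this triggers -Wunused-parameter under -Wextra:
	// float CalculateProbability(int index) const { return 0.0f; }

	// Leaving the parameter unnamed keeps the overload callable but emits no warning.
	float CalculateProbability(int) const
	{
		return 0.0f;
	}
};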
@@ -599,35 +601,34 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnQvalueFunction(ndInt3
 	const ndBrain& brain = **m_policyTrainer->GetBrain();
 	ndInt32 criticInputSize = brain.GetInputSize() + brain.GetOutputSize();
 
-	m_criticValue.SetCount(m_parameters.m_miniBatchSize);
-	m_criticGradients.SetCount(m_parameters.m_miniBatchSize);
+	m_criticValue[0].SetCount(m_parameters.m_miniBatchSize);
+	m_criticOutputGradients[0].SetCount(m_parameters.m_miniBatchSize);
 	m_criticObservationActionBatch.SetCount(m_parameters.m_miniBatchSize * criticInputSize);
-
+
 	for (ndInt32 n = 0; n < m_parameters.m_criticUpdatesCount; ++n)
 	{
 		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
 		{
 			const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
-
+
 			ndBrainMemVector criticObservationAction(&m_criticObservationActionBatch[i * criticInputSize], criticInputSize);
 			ndMemCpy(&criticObservationAction[0], m_replayBuffer.GetActions(index), brain.GetOutputSize());
 			ndMemCpy(&criticObservationAction[brain.GetOutputSize()], m_replayBuffer.GetObservations(index), brain.GetInputSize());
 		}
 		m_criticTrainer[criticIndex]->MakePrediction(m_criticObservationActionBatch);
-		m_criticTrainer[criticIndex]->GetOutput(m_criticValue);
-
+		m_criticTrainer[criticIndex]->GetOutput(m_criticValue[0]);
+
 		ndBrainLossLeastSquaredError loss(1);
 		ndBrainFixSizeVector<1> groundTruth;
 		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
 		{
 			groundTruth[0] = m_expectedRewards[n * m_parameters.m_miniBatchSize + i];
 			loss.SetTruth(groundTruth);
-			const ndBrainMemVector output(&m_criticValue[i], 1);
-			ndBrainMemVector gradient(&m_criticGradients[i], 1);
+			const ndBrainMemVector output(&m_criticValue[0][i], 1);
+			ndBrainMemVector gradient(&m_criticOutputGradients[0][i], 1);
 			loss.GetLoss(output, gradient);
 		}
-		//trainer.BackPropagate(criticObservationAction, loss);
-		m_criticTrainer[criticIndex]->BackPropagate(m_criticGradients);
+		m_criticTrainer[criticIndex]->BackPropagate(m_criticOutputGradients[0]);
 		m_criticTrainer[criticIndex]->ApplyLearnRate(m_parameters.m_criticLearnRate);
 	}
 }
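In this hunk the critic update is a plain least-squares regression: each critic's prediction for the replayed (action, observation) pair is pulled toward the precomputed target in m_expectedRewards, and the resulting output gradient is backpropagated before the critic learning rate is applied. A compact sketch of that per-sample loss and gradient, written against plain std::vector rather than the ndBrain types (the 0.5 factor and the output-minus-target gradient convention are assumptions about ndBrainLossLeastSquaredError, not something this diff states):

#include <vector>

// Per-sample squared-error loss used for the critic:
//   L_i = 0.5 * (q_i - y_i)^2   =>   dL_i/dq_i = q_i - y_i
// q_i is the critic value for sample i, y_i the expected-reward target.
std::vector<float> CriticLossGradient(const std::vector<float>& qValues,
                                      const std::vector<float>& targets)
{
	std::vector<float> gradient(qValues.size());
	for (size_t i = 0; i < qValues.size(); ++i)
	{
		gradient[i] = qValues[i] - targets[i];
	}
	return gradient;
}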
@@ -661,20 +662,20 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 	{
 		m_rewardBatch[i].SetCount(m_parameters.m_miniBatchSize);
 	}
-
+
 	for (ndInt32 n = 0; n < m_parameters.m_criticUpdatesCount; ++n)
 	{
 		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
 		{
 			const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
-
+
 			// get the state rewards
 			ndBrainFloat r = m_replayBuffer.GetReward(index);
 			for (ndInt32 j = 0; j < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++j)
 			{
 				m_rewardBatch[j][i] = r;
 			}
-
+
 			// get the next state actions
 			ndBrainMemVector nextObsevations(&m_nextObsevationsBatch[i * brain.GetInputSize()], brain.GetInputSize());
 			if (m_replayBuffer.GetTerminalState(index))
@@ -689,16 +690,16 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 		}
 		m_policyTrainer->MakePrediction(m_nextObsevationsBatch);
 		m_policyTrainer->GetOutput(m_nextActionBatch);
-
-		// calculate the expected rward for each action
+
+		// calculate the expected reward for each action
 		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
 		{
-			const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
-
+			//const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
+
 			const ndBrainMemVector nextAction(&m_nextActionBatch[i * brain.GetOutputSize()], brain.GetOutputSize());
 			const ndBrainMemVector nextObsevations(&m_nextObsevationsBatch[i * brain.GetInputSize()], brain.GetInputSize());
 			ndBrainMemVector criticNextObservationAction(&m_criticNextObservationActionBatch[i * criticInputSize], criticInputSize);
-
+
 			ndMemCpy(&criticNextObservationAction[0], &nextAction[0], brain.GetOutputSize());
 			ndMemCpy(&criticNextObservationAction[brain.GetOutputSize()], &nextObsevations[0], brain.GetInputSize());
 		}
@@ -707,7 +708,7 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 		{
 			m_referenceCriticTrainer[i]->MakePrediction(m_criticNextObservationActionBatch);
 			m_referenceCriticTrainer[i]->GetOutput(m_nextQValue);
-
+
 			const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
 			for (ndInt32 j = 0; j < m_parameters.m_miniBatchSize; ++j)
 			{
@@ -717,20 +718,20 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 				}
 			}
 		}
-
+
 		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
 		{
 			for (ndInt32 j = 1; j < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++j)
 			{
 				m_rewardBatch[0][i] = ndMin(m_rewardBatch[0][i], m_rewardBatch[j][i]);
 			}
 		}
-
+
 		if (m_parameters.m_entropyRegularizerCoef > ndBrainFloat(1.0e-6f))
 		{
 			ndAssert(0);
 		}
-
+
 		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
 		{
 			m_expectedRewards[n * m_parameters.m_miniBatchSize + i] = m_rewardBatch[0][i];
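The hunks above assemble the regression target used by LearnQvalueFunction: the current policy proposes the next action, every reference (target) critic scores the resulting next state-action pair, and the per-sample minimum over the critics is kept, which is the usual clipped double-Q construction; the entropy branch is asserted away for now. The discounting itself happens in a part of CalculateExpectedRewards this diff does not touch, so the sketch below treats gamma and the terminal-state handling as assumptions:

#include <algorithm>

// Sketch of the per-sample TD target with twin critics:
//   y = r                                    if the state is terminal
//   y = r + gamma * min_k Q'_k(s', pi(s'))   otherwise
float ExpectedRewardTarget(float reward, bool terminal, float gamma,
                           const float* nextQ, int criticCount)
{
	if (terminal || (criticCount <= 0))
	{
		return reward;
	}
	float minQ = nextQ[0];
	for (int k = 1; k < criticCount; ++k)
	{
		minQ = std::min(minQ, nextQ[k]);	// keep the most pessimistic critic
	}
	return reward + gamma * minQ;
}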
@@ -858,7 +859,16 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnPolicyFunction()
 	const ndBrain& brain = **m_policyTrainer->GetBrain();
 	ndInt32 criticInputSize = brain.GetInputSize() + brain.GetOutputSize();
 
+	m_actionBatch.SetCount(m_parameters.m_miniBatchSize * brain.GetOutputSize());
 	m_obsevationsBatch.SetCount(m_parameters.m_miniBatchSize * brain.GetInputSize());
+	m_criticObservationActionBatch.SetCount(m_parameters.m_miniBatchSize * criticInputSize);
+
+	for (ndInt32 i = 0; i < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++i)
+	{
+		m_criticValue[i].SetCount(m_parameters.m_miniBatchSize);
+		m_criticOutputGradients[i].SetCount(m_parameters.m_miniBatchSize);
+		m_criticInputGradients[i].SetCount(m_parameters.m_miniBatchSize * criticInputSize);
+	}
 
 	for (ndInt32 n = 0; n < m_parameters.m_policyUpdatesCount; ++n)
 	{
@@ -870,9 +880,40 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnPolicyFunction()
 			ndMemCpy(&observation[0], m_replayBuffer.GetActions(index), brain.GetOutputSize());
 		}
 		m_policyTrainer->MakePrediction(m_obsevationsBatch);
+		m_policyTrainer->GetOutput(m_actionBatch);
 
+		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
+		{
+			ndBrainMemVector criticObservationAction(&m_criticObservationActionBatch[i * criticInputSize], criticInputSize);
+			ndMemCpy(&criticObservationAction[0], &m_actionBatch[i * brain.GetOutputSize()], brain.GetOutputSize());
+			ndMemCpy(&criticObservationAction[brain.GetOutputSize()], &m_obsevationsBatch[i * brain.GetInputSize()], brain.GetInputSize());
+		}
 
+		for (ndInt32 i = 0; i < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++i)
+		{
+			m_criticTrainer[i]->MakePrediction(m_criticObservationActionBatch);
+			m_criticTrainer[i]->GetOutput(m_criticValue[i]);
+			if (m_parameters.m_entropyRegularizerCoef > ndBrainFloat(1.0e-6f))
+			{
+				ndAssert(0);
+			}
+			m_criticOutputGradients[i].Set(ndBrainFloat(1.0f));
+			m_criticTrainer[i]->BackPropagate(m_criticOutputGradients[i]);
+			m_criticTrainer[i]->GetInput(m_criticInputGradients[i]);
+		}
 
+		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
+		{
+			for (ndInt32 j = 1; j < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++j)
+			{
+				if (m_criticValue[j][i] < m_criticValue[0][i])
+				{
+					ndBrainMemVector dstObservationAction(&m_criticInputGradients[0][i * criticInputSize], criticInputSize);
+					const ndBrainMemVector srcObservationAction(&m_criticInputGradients[j][i * criticInputSize], criticInputSize);
+					dstObservationAction.Set(srcObservationAction);
+				}
+			}
+		}
 	}
 }
 
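The added code in LearnPolicyFunction prepares a deterministic policy-gradient step: the current policy's actions are concatenated with their observations, each critic scores the batch, a gradient of 1 is pushed backward through the critic so that GetInput yields the critic's input gradient with respect to the concatenated action-observation vector, and for every sample the gradient coming from the lower-valued critic is kept. What happens next, slicing out the action portion of that gradient and backpropagating it through the policy network, lies outside this hunk; the fragment below only illustrates the slicing implied by the [action | observation] input layout, using hypothetical names:

#include <vector>

// The critic input is laid out as [action | observation], so the first
// actionSize entries of the critic's input gradient are dQ/d(action),
// the quantity a deterministic policy gradient feeds back into the policy:
//   dJ/dtheta = dQ(a, s)/da * dpi_theta(s)/dtheta, with a = pi_theta(s).
std::vector<float> SliceActionGradient(const std::vector<float>& criticInputGradient,
                                       int actionSize)
{
	return std::vector<float>(criticInputGradient.begin(),
	                          criticInputGradient.begin() + actionSize);
}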