@@ -260,10 +260,6 @@ ndBrainAgentDeterministicPolicyGradient_Trainer::ndBrainAgentDeterministicPolicy
 	:ndBrainThreadPool()
 	,m_name()
 	,m_parameters(parameters)
-	//,m_policy()
-	//,m_referencePolicy()
-	//,m_policyTrainers()
-	//,m_policyOptimizer()
 	,m_expectedRewards()
 	,m_miniBatchIndexBuffer()
 	,m_replayBuffer()
@@ -365,15 +361,12 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::BuildPolicyClass()
 	policyOptimizer->SetRegularizer(m_parameters.m_policyRegularizer);
 	policyOptimizer->SetRegularizerType(m_parameters.m_policyRegularizerType);
 	m_policyTrainer = ndSharedPtr<ndBrainTrainerCpu>(new ndBrainTrainerCpu(policy, policyOptimizer, this, m_parameters.m_miniBatchSize));
-
-	ndSharedPtr<ndBrain> referencePolicy(new ndBrain(**policy));
-	m_referencePolicy = ndSharedPtr<ndBrainTrainerCpuInference>(new ndBrainTrainerCpuInference(referencePolicy, this, m_parameters.m_miniBatchSize));
 }
 
 void ndBrainAgentDeterministicPolicyGradient_Trainer::BuildCriticClass()
 {
 	const ndBrain& policy = **m_policyTrainer->GetBrain();
-	for (ndInt32 k = 0; k < sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0]); ++k)
+	for (ndInt32 k = 0; k < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++k)
 	{
 		ndFixSizeArray<ndBrainLayer*, 32> layers;
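Editor's note: the `-`/`+` pair above (and the matching ones further down) wraps the array-count expression in `ndInt32(...)`. The `sizeof` quotient has type `size_t`, which is unsigned, so comparing it against the signed `ndInt32` loop counter draws signed/unsigned-mismatch warnings; casting the whole quotient keeps the loop bound signed. A minimal standalone sketch of the same pattern, with plain `int` standing in for `ndInt32` and an illustrative array name:

#include <cstdio>

int main()
{
	float criticTrainer[2] = {0.0f, 0.0f};	// stand-in for the m_criticTrainer array
	// sizeof(a) / sizeof(a[0]) is a size_t (unsigned); casting the quotient to a
	// signed int makes the comparison with the signed index k warning-free
	for (int k = 0; k < int(sizeof(criticTrainer) / sizeof(criticTrainer[0])); ++k)
	{
		std::printf("critic %d\n", k);
	}
	return 0;
}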
@@ -603,45 +596,13 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::SaveTrajectory()
 //#pragma optimize( "", off )
 void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnQvalueFunction(ndInt32 criticIndex)
 {
-	//ndInt32 base = 0;
-	//ndAtomic<ndInt32> iterator(0);
-	//for (ndInt32 n = m_parameters.m_criticUpdatesCount - 1; n >= 0; --n)
-	//{
-	//	auto BackPropagateBatch = ndMakeObject::ndFunction([this, &iterator, base, criticIndex](ndInt32, ndInt32)
-	//	{
-	//		ndBrainFixSizeVector<256> criticObservationAction;
-	//		criticObservationAction.SetCount(m_policy.GetOutputSize() + m_policy.GetInputSize());
-	//
-	//		ndBrainLossLeastSquaredError loss(1);
-	//		ndBrainFixSizeVector<1> groundTruth;
-	//		for (ndInt32 i = iterator++; i < m_parameters.m_miniBatchSize; i = iterator++)
-	//		{
-	//			const ndInt32 index = m_miniBatchIndexBuffer[i + base];
-	//			ndBrainTrainer& trainer = *m_criticTrainers[criticIndex][i];
-	//
-	//			ndMemCpy(&criticObservationAction[0], m_replayBuffer.GetActions(index), m_policy.GetOutputSize());
-	//			ndMemCpy(&criticObservationAction[m_policy.GetOutputSize()], m_replayBuffer.GetObservations(index), m_policy.GetInputSize());
-	//
-	//			groundTruth[0] = m_expectedRewards[i + base];
-	//			loss.SetTruth(groundTruth);
-	//			trainer.BackPropagate(criticObservationAction, loss);
-	//		}
-	//	});
-	//
-	//	iterator = 0;
-	//	ndBrainThreadPool::ParallelExecute(BackPropagateBatch);
-	//	m_criticOptimizer[criticIndex]->Update(this, m_criticTrainers[criticIndex], m_parameters.m_criticLearnRate);
-	//	base += m_parameters.m_miniBatchSize;
-	//}
 	const ndBrain& brain = **m_policyTrainer->GetBrain();
 	ndInt32 criticInputSize = brain.GetInputSize() + brain.GetOutputSize();
 
 	m_criticValue.SetCount(m_parameters.m_miniBatchSize);
 	m_criticGradients.SetCount(m_parameters.m_miniBatchSize);
 	m_criticObservationActionBatch.SetCount(m_parameters.m_miniBatchSize * criticInputSize);
 
-	//ndMemSet((ndInt32*)&m_criticObservationActionBatch[0], ndInt32(0xffffffff), m_criticObservationActionBatch.GetCount());
-
 	for (ndInt32 n = 0; n < m_parameters.m_criticUpdatesCount; ++n)
 	{
 		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
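Editor's note: both the removed per-thread version and the new batched version feed the critic with the action and the observation packed into a single row (criticInputSize = brain.GetInputSize() + brain.GetOutputSize(), action first, observation after it). A minimal sketch of that row layout, using plain floats and hypothetical sizes rather than the ndBrain vector types:

#include <cstdio>
#include <cstring>
#include <vector>

int main()
{
	const int actionSize = 3;		// hypothetical policy output size
	const int observationSize = 8;	// hypothetical policy input size
	const int criticInputSize = actionSize + observationSize;

	std::vector<float> action(actionSize, 0.5f);
	std::vector<float> observation(observationSize, 1.0f);

	// one row of the batched critic input: [action | observation]
	std::vector<float> criticRow(criticInputSize);
	std::memcpy(&criticRow[0], action.data(), actionSize * sizeof(float));
	std::memcpy(&criticRow[actionSize], observation.data(), observationSize * sizeof(float));
	std::printf("row starts with action %f, observation begins at index %d\n", criticRow[0], actionSize);
	return 0;
}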
@@ -688,62 +649,6 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 			m_shuffleBuffer.RandomShuffle(m_shuffleBuffer.GetCount());
 		}
 	}
-
-	//ndAtomic<ndInt32> iterator(0);
-	//auto ExpectedRewards = ndMakeObject::ndFunction([this, count, &iterator](ndInt32, ndInt32)
-	//{
-	//	const ndInt32 batchSize = 128;
-	//	ndBrainFixSizeVector<256> policyEntropyAction;
-	//	ndBrainFixSizeVector<256> criticNextObservationAction;
-	//
-	//	ndAssert(count % batchSize == 0);
-	//	policyEntropyAction.SetCount(m_policy.GetOutputSize());
-	//	criticNextObservationAction.SetCount(m_policy.GetOutputSize() + m_policy.GetInputSize());
-	//	for (ndInt32 base = iterator.fetch_add(batchSize); base < count; base = iterator.fetch_add(batchSize))
-	//	{
-	//		for (ndInt32 j = 0; j < batchSize; ++j)
-	//		{
-	//			ndBrainFixSizeVector<ND_NUMBER_OF_CRITICS> rewards;
-	//			rewards.SetCount(0);
-	//			const ndInt32 index = m_miniBatchIndexBuffer[j + base];
-	//			ndBrainFloat r = m_replayBuffer.GetReward(index);
-	//			for (ndInt32 i = 0; i < sizeof(m_critic) / sizeof(m_critic[0]); ++i)
-	//			{
-	//				rewards.PushBack(r);
-	//			}
-	//			if (!m_replayBuffer.GetTerminalState(index))
-	//			{
-	//				ndBrainMemVector nextAction(&criticNextObservationAction[0], m_policy.GetOutputSize());
-	//				const ndBrainMemVector nextObservation(m_replayBuffer.GetNextObservations(index), m_policy.GetInputSize());
-	//				m_policy.MakePrediction(nextObservation, nextAction);
-	//				ndMemCpy(&criticNextObservationAction[m_policy.GetOutputSize()], &nextObservation[0], nextObservation.GetCount());
-	//
-	//				ndBrainFixSizeVector<1> criticQvalue;
-	//				for (ndInt32 i = 0; i < sizeof(m_critic) / sizeof(m_critic[0]); ++i)
-	//				{
-	//					m_referenceCritic[i].MakePrediction(criticNextObservationAction, criticQvalue);
-	//					rewards[i] += m_parameters.m_discountRewardFactor * criticQvalue[0];
-	//				}
-	//			}
-	//
-	//			ndBrainFloat minQ = ndBrainFloat(1.0e10f);
-	//			for (ndInt32 i = 0; i < sizeof(m_critic) / sizeof(m_critic[0]); ++i)
-	//			{
-	//				minQ = ndMin(minQ, rewards[i]);
-	//			}
-	//
-	//			// calculate entropy
-	//			if (m_parameters.m_entropyRegularizerCoef > ndBrainFloat(1.0e-6f))
-	//			{
-	//				ndBrainFloat prob = CalculatePolicyProbability(index);
-	//				ndBrainFloat logProb = ndBrainFloat(ndLog(prob));
-	//				minQ -= m_parameters.m_entropyRegularizerCoef * logProb;
-	//			}
-	//			m_expectedRewards[j + base] = minQ;
-	//		}
-	//	}
-	//});
-	//ndBrainThreadPool::ParallelExecute(ExpectedRewards);
 
 	const ndBrain& brain = **m_policyTrainer->GetBrain();
 	ndInt32 criticInputSize = brain.GetInputSize() + brain.GetOutputSize();
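Editor's note: the batched code in the following hunks reproduces the target that the removed per-thread lambda computed: a clipped double-Q target y = r + γ·min_i Q_i(s', π(s')) for non-terminal transitions, optionally lowered by an entropy term α·log π(a|s). A scalar sketch of that arithmetic, with plain floats, two critics, and illustrative values in place of the trainer's members:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
	const float reward = 0.8f;
	const float discount = 0.99f;			// analogue of m_discountRewardFactor
	const float entropyCoef = 0.01f;		// analogue of m_entropyRegularizerCoef
	const float nextQ[2] = {4.2f, 4.5f};	// reference critics evaluated at (s', pi(s'))
	const bool terminal = false;
	const float actionProbability = 0.7f;	// pi(a|s), illustrative

	// min over the critic estimates guards against Q overestimation (TD3-style)
	float target = reward;
	if (!terminal)
	{
		target += discount * std::min(nextQ[0], nextQ[1]);
	}
	// optional entropy regularization, matching the commented-out branch above
	if (entropyCoef > 1.0e-6f)
	{
		target -= entropyCoef * std::log(actionProbability);
	}
	std::printf("expected reward target: %f\n", target);
	return 0;
}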
@@ -752,7 +657,7 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 	m_nextActionBatch.SetCount(m_parameters.m_miniBatchSize * brain.GetOutputSize());
 	m_nextObsevationsBatch.SetCount(m_parameters.m_miniBatchSize * brain.GetInputSize());
 	m_criticNextObservationActionBatch.SetCount(m_parameters.m_miniBatchSize * criticInputSize);
-	for (ndInt32 i = 0; i < sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0]); ++i)
+	for (ndInt32 i = 0; i < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++i)
 	{
 		m_rewardBatch[i].SetCount(m_parameters.m_miniBatchSize);
 	}
@@ -765,7 +670,7 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 
 		// get the state rewards
 		ndBrainFloat r = m_replayBuffer.GetReward(index);
-		for (ndInt32 j = 0; j < sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0]); ++j)
+		for (ndInt32 j = 0; j < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++j)
 		{
 			m_rewardBatch[j][i] = r;
 		}
@@ -798,7 +703,7 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 		ndMemCpy(&criticNextObservationAction[brain.GetOutputSize()], &nextObsevations[0], brain.GetInputSize());
 	}
 
-	for (ndInt32 i = 0; i < sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0]); ++i)
+	for (ndInt32 i = 0; i < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++i)
 	{
 		m_referenceCriticTrainer[i]->MakePrediction(m_criticNextObservationActionBatch);
 		m_referenceCriticTrainer[i]->GetOutput(m_nextQValue);
@@ -815,7 +720,7 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 
 	for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
 	{
-		for (ndInt32 j = 1; j < sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0]); ++j)
+		for (ndInt32 j = 1; j < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++j)
 		{
 			m_rewardBatch[0][i] = ndMin(m_rewardBatch[0][i], m_rewardBatch[j][i]);
 		}
@@ -836,7 +741,6 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
 //#pragma optimize( "", off )
 void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnPolicyFunction()
 {
-
 	//ndAtomic<ndInt32> iterator(0);
 	//ndInt32 base = 0;
 	//for (ndInt32 n = m_parameters.m_policyUpdatesCount - 1; n >= 0; --n)
@@ -951,8 +855,23 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnPolicyFunction()
 	//	base += m_parameters.m_miniBatchSize;
 	//}
 
+	const ndBrain& brain = **m_policyTrainer->GetBrain();
+	ndInt32 criticInputSize = brain.GetInputSize() + brain.GetOutputSize();
+
+	m_obsevationsBatch.SetCount(m_parameters.m_miniBatchSize * brain.GetInputSize());
+
 	for (ndInt32 n = 0; n < m_parameters.m_policyUpdatesCount; ++n)
 	{
+		for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
+		{
+			const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
+
+			ndBrainMemVector observation(&m_obsevationsBatch[i * brain.GetInputSize()], brain.GetInputSize());
+			ndMemCpy(&observation[0], m_replayBuffer.GetObservations(index), brain.GetInputSize());
+		}
+		m_policyTrainer->MakePrediction(m_obsevationsBatch);
+
+
 
 	}
 }
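Editor's note: the new batched body stops after predicting actions for the sampled observations; the removed per-trainer code (and the deterministic policy-gradient method in general) would next evaluate the critic at (π(s), s) and push the critic's action gradient back through the policy. A toy sketch of that chain rule only, with a 1-D linear policy a = θ·s and a hand-written critic Q(s, a) = −(a − s)²; none of this is ndBrain API, it just illustrates the gradient-ascent step the missing code would perform:

#include <cstdio>

int main()
{
	float theta = 0.2f;				// single policy parameter, a = theta * s
	const float learnRate = 0.1f;
	const float s = 1.5f;			// one observation

	for (int step = 0; step < 100; ++step)
	{
		const float a = theta * s;				// policy prediction
		const float dQ_da = -2.0f * (a - s);	// critic gradient w.r.t. the action, Q = -(a - s)^2
		const float da_dtheta = s;				// policy gradient w.r.t. its parameter
		theta += learnRate * dQ_da * da_dtheta;	// ascend Q(s, pi(s))
	}
	std::printf("theta after training: %f (optimum is 1)\n", theta);
	return 0;
}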
@@ -961,21 +880,20 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnPolicyFunction()
 void ndBrainAgentDeterministicPolicyGradient_Trainer::Optimize()
 {
 	CalculateExpectedRewards();
-	for (ndInt32 k = 0; k < sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0]); ++k)
+	for (ndInt32 k = 0; k < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++k)
 	{
 		LearnQvalueFunction(k);
 	}
 
 	if (!ndPolycyDelayMod)
 	{
 		LearnPolicyFunction();
-		//m_referencePolicy.SoftCopy(m_policy, m_parameters.m_policyMovingAverageFactor);
 	}
 	//for (ndInt32 k = 0; k < sizeof(m_critic) / sizeof(m_critic[0]); ++k)
 	//{
 	//	m_referenceCritic[k].SoftCopy(m_critic[k], m_parameters.m_criticMovingAverageFactor);
 	//}
-	//ndPolycyDelayMod = (ndPolycyDelayMod + 1) % ND_TD3_POLICY_DELAY_MOD;
+	ndPolycyDelayMod = (ndPolycyDelayMod + 1) % ND_TD3_POLICY_DELAY_MOD;
 }
 
 //#pragma optimize( "", off )
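Editor's note: re-enabling the counter update makes the `if (!ndPolycyDelayMod)` branch fire only once every ND_TD3_POLICY_DELAY_MOD calls, which is TD3's delayed policy update: the critics learn on every Optimize() call, the policy only on every N-th. A minimal sketch of that cadence, with an illustrative delay of 2:

#include <cstdio>

int main()
{
	const int policyDelayMod = 2;	// stand-in for ND_TD3_POLICY_DELAY_MOD
	int delayCounter = 0;			// stand-in for ndPolycyDelayMod

	for (int optimizeCall = 0; optimizeCall < 6; ++optimizeCall)
	{
		std::printf("call %d: update critics", optimizeCall);
		if (!delayCounter)
		{
			std::printf(" + update policy");
		}
		std::printf("\n");
		delayCounter = (delayCounter + 1) % policyDelayMod;
	}
	return 0;
}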