
Commit ed13ad1

committed
more refactoring (wip)
1 parent f038d3d commit ed13ad1

7 files changed: +193 -108 lines changed

newton-4.00/sdk/dBrain/ndBrainAgentDeterministicPolicyGradient_Trainer.cpp

Lines changed: 65 additions & 24 deletions
@@ -401,7 +401,8 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::BuildCriticClass()
     }
 }
 
-ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32 index, const ndBrainVector& distribution) const
+//ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32 index, const ndBrainVector& distribution) const
+ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32, const ndBrainVector&) const
 {
     ndAssert(0);
     return 0;
@@ -447,7 +448,8 @@ ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyPro
 }
 
 //#pragma optimize( "", off )
-ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32 index) const
+//ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32 index) const
+ndBrainFloat ndBrainAgentDeterministicPolicyGradient_Trainer::CalculatePolicyProbability(ndInt32) const
 {
     ndAssert(0);
     return 0;
@@ -599,35 +601,34 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnQvalueFunction(ndInt3
     const ndBrain& brain = **m_policyTrainer->GetBrain();
     ndInt32 criticInputSize = brain.GetInputSize() + brain.GetOutputSize();
 
-    m_criticValue.SetCount(m_parameters.m_miniBatchSize);
-    m_criticGradients.SetCount(m_parameters.m_miniBatchSize);
+    m_criticValue[0].SetCount(m_parameters.m_miniBatchSize);
+    m_criticOutputGradients[0].SetCount(m_parameters.m_miniBatchSize);
     m_criticObservationActionBatch.SetCount(m_parameters.m_miniBatchSize * criticInputSize);
-
+
     for (ndInt32 n = 0; n < m_parameters.m_criticUpdatesCount; ++n)
     {
         for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
         {
             const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
-
+
             ndBrainMemVector criticObservationAction(&m_criticObservationActionBatch[i * criticInputSize], criticInputSize);
             ndMemCpy(&criticObservationAction[0], m_replayBuffer.GetActions(index), brain.GetOutputSize());
             ndMemCpy(&criticObservationAction[brain.GetOutputSize()], m_replayBuffer.GetObservations(index), brain.GetInputSize());
         }
         m_criticTrainer[criticIndex]->MakePrediction(m_criticObservationActionBatch);
-        m_criticTrainer[criticIndex]->GetOutput(m_criticValue);
-
+        m_criticTrainer[criticIndex]->GetOutput(m_criticValue[0]);
+
         ndBrainLossLeastSquaredError loss(1);
         ndBrainFixSizeVector<1> groundTruth;
         for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
         {
             groundTruth[0] = m_expectedRewards[n * m_parameters.m_miniBatchSize + i];
             loss.SetTruth(groundTruth);
-            const ndBrainMemVector output(&m_criticValue[i], 1);
-            ndBrainMemVector gradient(&m_criticGradients[i], 1);
+            const ndBrainMemVector output(&m_criticValue[0][i], 1);
+            ndBrainMemVector gradient(&m_criticOutputGradients[0][i], 1);
             loss.GetLoss(output, gradient);
         }
-        //trainer.BackPropagate(criticObservationAction, loss);
-        m_criticTrainer[criticIndex]->BackPropagate(m_criticGradients);
+        m_criticTrainer[criticIndex]->BackPropagate(m_criticOutputGradients[0]);
         m_criticTrainer[criticIndex]->ApplyLearnRate(m_parameters.m_criticLearnRate);
     }
 }
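
Note: the hunk above is the critic regression step. Each critic prediction in the mini batch is pulled toward the expected reward computed by CalculateExpectedRewards, one scalar per sample, through ndBrainLossLeastSquaredError. A minimal plain-C++ sketch of that per-sample gradient (not the dBrain API; the exact scaling used by the loss class is not shown in this diff, so the gradient below is only correct up to a constant factor):

#include <cstddef>
#include <vector>

// Hypothetical helper for illustration: output gradients of a squared-error loss
// between critic predictions q[i] and TD targets target[i].
std::vector<float> CriticOutputGradients(const std::vector<float>& q,
                                         const std::vector<float>& target)
{
    std::vector<float> grad(q.size());
    for (std::size_t i = 0; i < q.size(); ++i)
    {
        grad[i] = q[i] - target[i];   // d/dq of 0.5 * (q - target)^2
    }
    return grad;
}
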
@@ -661,20 +662,20 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
     {
         m_rewardBatch[i].SetCount(m_parameters.m_miniBatchSize);
     }
-
+
     for (ndInt32 n = 0; n < m_parameters.m_criticUpdatesCount; ++n)
     {
         for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
         {
             const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
-
+
             //get the state rewards
             ndBrainFloat r = m_replayBuffer.GetReward(index);
             for (ndInt32 j = 0; j < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++j)
             {
                 m_rewardBatch[j][i] = r;
             }
-
+
             // get the next state actions
             ndBrainMemVector nextObsevations(&m_nextObsevationsBatch[i * brain.GetInputSize()], brain.GetInputSize());
             if (m_replayBuffer.GetTerminalState(index))
@@ -689,16 +690,16 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
         }
         m_policyTrainer->MakePrediction(m_nextObsevationsBatch);
         m_policyTrainer->GetOutput(m_nextActionBatch);
-
-        // calculate the expected rward for each action
+
+        // calculate the expected reward for each action
         for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
         {
-            const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
-
+            //const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
+
             const ndBrainMemVector nextAction(&m_nextActionBatch[i * brain.GetOutputSize()], brain.GetOutputSize());
             const ndBrainMemVector nextObsevations(&m_nextObsevationsBatch[i * brain.GetInputSize()], brain.GetInputSize());
             ndBrainMemVector criticNextObservationAction(&m_criticNextObservationActionBatch[i * criticInputSize], criticInputSize);
-
+
             ndMemCpy(&criticNextObservationAction[0], &nextAction[0], brain.GetOutputSize());
             ndMemCpy(&criticNextObservationAction[brain.GetOutputSize()], &nextObsevations[0], brain.GetInputSize());
         }
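
Note: this packing (and the analogous loop in LearnQvalueFunction above) lays each critic input out as the action followed by the observation, which is why criticInputSize is GetOutputSize() + GetInputSize(). A plain-C++ sketch of that layout, with hypothetical sizes standing in for the brain's real dimensions:

#include <cstring>

const int actionSize = 4;                                  // stand-in for brain.GetOutputSize()
const int observationSize = 16;                            // stand-in for brain.GetInputSize()
const int criticInputSize = actionSize + observationSize;

// Pack one mini-batch entry as [ action | observation ].
void PackCriticInput(float* criticInput, const float* action, const float* observation)
{
    std::memcpy(criticInput, action, sizeof(float) * actionSize);
    std::memcpy(criticInput + actionSize, observation, sizeof(float) * observationSize);
}
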
@@ -707,7 +708,7 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
         {
             m_referenceCriticTrainer[i]->MakePrediction(m_criticNextObservationActionBatch);
             m_referenceCriticTrainer[i]->GetOutput(m_nextQValue);
-
+
             const ndInt32 index = m_miniBatchIndexBuffer[n * m_parameters.m_miniBatchSize + i];
             for (ndInt32 j = 0; j < m_parameters.m_miniBatchSize; ++j)
             {
@@ -717,20 +718,20 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::CalculateExpectedRewards()
                 }
             }
         }
-
+
         for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
         {
             for (ndInt32 j = 1; j < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++j)
             {
                 m_rewardBatch[0][i] = ndMin(m_rewardBatch[0][i], m_rewardBatch[j][i]);
             }
         }
-
+
         if (m_parameters.m_entropyRegularizerCoef > ndBrainFloat(1.0e-6f))
         {
             ndAssert(0);
         }
-
+
         for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
         {
             m_expectedRewards[n * m_parameters.m_miniBatchSize + i] = m_rewardBatch[0][i];
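
Note: taken together, these loops assemble the clipped double-Q target: every critic contributes a bootstrapped estimate built from the same reward, and the per-sample minimum becomes the expected reward that LearnQvalueFunction regresses against. The discount step and the terminal-state branch fall outside the hunks shown here, so the sketch below assumes the standard form; gamma and the terminal flag are stand-ins for whatever the full function uses.

#include <algorithm>

// Hypothetical stand-alone form of the target assembled above.
float ExpectedReward(float reward, bool terminalState, float gamma,
                     const float* nextQ, int numberOfCritics)
{
    if (terminalState)
    {
        return reward;                      // no bootstrap past a terminal state
    }
    float minQ = nextQ[0];
    for (int j = 1; j < numberOfCritics; ++j)
    {
        minQ = std::min(minQ, nextQ[j]);    // clipped double-Q: keep the smallest estimate
    }
    return reward + gamma * minQ;
}
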
@@ -858,7 +859,16 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnPolicyFunction()
     const ndBrain& brain = **m_policyTrainer->GetBrain();
     ndInt32 criticInputSize = brain.GetInputSize() + brain.GetOutputSize();
 
+    m_actionBatch.SetCount(m_parameters.m_miniBatchSize* brain.GetOutputSize());
     m_obsevationsBatch.SetCount(m_parameters.m_miniBatchSize * brain.GetInputSize());
+    m_criticObservationActionBatch.SetCount(m_parameters.m_miniBatchSize* criticInputSize);
+
+    for (ndInt32 i = 0; i < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++i)
+    {
+        m_criticValue[i].SetCount(m_parameters.m_miniBatchSize);
+        m_criticOutputGradients[i].SetCount(m_parameters.m_miniBatchSize);
+        m_criticInputGradients[i].SetCount(m_parameters.m_miniBatchSize * criticInputSize);
+    }
 
     for (ndInt32 n = 0; n < m_parameters.m_policyUpdatesCount; ++n)
     {
@@ -870,9 +880,40 @@ void ndBrainAgentDeterministicPolicyGradient_Trainer::LearnPolicyFunction()
             ndMemCpy(&observation[0], m_replayBuffer.GetActions(index), brain.GetOutputSize());
         }
         m_policyTrainer->MakePrediction(m_obsevationsBatch);
+        m_policyTrainer->GetOutput(m_actionBatch);
 
+        for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
+        {
+            ndBrainMemVector criticObservationAction(&m_criticObservationActionBatch[i * criticInputSize], criticInputSize);
+            ndMemCpy(&criticObservationAction[0], &m_actionBatch[i * brain.GetOutputSize()], brain.GetOutputSize());
+            ndMemCpy(&criticObservationAction[brain.GetOutputSize()], &m_obsevationsBatch[i * brain.GetInputSize()], brain.GetInputSize());
+        }
 
+        for (ndInt32 i = 0; i < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++i)
+        {
+            m_criticTrainer[i]->MakePrediction(m_criticObservationActionBatch);
+            m_policyTrainer->GetOutput(m_criticValue[i]);
+            if (m_parameters.m_entropyRegularizerCoef > ndBrainFloat(1.0e-6f))
+            {
+                ndAssert(0);
+            }
+            m_criticOutputGradients[i].Set(ndBrainFloat(1.0f));
+            m_criticTrainer[i]->BackPropagate(m_criticOutputGradients[i]);
+            m_criticTrainer[i]->GetInput(m_criticInputGradients[i]);
+        }
 
+        for (ndInt32 i = 0; i < m_parameters.m_miniBatchSize; ++i)
+        {
+            for (ndInt32 j = 1; j < ndInt32(sizeof(m_criticTrainer) / sizeof(m_criticTrainer[0])); ++j)
+            {
+                if (m_criticValue[j][i] < m_criticValue[0][i])
+                {
+                    ndBrainMemVector dstObservationAction(&m_criticInputGradients[0][i * criticInputSize], criticInputSize);
+                    const ndBrainMemVector srcObservationAction(&m_criticInputGradients[j][i * criticInputSize], criticInputSize);
+                    dstObservationAction.Set(srcObservationAction);
+                }
+            }
+        }
     }
 }
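
Note: the new policy-update path predicts actions for the sampled observations, packs them with the observations into critic inputs, back-propagates a unit output gradient through every critic, and reads the resulting input gradients back through the new GetInput() hook. The last loop then keeps, per sample, the input gradient of the critic with the smaller value, i.e. the gradient of the critic minimum with respect to the packed [action | observation] input; feeding that gradient back into the policy is not part of this commit (wip). An argmin formulation of that selection, as a plain-C++ sketch:

#include <cstring>

// gradients[j] points at critic j's input-gradient batch (batchSize * inputSize floats);
// values[j][i] is critic j's value for sample i. Slot 0 ends up holding, per sample,
// the gradient of the smallest critic, mirroring the selection loop above.
void KeepMinCriticGradient(float* const gradients[], const float* const values[],
                           int numberOfCritics, int batchSize, int inputSize)
{
    for (int i = 0; i < batchSize; ++i)
    {
        int minCritic = 0;
        for (int j = 1; j < numberOfCritics; ++j)
        {
            if (values[j][i] < values[minCritic][i])
            {
                minCritic = j;
            }
        }
        if (minCritic != 0)
        {
            std::memcpy(&gradients[0][i * inputSize], &gradients[minCritic][i * inputSize],
                        sizeof(float) * std::size_t(inputSize));
        }
    }
}
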

newton-4.00/sdk/dBrain/ndBrainAgentDeterministicPolicyGradient_Trainer.h

Lines changed: 6 additions & 5 deletions
@@ -194,20 +194,21 @@ class ndBrainAgentDeterministicPolicyGradient_Trainer : public ndBrainThreadPool
     HyperParameters m_parameters;
 
     ndSharedPtr<ndBrainTrainerCpu> m_policyTrainer;
-    //ndSharedPtr<ndBrainTrainerCpuInference> m_referencePolicy;
-
     ndSharedPtr<ndBrainTrainerCpu> m_criticTrainer[ND_NUMBER_OF_CRITICS];
     ndSharedPtr<ndBrainTrainerCpu> m_referenceCriticTrainer[ND_NUMBER_OF_CRITICS];
 
-    ndBrainVector m_nextQValue;
-    ndBrainVector m_criticValue;
-    ndBrainVector m_criticGradients;
+    ndBrainVector m_actionBatch;
     ndBrainVector m_nextActionBatch;
     ndBrainVector m_obsevationsBatch;
     ndBrainVector m_nextObsevationsBatch;
     ndBrainVector m_criticObservationActionBatch;
     ndBrainVector m_criticNextObservationActionBatch;
+
+    ndBrainVector m_nextQValue;
+    ndBrainVector m_criticValue[ND_NUMBER_OF_CRITICS];
     ndBrainVector m_rewardBatch[ND_NUMBER_OF_CRITICS];
+    ndBrainVector m_criticInputGradients[ND_NUMBER_OF_CRITICS];
+    ndBrainVector m_criticOutputGradients[ND_NUMBER_OF_CRITICS];
 
     ndBrainVector m_expectedRewards;
     ndArray<ndInt32> m_miniBatchIndexBuffer;

newton-4.00/sdk/dBrain/ndBrainTrainer.h

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,8 @@ class ndBrainTrainer: public ndClassAlloc
 
     ndSharedPtr<ndBrain>& GetBrain();
     virtual void UpdateParameters() = 0;
+
+    virtual void GetInput(ndBrainVector&) const {}
     virtual void GetOutput(ndBrainVector&) const {}
 
     // legacy method;
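
Note: GetInput() joins GetOutput() as an empty-bodied virtual on the base trainer. The call sites added in the .cpp above invoke it immediately after BackPropagate(), so it is evidently intended to hand back a buffer produced during back-propagation (the input gradients); the concrete ndBrainTrainerCpu override is not part of this commit. A plain-C++ sketch of that hook pattern under that assumption:

#include <vector>

struct TrainerBaseSketch
{
    virtual ~TrainerBaseSketch() {}
    virtual void GetInput(std::vector<float>&) const {}   // default: no-op, like the new virtual
    virtual void GetOutput(std::vector<float>&) const {}
};

struct CpuTrainerSketch : TrainerBaseSketch
{
    std::vector<float> m_inputGradients;                  // assumed to be filled during back-propagation

    void GetInput(std::vector<float>& out) const override
    {
        out = m_inputGradients;                           // hand the cached buffer back to the caller
    }
};
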
