 
 public interface IRLModelPPO
 {
+    /// <summary>
+    /// The entropy loss weight
+    /// </summary>
     float EntropyLossWeight { get; set; }
+
+    /// <summary>
+    /// The value loss weight
+    /// </summary>
     float ValueLossWeight { get; set; }
+
+    /// <summary>
+    /// The clip epsilon for PPO actor loss
+    /// </summary>
     float ClipEpsilon { get; set; }
+
+    /// <summary>
+    /// The value loss clip
+    /// </summary>
     float ClipValueLoss { get; set; }
 
+    /// <summary>
+    /// Evaluate the values of current states
+    /// </summary>
+    /// <param name="vectorObservation">Batched vector observations.</param>
+    /// <param name="visualObservation">List of batched visual observations.</param>
+    /// <returns>Values of the input batched states.</returns>
     float[] EvaluateValue(float[,] vectorObservation, List<float[,,,]> visualObservation);
+
+    /// <summary>
+    /// Evaluate the desired actions for the current states
+    /// </summary>
+    /// <param name="vectorObservation">Batched vector observations.</param>
+    /// <param name="actionProbs">Output parameter: the probabilities of the returned actions. Used for PPO training.</param>
+    /// <param name="visualObservation">List of batched visual observations.</param>
+    /// <param name="actionsMask">Action masks for discrete action space. Each element in the list is for one branch of the actions. Can be null if no mask.</param>
+    /// <returns>The desired actions of the batched input states.</returns>
     float[,] EvaluateAction(float[,] vectorObservation, out float[,] actionProbs, List<float[,,,]> visualObservation, List<float[,]> actionsMask = null);
+
+    /// <summary>
+    /// Evaluate the probabilities of the given actions in the current states
+    /// </summary>
+    /// <param name="vectorObservation">Batched vector observations.</param>
+    /// <param name="actions">The batched actions whose probabilities are needed.</param>
+    /// <param name="visualObservation">List of batched visual observations.</param>
+    /// <param name="actionsMask">Action masks for discrete action space. Each element in the list is for one branch of the actions. Can be null if no mask.</param>
+    /// <returns>The probabilities of the input actions. Used for PPO training.</returns>
     float[,] EvaluateProbability(float[,] vectorObservation, float[,] actions, List<float[,,,]> visualObservation, List<float[,]> actionsMask = null);
+
+    /// <summary>
+    /// Train a batch for PPO
+    /// </summary>
+    /// <param name="vectorObservations">Batched vector observations.</param>
+    /// <param name="visualObservations">List of batched visual observations.</param>
+    /// <param name="actions">The old actions taken in those input states.</param>
+    /// <param name="actionProbs">The old probabilities of the actions taken in those input states.</param>
+    /// <param name="targetValues">Target values.</param>
+    /// <param name="oldValues">Old values evaluated by the neural network for those input states.</param>
+    /// <param name="advantages">Advantages.</param>
+    /// <param name="actionsMask">Action masks for discrete action space. Each element in the list is for one branch of the actions. Can be null if no mask.</param>
+    /// <returns></returns>
     float[] TrainBatch(float[,] vectorObservations, List<float[,,,]> visualObservations, float[,] actions, float[,] actionProbs, float[] targetValues, float[] oldValues, float[] advantages, List<float[,]> actionsMask = null);
 }
 
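A minimal sketch of how a trainer could drive this interface during a PPO update follows. The PPOTrainingExample class, its parameter names, and the hyperparameter values are illustrative assumptions, not part of this commit; only the IRLModelPPO members shown above come from the interface itself.

using System.Collections.Generic;

public static class PPOTrainingExample
{
    // Runs one PPO update over a rollout that is assumed to have been collected
    // with EvaluateAction (actions plus their probabilities) and the environment.
    public static float[] UpdatePolicy(
        IRLModelPPO model,
        float[,] vectorObservations,          // [batch, observationSize]
        List<float[,,,]> visualObservations,  // one [batch, h, w, c] array per camera; empty list if none
        float[,] actions,                     // actions taken during the rollout
        float[,] oldActionProbs,              // action probabilities recorded at collection time
        float[] returns,                      // discounted returns used as value targets
        float[] advantages)                   // e.g. GAE estimates, typically normalized beforehand
    {
        // Typical PPO hyperparameters; the values here are placeholders to tune per task.
        model.ClipEpsilon = 0.2f;
        model.ClipValueLoss = 0.2f;
        model.ValueLossWeight = 0.5f;
        model.EntropyLossWeight = 0.01f;

        // Values predicted by the current network, needed for clipped value-loss training.
        float[] oldValues = model.EvaluateValue(vectorObservations, visualObservations);

        // A full trainer would shuffle the rollout, split it into minibatches and
        // repeat for several epochs; a single full-batch step is shown for brevity.
        return model.TrainBatch(
            vectorObservations, visualObservations,
            actions, oldActionProbs,
            returns, oldValues, advantages);
    }
}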