Skip to content

Commit

Permalink
Core Slates support (#146)
Browse files Browse the repository at this point in the history
* Fix OSX build issues.

* Add OSX build instructions.

* almost done

* Finished with slates skeleton.

* a bit more stuff

* Better API names. Add rl_sim.

* Add missing schema to proj file

* Small fixes, update submodule

* Update submodule

* Fix wrong word

* Fix initialization and mock implementation based on the interface change

* Must be able to check if slates model is compatible

* Update mock call counts

* Fix exploration, sampling bug that was fixed in CCB

* - Use default move constructor
- Fix variable name consistency

* Remove comments, extract UUID to a local variable

* Small stylistic cleanups

* Native bits of C# bindings.

* C# bits for slates.

* Fix vs build and bindings.

* sim WIP

* Fix malformed vcxproj

* Fix slates_response include, fix simulator usage of _slot_id

* Add a unit test for slates

* Update rlclientlib/vw_model/safe_vw.cc

Co-authored-by: Alexey Taymanov <[email protected]>

* Remove redundant moves

* refactor model type checking

* update submodule

Co-authored-by: Rodrigo Kumpera <[email protected]>
Co-authored-by: Rodrigo Kumpera <[email protected]>
Co-authored-by: Alexey Taymanov <[email protected]>
  • Loading branch information
4 people committed Jul 6, 2020
1 parent 58a2138 commit 97d6da5
Show file tree
Hide file tree
Showing 45 changed files with 1,427 additions and 67 deletions.
2 changes: 1 addition & 1 deletion bindings/cs/rl.net.cli/PerfTestCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private PerfTestStepProvider DoWork(string tag)
};

Console.WriteLine(stepProvider.DataSize);
RLDriver rlDriver = new RLDriver(liveModel)
RLDriver rlDriver = new RLDriver(liveModel, useSlates: false)
{
StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs)
};
Expand Down
13 changes: 12 additions & 1 deletion bindings/cs/rl.net.cli/PerfTestStepProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,22 @@ public float Outcome {
}

public string DecisionContext { get; set; }

/// <summary>
/// Not implemented for this step type: reading or writing the slates context
/// always throws <see cref="NotImplementedException"/>.
/// </summary>
public string SlatesContext
{
    get { throw new NotImplementedException(); }
    set { throw new NotImplementedException(); }
}

/// <summary>
/// Returns the outcome configured on this step. Neither the chosen action
/// nor its distribution is consulted.
/// </summary>
/// <param name="actionIndex">Index of the chosen action (unused).</param>
/// <param name="actionDistribution">Action distribution of the decision (unused).</param>
/// <returns>The value of <c>Outcome</c>.</returns>
public float GetOutcome(long actionIndex, IEnumerable<ActionProbability> actionDistribution) => this.Outcome;

/// <summary>
/// Returns the outcome configured on this step. The supplied action indexes
/// and probabilities are ignored.
/// </summary>
/// <param name="actionIndexes">Chosen action indexes (unused).</param>
/// <param name="probabilities">Probabilities of the chosen actions (unused).</param>
/// <returns>The value of <c>Outcome</c>.</returns>
public float GetOutcome(int[] actionIndexes, float[] probabilities) => this.Outcome;

/// <summary>
/// Not implemented for this step type; always throws.
/// </summary>
/// <param name="actionIndexes">Chosen action indexes (unused).</param>
/// <param name="probabilities">Probabilities of the chosen actions (unused).</param>
/// <exception cref="NotImplementedException">Always.</exception>
public float GetSlatesOutcome(int[] actionIndexes, float[] probabilities) => throw new NotImplementedException();
}

public string Tag { get; set; } = "Id";
Expand Down
56 changes: 43 additions & 13 deletions bindings/cs/rl.net.cli/RLDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ string DecisionContext
{
get;
}
/// <summary>
/// Context JSON the driver passes to the slates decision request for this step.
/// </summary>
string SlatesContext
{
get;
}

/// <summary>
/// Produces the outcome for a ranking decision. The driver calls this with the
/// chosen action index and the enumerated ranking response; a null result makes
/// the driver skip outcome reporting for the step.
/// </summary>
TOutcome GetOutcome(long actionIndex, IEnumerable<ActionProbability> actionDistribution);
/// <summary>
/// Produces the outcome from parallel arrays of action indexes and probabilities.
/// NOTE(review): no caller is visible in this view; presumably intended for
/// multi-slot decision responses - confirm.
/// </summary>
TOutcome GetOutcome(int[] actionIndexes, float[] probabilities);
/// <summary>
/// Produces the outcome for a slates decision. The driver calls this with the
/// per-slot action ids and probabilities taken from the slates response; a null
/// result makes the driver skip outcome reporting for the step.
/// </summary>
TOutcome GetSlatesOutcome(int[] actionIndexes, float[] probabilities);
}

internal class RunContext
Expand All @@ -38,6 +44,11 @@ public ApiStatus ApiStatusContainer
{
get;
} = new ApiStatus();

/// <summary>
/// Slates response instance created once with this run context and handed to
/// slates decision requests to receive their result.
/// </summary>
public SlatesResponse SlatesContainer { get; } = new SlatesResponse();
}

internal interface IOutcomeReporter<TOutcome>
Expand All @@ -48,10 +59,12 @@ internal interface IOutcomeReporter<TOutcome>
public class RLDriver : IOutcomeReporter<float>, IOutcomeReporter<string>
{
private LiveModel liveModel;
private bool useSlates;

public RLDriver(LiveModel liveModel)
public RLDriver(LiveModel liveModel, bool useSlates)
{
this.liveModel = liveModel;
this.useSlates = useSlates;
}

public TimeSpan StepInterval
Expand Down Expand Up @@ -100,22 +113,39 @@ bool IOutcomeReporter<string>.TryQueueOutcomeEvent(RunContext runContext, string
private void Step<TOutcome>(RunContext runContext, IOutcomeReporter<TOutcome> outcomeReporter, IStepContext<TOutcome> step)
{
string eventId = step.EventId;
TOutcome outcome = default(TOutcome);

if (!liveModel.TryChooseRank(eventId, step.DecisionContext, runContext.ResponseContainer, runContext.ApiStatusContainer))
{
this.SafeRaiseError(runContext.ApiStatusContainer);
}
if(useSlates) {
if(!liveModel.TryRequestSlatesDecision(eventId, step.SlatesContext, runContext.SlatesContainer, runContext.ApiStatusContainer))
{
this.SafeRaiseError(runContext.ApiStatusContainer);
}

long actionIndex = -1;
if (!runContext.ResponseContainer.TryGetChosenAction(out actionIndex, runContext.ApiStatusContainer))
int[] actions = runContext.SlatesContainer.Select(slot => slot.ActionId).ToArray();
float[] probs = runContext.SlatesContainer.Select(slot => slot.Probability).ToArray();
outcome = step.GetSlatesOutcome(actions, probs);
if (outcome == null)
{
return;
}
} else
{
this.SafeRaiseError(runContext.ApiStatusContainer);
}
if (!liveModel.TryChooseRank(eventId, step.DecisionContext, runContext.ResponseContainer, runContext.ApiStatusContainer))
{
this.SafeRaiseError(runContext.ApiStatusContainer);
}

TOutcome outcome = step.GetOutcome(actionIndex, runContext.ResponseContainer.AsEnumerable());
if (outcome == null)
{
return;
long actionIndex = -1;
if (!runContext.ResponseContainer.TryGetChosenAction(out actionIndex, runContext.ApiStatusContainer))
{
this.SafeRaiseError(runContext.ApiStatusContainer);
}

outcome = step.GetOutcome(actionIndex, runContext.ResponseContainer.AsEnumerable());
if (outcome == null)
{
return;
}
}

if (!outcomeReporter.TryQueueOutcomeEvent(runContext, eventId, outcome))
Expand Down
65 changes: 50 additions & 15 deletions bindings/cs/rl.net.cli/RLSimulator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ namespace Rl.Net.Cli
public enum Topic : long
{
HerbGarden,
MachineLearning
MachineLearning,
Soccer,
SpaceExploration
}

internal static class ActionDistributionExtensions
Expand All @@ -35,22 +37,25 @@ internal class SimulatorStepProvider : IDriverStepProvider<float>
{
public static readonly Random RandomSource = new Random();
public const int InfinitySteps = -1;
internal const int DefaultSlatesSlotCount = 2;

private static Func<Topic, float> GenerateRewardDistribution(float herbGardenProbability, float machineLearningProbability)
private static Func<Topic, float> GenerateRewardDistribution(params float[] probabilities)
{
Dictionary<Topic, float> topicProbabilities = new Dictionary<Topic, float>
{
{ Topic.HerbGarden, herbGardenProbability },
{ Topic.MachineLearning, machineLearningProbability }
{ Topic.HerbGarden, probabilities[0] },
{ Topic.MachineLearning, probabilities[1] },
{ Topic.Soccer, probabilities[2] },
{ Topic.SpaceExploration, probabilities[3] },
};

return (topic) => topicProbabilities[topic];
}

internal static Person[] People = new[]
{
new Person("rnc", "engineering", "hiking", "spock", GenerateRewardDistribution(0.03f, 0.1f)),
new Person("mk", "psychology", "kids", "7of9", GenerateRewardDistribution(0.3f, 0.1f))
new Person("rnc", "engineering", "hiking", "spock", GenerateRewardDistribution(0.03f, 0.1f ,0.05f, 0.15f)),
new Person("mk", "psychology", "kids", "7of9", GenerateRewardDistribution(0.3f, 0.1f, 0.08f, 0.1f))
};

private static Person GetRandomPerson()
Expand All @@ -61,10 +66,12 @@ private static Person GetRandomPerson()
}

private readonly int steps;
private readonly int slots;

public SimulatorStepProvider(int steps)
public SimulatorStepProvider(int steps, int slots)
{
this.steps = steps;
this.slots = slots;
}

public IEnumerator<IStepContext<float>> GetEnumerator()
Expand All @@ -78,7 +85,8 @@ public IEnumerator<IStepContext<float>> GetEnumerator()
{
StatisticsCalculator = stats,
EventId = Guid.NewGuid().ToString(),
Person = GetRandomPerson()
Person = GetRandomPerson(),
SlotCount = this.slots,
};

yield return step;
Expand All @@ -95,9 +103,10 @@ IEnumerator IEnumerable.GetEnumerator()

internal class SimulatorStep : IStepContext<float>
{
internal static readonly Topic[] ActionSet = new[] { Topic.HerbGarden, Topic.MachineLearning };
internal static readonly (Topic topic, int slot_id)[] ActionSet = new[] { (Topic.HerbGarden, 0), (Topic.MachineLearning, 0), (Topic.Soccer, 1), (Topic.SpaceExploration, 1) };

private static readonly string ActionsJson = string.Join(",", ActionSet.Select(topic => $"{{ \"TAction\": {{ \"topic\": \"{topic}\" }} }}"));
private static readonly string ActionsJson = string.Join(",", ActionSet.Select(action => $"{{ \"TAction\": {{ \"topic\": \"{action.topic}\" }}, \"_slot_id\": {action.slot_id} }}"));
private string SlotsJson => string.Join(",", Enumerable.Range(0, SlotCount).Select(slotId => $"{{ \"slot_id\": \"__{slotId}\" }}"));

public StatisticsCalculator StatisticsCalculator
{
Expand Down Expand Up @@ -136,7 +145,15 @@ public string ActionDistributionString
}
}

/// <summary>
/// Number of slot entries emitted into the "_slots" array of the slates context.
/// </summary>
public int SlotCount { get; set; }

/// <summary>
/// Context JSON for a ranking decision: the person's features followed by a
/// "_multi" array holding the shared action set.
/// </summary>
public string DecisionContext
{
    get
    {
        return $"{{ { this.Person.FeaturesJson }, \"_multi\": [{ ActionsJson }] }}";
    }
}
/// <summary>
/// Context JSON for a slates decision: the person's features, a "_multi" array
/// holding the shared action set, and a "_slots" array with one entry per slot.
/// </summary>
public string SlatesContext
{
    get
    {
        return $"{{ { this.Person.FeaturesJson }, \"_multi\":[{ActionsJson}], \"_slots\": [{SlotsJson}] }}";
    }
}


private float? outcomeCache;
private IEnumerable<ActionProbability> actionDistributionCache;
Expand All @@ -155,19 +172,37 @@ public float GetOutcome(long actionIndex, IEnumerable<ActionProbability> actionD
public void Record(StatisticsCalculator statisticsCalculator)
{
statisticsCalculator.Record(this.Person, this.DecisionCache.Value, this.outcomeCache.Value);

Console.WriteLine($" {statisticsCalculator.TotalActions}, ctxt, {this.Person.Id}, action, {this.DecisionCache.Value}, outcome, {this.outcomeCache.Value}, dist, {this.ActionDistributionString}, {statisticsCalculator.GetStats(this.Person, this.DecisionCache.Value)}");
}

/// <summary>
/// Not implemented by the simulator step; always throws.
/// </summary>
/// <param name="actionIndexes">Chosen action indexes (unused).</param>
/// <param name="probabilities">Probabilities of the chosen actions (unused).</param>
/// <exception cref="NotImplementedException">Always.</exception>
public float GetOutcome(int[] actionIndexes, float[] probabilities) => throw new NotImplementedException();

/// <summary>
/// Computes, once per step, the total outcome of a slates decision by summing
/// the person's generated outcome for the topic of every chosen action.
/// Later calls return the cached total without regenerating outcomes.
/// </summary>
/// <param name="actionIndexes">Chosen action ids, used as indexes into <c>ActionSet</c>.</param>
/// <param name="probabilities">
/// Probabilities paired with the actions. Their values are not used, but the
/// pairing limits evaluation to the shorter of the two arrays.
/// </param>
/// <returns>The cached sum of the generated per-action outcomes.</returns>
public float GetSlatesOutcome(int[] actionIndexes, float[] probabilities)
{
    // Already evaluated for this step: hand back the memoized total.
    if (this.outcomeCache.HasValue)
    {
        return this.outcomeCache.Value;
    }

    // The slates path records a fixed decision and an empty distribution.
    this.DecisionCache = (Topic)0;
    this.actionDistributionCache = new List<ActionProbability>();

    var perSlotOutcomes = actionIndexes.Zip(
        probabilities,
        (int action, float prob) => this.Person.GenerateOutcome(ActionSet[action].topic));

    // Accumulate in float, in slot order, starting from zero.
    float total = (float)0;
    foreach (var slotOutcome in perSlotOutcomes)
    {
        total = total + slotOutcome;
    }

    this.outcomeCache = total;
    return this.outcomeCache.Value;
}
}
}

internal class RLSimulator
{
private RLDriver driver;

public RLSimulator(LiveModel liveModel)
public RLSimulator(LiveModel liveModel, bool useSlates)
{
this.driver = new RLDriver(liveModel);
this.driver = new RLDriver(liveModel, useSlates);
}

public TimeSpan StepInterval
Expand All @@ -182,9 +217,9 @@ public TimeSpan StepInterval
}
}

public void Run(int steps = SimulatorStepProvider.InfinitySteps)
public void Run(int steps = SimulatorStepProvider.InfinitySteps, int slots = SimulatorStepProvider.DefaultSlatesSlotCount)
{
SimulatorStepProvider stepProvider = new SimulatorStepProvider(steps);
SimulatorStepProvider stepProvider = new SimulatorStepProvider(steps, slots);

this.driver.Run(stepProvider);
}
Expand Down
5 changes: 4 additions & 1 deletion bindings/cs/rl.net.cli/ReplayCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ class ReplayCommand : CommandBase
[Option(longName: "sleep", HelpText = "sleep interval in milliseconds", Required = false, Default = 100)]
public int SleepIntervalMs { get; set; }

// Command-line switch "slates"; optional and off by default. Its value is
// forwarded to the RLDriver constructor as useSlates.
[Option(
    longName: "slates",
    Required = false,
    Default = false,
    HelpText = "use slates")]
public bool UseSlates { get; set; }

public override void Run()
{
LiveModel liveModel = Helpers.CreateLiveModelOrExit(this.ConfigPath);
RLDriver rlDriver = new RLDriver(liveModel);
RLDriver rlDriver = new RLDriver(liveModel, useSlates: this.UseSlates);
rlDriver.StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs);

using (TextReader textReader = File.OpenText(this.LogPath))
Expand Down
14 changes: 14 additions & 0 deletions bindings/cs/rl.net.cli/ReplayStepProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,26 @@ public JObject[] Observations
[JsonIgnore]
public string DecisionContext => this.Context.ToString(Formatting.None);

/// <summary>
/// Not implemented for replayed steps: reading or writing the slates context
/// always throws <see cref="NotImplementedException"/>.
/// </summary>
public string SlatesContext
{
    get { throw new NotImplementedException(); }
    set { throw new NotImplementedException(); }
}

/// <summary>
/// Extracts the outcome recorded for this replayed event: the "v" token of the
/// first observation, serialized as compact JSON.
/// </summary>
/// <param name="actionIndex">Index of the chosen action (unused).</param>
/// <param name="actionDistribution">Action distribution of the decision (unused).</param>
/// <returns>
/// The compact JSON of the "v" token, or null when there are no observations,
/// the first observation is null, or it has no "v" token.
/// </returns>
public string GetOutcome(long actionIndex, IEnumerable<ActionProbability> actionDistribution)
{
    // FirstOrDefault rather than First: an empty observation array means
    // "no outcome" (null), not an InvalidOperationException.
    JToken observationValue = this.Observations?.FirstOrDefault()?.SelectToken("v");

    return observationValue?.ToString(Formatting.None);
}

/// <summary>
/// Extracts the outcome recorded for this replayed event: the "v" token of the
/// first observation, serialized as compact JSON.
/// </summary>
/// <param name="actionIndexes">Chosen action indexes (unused).</param>
/// <param name="probabilities">Probabilities of the chosen actions (unused).</param>
/// <returns>
/// The compact JSON of the "v" token, or null when there are no observations,
/// the first observation is null, or it has no "v" token.
/// </returns>
public string GetOutcome(int[] actionIndexes, float[] probabilities)
{
    // FirstOrDefault rather than First: an empty observation array means
    // "no outcome" (null), not an InvalidOperationException.
    JToken observationValue = this.Observations?.FirstOrDefault()?.SelectToken("v");

    return observationValue?.ToString(Formatting.None);
}

/// <summary>
/// Not implemented for replayed steps; always throws.
/// </summary>
/// <param name="actionIndexes">Chosen action indexes (unused).</param>
/// <param name="probabilities">Probabilities of the chosen actions (unused).</param>
/// <exception cref="NotImplementedException">Always.</exception>
public string GetSlatesOutcome(int[] actionIndexes, float[] probabilities) => throw new NotImplementedException();
}

public ReplayStepProvider(IEnumerable<string> dsJsonHistory)
Expand Down
5 changes: 4 additions & 1 deletion bindings/cs/rl.net.cli/RunSimulatorCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ class RunSimulatorCommand : CommandBase
[Option(longName: "steps", HelpText = "Amount of steps", Required = false, Default = SimulatorStepProvider.InfinitySteps)]
public int Steps { get; set; }

// Command-line switch "slates"; optional and off by default. Its value is
// forwarded to the RLSimulator constructor as useSlates.
[Option(
    longName: "slates",
    Required = false,
    Default = false,
    HelpText = "Use slates for decisions")]
public bool UseSlates { get; set; }

public override void Run()
{
LiveModel liveModel = Helpers.CreateLiveModelOrExit(this.ConfigPath);

RLSimulator rlSim = new RLSimulator(liveModel);
RLSimulator rlSim = new RLSimulator(liveModel, useSlates: this.UseSlates);
rlSim.StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs);
rlSim.OnError += (sender, apiStatus) => Helpers.WriteStatusAndExit(apiStatus);
rlSim.Run(this.Steps);
Expand Down
1 change: 1 addition & 0 deletions bindings/cs/rl.net.native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_library(rl.net.native SHARED
rl.net.live_model.cc
rl.net.ranking_response.cc
rl.net.decision_response.cc
rl.net.slates_response.cc
binding_tracer.cc
)

Expand Down
17 changes: 17 additions & 0 deletions bindings/cs/rl.net.native/rl.net.live_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ API int LiveModelRequestDecisionWithFlags(livemodel_context_t* context, const ch
return context->livemodel->request_decision(context_json, flags, *resp, status);
}

// Binding entry point for a slates decision request.
// Forwards to live_model::request_slates_decision, selecting the overload by
// whether the caller supplied an event id. The library's return code is passed
// back unchanged. `context` and `resp` are dereferenced without a null check.
API int LiveModelRequestSlatesDecision(livemodel_context_t* context, const char * event_id, const char * context_json, reinforcement_learning::slates_response* resp, reinforcement_learning::api_status* status)
{
    if (event_id != nullptr)
    {
        // Caller-provided event id.
        return context->livemodel->request_slates_decision(event_id, context_json, *resp, status);
    }

    // No event id supplied: use the overload that takes only the context.
    return context->livemodel->request_slates_decision(context_json, *resp, status);
}

// Binding entry point for a slates decision request with explicit flags.
// Forwards to live_model::request_slates_decision, selecting the overload by
// whether the caller supplied an event id. The library's return code is passed
// back unchanged. `context` and `resp` are dereferenced without a null check.
API int LiveModelRequestSlatesDecisionWithFlags(livemodel_context_t* context, const char * event_id, const char * context_json, unsigned int flags, reinforcement_learning::slates_response* resp, reinforcement_learning::api_status* status)
{
    if (event_id != nullptr)
    {
        // Caller-provided event id.
        return context->livemodel->request_slates_decision(event_id, context_json, flags, *resp, status);
    }

    // No event id supplied: use the overload that takes only context and flags.
    return context->livemodel->request_slates_decision(context_json, flags, *resp, status);
}


API int LiveModelReportActionTaken(livemodel_context_t* context, const char * event_id, reinforcement_learning::api_status* status)
{
return context->livemodel->report_action_taken(event_id, status);
Expand Down
4 changes: 4 additions & 0 deletions bindings/cs/rl.net.native/rl.net.live_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ extern "C" {
API int LiveModelRequestDecision(livemodel_context_t* livemodel, const char * context_json, reinforcement_learning::decision_response* resp, reinforcement_learning::api_status* status = nullptr);
API int LiveModelRequestDecisionWithFlags(livemodel_context_t* livemodel, const char * context_json, unsigned int flags, reinforcement_learning::decision_response* resp, reinforcement_learning::api_status* status = nullptr);

// Requests a slates decision for context_json and writes the result into *resp.
// event_id may be null, in which case the overload without an event id is used.
// status is optional (defaults to nullptr).
API int LiveModelRequestSlatesDecision(livemodel_context_t* context, const char * event_id, const char * context_json, reinforcement_learning::slates_response* resp, reinforcement_learning::api_status* status = nullptr);

// Same as LiveModelRequestSlatesDecision, additionally forwarding `flags` to the request.
API int LiveModelRequestSlatesDecisionWithFlags(livemodel_context_t* context, const char * event_id, const char * context_json, unsigned int flags, reinforcement_learning::slates_response* resp, reinforcement_learning::api_status* status = nullptr);

API int LiveModelReportActionTaken(livemodel_context_t* livemodel, const char * event_id, reinforcement_learning::api_status* status = nullptr);
API int LiveModelReportOutcomeF(livemodel_context_t* livemodel, const char * event_id, float outcome, reinforcement_learning::api_status* status = nullptr);
API int LiveModelReportOutcomeJson(livemodel_context_t* livemodel, const char * event_id, const char * outcomeJson, reinforcement_learning::api_status* status = nullptr);
Expand Down
2 changes: 2 additions & 0 deletions bindings/cs/rl.net.native/rl.net.native.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
<ClCompile Include="rl.net.live_model.cc" />
<ClCompile Include="rl.net.ranking_response.cc" />
<ClCompile Include="rl.net.decision_response.cc" />
<ClCompile Include="rl.net.slates_response.cc" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="binding_tracer.h" />
Expand All @@ -103,6 +104,7 @@
<ClInclude Include="rl.net.live_model.h" />
<ClInclude Include="rl.net.ranking_response.h" />
<ClInclude Include="rl.net.decision_response.h" />
<ClInclude Include="rl.net.slates_response.h" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\..\rlclientlib\rlclientlib.vcxproj">
Expand Down
Loading

0 comments on commit 97d6da5

Please sign in to comment.