A Discrete-Event Network Simulator
API
thompson-sampling-wifi-manager.cc
Go to the documentation of this file.
1/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
2/*
3 * Copyright (c) 2021 IITP RAS
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License version 2 as
7 * published by the Free Software Foundation;
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, write to the Free Software
16 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 *
18 * Author: Alexander Krotov <krotov@iitp.ru>
19 */
20
21#include "ns3/log.h"
22#include "ns3/double.h"
23#include "ns3/core-module.h"
24#include "ns3/packet.h"
25
26#include "ns3/wifi-phy.h"
27
29
30#include <cstdint>
31#include <cstdlib>
32#include <fstream>
33#include <iostream>
34#include <string>
35
36namespace ns3 {
37
// A structure containing parameters of a single rate and its statistics.
// NOTE(review): original lines 43 and 49 are absent from this listing. Other
// code in this file reads/writes `stats.mode` and `stats.lastDecay`, so those
// two members are presumably declared on the missing lines — confirm against
// the original file.
42struct RateStats {
 // Channel width in MHz.
44 uint16_t channelWidth;
 // Number of spatial streams.
45 uint8_t nss;
46
 // Averaged number of successful transmissions (starts at 0).
47 double success{0.0};
 // Averaged number of failed transmissions (starts at 0).
48 double fails{0.0};
50};
51
// Holds station state and collected statistics.
// NOTE(review): the struct header line is absent from this listing; the
// static_casts elsewhere in the file suggest this is
// ThompsonSamplingWifiRemoteStation — confirm against the original file.
59{
 // Index into m_mcsStats of the mode to select for the next transmission.
60 size_t m_nextMode;
 // Index into m_mcsStats of the most recently used mode; statistics from
 // transmission reports are written to this entry.
61 size_t m_lastMode;
62
 // Collected statistics, one RateStats entry per candidate rate.
63 std::vector<RateStats> m_mcsStats;
64};
65
67
68NS_LOG_COMPONENT_DEFINE ("ThompsonSamplingWifiManager");
69
// Get the type ID.
// Builds the TypeId "ns3::ThompsonSamplingWifiManager" (group "Wifi") in a
// function-local static and returns it. The visible registration adds a
// "Decay" attribute (default DoubleValue 1.0) and a "Rate" trace source.
// NOTE(review): the signature and original lines 74, 80 and 84 are absent from
// this listing (presumably the SetParent call and the attribute / trace source
// accessors) — confirm against the original file.
72{
73 static TypeId tid = TypeId ("ns3::ThompsonSamplingWifiManager")
75 .SetGroupName ("Wifi")
76 .AddConstructor<ThompsonSamplingWifiManager> ()
77 .AddAttribute ("Decay",
78 "Exponential decay coefficient, Hz; zero is a valid value for static scenarios",
79 DoubleValue (1.0),
81 MakeDoubleChecker<double> (0.0))
82 .AddTraceSource ("Rate",
83 "Traced value for rate changes (b/s)",
85 "ns3::TracedValueCallback::Uint64")
86 ;
87 return tid;
88}
89
// Constructor (signature line absent from this listing): starts the traced
// current rate at 0 and creates the gamma random variable that
// SampleBetaVariable draws from.
91 : m_currentRate{0}
92{
93 NS_LOG_FUNCTION (this);
94
95 m_gammaRandomVariable = CreateObject<GammaRandomVariable> ();
96}
97
// Body that only logs the call.
// NOTE(review): the signature line is absent from this listing; given its
// position right after the constructor this is presumably the destructor —
// confirm against the original file.
99{
100 NS_LOG_FUNCTION (this);
101}
102
// Returns a station object whose next/last mode indices are both reset to 0.
// NOTE(review): the signature and original line 107 (where `station` is
// created) are absent from this listing; presumably this is DoCreateStation
// allocating a ThompsonSamplingWifiRemoteStation — confirm.
105{
106 NS_LOG_FUNCTION (this);
108 station->m_nextMode = 0;
109 station->m_lastMode = 0;
110 return station;
111}
112
// Initializes station rate tables (signature line absent from this listing;
// the parameter is referred to as `st` in the body).
// Fills station->m_mcsStats once, then draws the first mode via UpdateNextMode.
113void
115{
116 auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
 // Already initialized: keep the existing table and its statistics.
117 if (!station->m_mcsStats.empty ())
118 {
119 return;
120 }
121
122 // Add HT, VHT or HE MCSes
123 for (const auto &mode : GetPhy ()->GetMcsList ())
124 {
 // Candidate widths start at 20 and double up to the PHY channel width.
125 for (uint16_t j = 20; j <= GetPhy ()->GetChannelWidth (); j *= 2)
126 {
 // Only one modulation class is kept: HT by default, overridden by VHT if
 // supported, and finally by HE if supported (HE wins when both are set).
127 WifiModulationClass modulationClass = WIFI_MOD_CLASS_HT;
128 if (GetVhtSupported ())
129 {
130 modulationClass = WIFI_MOD_CLASS_VHT;
131 }
132 if (GetHeSupported ())
133 {
134 modulationClass = WIFI_MOD_CLASS_HE;
135 }
136 if (mode.GetModulationClass () == modulationClass)
137 {
 // One table entry per (mode, width, nss) combination the mode allows.
138 for (uint8_t k = 1; k <= GetPhy ()->GetMaxSupportedTxSpatialStreams (); k++)
139 {
140 if (mode.IsAllowed (j, k))
141 {
142 RateStats stats;
143 stats.mode = mode;
144 stats.channelWidth = j;
145 stats.nss = k;
146
147 station->m_mcsStats.push_back (stats);
148 }
149 }
150 }
151 }
152 }
153
 // Fallback when no MCS entry was added above.
154 if (station->m_mcsStats.empty ())
155 {
156 // Add legacy non-HT modes.
157 for (uint8_t i = 0; i < GetNSupported (station); i++)
158 {
 // NOTE(review): original lines 161-162 (the condition choosing between
 // width 22 and width 20) are absent from this listing; presumably the
 // 22 MHz case is for DSSS / HR-DSSS modes — confirm.
159 RateStats stats;
160 stats.mode = GetSupported (station, i);
163 {
164 stats.channelWidth = 22;
165 }
166 else
167 {
168 stats.channelWidth = 20;
169 }
170 stats.nss = 1;
171 station->m_mcsStats.push_back (stats);
172 }
173 }
174
175 NS_ASSERT_MSG (!station->m_mcsStats.empty (), "No usable MCS found");
176
177 UpdateNextMode (st);
178}
179
// Reception report hook (signature line absent from this listing; the logged
// arguments match DoReportRxOk (station, rxSnr, txMode)).
// Only logs the call; no statistics are updated here.
180void
182{
183 NS_LOG_FUNCTION (this << station << rxSnr << txMode);
184}
185
// Report hook that only logs the call; no statistics are updated here.
// NOTE(review): the signature line is absent from this listing; by position
// this is presumably DoReportRtsFailed — confirm against the original file.
186void
188{
189 NS_LOG_FUNCTION (this << station);
190}
191
// Failed-transmission report hook: decays the statistics of the most recently
// used mode, counts one failure against it, then draws the next mode.
// NOTE(review): the signature line and original line 196 are absent from this
// listing; presumably this is DoReportDataFailed — confirm.
192void
194{
195 NS_LOG_FUNCTION (this << st);
197 auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
 // Decay first so the new sample is added at full weight.
198 Decay (st, station->m_lastMode);
199 station->m_mcsStats.at (station->m_lastMode).fails++;
200 UpdateNextMode (st);
201}
202
// RTS success report hook (first signature line absent from this listing; the
// logged arguments match DoReportRtsOk (st, ctsSnr, ctsMode, rtsSnr)).
// Only logs the call; no statistics are updated here.
203void
205 double rtsSnr)
206{
207 NS_LOG_FUNCTION (this << st << ctsSnr << ctsMode.GetUniqueName () << rtsSnr);
208}
209
// Draws a new MCS and related parameters to try next time for this station
// (signature line absent from this listing).
// For every table entry: decay its statistics, sample a frame success rate and
// multiply it by the entry's data rate; the entry with the largest product
// becomes station->m_nextMode.
210void
212{
214 auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
215
216 double maxThroughput = 0.0;
217 double frameSuccessRate = 1.0;
218
219 NS_ASSERT (!station->m_mcsStats.empty ());
220
221 // Use the most robust MCS if frameSuccessRate is 0 for all MCS.
222 station->m_nextMode = 0;
223
224 for (uint32_t i = 0; i < station->m_mcsStats.size (); i++)
225 {
226 Decay (st, i);
227 const WifiMode mode{station->m_mcsStats.at (i).mode};
228
229 uint16_t guardInterval = GetModeGuardInterval (st, mode);
230 double rate = mode.GetDataRate (station->m_mcsStats.at (i).channelWidth,
231 guardInterval,
232 station->m_mcsStats.at (i).nss);
233
234 // Thompson sampling
 // NOTE(review): the reference section of this listing declares
 // SampleBetaVariable (uint64_t alpha, uint64_t beta) while doubles are passed
 // here; if that declaration is accurate the decayed (fractional) counts are
 // truncated on the call — confirm whether that is intended.
235 frameSuccessRate = SampleBetaVariable (1.0 + station->m_mcsStats.at (i).success,
236 1.0 + station->m_mcsStats.at (i).fails);
237 NS_LOG_DEBUG ("Draw"
238 << " success=" << station->m_mcsStats.at (i).success
239 << " fails=" << station->m_mcsStats.at (i).fails
240 << " frameSuccessRate=" << frameSuccessRate
241 << " mode=" << mode);
 // Strict comparison: on a tie the earlier (lower-index) entry is kept.
242 if (frameSuccessRate * rate > maxThroughput)
243 {
244 maxThroughput = frameSuccessRate * rate;
245 station->m_nextMode = i;
246 }
247 }
248}
249
// Successful-transmission report hook (first signature line and original line
// 255 absent from this listing; the visible parameters match DoReportDataOk).
// Decays the statistics of the most recently used mode, counts one success
// for it, then draws the next mode. The SNR, channel width and NSS arguments
// are not used for the update in the visible code.
250void
252 double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss)
253{
254 NS_LOG_FUNCTION (this << st << ackSnr << ackMode.GetUniqueName () << dataSnr);
256 auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
257 Decay (st, station->m_lastMode);
258 station->m_mcsStats.at (station->m_lastMode).success++;
259 UpdateNextMode (st);
260}
261
// A-MPDU transmission report hook (first signature line and original line 268
// absent from this listing; the visible parameters match
// DoReportAmpduTxStatus).
// Decays the statistics of the most recently used mode, adds the per-MPDU
// success and failure counts to it, then draws the next mode.
262void
264 uint16_t nFailedMpdus, double rxSnr, double dataSnr,
265 uint16_t dataChannelWidth, uint8_t dataNss)
266{
267 NS_LOG_FUNCTION (this << st << nSuccessfulMpdus << nFailedMpdus << rxSnr << dataSnr);
269 auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
270
271 Decay (st, station->m_lastMode);
 // Every MPDU in the aggregate counts as one sample.
272 station->m_mcsStats.at (station->m_lastMode).success += nSuccessfulMpdus;
273 station->m_mcsStats.at (station->m_lastMode).fails += nFailedMpdus;
274
275 UpdateNextMode (st);
276}
277
// Report hook that only logs the call; no statistics are updated here.
// NOTE(review): the signature line is absent from this listing; by position
// this is presumably DoReportFinalRtsFailed — confirm against the original.
278void
280{
281 NS_LOG_FUNCTION (this << station);
282}
283
// Report hook that only logs the call; no statistics are updated here.
// NOTE(review): the signature line is absent from this listing; by position
// this is presumably DoReportFinalDataFailed — confirm against the original.
284void
286{
287 NS_LOG_FUNCTION (this << station);
288}
289
// Returns guard interval in nanoseconds for the given mode (signature line
// absent from this listing).
// NOTE(review): original lines 293, 295 and 298 are absent, so the first
// branch's condition and return value and the second half of the HT condition
// are not visible — presumably an HE branch and an "|| VHT" test; confirm
// against the original file.
290uint16_t
292{
294 {
296 }
297 else if ((mode.GetModulationClass () == WIFI_MOD_CLASS_HT) ||
299 {
 // Each side yields 400 when short guard interval is supported, else 800;
 // taking the max means 400 is returned only when both calls report support
 // (presumably the remote station `st` and the local device).
300 return std::max<uint16_t> (GetShortGuardIntervalSupported (st) ? 400 : 800,
301 GetShortGuardIntervalSupported () ? 400 : 800);
302 }
303 else
304 {
 // All remaining modulation classes use 800.
305 return 800;
306 }
307}
308
// Builds the TXVECTOR for the next data transmission from the entry selected
// by UpdateNextMode (station->m_nextMode), and records that entry as the last
// used mode so later reports update its statistics.
// NOTE(review): the signature and original lines 313, 339-341 and 343 are
// absent from this listing (some WifiTxVector arguments are therefore not
// visible); presumably this is DoGetDataTxVector — confirm.
311{
312 NS_LOG_FUNCTION (this << st);
314 auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
315
316 auto &stats = station->m_mcsStats.at (station->m_nextMode);
317 WifiMode mode = stats.mode;
 // Never exceed the PHY's current channel width.
318 uint16_t channelWidth = std::min (stats.channelWidth, GetPhy ()->GetChannelWidth ());
319 uint8_t nss = stats.nss;
320 uint16_t guardInterval = GetModeGuardInterval (st, mode);
321
322 station->m_lastMode = station->m_nextMode;
323
324 NS_LOG_DEBUG ("Using"
325 << " mode=" << mode
326 << " channelWidth=" << channelWidth
327 << " nss=" << +nss
328 << " guardInterval=" << guardInterval);
329
 // Write the traced rate only when it actually changes.
330 uint64_t rate = mode.GetDataRate (channelWidth, guardInterval, nss);
331 if (m_currentRate != rate)
332 {
333 NS_LOG_DEBUG ("New datarate: " << rate);
334 m_currentRate = rate;
335 }
336
337 return WifiTxVector (
338 mode,
342 GetModeGuardInterval (st, mode),
344 nss,
345 0, // NESS
346 GetChannelWidthForTransmission (mode, channelWidth),
347 GetAggregation (station),
348 false);
349}
350
// Builds the TXVECTOR for control (RTS) frames, always from the first table
// entry rather than from the sampled mode.
// NOTE(review): the signature and original lines 355, 369 and 371 are absent
// from this listing (some WifiTxVector arguments are therefore not visible);
// presumably this is DoGetRtsTxVector — confirm.
353{
354 NS_LOG_FUNCTION (this << st);
356 auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
357
358 // Use the most robust MCS for the control channel.
359 auto &stats = station->m_mcsStats.at (0);
360 WifiMode mode = stats.mode;
 // Never exceed the PHY's current channel width.
361 uint16_t channelWidth = std::min (stats.channelWidth, GetPhy ()->GetChannelWidth ());
362 uint8_t nss = stats.nss;
363
364 // Make sure control frames are sent using 1 spatial stream.
365 NS_ASSERT (nss == 1);
366
367 return WifiTxVector (
368 mode, GetDefaultTxPowerLevel (),
370 GetModeGuardInterval (st, mode),
372 nss,
373 0, // NESS
374 GetChannelWidthForTransmission (mode, channelWidth),
375 GetAggregation (station),
376 false);
377}
378
// Sample beta random variable with given parameters (signature line absent
// from this listing).
// Draws X and Y from the shared gamma random variable using alpha and beta
// respectively (second argument 1.0, presumably the scale — confirm) and
// returns X / (X + Y), the standard construction of a beta variate from two
// gamma variates.
// NOTE(review): the reference section of this listing declares the parameters
// as uint64_t while the caller passes doubles; confirm whether the implied
// truncation is intended.
379double
381{
382 double X = m_gammaRandomVariable->GetValue (alpha, 1.0);
383 double Y = m_gammaRandomVariable->GetValue (beta, 1.0);
384 return X / (X + Y);
385}
386
// Applies exponential decay to MCS statistics (signature line and original
// line 391 absent from this listing).
// Scales the success and failure counters of table entry `i` by
// exp (-m_decay * elapsed), where elapsed is the time in seconds since the
// entry was last decayed, and stamps the entry with the current time.
387void
389{
390 NS_LOG_FUNCTION (this << st << i);
392 auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
393
394 Time now = Simulator::Now ();
395 auto &stats = station->m_mcsStats.at (i);
 // Nothing to do when no simulation time has passed since the last decay.
396 if (now > stats.lastDecay)
397 {
 // (lastDecay - now) is negative, so the coefficient is at most 1; with
 // m_decay == 0 it is exactly 1 and the counters are left unchanged.
398 const double coefficient =
399 std::exp (m_decay * (stats.lastDecay - now).GetSeconds ());
400
401 stats.success *= coefficient;
402 stats.fails *= coefficient;
403 stats.lastDecay = now;
404 }
405}
406
// Assign a fixed random variable stream number to the random variables used by
// this model (signature line absent from this listing).
// Binds the gamma random variable to `stream` and returns 1 (presumably the
// number of streams consumed — confirm against the base-class contract).
407int64_t
409{
410 NS_LOG_FUNCTION (this << stream);
411 m_gammaRandomVariable->SetStream (stream);
412 return 1;
413}
414
415} //namespace ns3
#define min(a, b)
Definition: 80211b.c:42
#define max(a, b)
Definition: 80211b.c:43
This class can be used to hold variables of floating point type such as 'double' or 'float'.
Definition: double.h:41
static Time Now(void)
Return the current simulation virtual time.
Definition: simulator.cc:195
Thompson Sampling rate control algorithm.
uint16_t GetModeGuardInterval(WifiRemoteStation *st, WifiMode mode) const
Returns guard interval in nanoseconds for the given mode.
void DoReportRxOk(WifiRemoteStation *station, double rxSnr, WifiMode txMode) override
This method is a pure virtual method that must be implemented by the sub-class.
void InitializeStation(WifiRemoteStation *station) const
Initializes station rate tables.
void DoReportDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportDataOk(WifiRemoteStation *station, double ackSnr, WifiMode ackMode, double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportAmpduTxStatus(WifiRemoteStation *station, uint16_t nSuccessfulMpdus, uint16_t nFailedMpdus, double rxSnr, double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss) override
Typically called per A-MPDU, either when a Block ACK was successfully received or when a BlockAckTime...
TracedValue< uint64_t > m_currentRate
Trace rate changes.
double SampleBetaVariable(uint64_t alpha, uint64_t beta) const
Sample beta random variable with given parameters.
WifiRemoteStation * DoCreateStation() const override
Ptr< GammaRandomVariable > m_gammaRandomVariable
Variable used to sample beta-distributed random variables.
void DoReportFinalRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
static TypeId GetTypeId(void)
Get the type ID.
WifiTxVector DoGetDataTxVector(WifiRemoteStation *station) override
void DoReportFinalDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
double m_decay
Exponential decay coefficient, Hz.
void DoReportRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void UpdateNextMode(WifiRemoteStation *station) const
Draws a new MCS and related parameters to try next time for this station.
int64_t AssignStreams(int64_t stream) override
Assign a fixed random variable stream number to the random variables used by this model.
WifiTxVector DoGetRtsTxVector(WifiRemoteStation *station) override
void DoReportRtsOk(WifiRemoteStation *station, double ctsSnr, WifiMode ctsMode, double rtsSnr) override
This method is a pure virtual method that must be implemented by the sub-class.
void Decay(WifiRemoteStation *st, size_t i) const
Applies exponential decay to MCS statistics.
Simulation virtual time values and global simulation resolution.
Definition: nstime.h:103
a unique identifier for an interface.
Definition: type-id.h:59
TypeId SetParent(TypeId tid)
Set the parent TypeId.
Definition: type-id.cc:922
represent a single transmission mode
Definition: wifi-mode.h:48
WifiModulationClass GetModulationClass() const
Definition: wifi-mode.cc:177
std::string GetUniqueName(void) const
Definition: wifi-mode.cc:140
uint64_t GetDataRate(uint16_t channelWidth, uint16_t guardInterval, uint8_t nss) const
Definition: wifi-mode.cc:114
std::list< WifiMode > GetMcsList(void) const
The WifiPhy::GetMcsList() method is used (e.g., by a WifiRemoteStationManager) to determine the set o...
Definition: wifi-phy.cc:1750
uint8_t GetMaxSupportedTxSpatialStreams(void) const
Definition: wifi-phy.cc:1099
uint16_t GetChannelWidth(void) const
Definition: wifi-phy.cc:901
hold a list of per-remote-station state.
uint16_t GetChannelWidth(const WifiRemoteStation *station) const
Return the channel width supported by the station.
bool GetVhtSupported(void) const
Return whether the device has VHT capability support enabled.
Ptr< WifiPhy > GetPhy(void) const
Return the WifiPhy.
uint8_t GetNSupported(const WifiRemoteStation *station) const
Return the number of modes supported by the given station.
bool GetAggregation(const WifiRemoteStation *station) const
Return whether the given station supports A-MPDU.
bool GetShortPreambleEnabled(void) const
Return whether the device uses short PHY preambles.
bool GetHeSupported(void) const
Return whether the device has HE capability support enabled.
bool GetShortGuardIntervalSupported(void) const
Return whether the device has SGI support enabled.
WifiMode GetSupported(const WifiRemoteStation *station, uint8_t i) const
Return the mode associated with the specified station at the specified index.
uint16_t GetGuardInterval(void) const
Return the supported HE guard interval duration (in nanoseconds).
This class mimics the TXVECTOR which is to be passed to the PHY in order to define the parameters whi...
#define NS_ASSERT(condition)
At runtime, in debugging builds, if this condition is not true, the program prints the source file,...
Definition: assert.h:67
#define NS_ASSERT_MSG(condition, message)
At runtime, in debugging builds, if this condition is not true, the program prints the message to out...
Definition: assert.h:88
Ptr< const AttributeAccessor > MakeDoubleAccessor(T1 a1)
Definition: double.h:42
#define NS_LOG_COMPONENT_DEFINE(name)
Define a Log component with a specific name.
Definition: log.h:205
#define NS_LOG_DEBUG(msg)
Use NS_LOG to output a message of level LOG_DEBUG.
Definition: log.h:273
#define NS_LOG_FUNCTION(parameters)
If log level LOG_FUNCTION is enabled, this macro will output all input parameters separated by ",...
#define NS_OBJECT_ENSURE_REGISTERED(type)
Register an Object subclass with the TypeId system.
Definition: object-base.h:45
Ptr< const TraceSourceAccessor > MakeTraceSourceAccessor(T a)
Create a TraceSourceAccessor which will control access to the underlying trace source.
WifiModulationClass
This enumeration defines the modulation classes per (Table 10-6 "Modulation classes"; IEEE 802....
@ WIFI_MOD_CLASS_HR_DSSS
HR/DSSS (Clause 16)
@ WIFI_MOD_CLASS_HT
HT (Clause 19)
@ WIFI_MOD_CLASS_VHT
VHT (Clause 22)
@ WIFI_MOD_CLASS_HE
HE (Clause 27)
@ WIFI_MOD_CLASS_DSSS
DSSS (Clause 15)
Every class exported by the ns3 library is enclosed in the ns3 namespace.
uint16_t GetChannelWidthForTransmission(WifiMode mode, uint16_t maxSupportedChannelWidth)
Return the channel width that corresponds to the selected mode (instead of letting the PHY's default ...
WifiPreamble GetPreambleForTransmission(WifiModulationClass modulation, bool useShortPreamble)
Return the preamble to be used for the transmission.
float alpha
Plot alpha value (transparency)
A structure containing parameters of a single rate and its statistics.
uint16_t channelWidth
channel width in MHz
uint8_t nss
Number of spatial streams.
double success
averaged number of successful transmissions
double fails
averaged number of failed transmissions
Time lastDecay
last time exponential decay was applied to this rate
Holds station state and collected statistics.
size_t m_nextMode
Mode to select for the next transmission.
std::vector< RateStats > m_mcsStats
Collected statistics.
size_t m_lastMode
Most recently used mode, used to write statistics.
hold per-remote-station state.