A Discrete-Event Network Simulator
API
Loading...
Searching...
No Matches
thompson-sampling-wifi-manager.cc
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 IITP RAS
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License version 2 as
6 * published by the Free Software Foundation;
7 *
8 * This program is distributed in the hope that it will be useful,
9 * but WITHOUT ANY WARRANTY; without even the implied warranty of
10 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 * GNU General Public License for more details.
12 *
13 * You should have received a copy of the GNU General Public License
14 * along with this program; if not, write to the Free Software
15 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
16 *
17 * Author: Alexander Krotov <krotov@iitp.ru>
18 */
19
21
22#include "ns3/core-module.h"
23#include "ns3/double.h"
24#include "ns3/log.h"
25#include "ns3/packet.h"
26#include "ns3/wifi-phy.h"
27
28#include <cstdint>
29#include <cstdlib>
30#include <fstream>
31#include <iostream>
32#include <string>
33
34namespace ns3
35{
36
37/**
38 * A structure containing parameters of a single rate and its
39 * statistics.
40 */
42{
43 WifiMode mode; ///< MCS
44 uint16_t channelWidth; ///< channel width in MHz
45 uint8_t nss; ///< Number of spatial streams
46
47 double success{0.0}; ///< averaged number of successful transmissions
48 double fails{0.0}; ///< averaged number of failed transmissions
49 Time lastDecay{0}; ///< last time exponential decay was applied to this rate
50};
51
52/**
53 * Holds station state and collected statistics.
54 *
55 * This struct extends from WifiRemoteStation to hold additional
56 * information required by ThompsonSamplingWifiManager.
57 */
59{
60 size_t m_nextMode; //!< Mode to select for the next transmission
61 size_t m_lastMode; //!< Most recently used mode, used to write statistics
62
63 std::vector<RateStats> m_mcsStats; //!< Collected statistics
64};
65
67
68NS_LOG_COMPONENT_DEFINE("ThompsonSamplingWifiManager");
69
72{
73 static TypeId tid =
74 TypeId("ns3::ThompsonSamplingWifiManager")
76 .SetGroupName("Wifi")
77 .AddConstructor<ThompsonSamplingWifiManager>()
78 .AddAttribute(
79 "Decay",
80 "Exponential decay coefficient, Hz; zero is a valid value for static scenarios",
81 DoubleValue(1.0),
83 MakeDoubleChecker<double>(0.0))
84 .AddTraceSource("Rate",
85 "Traced value for rate changes (b/s)",
87 "ns3::TracedValueCallback::Uint64");
88 return tid;
89}
90
92 : m_currentRate{0}
93{
94 NS_LOG_FUNCTION(this);
95
96 m_gammaRandomVariable = CreateObject<GammaRandomVariable>();
97}
98
100{
101 NS_LOG_FUNCTION(this);
102}
103
106{
107 NS_LOG_FUNCTION(this);
108 auto station = new ThompsonSamplingWifiRemoteStation();
109 station->m_nextMode = 0;
110 station->m_lastMode = 0;
111 return station;
112}
113
114void
116{
117 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
118 if (!station->m_mcsStats.empty())
119 {
120 return;
121 }
122
123 // Add HT, VHT or HE MCSes
124 for (const auto& mode : GetPhy()->GetMcsList())
125 {
126 for (uint16_t j = 20; j <= GetPhy()->GetChannelWidth(); j *= 2)
127 {
128 WifiModulationClass modulationClass = WIFI_MOD_CLASS_HT;
129 if (GetVhtSupported())
130 {
131 modulationClass = WIFI_MOD_CLASS_VHT;
132 }
133 if (GetHeSupported())
134 {
135 modulationClass = WIFI_MOD_CLASS_HE;
136 }
137 if (mode.GetModulationClass() == modulationClass)
138 {
139 for (uint8_t k = 1; k <= GetPhy()->GetMaxSupportedTxSpatialStreams(); k++)
140 {
141 if (mode.IsAllowed(j, k))
142 {
143 RateStats stats;
144 stats.mode = mode;
145 stats.channelWidth = j;
146 stats.nss = k;
147
148 station->m_mcsStats.push_back(stats);
149 }
150 }
151 }
152 }
153 }
154
155 if (station->m_mcsStats.empty())
156 {
157 // Add legacy non-HT modes.
158 for (uint8_t i = 0; i < GetNSupported(station); i++)
159 {
160 RateStats stats;
161 stats.mode = GetSupported(station, i);
164 {
165 stats.channelWidth = 22;
166 }
167 else
168 {
169 stats.channelWidth = 20;
170 }
171 stats.nss = 1;
172 station->m_mcsStats.push_back(stats);
173 }
174 }
175
176 NS_ASSERT_MSG(!station->m_mcsStats.empty(), "No usable MCS found");
177
178 UpdateNextMode(st);
179}
180
181void
183{
184 NS_LOG_FUNCTION(this << station << rxSnr << txMode);
185}
186
187void
189{
190 NS_LOG_FUNCTION(this << station);
191}
192
193void
195{
196 NS_LOG_FUNCTION(this << st);
198 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
199 Decay(st, station->m_lastMode);
200 station->m_mcsStats.at(station->m_lastMode).fails++;
201 UpdateNextMode(st);
202}
203
204void
206 double ctsSnr,
207 WifiMode ctsMode,
208 double rtsSnr)
209{
210 NS_LOG_FUNCTION(this << st << ctsSnr << ctsMode.GetUniqueName() << rtsSnr);
211}
212
213void
215{
217 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
218
219 double maxThroughput = 0.0;
220 double frameSuccessRate = 1.0;
221
222 NS_ASSERT(!station->m_mcsStats.empty());
223
224 // Use the most robust MCS if frameSuccessRate is 0 for all MCS.
225 station->m_nextMode = 0;
226
227 for (uint32_t i = 0; i < station->m_mcsStats.size(); i++)
228 {
229 Decay(st, i);
230 const WifiMode mode{station->m_mcsStats.at(i).mode};
231
232 uint16_t guardInterval = GetModeGuardInterval(st, mode);
233 double rate = mode.GetDataRate(station->m_mcsStats.at(i).channelWidth,
234 guardInterval,
235 station->m_mcsStats.at(i).nss);
236
237 // Thompson sampling
238 frameSuccessRate = SampleBetaVariable(1.0 + station->m_mcsStats.at(i).success,
239 1.0 + station->m_mcsStats.at(i).fails);
240 NS_LOG_DEBUG("Draw"
241 << " success=" << station->m_mcsStats.at(i).success
242 << " fails=" << station->m_mcsStats.at(i).fails
243 << " frameSuccessRate=" << frameSuccessRate << " mode=" << mode);
244 if (frameSuccessRate * rate > maxThroughput)
245 {
246 maxThroughput = frameSuccessRate * rate;
247 station->m_nextMode = i;
248 }
249 }
250}
251
252void
254 double ackSnr,
255 WifiMode ackMode,
256 double dataSnr,
257 uint16_t dataChannelWidth,
258 uint8_t dataNss)
259{
260 NS_LOG_FUNCTION(this << st << ackSnr << ackMode.GetUniqueName() << dataSnr);
262 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
263 Decay(st, station->m_lastMode);
264 station->m_mcsStats.at(station->m_lastMode).success++;
265 UpdateNextMode(st);
266}
267
268void
270 uint16_t nSuccessfulMpdus,
271 uint16_t nFailedMpdus,
272 double rxSnr,
273 double dataSnr,
274 uint16_t dataChannelWidth,
275 uint8_t dataNss)
276{
277 NS_LOG_FUNCTION(this << st << nSuccessfulMpdus << nFailedMpdus << rxSnr << dataSnr);
279 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
280
281 Decay(st, station->m_lastMode);
282 station->m_mcsStats.at(station->m_lastMode).success += nSuccessfulMpdus;
283 station->m_mcsStats.at(station->m_lastMode).fails += nFailedMpdus;
284
285 UpdateNextMode(st);
286}
287
288void
290{
291 NS_LOG_FUNCTION(this << station);
292}
293
294void
296{
297 NS_LOG_FUNCTION(this << station);
298}
299
300uint16_t
302{
304 {
305 return std::max(GetGuardInterval(st), GetGuardInterval());
306 }
307 else if ((mode.GetModulationClass() == WIFI_MOD_CLASS_HT) ||
309 {
310 return std::max<uint16_t>(GetShortGuardIntervalSupported(st) ? 400 : 800,
311 GetShortGuardIntervalSupported() ? 400 : 800);
312 }
313 else
314 {
315 return 800;
316 }
317}
318
321{
322 NS_LOG_FUNCTION(this << st << allowedWidth);
324 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
325
326 auto& stats = station->m_mcsStats.at(station->m_nextMode);
327 WifiMode mode = stats.mode;
328 uint16_t channelWidth = std::min(stats.channelWidth, allowedWidth);
329 uint8_t nss = stats.nss;
330 uint16_t guardInterval = GetModeGuardInterval(st, mode);
331
332 station->m_lastMode = station->m_nextMode;
333
334 NS_LOG_DEBUG("Using"
335 << " mode=" << mode << " channelWidth=" << channelWidth << " nss=" << +nss
336 << " guardInterval=" << guardInterval);
337
338 uint64_t rate = mode.GetDataRate(channelWidth, guardInterval, nss);
339 if (m_currentRate != rate)
340 {
341 NS_LOG_DEBUG("New datarate: " << rate);
342 m_currentRate = rate;
343 }
344
345 return WifiTxVector(
346 mode,
349 GetModeGuardInterval(st, mode),
351 nss,
352 0, // NESS
353 GetPhy()->GetTxBandwidth(mode, channelWidth),
354 GetAggregation(station),
355 false);
356}
357
360{
361 NS_LOG_FUNCTION(this << st);
363 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
364
365 // Use the most robust MCS for the control channel.
366 auto& stats = station->m_mcsStats.at(0);
367 WifiMode mode = stats.mode;
368 uint8_t nss = stats.nss;
369
370 // Make sure control frames are sent using 1 spatial stream.
371 NS_ASSERT(nss == 1);
372
373 return WifiTxVector(
374 mode,
377 GetModeGuardInterval(st, mode),
379 nss,
380 0, // NESS
381 GetPhy()->GetTxBandwidth(mode, stats.channelWidth),
382 GetAggregation(station),
383 false);
384}
385
386double
387ThompsonSamplingWifiManager::SampleBetaVariable(uint64_t alpha, uint64_t beta) const
388{
389 double X = m_gammaRandomVariable->GetValue(alpha, 1.0);
390 double Y = m_gammaRandomVariable->GetValue(beta, 1.0);
391 return X / (X + Y);
392}
393
394void
396{
397 NS_LOG_FUNCTION(this << st << i);
399 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
400
401 Time now = Simulator::Now();
402 auto& stats = station->m_mcsStats.at(i);
403 if (now > stats.lastDecay)
404 {
405 const double coefficient = std::exp(m_decay * (stats.lastDecay - now).GetSeconds());
406
407 stats.success *= coefficient;
408 stats.fails *= coefficient;
409 stats.lastDecay = now;
410 }
411}
412
413int64_t
415{
416 NS_LOG_FUNCTION(this << stream);
417 m_gammaRandomVariable->SetStream(stream);
418 return 1;
419}
420
421} // namespace ns3
This class can be used to hold variables of floating point type such as 'double' or 'float'.
Definition: double.h:42
static Time Now()
Return the current simulation virtual time.
Definition: simulator.cc:208
Thompson Sampling rate control algorithm.
uint16_t GetModeGuardInterval(WifiRemoteStation *st, WifiMode mode) const
Returns guard interval in nanoseconds for the given mode.
void DoReportRxOk(WifiRemoteStation *station, double rxSnr, WifiMode txMode) override
This method is a pure virtual method that must be implemented by the sub-class.
void InitializeStation(WifiRemoteStation *station) const
Initializes station rate tables.
void DoReportDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportDataOk(WifiRemoteStation *station, double ackSnr, WifiMode ackMode, double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportAmpduTxStatus(WifiRemoteStation *station, uint16_t nSuccessfulMpdus, uint16_t nFailedMpdus, double rxSnr, double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss) override
Typically called per A-MPDU, either when a Block ACK was successfully received or when a BlockAckTime...
TracedValue< uint64_t > m_currentRate
Trace rate changes.
double SampleBetaVariable(uint64_t alpha, uint64_t beta) const
Sample beta random variable with given parameters.
WifiRemoteStation * DoCreateStation() const override
Ptr< GammaRandomVariable > m_gammaRandomVariable
Variable used to sample beta-distributed random variables.
void DoReportFinalRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportFinalDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
static TypeId GetTypeId()
Get the type ID.
WifiTxVector DoGetDataTxVector(WifiRemoteStation *station, uint16_t allowedWidth) override
double m_decay
Exponential decay coefficient, Hz.
void DoReportRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void UpdateNextMode(WifiRemoteStation *station) const
Draws a new MCS and related parameters to try next time for this station.
int64_t AssignStreams(int64_t stream) override
Assign a fixed random variable stream number to the random variables used by this model.
WifiTxVector DoGetRtsTxVector(WifiRemoteStation *station) override
void DoReportRtsOk(WifiRemoteStation *station, double ctsSnr, WifiMode ctsMode, double rtsSnr) override
This method is a pure virtual method that must be implemented by the sub-class.
void Decay(WifiRemoteStation *st, size_t i) const
Applies exponential decay to MCS statistics.
Simulation virtual time values and global simulation resolution.
Definition: nstime.h:105
A unique identifier for an interface.
Definition: type-id.h:59
TypeId SetParent(TypeId tid)
Set the parent TypeId.
Definition: type-id.cc:932
Represents a single transmission mode.
Definition: wifi-mode.h:51
std::string GetUniqueName() const
Definition: wifi-mode.cc:148
WifiModulationClass GetModulationClass() const
Definition: wifi-mode.cc:185
uint64_t GetDataRate(uint16_t channelWidth, uint16_t guardInterval, uint8_t nss) const
Definition: wifi-mode.cc:122
uint16_t GetChannelWidth() const
Definition: wifi-phy.cc:1072
uint8_t GetMaxSupportedTxSpatialStreams() const
Definition: wifi-phy.cc:1322
std::list< WifiMode > GetMcsList() const
The WifiPhy::GetMcsList() method is used (e.g., by a WifiRemoteStationManager) to determine the set o...
Definition: wifi-phy.cc:2034
Holds a list of per-remote-station state.
uint8_t GetNSupported(const WifiRemoteStation *station) const
Return the number of modes supported by the given station.
Ptr< WifiPhy > GetPhy() const
Return the WifiPhy.
uint16_t GetGuardInterval() const
Return the supported HE guard interval duration (in nanoseconds).
bool GetAggregation(const WifiRemoteStation *station) const
Return whether the given station supports A-MPDU.
bool GetShortGuardIntervalSupported() const
Return whether the device has SGI support enabled.
bool GetVhtSupported() const
Return whether the device has VHT capability support enabled.
bool GetShortPreambleEnabled() const
Return whether the device uses short PHY preambles.
WifiMode GetSupported(const WifiRemoteStation *station, uint8_t i) const
Return the mode associated with the specified station at the specified index.
bool GetHeSupported() const
Return whether the device has HE capability support enabled.
This class mimics the TXVECTOR which is to be passed to the PHY in order to define the parameters whi...
#define NS_ASSERT(condition)
At runtime, in debugging builds, if this condition is not true, the program prints the source file,...
Definition: assert.h:66
#define NS_ASSERT_MSG(condition, message)
At runtime, in debugging builds, if this condition is not true, the program prints the message to out...
Definition: assert.h:86
Ptr< const AttributeAccessor > MakeDoubleAccessor(T1 a1)
Definition: double.h:43
#define NS_LOG_COMPONENT_DEFINE(name)
Define a Log component with a specific name.
Definition: log.h:202
#define NS_LOG_DEBUG(msg)
Use NS_LOG to output a message of level LOG_DEBUG.
Definition: log.h:268
#define NS_LOG_FUNCTION(parameters)
If log level LOG_FUNCTION is enabled, this macro will output all input parameters separated by ",...
#define NS_OBJECT_ENSURE_REGISTERED(type)
Register an Object subclass with the TypeId system.
Definition: object-base.h:46
Ptr< const TraceSourceAccessor > MakeTraceSourceAccessor(T a)
Create a TraceSourceAccessor which will control access to the underlying trace source.
WifiModulationClass
This enumeration defines the modulation classes per (Table 10-6 "Modulation classes"; IEEE 802....
@ WIFI_MOD_CLASS_HR_DSSS
HR/DSSS (Clause 16)
@ WIFI_MOD_CLASS_HT
HT (Clause 19)
@ WIFI_MOD_CLASS_VHT
VHT (Clause 22)
@ WIFI_MOD_CLASS_HE
HE (Clause 27)
@ WIFI_MOD_CLASS_DSSS
DSSS (Clause 15)
Every class exported by the ns3 library is enclosed in the ns3 namespace.
WifiPreamble GetPreambleForTransmission(WifiModulationClass modulation, bool useShortPreamble)
Return the preamble to be used for the transmission.
A structure containing parameters of a single rate and its statistics.
uint16_t channelWidth
channel width in MHz
uint8_t nss
Number of spatial streams.
double success
averaged number of successful transmissions
double fails
averaged number of failed transmissions
Time lastDecay
last time exponential decay was applied to this rate
Holds station state and collected statistics.
size_t m_nextMode
Mode to select for the next transmission.
std::vector< RateStats > m_mcsStats
Collected statistics.
size_t m_lastMode
Most recently used mode, used to write statistics.
Holds per-remote-station state.