A Discrete-Event Network Simulator
API
Loading...
Searching...
No Matches
thompson-sampling-wifi-manager.cc
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 IITP RAS
3 *
4 * SPDX-License-Identifier: GPL-2.0-only
5 *
6 * Author: Alexander Krotov <krotov@iitp.ru>
7 */
8
10
11#include "ns3/core-module.h"
12#include "ns3/double.h"
13#include "ns3/log.h"
14#include "ns3/packet.h"
15#include "ns3/wifi-phy.h"
16
17#include <cstdint>
18#include <cstdlib>
19#include <fstream>
20#include <iostream>
21#include <string>
22
23namespace ns3
24{
25
26/**
27 * A structure containing parameters of a single rate and its
28 * statistics.
29 */
31{
32 WifiMode mode; ///< MCS
33 MHz_u channelWidth; ///< channel width
34 uint8_t nss; ///< Number of spatial streams
35
// success and fails are real-valued rather than integer counts: Decay()
// multiplies both by exp(-m_decay * elapsed seconds), so they hold
// exponentially weighted (aged) counts of feedback for this rate.
36 double success{0.0}; ///< averaged number of successful transmissions
37 double fails{0.0}; ///< averaged number of failed transmissions
// Decay() only rescales the counters when Simulator::Now() is later than this
// timestamp, then sets it to the current time.
38 Time lastDecay{0}; ///< last time exponential decay was applied to this rate
39};
40
41/**
42 * Holds station state and collected statistics.
43 *
44 * This struct extends from WifiRemoteStation to hold additional
45 * information required by ThompsonSamplingWifiManager.
46 */
48{
// Both indices point into m_mcsStats. UpdateNextMode() writes m_nextMode;
// DoGetDataTxVector() copies it into m_lastMode when a data TXVECTOR is built.
// Because of that copy, later success/failure feedback is credited to the rate
// that was actually used.
49 size_t m_nextMode; //!< Mode to select for the next transmission
50 size_t m_lastMode; //!< Most recently used mode, used to write statistics
51
// One entry per candidate (mode, channel width, NSS), appended by InitializeStation().
52 std::vector<RateStats> m_mcsStats; //!< Collected statistics
53};
54
56
// Defines the log component used by the NS_LOG_* macros in this file.
57NS_LOG_COMPONENT_DEFINE("ThompsonSamplingWifiManager");
58
// Registers ns3::ThompsonSamplingWifiManager with the TypeId system. The
// registration includes:
// - the "Wifi" group;
// - a default constructor;
// - a "Decay" double attribute (default 1.0, in Hz; zero leaves the statistics
//   unchanged because exp(0) == 1 in Decay());
// - a "Rate" trace source reporting the data rate in b/s.
// NOTE(review): the SetParent, attribute accessor/checker and trace-source
// accessor lines (original lines 64, 71-72, 75) are not shown in this listing.
// Presumably they bind m_decay and m_currentRate -- confirm against the source.
61{
62 static TypeId tid =
63 TypeId("ns3::ThompsonSamplingWifiManager")
65 .SetGroupName("Wifi")
66 .AddConstructor<ThompsonSamplingWifiManager>()
67 .AddAttribute(
68 "Decay",
69 "Exponential decay coefficient, Hz; zero is a valid value for static scenarios",
70 DoubleValue(1.0),
73 .AddTraceSource("Rate",
74 "Traced value for rate changes (b/s)",
76 "ns3::TracedValueCallback::Uint64");
77 return tid;
78}
79
87
92
// Allocates fresh per-station state with both mode indices set to 0.
// m_mcsStats is left empty here; InitializeStation() fills it and returns early
// once it is non-empty.
95{
96 NS_LOG_FUNCTION(this);
97 auto station = new ThompsonSamplingWifiRemoteStation();
98 station->m_nextMode = 0;
99 station->m_lastMode = 0;
100 return station;
101}
102
// Builds the station's table of candidate rates (m_mcsStats) and draws the
// first rate to use. The table is built only once: if it is already populated
// the function returns immediately.
103void
105{
106 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
107 if (!station->m_mcsStats.empty())
108 {
109 return;
110 }
111
// Enumerate every (MCS, channel width, NSS) combination of the newest enabled
// modulation class: HE if supported, else VHT, else HT.
// - Widths double from 20 MHz up to the PHY channel width.
// - NSS runs from 1 to the PHY's maximum number of TX spatial streams.
// - IsAllowed() filters out invalid combinations.
// Entries are appended in this loop order, so entry 0 (used by
// DoGetRtsTxVector()) is the first allowed combination of the first listed MCS.
112 // Add HT, VHT or HE MCSes
113 for (const auto& mode : GetPhy()->GetMcsList())
114 {
115 for (MHz_u j{20}; j <= GetPhy()->GetChannelWidth(); j *= 2)
116 {
// NOTE(review): modulationClass depends on neither `mode` nor `j`; it could
// be computed once before the loops.
117 WifiModulationClass modulationClass = WIFI_MOD_CLASS_HT;
118 if (GetVhtSupported())
119 {
120 modulationClass = WIFI_MOD_CLASS_VHT;
121 }
122 if (GetHeSupported())
123 {
124 modulationClass = WIFI_MOD_CLASS_HE;
125 }
126 if (mode.GetModulationClass() == modulationClass)
127 {
128 for (uint8_t k = 1; k <= GetPhy()->GetMaxSupportedTxSpatialStreams(); k++)
129 {
130 if (mode.IsAllowed(j, k))
131 {
132 RateStats stats;
133 stats.mode = mode;
134 stats.channelWidth = j;
135 stats.nss = k;
136
137 station->m_mcsStats.push_back(stats);
138 }
139 }
140 }
141 }
142 }
143
// Legacy fallback, used only when no HT/VHT/HE entry was added. Each mode the
// station supports gets one entry with a single spatial stream.
144 if (station->m_mcsStats.empty())
145 {
146 // Add legacy non-HT modes.
147 for (uint8_t i = 0; i < GetNSupported(station); i++)
148 {
149 RateStats stats;
150 stats.mode = GetSupported(station, i);
// NOTE(review): the condition for this branch (original lines 151-152) is
// missing from this listing. Given the 22 MHz width it presumably selects
// DSSS/HR-DSSS modes -- confirm against the source.
153 {
154 stats.channelWidth = MHz_u{22};
155 }
156 else
157 {
158 stats.channelWidth = MHz_u{20};
159 }
160 stats.nss = 1;
161 station->m_mcsStats.push_back(stats);
162 }
163 }
164
165 NS_ASSERT_MSG(!station->m_mcsStats.empty(), "No usable MCS found");
166
// Draw the rate for the first transmission.
167 UpdateNextMode(st);
168}
169
// SNR reports for received frames are only logged; this handler does not touch
// the Thompson-sampling statistics.
170void
172{
173 NS_LOG_FUNCTION(this << station << rxSnr << txMode);
174}
175
// NOTE(review): original lines 177-180 of this definition (presumably
// DoReportRtsFailed) are missing from this listing.
176void
181
// Handles a failed data frame. It ages the statistics of the rate that was last
// used, counts one failure for that rate, and draws the next rate.
// NOTE(review): original line 186 is not shown in this listing. Make sure the
// station's rate table is initialized before m_mcsStats.at() below, which
// throws std::out_of_range on an empty table.
182void
184{
185 NS_LOG_FUNCTION(this << st);
187 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
// Decay first, so the new failure is added to counts already aged to "now".
188 Decay(st, station->m_lastMode);
189 station->m_mcsStats.at(station->m_lastMode).fails++;
190 UpdateNextMode(st);
191}
192
// Successful RTS/CTS exchanges are only logged; this handler does not touch the
// Thompson-sampling statistics.
193void
195 double ctsSnr,
196 WifiMode ctsMode,
197 double rtsSnr)
198{
199 NS_LOG_FUNCTION(this << st << ctsSnr << ctsMode.GetUniqueName() << rtsSnr);
200}
201
// Thompson-sampling step. For every candidate rate it:
// 1. ages the rate's statistics;
// 2. draws a success probability from Beta(1 + success, 1 + fails);
// 3. multiplies that probability by the rate's nominal data rate.
// The candidate with the largest sampled throughput is stored in m_nextMode.
202void
204{
206 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
207
208 double maxThroughput = 0.0;
209 double frameSuccessRate = 1.0;
210
211 NS_ASSERT(!station->m_mcsStats.empty());
212
213 // Use the most robust MCS if frameSuccessRate is 0 for all MCS.
214 station->m_nextMode = 0;
215
216 for (uint32_t i = 0; i < station->m_mcsStats.size(); i++)
217 {
218 Decay(st, i);
219 const auto mode{station->m_mcsStats.at(i).mode};
220
221 const auto guardInterval = GetModeGuardInterval(st, mode);
222 const auto rate = mode.GetDataRate(station->m_mcsStats.at(i).channelWidth,
223 guardInterval,
224 station->m_mcsStats.at(i).nss);
225
226 // Thompson sampling
// NOTE(review): SampleBetaVariable() declares its parameters as uint64_t. The
// real-valued (decayed) arguments below are therefore truncated toward zero
// before sampling, and the fractional part of the decayed evidence is lost.
227 frameSuccessRate = SampleBetaVariable(1.0 + station->m_mcsStats.at(i).success,
228 1.0 + station->m_mcsStats.at(i).fails);
229 NS_LOG_DEBUG("Draw success=" << station->m_mcsStats.at(i).success
230 << " fails=" << station->m_mcsStats.at(i).fails
231 << " frameSuccessRate=" << frameSuccessRate
232 << " mode=" << mode);
// Strict '>' keeps the lower index when sampled throughputs tie.
233 if (frameSuccessRate * rate > maxThroughput)
234 {
235 maxThroughput = frameSuccessRate * rate;
236 station->m_nextMode = i;
237 }
238 }
239}
240
// Handles an acknowledged data frame. It ages the statistics of the rate that
// was last used, counts one success for that rate, and draws the next rate.
// ackSnr, ackMode and dataSnr are only logged; dataChannelWidth and dataNss are
// not used.
241void
243 double ackSnr,
244 WifiMode ackMode,
245 double dataSnr,
246 MHz_u dataChannelWidth,
247 uint8_t dataNss)
248{
249 NS_LOG_FUNCTION(this << st << ackSnr << ackMode.GetUniqueName() << dataSnr);
251 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
// Decay first, so the new success is added to counts already aged to "now".
252 Decay(st, station->m_lastMode);
253 station->m_mcsStats.at(station->m_lastMode).success++;
254 UpdateNextMode(st);
255}
256
// Handles the outcome of an A-MPDU. The last-used rate's statistics are aged
// once, then all successful and all failed MPDUs are added to that rate before
// the next rate is drawn. dataChannelWidth and dataNss are not used.
257void
259 uint16_t nSuccessfulMpdus,
260 uint16_t nFailedMpdus,
261 double rxSnr,
262 double dataSnr,
263 MHz_u dataChannelWidth,
264 uint8_t dataNss)
265{
266 NS_LOG_FUNCTION(this << st << nSuccessfulMpdus << nFailedMpdus << rxSnr << dataSnr);
268 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
269
270 Decay(st, station->m_lastMode);
271 station->m_mcsStats.at(station->m_lastMode).success += nSuccessfulMpdus;
272 station->m_mcsStats.at(station->m_lastMode).fails += nFailedMpdus;
273
274 UpdateNextMode(st);
275}
276
// NOTE(review): original lines 278-281 of this definition (presumably
// DoReportFinalRtsFailed) are missing from this listing.
277void
282
// NOTE(review): original lines 284-287 of this definition (presumably
// DoReportFinalDataFailed) are missing from this listing.
283void
288
// Returns the guard interval to use with the given mode for this station.
// NOTE(review): the signature and several branch conditions (original lines
// 290, 292, 297, 299) are missing from this listing. Presumably:
// - the first branch handles HE modes and takes the longer (std::max) of the
//   station's and the device's guard intervals;
// - the second branch covers HT (and likely VHT), using 400 ns when `useSgi`
//   holds and 800 ns otherwise;
// - every other mode gets 800 ns.
// Confirm against the source.
289Time
291{
293 {
294 return std::max(GetGuardInterval(st), GetGuardInterval());
295 }
296 else if ((mode.GetModulationClass() == WIFI_MOD_CLASS_HT) ||
298 {
300 return NanoSeconds(useSgi ? 400 : 800);
301 }
302 else
303 {
304 return NanoSeconds(800);
305 }
306}
307
// Builds the TXVECTOR for the next data frame from the rate chosen by the last
// Thompson-sampling draw (m_nextMode). Along the way it:
// - clamps the channel width to allowedWidth;
// - stores the chosen index in m_lastMode, so the upcoming feedback is
//   credited to it;
// - updates the traced m_currentRate when the data rate changes.
310{
311 NS_LOG_FUNCTION(this << st << allowedWidth);
313 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
314
315 auto& stats = station->m_mcsStats.at(station->m_nextMode);
316 const auto mode = stats.mode;
317 const auto channelWidth = std::min(stats.channelWidth, allowedWidth);
318 const auto nss = stats.nss;
319 const auto guardInterval = GetModeGuardInterval(st, mode);
320
321 station->m_lastMode = station->m_nextMode;
322
323 NS_LOG_DEBUG("Using mode=" << mode << " channelWidth=" << channelWidth << " nss=" << +nss
324 << " guardInterval=" << guardInterval);
325
326 const auto rate = mode.GetDataRate(channelWidth, guardInterval, nss);
327 if (m_currentRate != rate)
328 {
329 NS_LOG_DEBUG("New datarate: " << rate);
330 m_currentRate = rate;
331 }
332
// NOTE(review): two constructor arguments (original lines 335 and 338) are
// missing from this listing. The GetModeGuardInterval(st, mode) call below
// recomputes the value already held in `guardInterval`.
333 return WifiTxVector(
334 mode,
336 GetPreambleForTransmission(mode.GetModulationClass(), GetShortPreambleEnabled()),
337 GetModeGuardInterval(st, mode),
339 nss,
340 0, // NESS
341 GetPhy()->GetTxBandwidth(mode, channelWidth),
342 GetAggregation(station),
343 false);
344}
345
// Builds the TXVECTOR for RTS frames from entry 0 of the station's rate table.
// NS_ASSERT checks, in debug builds only, that this entry uses a single spatial
// stream.
// NOTE(review): several lines (original 346-347, 350, 363-364, 366) are
// missing from this listing, including the signature and some TXVECTOR
// arguments.
348{
349 NS_LOG_FUNCTION(this << st);
351 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
352
353 // Use the most robust MCS for the control channel.
354 auto& stats = station->m_mcsStats.at(0);
355 WifiMode mode = stats.mode;
356 uint8_t nss = stats.nss;
357
358 // Make sure control frames are sent using 1 spatial stream.
359 NS_ASSERT(nss == 1);
360
361 return WifiTxVector(
362 mode,
365 GetModeGuardInterval(st, mode),
367 nss,
368 0, // NESS
369 GetPhy()->GetTxBandwidth(mode, stats.channelWidth),
370 GetAggregation(station),
371 false);
372}
373
374double
375ThompsonSamplingWifiManager::SampleBetaVariable(uint64_t alpha, uint64_t beta) const
376{
377 double X = m_gammaRandomVariable->GetValue(alpha, 1.0);
378 double Y = m_gammaRandomVariable->GetValue(beta, 1.0);
379 return X / (X + Y);
380}
381
// Ages the statistics of rate i to the current simulation time. Both success
// and fails are multiplied by exp(-m_decay * elapsed seconds), and lastDecay is
// set to now. Nothing changes if no simulated time has passed since the last
// decay. With m_decay == 0 the factor is 1, so the counts are preserved.
// NOTE(review): original lines 383 (signature) and 386 are missing from this listing.
382void
384{
385 NS_LOG_FUNCTION(this << st << i);
387 auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
388
389 Time now = Simulator::Now();
390 auto& stats = station->m_mcsStats.at(i);
391 if (now > stats.lastDecay)
392 {
// (lastDecay - now) is negative here, so the exponent is -m_decay * elapsed.
393 const double coefficient = std::exp(m_decay * (stats.lastDecay - now).GetSeconds());
394
395 stats.success *= coefficient;
396 stats.fails *= coefficient;
397 stats.lastDecay = now;
398 }
399}
400
// Pins the gamma random variable that drives SampleBetaVariable() to the given
// stream number. It returns 1, the number of streams assigned; the gamma
// variable is the only one set here.
401int64_t
403{
404 NS_LOG_FUNCTION(this << stream);
405 m_gammaRandomVariable->SetStream(stream);
406 return 1;
407}
408
409} // namespace ns3
This class can be used to hold variables of floating point type such as 'double' or 'float'.
Definition double.h:31
static Time Now()
Return the current simulation virtual time.
Definition simulator.cc:197
Thompson Sampling rate control algorithm.
void DoReportRxOk(WifiRemoteStation *station, double rxSnr, WifiMode txMode) override
This method is a pure virtual method that must be implemented by the sub-class.
void InitializeStation(WifiRemoteStation *station) const
Initializes station rate tables.
void DoReportDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
TracedValue< uint64_t > m_currentRate
Trace rate changes.
Time GetModeGuardInterval(WifiRemoteStation *st, WifiMode mode) const
Returns guard interval for the given mode.
double SampleBetaVariable(uint64_t alpha, uint64_t beta) const
Sample beta random variable with given parameters.
WifiRemoteStation * DoCreateStation() const override
Ptr< GammaRandomVariable > m_gammaRandomVariable
Variable used to sample beta-distributed random variables.
void DoReportFinalRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportFinalDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
double m_decay
Exponential decay coefficient, Hz.
void DoReportRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportDataOk(WifiRemoteStation *station, double ackSnr, WifiMode ackMode, double dataSnr, MHz_u dataChannelWidth, uint8_t dataNss) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportAmpduTxStatus(WifiRemoteStation *station, uint16_t nSuccessfulMpdus, uint16_t nFailedMpdus, double rxSnr, double dataSnr, MHz_u dataChannelWidth, uint8_t dataNss) override
Typically called per A-MPDU, either when a Block ACK was successfully received or when a BlockAckTime...
void UpdateNextMode(WifiRemoteStation *station) const
Draws a new MCS and related parameters to try next time for this station.
int64_t AssignStreams(int64_t stream) override
Assign a fixed random variable stream number to the random variables used by this model.
WifiTxVector DoGetDataTxVector(WifiRemoteStation *station, MHz_u allowedWidth) override
WifiTxVector DoGetRtsTxVector(WifiRemoteStation *station) override
void DoReportRtsOk(WifiRemoteStation *station, double ctsSnr, WifiMode ctsMode, double rtsSnr) override
This method is a pure virtual method that must be implemented by the sub-class.
void Decay(WifiRemoteStation *st, size_t i) const
Applies exponential decay to MCS statistics.
Simulation virtual time values and global simulation resolution.
Definition nstime.h:94
a unique identifier for an interface.
Definition type-id.h:49
TypeId SetParent(TypeId tid)
Set the parent TypeId.
Definition type-id.cc:1001
represent a single transmission mode
Definition wifi-mode.h:40
const std::string & GetUniqueName() const
Definition wifi-mode.cc:137
WifiModulationClass GetModulationClass() const
Definition wifi-mode.cc:174
MHz_u GetChannelWidth() const
Definition wifi-phy.cc:1099
uint8_t GetMaxSupportedTxSpatialStreams() const
Definition wifi-phy.cc:1378
std::list< WifiMode > GetMcsList() const
The WifiPhy::GetMcsList() method is used (e.g., by a WifiRemoteStationManager) to determine the set o...
Definition wifi-phy.cc:2131
hold a list of per-remote-station state.
Time GetGuardInterval() const
Return the shortest supported HE guard interval duration.
uint8_t GetNSupported(const WifiRemoteStation *station) const
Return the number of modes supported by the given station.
Ptr< WifiPhy > GetPhy() const
Return the WifiPhy.
bool GetAggregation(const WifiRemoteStation *station) const
Return whether the given station supports A-MPDU.
bool GetShortGuardIntervalSupported() const
Return whether the device has SGI support enabled.
bool GetVhtSupported() const
Return whether the device has VHT capability support enabled on the link this manager is associated with.
bool GetShortPreambleEnabled() const
Return whether the device uses short PHY preambles.
WifiMode GetSupported(const WifiRemoteStation *station, uint8_t i) const
Return the mode supported by the specified station at the specified index.
bool GetHeSupported() const
Return whether the device has HE capability support enabled.
This class mimics the TXVECTOR which is to be passed to the PHY in order to define the parameters whi...
#define NS_ASSERT(condition)
At runtime, in debugging builds, if this condition is not true, the program prints the source file,...
Definition assert.h:55
#define NS_ASSERT_MSG(condition, message)
At runtime, in debugging builds, if this condition is not true, the program prints the message to out...
Definition assert.h:75
Ptr< const AttributeChecker > MakeDoubleChecker()
Definition double.h:82
Ptr< const AttributeAccessor > MakeDoubleAccessor(T1 a1)
Create an AttributeAccessor for a class data member, or a lone class get functor or set method.
Definition double.h:32
#define NS_LOG_COMPONENT_DEFINE(name)
Define a Log component with a specific name.
Definition log.h:191
#define NS_LOG_DEBUG(msg)
Use NS_LOG to output a message of level LOG_DEBUG.
Definition log.h:257
#define NS_LOG_FUNCTION(parameters)
If log level LOG_FUNCTION is enabled, this macro will output all input parameters separated by ",...
Ptr< T > CreateObject(Args &&... args)
Create an object by type, with varying number of constructor parameters.
Definition object.h:619
#define NS_OBJECT_ENSURE_REGISTERED(type)
Register an Object subclass with the TypeId system.
Definition object-base.h:35
Time NanoSeconds(uint64_t value)
Construct a Time in the indicated unit.
Definition nstime.h:1381
Ptr< const TraceSourceAccessor > MakeTraceSourceAccessor(T a)
Create a TraceSourceAccessor which will control access to the underlying trace source.
WifiModulationClass
This enumeration defines the modulation classes per (Table 10-6 "Modulation classes"; IEEE 802....
@ WIFI_MOD_CLASS_HR_DSSS
HR/DSSS (Clause 16)
@ WIFI_MOD_CLASS_HT
HT (Clause 19)
@ WIFI_MOD_CLASS_VHT
VHT (Clause 22)
@ WIFI_MOD_CLASS_HE
HE (Clause 27)
@ WIFI_MOD_CLASS_DSSS
DSSS (Clause 15)
Every class exported by the ns3 library is enclosed in the ns3 namespace.
WifiPreamble GetPreambleForTransmission(WifiModulationClass modulation, bool useShortPreamble)
Return the preamble to be used for the transmission.
A structure containing parameters of a single rate and its statistics.
uint8_t nss
Number of spatial streams.
double success
averaged number of successful transmissions
double fails
averaged number of failed transmissions
Time lastDecay
last time exponential decay was applied to this rate
Holds station state and collected statistics.
size_t m_nextMode
Mode to select for the next transmission.
std::vector< RateStats > m_mcsStats
Collected statistics.
size_t m_lastMode
Most recently used mode, used to write statistics.
hold per-remote-station state.