|
| 1 | +// SPDX-License-Identifier: GPL-2.0-only |
| 2 | + |
| 3 | +//! TCP CUBIC congestion control algorithm. |
| 4 | +//! |
| 5 | +//! Based on: |
| 6 | +//! Sangtae Ha, Injong Rhee, and Lisong Xu. 2008. |
| 7 | +//! CUBIC: A New TCP-Friendly High-Speed TCP Variant. |
| 8 | +//! SIGOPS Oper. Syst. Rev. 42, 5 (July 2008), 64–74. |
| 9 | +//! <https://doi.org/10.1145/1400097.1400105> |
| 10 | +//! |
| 11 | +//! CUBIC is also described in [RFC9438](https://www.rfc-editor.org/rfc/rfc9438). |
| 12 | +
|
| 13 | +use core::cmp::{max, min}; |
| 14 | +use core::num::NonZeroU32; |
| 15 | +use kernel::net::tcp; |
| 16 | +use kernel::net::tcp::cong::{self, hystart, hystart::HystartDetect}; |
| 17 | +use kernel::prelude::*; |
| 18 | +use kernel::time; |
| 19 | +use kernel::{c_str, module_cca}; |
| 20 | + |
/// Fixed-point denominator for [`BETA`]: `BETA / BICTCP_BETA_SCALE` is the
/// actual multiplicative-decrease factor.
const BICTCP_BETA_SCALE: u32 = 1 << 10;

// TODO: Convert to module parameters once they are available. Currently these
// are the defaults from the C implementation.
// TODO: Use `NonZeroU32` where appropriate.
/// Enables fast convergence: a heuristic that makes existing flows give up
/// bandwidth more quickly, so that a steady state is reached sooner after a
/// new flow joins the link.
const FAST_CONVERGENCE: bool = true;
/// Numerator of the multiplicative decrease applied to cwnd on a loss event.
/// Divided by `BICTCP_BETA_SCALE` this is approximately 0.7.
const BETA: u32 = 717;
/// Initial ssthresh for new connections. `None` implies `i32::MAX`.
const INITIAL_SSTHRESH: Option<u32> = None;
/// Scaling of the cubic term: the parameter `C` is defined as
/// `BIC_SCALE/2^10`. (For C: Dimension: Time^-2, Unit: s^-2).
const BIC_SCALE: u32 = 41;
/// When enabled, CUBIC behaves like normal TCP in environments where it would
/// otherwise grow cwnd less aggressively than normal TCP, i.e., in short RTT
/// and/or low bandwidth delay product networks.
const TCP_FRIENDLINESS: bool = true;
/// Enables the [HyStart] slow start algorithm.
///
/// [HyStart]: hystart::HyStart
const HYSTART: bool = true;
| 47 | + |
| 48 | +impl hystart::HyStart for Cubic { |
| 49 | + /// Which mechanism to use for deciding when it is time to exit slow start. |
| 50 | + const DETECT: HystartDetect = HystartDetect::Both; |
| 51 | + /// Lower bound for cwnd during hybrid slow start. |
| 52 | + const LOW_WINDOW: u32 = 16; |
| 53 | + /// Spacing between ACKs indicating an ACK-train. |
| 54 | + /// (Dimension: Time. Unit: us). |
| 55 | + const ACK_DELTA: time::Usecs32 = 2000; |
| 56 | +} |
| 57 | + |
| 58 | +// TODO: Those are computed based on the module parameters in the init. Even |
| 59 | +// with module parameters available this will be a bit tricky to do in Rust. |
| 60 | +/// Factor of `8/3 * (1 + beta) / (1 - beta)` that is used in various |
| 61 | +/// calculations. (Dimension: none) |
| 62 | +const BETA_SCALE: u32 = ((8 * (BICTCP_BETA_SCALE + BETA)) / 3) / (BICTCP_BETA_SCALE - BETA); |
| 63 | +/// Factor of `2^10*C/SRTT` where `SRTT = 100ms` that is used in various |
| 64 | +/// calculations. (Dimension: Time^-3, Unit: s^-3). |
| 65 | +const CUBE_RTT_SCALE: u32 = BIC_SCALE * 10; |
| 66 | +/// Factor of `SRTT/C` where `SRTT = 100ms` and `C` from above. |
| 67 | +/// (Dimension: Time^3. Unit: (ms)^3) |
| 68 | +// Note: C uses a custom time unit of 2^-10 s called `BICTCP_HZ`. This |
| 69 | +// implementation consistently uses milliseconds instead. |
| 70 | +const CUBE_FACTOR: u64 = 1_000_000_000 * (1u64 << 10) / (CUBE_RTT_SCALE as u64); |
| 71 | + |
// Module declaration for the congestion control algorithm implemented by
// `Cubic`.
// NOTE(review): what exactly `module_cca!` generates and registers is defined
// in the `kernel` crate, not in this file -- confirm there.
module_cca! {
    type: Cubic,
    name: "tcp_cubic_rust",
    author: "Rust for Linux Contributors",
    description: "TCP CUBIC congestion control algorithm, Rust implementation",
    license: "GPL v2",
}
| 79 | + |
/// Field-less type that carries the CUBIC implementations of
/// [`cong::Algorithm`] and [`hystart::HyStart`]. The per-socket state is kept
/// in [`CubicState`].
struct Cubic {}
| 81 | + |
| 82 | +#[vtable] |
| 83 | +impl cong::Algorithm for Cubic { |
| 84 | + type Data = CubicState; |
| 85 | + |
| 86 | + const NAME: &'static CStr = c_str!("cubic_rust"); |
| 87 | + |
| 88 | + fn init(sk: &mut cong::Sock<'_, Self>) { |
| 89 | + if HYSTART { |
| 90 | + <Self as hystart::HyStart>::reset(sk) |
| 91 | + } else if let Some(ssthresh) = INITIAL_SSTHRESH { |
| 92 | + sk.tcp_sk_mut().set_snd_ssthresh(ssthresh); |
| 93 | + } |
| 94 | + |
| 95 | + // TODO: remove |
| 96 | + pr_info!( |
| 97 | + "init: socket created: start {}us", |
| 98 | + sk.inet_csk_ca().start_time |
| 99 | + ); |
| 100 | + } |
| 101 | + |
| 102 | + // TODO: remove |
| 103 | + fn release(sk: &mut cong::Sock<'_, Self>) { |
| 104 | + pr_info!( |
| 105 | + "release: socket destroyed: start {}us, end {}us", |
| 106 | + sk.inet_csk_ca().start_time, |
| 107 | + time::ktime_get_boot_fast_us32(), |
| 108 | + ); |
| 109 | + } |
| 110 | + |
| 111 | + fn cwnd_event(sk: &mut cong::Sock<'_, Self>, ev: cong::Event) { |
| 112 | + if matches!(ev, cong::Event::TxStart) { |
| 113 | + // Here we cannot avoid jiffies as the `lsndtime` field is measured |
| 114 | + // in jiffies. |
| 115 | + let now = time::jiffies32(); |
| 116 | + let delta: time::Jiffies32 = now.wrapping_sub(sk.tcp_sk().lsndtime()); |
| 117 | + |
| 118 | + if (delta as i32) <= 0 { |
| 119 | + return; |
| 120 | + } |
| 121 | + |
| 122 | + let ca = sk.inet_csk_ca_mut(); |
| 123 | + // Ok, lets switch to SI units. |
| 124 | + let now = time::ktime_get_boot_fast_ms32(); |
| 125 | + let delta = time::jiffies_to_msecs(delta as time::Jiffies); |
| 126 | + // TODO: remove |
| 127 | + pr_debug!("cwnd_event: TxStart, now {}ms, delta {}ms", now, delta); |
| 128 | + // We were application limited, i.e., idle, for a while. If we are |
| 129 | + // in congestion avoidance, shift `epoch_start` by the time we were |
| 130 | + // idle to keep cwnd growth to cubic curve. |
| 131 | + ca.epoch_start = ca.epoch_start.map(|mut epoch_start| { |
| 132 | + epoch_start = epoch_start.wrapping_add(delta); |
| 133 | + if tcp::after(epoch_start, now) { |
| 134 | + epoch_start = now; |
| 135 | + } |
| 136 | + epoch_start |
| 137 | + }); |
| 138 | + } |
| 139 | + } |
| 140 | + |
| 141 | + fn set_state(sk: &mut cong::Sock<'_, Self>, new_state: cong::State) { |
| 142 | + if matches!(new_state, cong::State::Loss) { |
| 143 | + pr_info!( |
| 144 | + // TODO: remove |
| 145 | + "set_state: Loss, time {}us, start {}us", |
| 146 | + time::ktime_get_boot_fast_us32(), |
| 147 | + sk.inet_csk_ca().start_time |
| 148 | + ); |
| 149 | + sk.inet_csk_ca_mut().reset(); |
| 150 | + <Self as hystart::HyStart>::reset(sk); |
| 151 | + } |
| 152 | + } |
| 153 | + |
| 154 | + fn pkts_acked(sk: &mut cong::Sock<'_, Self>, sample: &cong::AckSample) { |
| 155 | + // Some samples do not include RTTs. |
| 156 | + let Some(rtt_us) = sample.rtt_us() else { |
| 157 | + // TODO: remove |
| 158 | + pr_info!( |
| 159 | + "pkts_acked: no RTT sample, start {}us", |
| 160 | + sk.inet_csk_ca().start_time, |
| 161 | + ); |
| 162 | + return; |
| 163 | + }; |
| 164 | + |
| 165 | + let epoch_start = sk.inet_csk_ca().epoch_start; |
| 166 | + // For some time after existing fast recovery the samples might still be |
| 167 | + // inaccurate. |
| 168 | + if epoch_start.is_some_and(|epoch_start| { |
| 169 | + time::ktime_get_boot_fast_ms32().wrapping_sub(epoch_start) < time::MSEC_PER_SEC |
| 170 | + }) { |
| 171 | + // TODO: remove |
| 172 | + pr_debug!( |
| 173 | + "pkts_acked: {}ms - {}ms < 1s, too close to epoch_start", |
| 174 | + time::ktime_get_boot_fast_ms32(), |
| 175 | + epoch_start.unwrap() |
| 176 | + ); |
| 177 | + return; |
| 178 | + } |
| 179 | + |
| 180 | + let delay = max(1, rtt_us); |
| 181 | + let cwnd = sk.tcp_sk().snd_cwnd(); |
| 182 | + let in_slow_start = sk.tcp_sk().in_slow_start(); |
| 183 | + let ca = sk.inet_csk_ca_mut(); |
| 184 | + |
| 185 | + // TODO: remove |
| 186 | + pr_debug!( |
| 187 | + "pkts_acked: delay {}us, cwnd {}, ss {}", |
| 188 | + delay, |
| 189 | + cwnd, |
| 190 | + in_slow_start |
| 191 | + ); |
| 192 | + |
| 193 | + // First call after reset or the delay decreased. |
| 194 | + if ca.hystart_state.delay_min.is_none() |
| 195 | + || ca |
| 196 | + .hystart_state |
| 197 | + .delay_min |
| 198 | + .is_some_and(|delay_min| delay_min > delay) |
| 199 | + { |
| 200 | + ca.hystart_state.delay_min = Some(delay); |
| 201 | + } |
| 202 | + |
| 203 | + if in_slow_start && HYSTART && ca.hystart_state.in_hystart::<Self>(cwnd) { |
| 204 | + hystart::HyStart::update(sk, delay); |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + fn ssthresh(sk: &mut cong::Sock<'_, Self>) -> u32 { |
| 209 | + let cwnd = sk.tcp_sk().snd_cwnd(); |
| 210 | + let ca = sk.inet_csk_ca_mut(); |
| 211 | + |
| 212 | + pr_info!( |
| 213 | + // TODO: remove |
| 214 | + "ssthresh: time {}us, start {}us", |
| 215 | + time::ktime_get_boot_fast_us32(), |
| 216 | + ca.start_time |
| 217 | + ); |
| 218 | + |
| 219 | + // Epoch has ended. |
| 220 | + ca.epoch_start = None; |
| 221 | + ca.last_max_cwnd = if cwnd < ca.last_max_cwnd && FAST_CONVERGENCE { |
| 222 | + (cwnd * (BICTCP_BETA_SCALE + BETA)) / (2 * BICTCP_BETA_SCALE) |
| 223 | + } else { |
| 224 | + cwnd |
| 225 | + }; |
| 226 | + |
| 227 | + max((cwnd * BETA) / BICTCP_BETA_SCALE, 2) |
| 228 | + } |
| 229 | + |
| 230 | + fn undo_cwnd(sk: &mut cong::Sock<'_, Self>) -> u32 { |
| 231 | + pr_info!( |
| 232 | + // TODO: remove |
| 233 | + "undo_cwnd: time {}us, start {}us", |
| 234 | + time::ktime_get_boot_fast_us32(), |
| 235 | + sk.inet_csk_ca().start_time |
| 236 | + ); |
| 237 | + |
| 238 | + cong::reno::undo_cwnd(sk) |
| 239 | + } |
| 240 | + |
| 241 | + fn cong_avoid(sk: &mut cong::Sock<'_, Self>, _ack: u32, mut acked: u32) { |
| 242 | + if !sk.tcp_is_cwnd_limited() { |
| 243 | + return; |
| 244 | + } |
| 245 | + |
| 246 | + let tp = sk.tcp_sk_mut(); |
| 247 | + |
| 248 | + if tp.in_slow_start() { |
| 249 | + acked = tp.slow_start(acked); |
| 250 | + if acked == 0 { |
| 251 | + pr_info!( |
| 252 | + // TODO: remove |
| 253 | + "cong_avoid: new cwnd {}, time {}us, ssthresh {}, start {}us, ss 1", |
| 254 | + sk.tcp_sk().snd_cwnd(), |
| 255 | + time::ktime_get_boot_fast_us32(), |
| 256 | + sk.tcp_sk().snd_ssthresh(), |
| 257 | + sk.inet_csk_ca().start_time |
| 258 | + ); |
| 259 | + return; |
| 260 | + } |
| 261 | + } |
| 262 | + |
| 263 | + let cwnd = tp.snd_cwnd(); |
| 264 | + let cnt = sk.inet_csk_ca_mut().update(cwnd, acked); |
| 265 | + sk.tcp_sk_mut().cong_avoid_ai(cnt, acked); |
| 266 | + |
| 267 | + pr_info!( |
| 268 | + // TODO: remove |
| 269 | + "cong_avoid: new cwnd {}, time {}us, ssthresh {}, start {}us, ss 0", |
| 270 | + sk.tcp_sk().snd_cwnd(), |
| 271 | + time::ktime_get_boot_fast_us32(), |
| 272 | + sk.tcp_sk().snd_ssthresh(), |
| 273 | + sk.inet_csk_ca().start_time |
| 274 | + ); |
| 275 | + } |
| 276 | +} |
| 277 | + |
/// Per-socket state of the CUBIC algorithm.
// `K` keeps the symbol used in the CUBIC formulas, hence the lint exception.
#[allow(non_snake_case)]
struct CubicState {
    /// Increase cwnd by one step after `cnt` ACKs.
    cnt: NonZeroU32,
    /// W__last_max: the cwnd recorded by `ssthresh` at the last window
    /// reduction, lowered further when fast convergence applies. Zero in the
    /// initial epoch and after a reset.
    last_max_cwnd: u32,
    /// Value of cwnd when `update` last recomputed the cubic function.
    last_cwnd: u32,
    /// Time when `last_cwnd` was updated.
    last_time: time::Msecs32,
    /// Value of cwnd where the plateau of the cubic function is located.
    origin_point: u32,
    /// Time it takes to reach `origin_point`, measured from the beginning of
    /// an epoch.
    K: time::Msecs32,
    /// Time when the current epoch has started. `None` when not in congestion
    /// avoidance.
    epoch_start: Option<time::Msecs32>,
    /// Number of packets that have been ACKed in the current epoch, minus
    /// those that `tcp_friendliness` already converted into `tcp_cwnd`
    /// increments.
    ack_cnt: u32,
    /// Estimate for the cwnd of TCP Reno.
    tcp_cwnd: u32,
    /// State of the HyStart slow start algorithm.
    hystart_state: hystart::HyStartState,
    /// Time when the connection was created. Only used for log output.
    // TODO: remove
    start_time: time::Usecs32,
}
| 306 | + |
| 307 | +impl hystart::HasHyStartState for CubicState { |
| 308 | + fn hy(&self) -> &hystart::HyStartState { |
| 309 | + &self.hystart_state |
| 310 | + } |
| 311 | + |
| 312 | + fn hy_mut(&mut self) -> &mut hystart::HyStartState { |
| 313 | + &mut self.hystart_state |
| 314 | + } |
| 315 | +} |
| 316 | + |
| 317 | +impl Default for CubicState { |
| 318 | + fn default() -> Self { |
| 319 | + Self { |
| 320 | + // NOTE: Initializing this to 1 deviates from the C code. It does |
| 321 | + // not change the behavior. |
| 322 | + cnt: NonZeroU32::MIN, |
| 323 | + last_max_cwnd: 0, |
| 324 | + last_cwnd: 0, |
| 325 | + last_time: 0, |
| 326 | + origin_point: 0, |
| 327 | + K: 0, |
| 328 | + epoch_start: None, |
| 329 | + ack_cnt: 0, |
| 330 | + tcp_cwnd: 0, |
| 331 | + hystart_state: hystart::HyStartState::default(), |
| 332 | + // TODO: remove |
| 333 | + start_time: time::ktime_get_boot_fast_us32(), |
| 334 | + } |
| 335 | + } |
| 336 | +} |
| 337 | + |
| 338 | +impl CubicState { |
| 339 | + /// Checks if the current CUBIC increase is less aggressive than normal TCP, |
| 340 | + /// i.e., if we are in the TCP-friendly region. If so, returns `cnt` that |
| 341 | + /// increases at the speed of normal TCP. |
| 342 | + #[inline] |
| 343 | + fn tcp_friendliness(&mut self, cnt: u32, cwnd: u32) -> u32 { |
| 344 | + if !TCP_FRIENDLINESS { |
| 345 | + return cnt; |
| 346 | + } |
| 347 | + |
| 348 | + // Estimate cwnd of normal TCP. |
| 349 | + // cwnd/3 * (1 + BETA)/(1 - BETA) |
| 350 | + let delta = (cwnd * BETA_SCALE) >> 3; |
| 351 | + // W__tcp(t) = W__tcp(t__0) + (acks(t) - acks(t__0)) / delta |
| 352 | + while self.ack_cnt > delta { |
| 353 | + self.ack_cnt -= delta; |
| 354 | + self.tcp_cwnd += 1; |
| 355 | + } |
| 356 | + |
| 357 | + //TODO: remove |
| 358 | + pr_info!( |
| 359 | + "tcp_friendliness: tcp_cwnd {}, cwnd {}, start {}us", |
| 360 | + self.tcp_cwnd, |
| 361 | + cwnd, |
| 362 | + self.start_time, |
| 363 | + ); |
| 364 | + |
| 365 | + // We are slower than normal TCP. |
| 366 | + if self.tcp_cwnd > cwnd { |
| 367 | + let delta = self.tcp_cwnd - cwnd; |
| 368 | + |
| 369 | + min(cnt, cwnd / delta) |
| 370 | + } else { |
| 371 | + cnt |
| 372 | + } |
| 373 | + } |
| 374 | + |
| 375 | + /// Returns the new value of `cnt` to keep the window grow on the cubic |
| 376 | + /// curve. |
| 377 | + fn update(&mut self, cwnd: u32, acked: u32) -> NonZeroU32 { |
| 378 | + let now: time::Msecs32 = time::ktime_get_boot_fast_ms32(); |
| 379 | + |
| 380 | + self.ack_cnt += acked; |
| 381 | + |
| 382 | + if self.last_cwnd == cwnd && now.wrapping_sub(self.last_time) <= time::MSEC_PER_SEC / 32 { |
| 383 | + return self.cnt; |
| 384 | + } |
| 385 | + |
| 386 | + // We can update the CUBIC function at most once every ms. |
| 387 | + if self.epoch_start.is_some() && now == self.last_time { |
| 388 | + let cnt = self.tcp_friendliness(self.cnt.get(), cwnd); |
| 389 | + |
| 390 | + // SAFETY: 2 != 0. QED. |
| 391 | + self.cnt = unsafe { NonZeroU32::new_unchecked(max(2, cnt)) }; |
| 392 | + |
| 393 | + return self.cnt; |
| 394 | + } |
| 395 | + |
| 396 | + self.last_cwnd = cwnd; |
| 397 | + self.last_time = now; |
| 398 | + |
| 399 | + if self.epoch_start.is_none() { |
| 400 | + self.epoch_start = Some(now); |
| 401 | + self.ack_cnt = acked; |
| 402 | + self.tcp_cwnd = cwnd; |
| 403 | + |
| 404 | + if self.last_max_cwnd <= cwnd { |
| 405 | + self.K = 0; |
| 406 | + self.origin_point = cwnd; |
| 407 | + } else { |
| 408 | + // K = (SRTT/C * (W__max - cwnd))^1/3 |
| 409 | + self.K = cubic_root(CUBE_FACTOR * ((self.last_max_cwnd - cwnd) as u64)); |
| 410 | + self.origin_point = self.last_max_cwnd; |
| 411 | + } |
| 412 | + } |
| 413 | + |
| 414 | + // PANIC: This is always `Some`. |
| 415 | + let epoch_start: time::Msecs32 = self.epoch_start.unwrap(); |
| 416 | + let Some(delay_min) = self.hystart_state.delay_min else { |
| 417 | + pr_err!("update: delay_min was None"); |
| 418 | + return self.cnt; |
| 419 | + }; |
| 420 | + |
| 421 | + // NOTE: Addition might overflow after 50 days without a loss, C uses a |
| 422 | + // `u64` here. |
| 423 | + let t: time::Msecs32 = |
| 424 | + now.wrapping_sub(epoch_start) + (delay_min / (time::USEC_PER_MSEC as time::Usecs32)); |
| 425 | + let offs: time::Msecs32 = if t < self.K { self.K - t } else { t - self.K }; |
| 426 | + |
| 427 | + // Calculate c/rtt * (t-K)^3 and change units to seconds. |
| 428 | + // Widen type to prevent overflow. |
| 429 | + let offs = offs as u64; |
| 430 | + let delta = (((CUBE_RTT_SCALE as u64 * offs * offs * offs) >> 10) / 1_000_000_000) as u32; |
| 431 | + // Calculate the full cubic function c/rtt * (t - K)^3 + W__max. |
| 432 | + let target = if t < self.K { |
| 433 | + self.origin_point - delta |
| 434 | + } else { |
| 435 | + self.origin_point + delta |
| 436 | + }; |
| 437 | + |
| 438 | + // TODO: remove |
| 439 | + pr_info!( |
| 440 | + "update: now {}ms, epoch_start {}ms, t {}ms, K {}ms, |t - K| {}ms, last_max_cwnd {}, origin_point {}, target {}, start {}us", |
| 441 | + now, |
| 442 | + epoch_start, |
| 443 | + t, |
| 444 | + self.K, |
| 445 | + offs, |
| 446 | + self.last_max_cwnd, |
| 447 | + self.origin_point, |
| 448 | + target, |
| 449 | + self.start_time, |
| 450 | + ); |
| 451 | + |
| 452 | + let mut cnt = if target > cwnd { |
| 453 | + cwnd / (target - cwnd) |
| 454 | + } else { |
| 455 | + // Effectively keeps cwnd constant for the next RTT. |
| 456 | + 100 * cwnd |
| 457 | + }; |
| 458 | + |
| 459 | + // In initial epoch or after timeout we grow at a minimum rate. |
| 460 | + if self.last_max_cwnd == 0 { |
| 461 | + cnt = min(cnt, 20); |
| 462 | + } |
| 463 | + |
| 464 | + // SAFETY: 2 != 0. QED. |
| 465 | + self.cnt = unsafe { NonZeroU32::new_unchecked(max(2, self.tcp_friendliness(cnt, cwnd))) }; |
| 466 | + |
| 467 | + self.cnt |
| 468 | + } |
| 469 | + |
| 470 | + fn reset(&mut self) { |
| 471 | + // TODO: remove |
| 472 | + let tmp = self.start_time; |
| 473 | + |
| 474 | + *self = Self::default(); |
| 475 | + |
| 476 | + // TODO: remove |
| 477 | + self.start_time = tmp; |
| 478 | + } |
| 479 | +} |
| 480 | + |
/// Calculate the cubic root of `a` using a table lookup followed by one
/// Newton-Raphson iteration.
///
/// Inputs below 64 are answered from the table alone. The result is an
/// approximation, see the error figures below.
// E[ |(cubic_root(x) - x.cbrt()) / x.cbrt()| ] = 0.71% for x in 1..1_000_000.
// E[ |(cubic_root(x) - x.cbrt()) / x.cbrt()| ] = 8.87% for x in 1..63.
// Where everything is `f64` and `.cbrt` is Rust's builtin.
const fn cubic_root(a: u64) -> u32 {
    // Cube roots of 0..=63, scaled by 64.
    const V: [u8; 64] = [
        0, 54, 54, 54, 118, 118, 118, 118, 123, 129, 134, 138, 143, 147, 151, 156, 157, 161, 164,
        168, 170, 173, 176, 179, 181, 185, 187, 190, 192, 194, 197, 199, 200, 202, 204, 206, 209,
        211, 213, 215, 217, 219, 221, 222, 224, 225, 227, 229, 231, 232, 234, 236, 237, 239, 240,
        242, 244, 245, 246, 248, 250, 251, 252, 254,
    ];

    let bits = fls64(a) as u32;
    if bits < 7 {
        // `a` is in `0..=63`: look it up directly.
        return ((V[a as usize] as u32) + 35) >> 6;
    }

    // Roughly `bits / 3 - 1`: `a >> (3 * b)` keeps at most six significant
    // bits and thus indexes the table.
    let b = ((bits * 84) >> 8) - 1;
    let shift = a >> (b * 3);

    // Initial estimate from the table.
    let mut x = (((V[shift as usize] as u32) + 10) << b) >> 6;

    // Newton-Raphson iteration: x' = (2 * x + a / x^2) / 3.
    // The divisor must be multiplied in 64 bit: `x` reaches about 2^21 for
    // large `a`, so `x * (x - 1)` does not fit into a `u32`.
    x = 2 * x + (a / ((x as u64) * ((x - 1) as u64))) as u32;

    // Division by 3, approximated as a multiplication by 341/1024.
    (x * 341) >> 10
}

/// Find last set bit in a 64-bit word.
///
/// The last (most significant) bit is at position 64. Returns 0 if no bit is
/// set.
#[inline]
const fn fls64(x: u64) -> u8 {
    (u64::BITS - x.leading_zeros()) as u8
}
0 commit comments