Skip to content

Commit 95abb0e

Browse files
authored
libm: Improved integer utilities, implement shifts and bug fixes for i256 and u256
`i256` and `u256` - operators now use the same overflow convention as primitives - implement `<<` and `-` (previously just `>>` and `+`) - implement `Ord` correctly (the previous `PartialOrd` was broken) - correct `i256::SIGNED` to `true` The `Int`-trait is extended with `trailing_zeros`, `carrying_add`, and `borrowing_sub`.
1 parent cc53499 commit 95abb0e

File tree

5 files changed

+223
-60
lines changed

5 files changed

+223
-60
lines changed

libm-test/benches/icount.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,22 @@ fn icount_bench_u256_add(cases: Vec<(u256, u256)>) {
119119
}
120120
}
121121

122+
#[library_benchmark]
123+
#[bench::linspace(setup_u256_add())]
124+
fn icount_bench_u256_sub(cases: Vec<(u256, u256)>) {
125+
for (x, y) in cases.iter().copied() {
126+
black_box(black_box(x) - black_box(y));
127+
}
128+
}
129+
130+
#[library_benchmark]
131+
#[bench::linspace(setup_u256_shift())]
132+
fn icount_bench_u256_shl(cases: Vec<(u256, u32)>) {
133+
for (x, y) in cases.iter().copied() {
134+
black_box(black_box(x) << black_box(y));
135+
}
136+
}
137+
122138
#[library_benchmark]
123139
#[bench::linspace(setup_u256_shift())]
124140
fn icount_bench_u256_shr(cases: Vec<(u256, u32)>) {
@@ -129,7 +145,7 @@ fn icount_bench_u256_shr(cases: Vec<(u256, u32)>) {
129145

130146
library_benchmark_group!(
131147
name = icount_bench_u128_group;
132-
benchmarks = icount_bench_u128_widen_mul, icount_bench_u256_add, icount_bench_u256_shr
148+
benchmarks = icount_bench_u128_widen_mul, icount_bench_u256_add, icount_bench_u256_sub, icount_bench_u256_shl, icount_bench_u256_shr
133149
);
134150

135151
#[library_benchmark]

libm-test/tests/u256.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,62 @@ fn mp_u256_add() {
111111
let y = random_u256(&mut rng);
112112
assign_bigint(&mut bx, x);
113113
assign_bigint(&mut by, y);
114-
let actual = x + y;
114+
let actual = if u256::MAX - x >= y {
115+
x + y
116+
} else {
117+
// otherwise (u256::MAX - x) < y, so the wrapped result is
118+
// (x + y) - (u256::MAX + 1) == y - (u256::MAX - x) - 1
119+
y - (u256::MAX - x) - 1_u128.widen()
120+
};
115121
bx += &by;
116122
check_one(|| hexu(x), || Some(hexu(y)), actual, &mut bx);
117123
}
118124
}
119125

126+
#[test]
127+
fn mp_u256_sub() {
128+
let mut rng = ChaCha8Rng::from_seed(*SEED);
129+
let mut bx = BigInt::new();
130+
let mut by = BigInt::new();
131+
132+
for _ in 0..bigint_fuzz_iteration_count() {
133+
let x = random_u256(&mut rng);
134+
let y = random_u256(&mut rng);
135+
assign_bigint(&mut bx, x);
136+
assign_bigint(&mut by, y);
137+
138+
// since the operators (may) panic on overflow,
139+
// we should test something that doesn't
140+
let actual = if x >= y { x - y } else { y - x };
141+
bx -= &by;
142+
bx.abs_mut();
143+
check_one(|| hexu(x), || Some(hexu(y)), actual, &mut bx);
144+
}
145+
}
146+
147+
#[test]
148+
fn mp_u256_shl() {
149+
let mut rng = ChaCha8Rng::from_seed(*SEED);
150+
let mut bx = BigInt::new();
151+
152+
for _ in 0..bigint_fuzz_iteration_count() {
153+
let x = random_u256(&mut rng);
154+
let shift: u32 = rng.random_range(0..256);
155+
assign_bigint(&mut bx, x);
156+
let actual = x << shift;
157+
bx <<= shift;
158+
check_one(|| hexu(x), || Some(shift.to_string()), actual, &mut bx);
159+
}
160+
}
161+
120162
#[test]
121163
fn mp_u256_shr() {
122164
let mut rng = ChaCha8Rng::from_seed(*SEED);
123165
let mut bx = BigInt::new();
124166

125167
for _ in 0..bigint_fuzz_iteration_count() {
126168
let x = random_u256(&mut rng);
127-
let shift: u32 = rng.random_range(0..255);
169+
let shift: u32 = rng.random_range(0..256);
128170
assign_bigint(&mut bx, x);
129171
let actual = x >> shift;
130172
bx >>= shift;

libm/src/math/support/big.rs

Lines changed: 79 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ const U128_LO_MASK: u128 = u64::MAX as u128;
1111

1212
/// A 256-bit unsigned integer represented as two 128-bit native-endian limbs.
1313
#[allow(non_camel_case_types)]
14-
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
14+
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Eq, Ord)]
1515
pub struct u256 {
16-
pub lo: u128,
1716
pub hi: u128,
17+
pub lo: u128,
1818
}
1919

2020
impl u256 {
@@ -28,17 +28,17 @@ impl u256 {
2828
pub fn signed(self) -> i256 {
2929
i256 {
3030
lo: self.lo,
31-
hi: self.hi,
31+
hi: self.hi as i128,
3232
}
3333
}
3434
}
3535

3636
/// A 256-bit signed integer represented as two 128-bit native-endian limbs.
3737
#[allow(non_camel_case_types)]
38-
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
38+
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Eq, Ord)]
3939
pub struct i256 {
40+
pub hi: i128,
4041
pub lo: u128,
41-
pub hi: u128,
4242
}
4343

4444
impl i256 {
@@ -47,7 +47,7 @@ impl i256 {
4747
pub fn unsigned(self) -> u256 {
4848
u256 {
4949
lo: self.lo,
50-
hi: self.hi,
50+
hi: self.hi as u128,
5151
}
5252
}
5353
}
@@ -73,17 +73,17 @@ impl MinInt for i256 {
7373

7474
type Unsigned = u256;
7575

76-
const SIGNED: bool = false;
76+
const SIGNED: bool = true;
7777
const BITS: u32 = 256;
7878
const ZERO: Self = Self { lo: 0, hi: 0 };
7979
const ONE: Self = Self { lo: 1, hi: 0 };
8080
const MIN: Self = Self {
81-
lo: 0,
82-
hi: 1 << 127,
81+
lo: u128::MIN,
82+
hi: i128::MIN,
8383
};
8484
const MAX: Self = Self {
8585
lo: u128::MAX,
86-
hi: u128::MAX >> 1,
86+
hi: i128::MAX,
8787
};
8888
}
8989

@@ -109,60 +109,86 @@ macro_rules! impl_common {
109109
}
110110
}
111111

112-
impl ops::Shl<u32> for $ty {
112+
impl ops::Add<Self> for $ty {
113113
type Output = Self;
114114

115-
fn shl(self, _rhs: u32) -> Self::Output {
116-
unimplemented!("only used to meet trait bounds")
115+
fn add(self, rhs: Self) -> Self::Output {
116+
let (lo, carry) = self.lo.overflowing_add(rhs.lo);
117+
let (hi, of) = Int::carrying_add(self.hi, rhs.hi, carry);
118+
debug_assert!(!of, "attempt to add with overflow");
119+
Self { lo, hi }
117120
}
118121
}
119-
};
120-
}
121122

122-
impl_common!(i256);
123-
impl_common!(u256);
123+
impl ops::Sub<Self> for $ty {
124+
type Output = Self;
124125

125-
impl ops::Add<Self> for u256 {
126-
type Output = Self;
126+
fn sub(self, rhs: Self) -> Self::Output {
127+
let (lo, borrow) = self.lo.overflowing_sub(rhs.lo);
128+
let (hi, of) = Int::borrowing_sub(self.hi, rhs.hi, borrow);
129+
debug_assert!(!of, "attempt to subtract with overflow");
130+
Self { lo, hi }
131+
}
132+
}
127133

128-
fn add(self, rhs: Self) -> Self::Output {
129-
let (lo, carry) = self.lo.overflowing_add(rhs.lo);
130-
let hi = self.hi.wrapping_add(carry as u128).wrapping_add(rhs.hi);
134+
impl ops::Shl<u32> for $ty {
135+
type Output = Self;
131136

132-
Self { lo, hi }
133-
}
134-
}
137+
fn shl(mut self, rhs: u32) -> Self::Output {
138+
debug_assert!(rhs < Self::BITS, "attempt to shift left with overflow");
135139

136-
impl ops::Shr<u32> for u256 {
137-
type Output = Self;
140+
let half_bits = Self::BITS / 2;
141+
let low_mask = half_bits - 1;
142+
let s = rhs & low_mask;
138143

139-
fn shr(mut self, rhs: u32) -> Self::Output {
140-
debug_assert!(rhs < Self::BITS, "attempted to shift right with overflow");
141-
if rhs >= Self::BITS {
142-
return Self::ZERO;
143-
}
144+
let lo = self.lo;
145+
let hi = self.hi;
144146

145-
if rhs == 0 {
146-
return self;
147-
}
147+
self.lo = lo << s;
148148

149-
if rhs < 128 {
150-
self.lo >>= rhs;
151-
self.lo |= self.hi << (128 - rhs);
152-
} else {
153-
self.lo = self.hi >> (rhs - 128);
149+
if rhs & half_bits == 0 {
150+
self.hi = (lo >> (low_mask ^ s) >> 1) as _;
151+
self.hi |= hi << s;
152+
} else {
153+
self.hi = self.lo as _;
154+
self.lo = 0;
155+
}
156+
self
157+
}
154158
}
155159

156-
if rhs < 128 {
157-
self.hi >>= rhs;
158-
} else {
159-
self.hi = 0;
160-
}
160+
impl ops::Shr<u32> for $ty {
161+
type Output = Self;
161162

162-
self
163-
}
163+
fn shr(mut self, rhs: u32) -> Self::Output {
164+
debug_assert!(rhs < Self::BITS, "attempt to shift right with overflow");
165+
166+
let half_bits = Self::BITS / 2;
167+
let low_mask = half_bits - 1;
168+
let s = rhs & low_mask;
169+
170+
let lo = self.lo;
171+
let hi = self.hi;
172+
173+
self.hi = hi >> s;
174+
175+
#[allow(unused_comparisons)]
176+
if rhs & half_bits == 0 {
177+
self.lo = (hi << (low_mask ^ s) << 1) as _;
178+
self.lo |= lo >> s;
179+
} else {
180+
self.lo = self.hi as _;
181+
self.hi = if hi < 0 { !0 } else { 0 };
182+
}
183+
self
184+
}
185+
}
186+
};
164187
}
165188

189+
impl_common!(i256);
190+
impl_common!(u256);
191+
166192
impl HInt for u128 {
167193
type D = u256;
168194

@@ -200,19 +226,18 @@ impl HInt for u128 {
200226
}
201227

202228
fn widen_hi(self) -> Self::D {
203-
self.widen() << <Self as MinInt>::BITS
229+
u256 { lo: 0, hi: self }
204230
}
205231
}
206232

207233
impl HInt for i128 {
208234
type D = i256;
209235

210236
fn widen(self) -> Self::D {
211-
let mut ret = self.unsigned().zero_widen().signed();
212-
if self.is_negative() {
213-
ret.hi = u128::MAX;
237+
i256 {
238+
lo: self as u128,
239+
hi: if self < 0 { -1 } else { 0 },
214240
}
215-
ret
216241
}
217242

218243
fn zero_widen(self) -> Self::D {
@@ -228,7 +253,7 @@ impl HInt for i128 {
228253
}
229254

230255
fn widen_hi(self) -> Self::D {
231-
self.widen() << <Self as MinInt>::BITS
256+
i256 { lo: 0, hi: self }
232257
}
233258
}
234259

@@ -252,6 +277,6 @@ impl DInt for i256 {
252277
}
253278

254279
fn hi(self) -> Self::H {
255-
self.hi as i128
280+
self.hi
256281
}
257282
}

0 commit comments

Comments
 (0)