Draft NEON that might work
rust-lightning: lightning/src/util/simd_f32.rs
#[cfg(not(any(target_feature = "sse", all(target_feature = "neon", target_arch = "aarch64"))))]
mod non_simd {
	/// Scalar fallback used when no SIMD backend is compiled in. Fields are stored in the
	/// order the arguments were passed to `new`, and `dump` returns them in that same order,
	/// matching the behavior of the SSE implementation below.
	#[derive(Clone, Copy)]
	pub(crate) struct FourF32(f32, f32, f32, f32);
	impl FourF32 {
		#[inline(always)]
		pub(crate) fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
			Self(a, b, c, d)
		}
		#[inline(always)]
		pub(crate) fn from_ints(a: u16, b: u16, c: u16, d: u16) -> Self {
			Self(a as f32, b as f32, c as f32, d as f32)
		}
		#[inline(always)]
		pub(crate) fn hsub(&self) -> Self {
			// _mm_hsub_ps with the second argument zeros: the two pairwise differences land
			// in the last two `dump` positions, with the first two positions zeroed.
			Self(0.0, 0.0, self.1 - self.0, self.3 - self.2)
		}
		#[inline(always)]
		pub(crate) fn consuming_sum(self) -> f32 {
			self.0 + self.1 + self.2 + self.3
		}
		#[inline(always)]
		pub(crate) fn dump(self) -> (f32, f32, f32, f32) {
			(self.0, self.1, self.2, self.3)
		}
	}
	impl core::ops::Div<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn div(self, o: FourF32) -> Self {
			Self(self.0 / o.0, self.1 / o.1, self.2 / o.2, self.3 / o.3)
		}
	}
	impl core::ops::Mul<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn mul(self, o: FourF32) -> Self {
			Self(self.0 * o.0, self.1 * o.1, self.2 * o.2, self.3 * o.3)
		}
	}
	impl core::ops::Add<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn add(self, o: FourF32) -> Self {
			Self(self.0 + o.0, self.1 + o.1, self.2 + o.2, self.3 + o.3)
		}
	}
	impl core::ops::Sub<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn sub(self, o: FourF32) -> Self {
			Self(self.0 - o.0, self.1 - o.1, self.2 - o.2, self.3 - o.3)
		}
	}
}
#[cfg(not(any(target_feature = "sse", all(target_feature = "neon", target_arch = "aarch64"))))]
pub(crate) use non_simd::*;

#[cfg(target_feature = "sse")]
mod x86_sse {
	#[cfg(target_arch = "x86")]
	use core::arch::x86::*;
	#[cfg(target_arch = "x86_64")]
	use core::arch::x86_64::*;

	#[repr(align(16))]
	struct AlignedFloats([f32; 4]);

	// Note that `_mm_hsub_ps` and `_mm_hadd_ps` are SSE3 instructions, so this module
	// effectively assumes SSE3 even though it is only gated on `sse`.
	#[derive(Clone, Copy)]
	pub(crate) struct FourF32(__m128);
	impl FourF32 {
		#[inline(always)]
		pub(crate) fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
			// `_mm_set_ps` places `a` in the highest lane and `d` in the lowest, which
			// `dump` undoes by reading the lanes back highest-first.
			Self(unsafe { _mm_set_ps(a, b, c, d) })
		}
		#[inline(always)]
		pub(crate) fn from_ints(a: u16, b: u16, c: u16, d: u16) -> Self {
			unsafe {
				let ints = _mm_set_epi32(a as i32, b as i32, c as i32, d as i32);
				Self(_mm_cvtepi32_ps(ints))
			}
		}
		#[inline(always)]
		pub(crate) fn hsub(&self) -> Self {
			let dummy = unsafe { _mm_setzero_ps() };
			Self(unsafe { _mm_hsub_ps(self.0, dummy) })
		}
		#[inline(always)]
		pub(crate) fn consuming_sum(self) -> f32 {
			let im = unsafe {
				let dummy = _mm_setzero_ps();
				Self(_mm_hadd_ps(self.0, dummy))
			};
			let res = im.dump();
			res.2 + res.3
		}
		#[inline(always)]
		pub(crate) fn dump(self) -> (f32, f32, f32, f32) {
			let mut res = AlignedFloats([0.0; 4]);
			unsafe { _mm_store_ps(&mut res.0[0], self.0) };
			(res.0[3], res.0[2], res.0[1], res.0[0])
		}
	}
	impl core::ops::Div<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn div(self, o: FourF32) -> Self {
			Self(unsafe { _mm_div_ps(self.0, o.0) })
		}
	}
	impl core::ops::Mul<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn mul(self, o: FourF32) -> Self {
			Self(unsafe { _mm_mul_ps(self.0, o.0) })
		}
	}
	impl core::ops::Add<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn add(self, o: FourF32) -> Self {
			Self(unsafe { _mm_add_ps(self.0, o.0) })
		}
	}
	impl core::ops::Sub<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn sub(self, o: FourF32) -> Self {
			Self(unsafe { _mm_sub_ps(self.0, o.0) })
		}
	}
}
#[cfg(target_feature = "sse")]
pub(crate) use x86_sse::*;

#[cfg(all(target_feature = "neon", target_arch = "aarch64"))]
mod aarch64_neon {
	use core::arch::aarch64::*;

	// Not actually clear if the relevant instructions require alignment, but there's no harm in it
	// and it may improve performance.
	#[repr(align(16))]
	struct AlignedFloats([f32; 4]);
	#[repr(align(16))]
	struct AlignedInts([u32; 4]);

	#[derive(Clone, Copy)]
	pub(crate) struct FourF32(float32x4_t);
	impl FourF32 {
		#[inline(always)]
		pub(crate) fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
			// Unlike `_mm_set_ps`, `vld1q_f32` keeps the arguments in order: `a` ends up in
			// lane 0, and `dump` reads the lanes back in that same order.
			let data = AlignedFloats([a, b, c, d]);
			Self(unsafe { vld1q_f32(&data.0[0]) })
		}
		#[inline(always)]
		pub(crate) fn from_ints(a: u16, b: u16, c: u16, d: u16) -> Self {
			let data = AlignedInts([a as u32, b as u32, c as u32, d as u32]);
			let ints = unsafe { vld1q_u32(&data.0[0]) };
			Self(unsafe { vcvtq_f32_u32(ints) })
		}
		#[inline(always)]
		pub(crate) fn hsub(&self) -> Self {
			// NEON has no direct equivalent of `_mm_hsub_ps`, so de-interleave the even and
			// odd lanes (padded with zeros) and subtract, placing the two pairwise
			// differences in the last two `dump` positions just like the SSE version.
			let zeros = unsafe { vdupq_n_f32(0.0) };
			let evens = unsafe { vuzp1q_f32(zeros, self.0) };
			let odds = unsafe { vuzp2q_f32(zeros, self.0) };
			Self(unsafe { vsubq_f32(odds, evens) })
		}
		#[inline(always)]
		pub(crate) fn consuming_sum(self) -> f32 {
			unsafe { vaddvq_f32(self.0) }
		}
		#[inline(always)]
		pub(crate) fn dump(self) -> (f32, f32, f32, f32) {
			let mut res = AlignedFloats([0.0; 4]);
			unsafe { vst1q_f32(&mut res.0[0], self.0) };
			(res.0[0], res.0[1], res.0[2], res.0[3])
		}
	}
	impl core::ops::Div<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn div(self, o: FourF32) -> Self {
			Self(unsafe { vdivq_f32(self.0, o.0) })
		}
	}
	impl core::ops::Mul<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn mul(self, o: FourF32) -> Self {
			Self(unsafe { vmulq_f32(self.0, o.0) })
		}
	}
	impl core::ops::Add<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn add(self, o: FourF32) -> Self {
			Self(unsafe { vaddq_f32(self.0, o.0) })
		}
	}
	impl core::ops::Sub<FourF32> for FourF32 {
		type Output = FourF32;
		#[inline(always)]
		fn sub(self, o: FourF32) -> Self {
			Self(unsafe { vsubq_f32(self.0, o.0) })
		}
	}
}
#[cfg(all(target_feature = "neon", target_arch = "aarch64"))]
pub(crate) use aarch64_neon::*;
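
A small sanity test one might add alongside the draft (not part of the file above; the test name and values are illustrative) to check that whichever backend gets compiled honors the ordering invariant described in the comments: dump() returns values in the order they were passed to new()/from_ints(), and hsub() behaves like _mm_hsub_ps with a zeroed second argument.

#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn four_f32_backends_agree() {
		// from_ints and dump round-trip in argument order.
		assert_eq!(FourF32::from_ints(1, 3, 2, 8).dump(), (1.0, 3.0, 2.0, 8.0));
		// Element-wise multiply, then the pairwise "second minus first" differences land in
		// the last two dump positions.
		let weighted = FourF32::from_ints(1, 3, 2, 8) * FourF32::new(0.5, 0.5, 0.25, 0.25);
		assert_eq!(weighted.dump(), (0.5, 1.5, 0.5, 2.0));
		assert_eq!(weighted.hsub().dump(), (0.0, 0.0, 1.0, 1.5));
		// consuming_sum adds all four lanes.
		assert_eq!(weighted.consuming_sum(), 4.5);
	}
}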