HIP: Heterogeneous-computing Interface for Portability
amd_hip_bfloat16.h
Go to the documentation of this file.
1 
29 #ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_
30 #define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_
31 
32 #include "host_defines.h"
// Qualifier macro applied to every hip_bfloat16 member function and operator in
// this header: device-only when __HIPCC_RTC__ is defined, host + device otherwise.
#ifdef __HIPCC_RTC__
    #define __HOST_DEVICE__ __device__
#else
    #define __HOST_DEVICE__ __host__ __device__
#endif
38 
39 #if __cplusplus < 201103L || !defined(__HIPCC__)
40 
41 // If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
42 // include a minimal definition of hip_bfloat16
43 
44 #include <stdint.h>
/* Struct to represent a 16 bit brain floating point number.
 * Minimal definition used when the full C++ class is not available
 * (C compiler, pre-C++11 compiler, or host-only compiler): it only
 * carries the raw 16-bit pattern. */
typedef struct
{
    uint16_t data; /* raw bfloat16 bits */
} hip_bfloat16;
50 
51 #else // __cplusplus < 201103L || !defined(__HIPCC__)
52 
53 #include <hip/hip_runtime.h>
54 
55 #pragma clang diagnostic push
56 #pragma clang diagnostic ignored "-Wshadow"
57 struct hip_bfloat16
58 {
59  __hip_uint16_t data;
60 
61  enum truncate_t
62  {
63  truncate
64  };
65 
66  __HOST_DEVICE__ hip_bfloat16() = default;
67 
68  // round upper 16 bits of IEEE float to convert to bfloat16
69  explicit __HOST_DEVICE__ hip_bfloat16(float f)
70  : data(float_to_bfloat16(f))
71  {
72  }
73 
74  explicit __HOST_DEVICE__ hip_bfloat16(float f, truncate_t)
75  : data(truncate_float_to_bfloat16(f))
76  {
77  }
78 
79  // zero extend lower 16 bits of bfloat16 to convert to IEEE float
80  __HOST_DEVICE__ operator float() const
81  {
82  union
83  {
84  uint32_t int32;
85  float fp32;
86  } u = {uint32_t(data) << 16};
87  return u.fp32;
88  }
89 
90  __HOST_DEVICE__ hip_bfloat16 &operator=(const float& f)
91  {
92  data = float_to_bfloat16(f);
93  return *this;
94  }
95 
96  static __HOST_DEVICE__ hip_bfloat16 round_to_bfloat16(float f)
97  {
98  hip_bfloat16 output;
99  output.data = float_to_bfloat16(f);
100  return output;
101  }
102 
103  static __HOST_DEVICE__ hip_bfloat16 round_to_bfloat16(float f, truncate_t)
104  {
105  hip_bfloat16 output;
106  output.data = truncate_float_to_bfloat16(f);
107  return output;
108  }
109 
110 private:
111  static __HOST_DEVICE__ __hip_uint16_t float_to_bfloat16(float f)
112  {
113  union
114  {
115  float fp32;
116  uint32_t int32;
117  } u = {f};
118  if(~u.int32 & 0x7f800000)
119  {
120  // When the exponent bits are not all 1s, then the value is zero, normal,
121  // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
122  // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
123  // This causes the bfloat16's mantissa to be incremented by 1 if the 16
124  // least significant bits of the float mantissa are greater than 0x8000,
125  // or if they are equal to 0x8000 and the least significant bit of the
126  // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
127  // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
128  // has the value 0x7f, then incrementing it causes it to become 0x00 and
129  // the exponent is incremented by one, which is the next higher FP value
130  // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
131  // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
132  // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
133  // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
134  // incrementing it causes it to become an exponent of 0xFF and a mantissa
135  // of 0x00, which is Inf, the next higher value to the unrounded value.
136  u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
137  }
138  else if(u.int32 & 0xffff)
139  {
140  // When all of the exponent bits are 1, the value is Inf or NaN.
141  // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
142  // mantissa bit. Quiet NaN is indicated by the most significant mantissa
143  // bit being 1. Signaling NaN is indicated by the most significant
144  // mantissa bit being 0 but some other bit(s) being 1. If any of the
145  // lower 16 bits of the mantissa are 1, we set the least significant bit
146  // of the bfloat16 mantissa, in order to preserve signaling NaN in case
147  // the bloat16's mantissa bits are all 0.
148  u.int32 |= 0x10000; // Preserve signaling NaN
149  }
150  return __hip_uint16_t(u.int32 >> 16);
151  }
152 
153  // Truncate instead of rounding, preserving SNaN
154  static __HOST_DEVICE__ __hip_uint16_t truncate_float_to_bfloat16(float f)
155  {
156  union
157  {
158  float fp32;
159  uint32_t int32;
160  } u = {f};
161  return __hip_uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
162  }
163 };
164 #pragma clang diagnostic pop
165 
166 typedef struct
167 {
168  __hip_uint16_t data;
169 } hip_bfloat16_public;
170 
171 static_assert(__hip_internal::is_standard_layout<hip_bfloat16>{},
172  "hip_bfloat16 is not a standard layout type, and thus is "
173  "incompatible with C.");
174 
175 static_assert(__hip_internal::is_trivial<hip_bfloat16>{},
176  "hip_bfloat16 is not a trivial type, and thus is "
177  "incompatible with C.");
178 #if !defined(__HIPCC_RTC__)
179 static_assert(sizeof(hip_bfloat16) == sizeof(hip_bfloat16_public)
180  && offsetof(hip_bfloat16, data) == offsetof(hip_bfloat16_public, data),
181  "internal hip_bfloat16 does not match public hip_bfloat16");
182 
183 inline std::ostream& operator<<(std::ostream& os, const hip_bfloat16& bf16)
184 {
185  return os << float(bf16);
186 }
187 #endif
188 
189 inline __HOST_DEVICE__ hip_bfloat16 operator+(hip_bfloat16 a)
190 {
191  return a;
192 }
193 inline __HOST_DEVICE__ hip_bfloat16 operator-(hip_bfloat16 a)
194 {
195  a.data ^= 0x8000;
196  return a;
197 }
198 inline __HOST_DEVICE__ hip_bfloat16 operator+(hip_bfloat16 a, hip_bfloat16 b)
199 {
200  return hip_bfloat16(float(a) + float(b));
201 }
202 inline __HOST_DEVICE__ hip_bfloat16 operator-(hip_bfloat16 a, hip_bfloat16 b)
203 {
204  return hip_bfloat16(float(a) - float(b));
205 }
206 inline __HOST_DEVICE__ hip_bfloat16 operator*(hip_bfloat16 a, hip_bfloat16 b)
207 {
208  return hip_bfloat16(float(a) * float(b));
209 }
210 inline __HOST_DEVICE__ hip_bfloat16 operator/(hip_bfloat16 a, hip_bfloat16 b)
211 {
212  return hip_bfloat16(float(a) / float(b));
213 }
214 inline __HOST_DEVICE__ bool operator<(hip_bfloat16 a, hip_bfloat16 b)
215 {
216  return float(a) < float(b);
217 }
218 inline __HOST_DEVICE__ bool operator==(hip_bfloat16 a, hip_bfloat16 b)
219 {
220  return float(a) == float(b);
221 }
222 inline __HOST_DEVICE__ bool operator>(hip_bfloat16 a, hip_bfloat16 b)
223 {
224  return b < a;
225 }
226 inline __HOST_DEVICE__ bool operator<=(hip_bfloat16 a, hip_bfloat16 b)
227 {
228  return !(a > b);
229 }
230 inline __HOST_DEVICE__ bool operator!=(hip_bfloat16 a, hip_bfloat16 b)
231 {
232  return !(a == b);
233 }
234 inline __HOST_DEVICE__ bool operator>=(hip_bfloat16 a, hip_bfloat16 b)
235 {
236  return !(a < b);
237 }
238 inline __HOST_DEVICE__ hip_bfloat16& operator+=(hip_bfloat16& a, hip_bfloat16 b)
239 {
240  return a = a + b;
241 }
242 inline __HOST_DEVICE__ hip_bfloat16& operator-=(hip_bfloat16& a, hip_bfloat16 b)
243 {
244  return a = a - b;
245 }
246 inline __HOST_DEVICE__ hip_bfloat16& operator*=(hip_bfloat16& a, hip_bfloat16 b)
247 {
248  return a = a * b;
249 }
250 inline __HOST_DEVICE__ hip_bfloat16& operator/=(hip_bfloat16& a, hip_bfloat16 b)
251 {
252  return a = a / b;
253 }
254 inline __HOST_DEVICE__ hip_bfloat16& operator++(hip_bfloat16& a)
255 {
256  return a += hip_bfloat16(1.0f);
257 }
258 inline __HOST_DEVICE__ hip_bfloat16& operator--(hip_bfloat16& a)
259 {
260  return a -= hip_bfloat16(1.0f);
261 }
262 inline __HOST_DEVICE__ hip_bfloat16 operator++(hip_bfloat16& a, int)
263 {
264  hip_bfloat16 orig = a;
265  ++a;
266  return orig;
267 }
268 inline __HOST_DEVICE__ hip_bfloat16 operator--(hip_bfloat16& a, int)
269 {
270  hip_bfloat16 orig = a;
271  --a;
272  return orig;
273 }
274 
275 namespace std
276 {
277  constexpr __HOST_DEVICE__ bool isinf(hip_bfloat16 a)
278  {
279  return !(~a.data & 0x7f80) && !(a.data & 0x7f);
280  }
281  constexpr __HOST_DEVICE__ bool isnan(hip_bfloat16 a)
282  {
283  return !(~a.data & 0x7f80) && +(a.data & 0x7f);
284  }
285  constexpr __HOST_DEVICE__ bool iszero(hip_bfloat16 a)
286  {
287  return !(a.data & 0x7fff);
288  }
289 }
290 
291 #endif // __cplusplus < 201103L || !defined(__HIPCC__)
292 
293 #endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16 &l)
Unary plus operator for a __hip_bfloat16 number.
Definition: amd_hip_bf16.h:835
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & operator/=(__hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to divide-assign two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:930
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator/(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to divide two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:921
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & operator*=(__hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to multiply-assign two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:826
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16 &l)
Operator to negate a __hip_bfloat16 number.
Definition: amd_hip_bf16.h:850
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & operator+=(__hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to add-assign two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:903
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator*(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to multiply two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:817
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator++(__hip_bfloat16 &l, const int)
Operator to post increment a __hip_bfloat16 number.
Definition: amd_hip_bf16.h:865
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator--(__hip_bfloat16 &l, const int)
Operator to post decrement a __hip_bfloat16 number.
Definition: amd_hip_bf16.h:884
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & operator-=(__hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to subtract-assign two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:912
__BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform an equal compare on two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:1463
__BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a not-equal compare on two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:1471
__BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a greater-than compare on two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:1495
__BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a less-than-or-equal compare on two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:1487
__BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a less-than compare on two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:1479
__BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a greater-than-or-equal compare on two __hip_bfloat16 numbers.
Definition: amd_hip_bf16.h:1503
Struct to represent a 16 bit brain floating point number.
Definition: amd_hip_bfloat16.h:47