// File: tiny_fp8.hpp
/**
* @file tiny_fp8.hpp
* @brief Software FP8 implementation for tiny_ai.
* Supports OCP FP8 E4M3FN (weights/activations) and E5M2 (gradients).
* The ESP32-S3 has no FP8 hardware; values are stored as uint8_t and
* upcast to float32 for all arithmetic.
*/
#pragma once
#include "tiny_quant_config.h"
#include <stdint.h>
#ifdef __cplusplus
#include <cstring>
#include <cmath>
namespace tiny
{
/**
 * @brief Encode one float32 value as an FP8 E4M3FN byte.
 * NaN inputs yield TINY_FP8_E4M3_NAN; magnitudes too large for the format
 * (including infinities) yield the sign bit combined with 0x7E.
 * @param val Value to encode.
 * @return Encoded FP8 byte.
 */
uint8_t fp32_to_fp8_e4m3 ( float val );
/**
 * @brief Decode one FP8 E4M3FN byte to float32.
 * Bytes whose low seven bits are all set (0x7F, 0xFF) decode to NaN.
 * @param fp8 Encoded FP8 byte.
 * @return Decoded float32 value.
 */
float fp8_e4m3_to_fp32 ( uint8_t fp8 );
/**
 * @brief Encode n float32 values from src into dst as E4M3FN bytes, in index order.
 * Nothing is read or written when n <= 0.
 * @param src Input array, at least n floats.
 * @param dst Output array, at least n bytes.
 * @param n   Number of elements to convert.
 */
void fp32_to_fp8_e4m3_batch ( const float * src , uint8_t * dst , int n );
/**
 * @brief Decode n E4M3FN bytes from src into dst as float32 values, in index order.
 * Nothing is read or written when n <= 0.
 * @param src Input array, at least n bytes.
 * @param dst Output array, at least n floats.
 * @param n   Number of elements to convert.
 */
void fp8_e4m3_to_fp32_batch ( const uint8_t * src , float * dst , int n );
/**
 * @brief Encode one float32 value as an FP8 E5M2 byte.
 * NaN inputs yield the sign bit combined with TINY_FP8_E5M2_NAN; infinities
 * and magnitudes too large for the format yield the sign bit combined with
 * TINY_FP8_E5M2_INF.
 * @param val Value to encode.
 * @return Encoded FP8 byte.
 */
uint8_t fp32_to_fp8_e5m2 ( float val );
/**
 * @brief Decode one FP8 E5M2 byte to float32.
 * An all-ones exponent field decodes to +/-infinity when the mantissa is
 * zero and to NaN otherwise.
 * @param fp8 Encoded FP8 byte.
 * @return Decoded float32 value.
 */
float fp8_e5m2_to_fp32 ( uint8_t fp8 );
/**
 * @brief Encode n float32 values from src into dst as E5M2 bytes, in index order.
 * Nothing is read or written when n <= 0.
 * @param src Input array, at least n floats.
 * @param dst Output array, at least n bytes.
 * @param n   Number of elements to convert.
 */
void fp32_to_fp8_e5m2_batch ( const float * src , uint8_t * dst , int n );
/**
 * @brief Decode n E5M2 bytes from src into dst as float32 values, in index order.
 * Nothing is read or written when n <= 0.
 * @param src Input array, at least n bytes.
 * @param dst Output array, at least n floats.
 * @param n   Number of elements to convert.
 */
void fp8_e5m2_to_fp32_batch ( const uint8_t * src , float * dst , int n );
/**
 * @brief Encode one float32 value in the FP8 format selected by dtype.
 * TINY_DTYPE_FP8_E5M2 selects E5M2; every other dtype value is treated as E4M3FN.
 * @param val   Value to encode.
 * @param dtype Target FP8 format.
 * @return Encoded FP8 byte.
 */
uint8_t fp32_to_fp8 ( float val , tiny_dtype_t dtype );
/**
 * @brief Decode one FP8 byte stored in the format selected by dtype.
 * TINY_DTYPE_FP8_E5M2 selects E5M2; every other dtype value is treated as E4M3FN.
 * @param fp8   Encoded FP8 byte.
 * @param dtype Source FP8 format.
 * @return Decoded float32 value.
 */
float fp8_to_fp32 ( uint8_t fp8 , tiny_dtype_t dtype );
/**
 * @brief Batch form of fp32_to_fp8(): encode n values from src into dst.
 * TINY_DTYPE_FP8_E5M2 selects E5M2; every other dtype value is treated as E4M3FN.
 */
void fp32_to_fp8_batch ( const float * src , uint8_t * dst , int n , tiny_dtype_t dtype );
/**
 * @brief Batch form of fp8_to_fp32(): decode n bytes from src into dst.
 * TINY_DTYPE_FP8_E5M2 selects E5M2; every other dtype value is treated as E4M3FN.
 */
void fp8_to_fp32_batch ( const uint8_t * src , float * dst , int n , tiny_dtype_t dtype );
} // namespace tiny
#endif // __cplusplus
// File: tiny_fp8.cpp
/**
* @file tiny_fp8.cpp
* @brief Software FP8 implementation — E4M3FN and E5M2 formats.
*/
#include "tiny_fp8.hpp"
#ifdef __cplusplus
#include <cstring>
#include <cmath>
#include <stdint.h>
namespace tiny
{
/** Return the raw 32-bit storage pattern of a float32 (byte copy, no numeric conversion). */
static inline uint32_t f32_bits(float f)
{
    uint32_t u;
    memcpy(&u, &f, sizeof u);
    return u;
}

/** Build a float32 whose storage is exactly the given 32-bit pattern. */
static inline float bits_f32(uint32_t u)
{
    float f;
    memcpy(&f, &u, sizeof f);
    return f;
}

/** float32 with pattern 0x7FC00000: exponent all ones, top mantissa bit set (a NaN). */
static inline float f32_nan()
{
    const uint32_t pattern = 0x7FC00000u;
    return bits_f32(pattern);
}

/** float32 with pattern 0x7F800000: positive infinity. */
static inline float f32_inf()
{
    const uint32_t pattern = 0x7F800000u;
    return bits_f32(pattern);
}

/** float32 with pattern 0xFF800000: negative infinity. */
static inline float f32_ninf()
{
    const uint32_t pattern = 0xFF800000u;
    return bits_f32(pattern);
}
// ============================================================================
// FP8 E4M3FN
// ============================================================================
/**
 * @brief Encode a float32 as an FP8 E4M3FN byte
 *        (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits).
 *
 * Rounding is round-to-nearest, ties-to-even. E4M3FN has no infinity, so any
 * magnitude that would exceed the largest finite value (448, code 0x7E),
 * float32 infinities included, saturates to +/-448. NaN maps to
 * TINY_FP8_E4M3_NAN. Magnitudes of at most half the smallest subnormal
 * (2^-9) become signed zero.
 *
 * @param val Value to encode.
 * @return FP8 E4M3FN bit pattern.
 */
uint8_t fp32_to_fp8_e4m3(float val)
{
    // NaN is the only value that compares unequal to itself.
    if (val != val) return TINY_FP8_E4M3_NAN;

    const uint32_t bits = f32_bits(val);
    const uint8_t  sign = (uint8_t)((bits >> 24) & 0x80u);   // sign already in bit 7
    const int      exp  = (int)((bits >> 23) & 0xFFu) - 127; // unbiased float32 exponent

    // |val| >= 512 (or infinite): beyond the largest finite code, saturate.
    if (exp > 8) return (uint8_t)(sign | 0x7Eu);
    // |val| < 2^-10 is below half of the smallest subnormal and rounds to zero.
    // float32 zeros and subnormals (exponent field 0) also land here.
    if (exp < -10) return sign;

    // 24-bit significand with the implicit leading one: |val| = sig * 2^(exp - 23).
    const uint32_t sig = (bits & 0x7FFFFFu) | 0x800000u;

    // Keeping the implicit bit plus 3 mantissa bits means dropping 20 bits.
    // Below the minimum normal exponent (-6) the value is denormalised by
    // dropping one extra bit per missing exponent step.
    int      shift = 20;
    uint32_t base  = 0u;
    if (exp >= -6)
        base = (uint32_t)(exp + 6) << 3; // (biased exponent - 1) << 3; the implicit bit in q supplies the +1
    else
        shift += -6 - exp;

    uint32_t       q    = sig >> shift;
    const uint32_t rem  = sig & ((1u << shift) - 1u);
    const uint32_t half = 1u << (shift - 1);
    if (rem > half || (rem == half && (q & 1u) != 0u))
        q++; // nearest, ties to even; a mantissa carry flows into the exponent via base + q

    uint32_t code = base + q;
    if (code > 0x7Eu) code = 0x7Eu; // 0x7F is the NaN code, so clamp to the largest finite value
    return (uint8_t)(sign | code);
}
/**
 * @brief Decode an FP8 E4M3FN byte to float32.
 *
 * Bytes whose low seven bits are all set (0x7F, 0xFF) decode to NaN. An
 * exponent field of zero denotes a subnormal: mant * 2^-9. Otherwise the
 * value is (1 + mant/8) * 2^(exp4 - 7). The sign bit is applied last, so
 * 0x80 decodes to negative zero.
 *
 * @param fp8 Encoded FP8 byte.
 * @return Decoded float32 value.
 */
float fp8_e4m3_to_fp32(uint8_t fp8)
{
    if ((fp8 & 0x7Fu) == 0x7Fu) return f32_nan();

    const int exp4     = (fp8 >> 3) & 0xF;
    const int mant     = fp8 & 0x7;
    const int negative = (fp8 & 0x80u) != 0u;

    // Express the magnitude as integer_significand * 2^exponent so that one
    // exact ldexpf scaling replaces a powf call. Normals carry the implicit
    // leading one (8 + mant, in units of 1/8); subnormals do not and use the
    // minimum exponent, i.e. they behave as if the exponent field were 1.
    const int significand = (exp4 != 0) ? (8 + mant) : mant;
    const int exponent    = ((exp4 != 0) ? exp4 : 1) - 7 - 3;

    const float magnitude = ldexpf((float)significand, exponent);
    return negative ? -magnitude : magnitude;
}
/**
 * @brief Encode n float32 values as FP8 E4M3FN bytes.
 *
 * Elements are converted one at a time in ascending index order with
 * fp32_to_fp8_e4m3(). Nothing is read or written when n <= 0.
 *
 * @param src Input array, at least n floats.
 * @param dst Output array, at least n bytes.
 * @param n   Number of elements to convert.
 */
void fp32_to_fp8_e4m3_batch(const float *src, uint8_t *dst, int n)
{
    int idx = 0;
    while (idx < n)
    {
        dst[idx] = fp32_to_fp8_e4m3(src[idx]);
        ++idx;
    }
}
/**
 * @brief Decode n FP8 E4M3FN bytes to float32 values.
 *
 * Elements are converted one at a time in ascending index order with
 * fp8_e4m3_to_fp32(). Nothing is read or written when n <= 0.
 *
 * @param src Input array, at least n bytes.
 * @param dst Output array, at least n floats.
 * @param n   Number of elements to convert.
 */
void fp8_e4m3_to_fp32_batch(const uint8_t *src, float *dst, int n)
{
    int idx = 0;
    while (idx < n)
    {
        dst[idx] = fp8_e4m3_to_fp32(src[idx]);
        ++idx;
    }
}
// ============================================================================
// FP8 E5M2
// ============================================================================
/**
 * @brief Encode a float32 as an FP8 E5M2 byte
 *        (1 sign bit, 5 exponent bits with bias 15, 2 mantissa bits).
 *
 * Rounding is round-to-nearest, ties-to-even. NaN maps to the sign bit
 * combined with TINY_FP8_E5M2_NAN. Infinities, and finite magnitudes that
 * round past the largest finite value (57344, code 0x7B), map to the sign
 * bit combined with TINY_FP8_E5M2_INF. Magnitudes of at most half the
 * smallest subnormal (2^-16) become signed zero.
 *
 * @param val Value to encode.
 * @return FP8 E5M2 bit pattern.
 */
uint8_t fp32_to_fp8_e5m2(float val)
{
    const uint32_t bits = f32_bits(val);
    const uint8_t  sign = (uint8_t)((bits >> 24) & 0x80u);   // sign already in bit 7

    // NaN is the only value that compares unequal to itself.
    if (val != val) return (uint8_t)(sign | TINY_FP8_E5M2_NAN);

    const int exp = (int)((bits >> 23) & 0xFFu) - 127;       // unbiased float32 exponent

    // |val| >= 65536, which also covers float32 infinity (exponent field 255).
    if (exp > 15) return (uint8_t)(sign | TINY_FP8_E5M2_INF);
    // |val| < 2^-17 is below half of the smallest subnormal and rounds to zero.
    // float32 zeros and subnormals (exponent field 0) also land here.
    if (exp < -17) return sign;

    // 24-bit significand with the implicit leading one: |val| = sig * 2^(exp - 23).
    const uint32_t sig = (bits & 0x7FFFFFu) | 0x800000u;

    // Keeping the implicit bit plus 2 mantissa bits means dropping 21 bits.
    // Below the minimum normal exponent (-14) the value is denormalised by
    // dropping one extra bit per missing exponent step.
    int      shift = 21;
    uint32_t base  = 0u;
    if (exp >= -14)
        base = (uint32_t)(exp + 14) << 2; // (biased exponent - 1) << 2; the implicit bit in q supplies the +1
    else
        shift += -14 - exp;

    uint32_t       q    = sig >> shift;
    const uint32_t rem  = sig & ((1u << shift) - 1u);
    const uint32_t half = 1u << (shift - 1);
    if (rem > half || (rem == half && (q & 1u) != 0u))
        q++; // nearest, ties to even; a mantissa carry flows into the exponent via base + q

    const uint32_t code = base + q;
    // An all-ones exponent field (codes >= 0x7C) is reserved for inf/NaN, so
    // a value that rounded up into it has overflowed.
    if (code >= 0x7Cu) return (uint8_t)(sign | TINY_FP8_E5M2_INF);
    return (uint8_t)(sign | code);
}
/**
 * @brief Decode an FP8 E5M2 byte to float32.
 *
 * An all-ones exponent field decodes to +/-infinity when the mantissa is
 * zero and to NaN otherwise. An exponent field of zero denotes a subnormal:
 * mant * 2^-16. Otherwise the value is (1 + mant/4) * 2^(exp5 - 15). The
 * sign bit is applied last, so 0x80 decodes to negative zero.
 *
 * @param fp8 Encoded FP8 byte.
 * @return Decoded float32 value.
 */
float fp8_e5m2_to_fp32(uint8_t fp8)
{
    const int exp5     = (fp8 >> 2) & 0x1F;
    const int mant     = fp8 & 0x3;
    const int negative = (fp8 & 0x80u) != 0u;

    if (exp5 == 0x1F)
    {
        if (mant != 0) return f32_nan();
        return negative ? f32_ninf() : f32_inf();
    }

    // Express the magnitude as integer_significand * 2^exponent so that one
    // exact ldexpf scaling replaces a powf call. Normals carry the implicit
    // leading one (4 + mant, in units of 1/4); subnormals do not and use the
    // minimum exponent, i.e. they behave as if the exponent field were 1.
    const int significand = (exp5 != 0) ? (4 + mant) : mant;
    const int exponent    = ((exp5 != 0) ? exp5 : 1) - 15 - 2;

    const float magnitude = ldexpf((float)significand, exponent);
    return negative ? -magnitude : magnitude;
}
/**
 * @brief Encode n float32 values as FP8 E5M2 bytes.
 *
 * Elements are converted one at a time in ascending index order with
 * fp32_to_fp8_e5m2(). Nothing is read or written when n <= 0.
 *
 * @param src Input array, at least n floats.
 * @param dst Output array, at least n bytes.
 * @param n   Number of elements to convert.
 */
void fp32_to_fp8_e5m2_batch(const float *src, uint8_t *dst, int n)
{
    int idx = 0;
    while (idx < n)
    {
        dst[idx] = fp32_to_fp8_e5m2(src[idx]);
        ++idx;
    }
}
/**
 * @brief Decode n FP8 E5M2 bytes to float32 values.
 *
 * Elements are converted one at a time in ascending index order with
 * fp8_e5m2_to_fp32(). Nothing is read or written when n <= 0.
 *
 * @param src Input array, at least n bytes.
 * @param dst Output array, at least n floats.
 * @param n   Number of elements to convert.
 */
void fp8_e5m2_to_fp32_batch(const uint8_t *src, float *dst, int n)
{
    int idx = 0;
    while (idx < n)
    {
        dst[idx] = fp8_e5m2_to_fp32(src[idx]);
        ++idx;
    }
}
// ============================================================================
// Dispatch
// ============================================================================
/**
 * @brief Encode one float32 value in the FP8 format selected by dtype.
 *
 * TINY_DTYPE_FP8_E5M2 routes to fp32_to_fp8_e5m2(); every other dtype value
 * routes to fp32_to_fp8_e4m3().
 *
 * @param val   Value to encode.
 * @param dtype Target FP8 format.
 * @return Encoded FP8 byte.
 */
uint8_t fp32_to_fp8(float val, tiny_dtype_t dtype)
{
    return (dtype == TINY_DTYPE_FP8_E5M2) ? fp32_to_fp8_e5m2(val)
                                          : fp32_to_fp8_e4m3(val);
}
/**
 * @brief Decode one FP8 byte stored in the format selected by dtype.
 *
 * TINY_DTYPE_FP8_E5M2 routes to fp8_e5m2_to_fp32(); every other dtype value
 * routes to fp8_e4m3_to_fp32().
 *
 * @param fp8   Encoded FP8 byte.
 * @param dtype Source FP8 format.
 * @return Decoded float32 value.
 */
float fp8_to_fp32(uint8_t fp8, tiny_dtype_t dtype)
{
    return (dtype == TINY_DTYPE_FP8_E5M2) ? fp8_e5m2_to_fp32(fp8)
                                          : fp8_e4m3_to_fp32(fp8);
}
/**
 * @brief Encode n float32 values in the FP8 format selected by dtype.
 *
 * TINY_DTYPE_FP8_E5M2 routes to fp32_to_fp8_e5m2_batch(); every other dtype
 * value routes to fp32_to_fp8_e4m3_batch().
 *
 * @param src   Input array, at least n floats.
 * @param dst   Output array, at least n bytes.
 * @param n     Number of elements to convert.
 * @param dtype Target FP8 format.
 */
void fp32_to_fp8_batch(const float *src, uint8_t *dst, int n, tiny_dtype_t dtype)
{
    if (dtype == TINY_DTYPE_FP8_E5M2)
    {
        fp32_to_fp8_e5m2_batch(src, dst, n);
    }
    else
    {
        fp32_to_fp8_e4m3_batch(src, dst, n);
    }
}
/**
 * @brief Decode n FP8 bytes stored in the format selected by dtype.
 *
 * TINY_DTYPE_FP8_E5M2 routes to fp8_e5m2_to_fp32_batch(); every other dtype
 * value routes to fp8_e4m3_to_fp32_batch().
 *
 * @param src   Input array, at least n bytes.
 * @param dst   Output array, at least n floats.
 * @param n     Number of elements to convert.
 * @param dtype Source FP8 format.
 */
void fp8_to_fp32_batch(const uint8_t *src, float *dst, int n, tiny_dtype_t dtype)
{
    if (dtype == TINY_DTYPE_FP8_E5M2)
    {
        fp8_e5m2_to_fp32_batch(src, dst, n);
    }
    else
    {
        fp8_e4m3_to_fp32_batch(src, dst, n);
    }
}
} // namespace tiny
#endif // __cplusplus