Code
tiny_quant.h(C 接口)
/**
* @file tiny_quant.h
* @brief C-compatible interface for tiny_ai quantisation utilities.
* Provides symmetric max-abs calibration (zero_point = 0), INT8/INT16 quantise/dequantise,
* and a simple INT8 dense-layer forward pass.
*/
#pragma once
#include "tiny_quant_config.h"
#ifdef __cplusplus
extern "C"
{
#endif
/**
 * @brief Compute symmetric quantisation parameters for @p data.
 *
 * Despite the "minmax" name this is max-abs calibration:
 * scale = (largest sample magnitude) / QMAX, where QMAX is TINY_INT8_MAX or
 * TINY_INT16_MAX, and zero_point = 0. If the largest magnitude is below
 * TINY_MATH_MIN_DENOMINATOR, 1.0 is used in its place.
 *
 * @param data   Samples to calibrate on (@p n elements).
 * @param n      Number of samples; must be > 0.
 * @param dtype  TINY_DTYPE_INT8 or TINY_DTYPE_INT16.
 * @param params Receives dtype, scale and zero_point.
 * @return TINY_OK on success; TINY_ERR_INVALID_ARG for NULL pointers or n <= 0;
 *         TINY_ERR_AI_INVALID_DTYPE for any other dtype.
 */
tiny_error_t tiny_quant_calibrate_minmax(const float *data, int n,
                                         tiny_dtype_t dtype,
                                         tiny_quant_params_t *params);

/**
 * @brief Quantise floats to INT8: q = round(x / scale) + zero_point,
 *        saturated to [TINY_INT8_MIN, TINY_INT8_MAX].
 * @return TINY_OK; TINY_ERR_INVALID_ARG for NULL pointers or n <= 0;
 *         TINY_ERR_AI_QUANT_FAILED if params->scale < TINY_MATH_MIN_DENOMINATOR.
 */
tiny_error_t tiny_quant_f32_to_int8(const float *src, int8_t *dst, int n,
                                    const tiny_quant_params_t *params);

/**
 * @brief Dequantise INT8 codes: x = (q - zero_point) * scale.
 * @return TINY_OK; TINY_ERR_INVALID_ARG for NULL pointers or n <= 0.
 */
tiny_error_t tiny_quant_int8_to_f32(const int8_t *src, float *dst, int n,
                                    const tiny_quant_params_t *params);

/**
 * @brief Quantise floats to INT16: q = round(x / scale) + zero_point,
 *        saturated to [TINY_INT16_MIN, TINY_INT16_MAX].
 * @return TINY_OK; TINY_ERR_INVALID_ARG for NULL pointers or n <= 0;
 *         TINY_ERR_AI_QUANT_FAILED if params->scale < TINY_MATH_MIN_DENOMINATOR.
 */
tiny_error_t tiny_quant_f32_to_int16(const float *src, int16_t *dst, int n,
                                     const tiny_quant_params_t *params);

/**
 * @brief Dequantise INT16 codes: x = (q - zero_point) * scale.
 * @return TINY_OK; TINY_ERR_INVALID_ARG for NULL pointers or n <= 0.
 */
tiny_error_t tiny_quant_int16_to_f32(const int16_t *src, float *dst, int n,
                                     const tiny_quant_params_t *params);

/**
 * @brief Symmetric (zero-point-free) INT8 fully-connected layer.
 *
 * For every row b and output o:
 *   acc = bias[o] + sum_i input[b][i] * weight[o][i]
 *   output[b][o] = saturate_int8(round(acc * input_scale * weight_scale / output_scale))
 *
 * @param input   Row-major [batch][in_feat].
 * @param weight  Row-major [out_feat][in_feat].
 * @param bias    Optional (may be NULL) [out_feat]. It is added to the integer
 *                accumulator, so it is expressed in units of input_scale * weight_scale.
 * @param output  Row-major [batch][out_feat].
 * @return TINY_OK; TINY_ERR_INVALID_ARG if input, weight or output is NULL;
 *         TINY_ERR_AI_QUANT_FAILED if output_scale < TINY_MATH_MIN_DENOMINATOR.
 */
tiny_error_t tiny_quant_dense_forward_int8(
    const int8_t *input,
    const int8_t *weight,
    const int32_t *bias,
    int8_t *output,
    int batch, int in_feat, int out_feat,
    float input_scale, float weight_scale, float output_scale);
#ifdef __cplusplus
}
#endif
tiny_quant.c(C 实现)
/**
* @file tiny_quant.c
* @brief C implementation of INT8/INT16 quantisation utilities.
*/
#include "tiny_quant.h"
#include <math.h>
#include <stdint.h>
/* Saturate v into the closed interval [lo, hi]; lo is checked first. */
static inline int32_t clamp_i32(int32_t v, int32_t lo, int32_t hi)
{
    if (v < lo)
        return lo;
    if (v > hi)
        return hi;
    return v;
}
/**
 * @brief Symmetric max-abs calibration for INT8/INT16.
 *
 * scale = (largest finite |data[i]|) / QMAX with QMAX = TINY_INT8_MAX or
 * TINY_INT16_MAX; zero_point = 0. If that magnitude is below
 * TINY_MATH_MIN_DENOMINATOR (e.g. all-zero data), 1.0 is used instead.
 *
 * @param data   Samples (n elements).
 * @param n      Sample count; must be > 0.
 * @param dtype  TINY_DTYPE_INT8 or TINY_DTYPE_INT16.
 * @param params Output; written only when TINY_OK is returned.
 * @return TINY_OK, TINY_ERR_INVALID_ARG, or TINY_ERR_AI_INVALID_DTYPE.
 */
tiny_error_t tiny_quant_calibrate_minmax(const float *data, int n,
                                         tiny_dtype_t dtype,
                                         tiny_quant_params_t *params)
{
    if (!data || !params || n <= 0) return TINY_ERR_INVALID_ARG;

    /* Resolve the dtype first so an unsupported dtype leaves *params untouched. */
    float qmax;
    if (dtype == TINY_DTYPE_INT8)
        qmax = (float)TINY_INT8_MAX;
    else if (dtype == TINY_DTYPE_INT16)
        qmax = (float)TINY_INT16_MAX;
    else
        return TINY_ERR_AI_INVALID_DTYPE;

    float val_max = 0.0f;
    for (int i = 0; i < n; i++)
    {
        float a = fabsf(data[i]);
        /* Skip NaN/inf: a single infinite sample would give scale = +inf and
         * quantise every finite value to zero. */
        if (isfinite(a) && a > val_max) val_max = a;
    }
    if (val_max < TINY_MATH_MIN_DENOMINATOR) val_max = 1.0f;

    params->dtype = dtype;
    params->zero_point = 0;
    params->scale = val_max / qmax;
    return TINY_OK;
}
/**
 * @brief Quantise floats to INT8: q = round(x / scale) + zero_point,
 *        saturated to [TINY_INT8_MIN, TINY_INT8_MAX].
 *
 * Rounding is half-away-from-zero (roundf). NaN inputs map to zero_point
 * (the code that dequantises to 0.0), saturated to the INT8 range.
 *
 * @return TINY_OK; TINY_ERR_INVALID_ARG for NULL pointers or n <= 0;
 *         TINY_ERR_AI_QUANT_FAILED if params->scale is below
 *         TINY_MATH_MIN_DENOMINATOR or is NaN.
 */
tiny_error_t tiny_quant_f32_to_int8(const float *src, int8_t *dst, int n,
                                    const tiny_quant_params_t *params)
{
    if (!src || !dst || !params || n <= 0) return TINY_ERR_INVALID_ARG;
    /* Written as !(scale >= min) so that a NaN scale is rejected too. */
    if (!(params->scale >= TINY_MATH_MIN_DENOMINATOR)) return TINY_ERR_AI_QUANT_FAILED;

    const float inv_scale = 1.0f / params->scale;
    const float zp = (float)params->zero_point;
    const float lo = (float)TINY_INT8_MIN;
    const float hi = (float)TINY_INT8_MAX;
    for (int i = 0; i < n; i++)
    {
        /* Add the zero point and saturate in the float domain: converting a
         * NaN or out-of-range float to an integer type is undefined behaviour
         * (C11 6.3.1.4), so the cast below only ever sees values in [lo, hi]. */
        float r = roundf(src[i] * inv_scale) + zp;
        if (isnan(r)) r = zp;
        r = fminf(fmaxf(r, lo), hi);
        dst[i] = (int8_t)r;
    }
    return TINY_OK;
}
/**
 * @brief Dequantise INT8 codes back to float: dst[k] = (src[k] - zero_point) * scale.
 * @return TINY_OK, or TINY_ERR_INVALID_ARG for NULL pointers or n <= 0.
 */
tiny_error_t tiny_quant_int8_to_f32(const int8_t *src, float *dst, int n,
                                    const tiny_quant_params_t *params)
{
    if (!(src && dst && params) || n <= 0)
    {
        return TINY_ERR_INVALID_ARG;
    }
    for (int k = 0; k < n; ++k)
    {
        const float centred = (float)src[k] - (float)params->zero_point;
        dst[k] = centred * params->scale;
    }
    return TINY_OK;
}
/**
 * @brief Quantise floats to INT16: q = round(x / scale) + zero_point,
 *        saturated to [TINY_INT16_MIN, TINY_INT16_MAX].
 *
 * Rounding is half-away-from-zero (roundf). NaN inputs map to zero_point
 * (the code that dequantises to 0.0), saturated to the INT16 range.
 *
 * @return TINY_OK; TINY_ERR_INVALID_ARG for NULL pointers or n <= 0;
 *         TINY_ERR_AI_QUANT_FAILED if params->scale is below
 *         TINY_MATH_MIN_DENOMINATOR or is NaN.
 */
tiny_error_t tiny_quant_f32_to_int16(const float *src, int16_t *dst, int n,
                                     const tiny_quant_params_t *params)
{
    if (!src || !dst || !params || n <= 0) return TINY_ERR_INVALID_ARG;
    /* Written as !(scale >= min) so that a NaN scale is rejected too. */
    if (!(params->scale >= TINY_MATH_MIN_DENOMINATOR)) return TINY_ERR_AI_QUANT_FAILED;

    const float inv_scale = 1.0f / params->scale;
    const float zp = (float)params->zero_point;
    const float lo = (float)TINY_INT16_MIN;
    const float hi = (float)TINY_INT16_MAX;
    for (int i = 0; i < n; i++)
    {
        /* Saturate in the float domain before the cast: a NaN or out-of-range
         * float-to-integer conversion is undefined behaviour (C11 6.3.1.4). */
        float r = roundf(src[i] * inv_scale) + zp;
        if (isnan(r)) r = zp;
        r = fminf(fmaxf(r, lo), hi);
        dst[i] = (int16_t)r;
    }
    return TINY_OK;
}
/**
 * @brief Dequantise INT16 codes back to float: dst[k] = (src[k] - zero_point) * scale.
 * @return TINY_OK, or TINY_ERR_INVALID_ARG for NULL pointers or n <= 0.
 */
tiny_error_t tiny_quant_int16_to_f32(const int16_t *src, float *dst, int n,
                                     const tiny_quant_params_t *params)
{
    if (!(src && dst && params) || n <= 0)
    {
        return TINY_ERR_INVALID_ARG;
    }
    for (int k = 0; k < n; ++k)
    {
        const float centred = (float)src[k] - (float)params->zero_point;
        dst[k] = centred * params->scale;
    }
    return TINY_OK;
}
/**
 * @brief Symmetric (zero-point-free) INT8 fully-connected layer.
 *
 *   acc          = bias[o] + sum_i input[b][i] * weight[o][i]
 *   output[b][o] = saturate_int8(round(acc * input_scale * weight_scale / output_scale))
 *
 * @param input   Row-major [batch][in_feat].
 * @param weight  Row-major [out_feat][in_feat].
 * @param bias    Optional (NULL allowed) [out_feat], added to the integer
 *                accumulator (units of input_scale * weight_scale).
 * @param output  Row-major [batch][out_feat].
 * @return TINY_OK (also for zero-sized extents); TINY_ERR_INVALID_ARG for NULL
 *         input/weight/output or negative extents; TINY_ERR_AI_QUANT_FAILED if
 *         output_scale is below TINY_MATH_MIN_DENOMINATOR or is NaN.
 */
tiny_error_t tiny_quant_dense_forward_int8(
    const int8_t *input,
    const int8_t *weight,
    const int32_t *bias,
    int8_t *output,
    int batch, int in_feat, int out_feat,
    float input_scale, float weight_scale, float output_scale)
{
    if (!input || !weight || !output) return TINY_ERR_INVALID_ARG;
    /* Negative extents would make the row offsets below point outside the arrays. */
    if (batch < 0 || in_feat < 0 || out_feat < 0) return TINY_ERR_INVALID_ARG;
    /* Written as !(x >= min) so that a NaN scale is rejected too. */
    if (!(output_scale >= TINY_MATH_MIN_DENOMINATOR)) return TINY_ERR_AI_QUANT_FAILED;

    const float combined_scale = (input_scale * weight_scale) / output_scale;
    const float lo = (float)TINY_INT8_MIN;
    const float hi = (float)TINY_INT8_MAX;

    for (int b = 0; b < batch; b++)
    {
        /* size_t offsets: b * in_feat can overflow int for large layers. */
        const int8_t *x_row = input + (size_t)b * (size_t)in_feat;
        int8_t *y_row = output + (size_t)b * (size_t)out_feat;
        for (int o = 0; o < out_feat; o++)
        {
            const int8_t *w_row = weight + (size_t)o * (size_t)in_feat;
            /* 64-bit accumulator: int8*int8 products (up to 2^14 each) summed
             * over a long row plus bias can exceed int32, and signed overflow is UB. */
            int64_t acc = bias ? (int64_t)bias[o] : 0;
            for (int i = 0; i < in_feat; i++)
                acc += (int32_t)x_row[i] * (int32_t)w_row[i];

            /* Saturate in float before the cast: NaN/out-of-range float-to-int
             * conversion is undefined behaviour. */
            float r = roundf((float)acc * combined_scale);
            if (isnan(r)) r = 0.0f;
            r = fminf(fmaxf(r, lo), hi);
            y_row[o] = (int8_t)r;
        }
    }
    return TINY_OK;
}
tiny_quant.hpp(C++ 接口)
/**
* @file tiny_quant.hpp
* @brief C++ quantisation utilities for tiny_ai.
*/
#pragma once
#include "tiny_quant_config.h"
#include "tiny_quant.h"
#include "tiny_fp8.hpp"
#ifdef __cplusplus
#include <stdint.h>
namespace tiny { class Tensor; }
namespace tiny
{
/**
 * @brief C++ counterpart of tiny_quant_params_t.
 *
 * For the INT8/INT16 paths, values dequantise as real = (q - zero_point) * scale.
 */
struct QuantParams
{
    tiny_dtype_t dtype = TINY_DTYPE_INT8; ///< storage type of the quantised values
    float scale = 1.0f;                   ///< real-valued step between adjacent codes
    int zero_point = 0;                   ///< integer code that represents 0.0

    QuantParams() = default;
    QuantParams(tiny_dtype_t d, float s, int zp = 0) : dtype(d), scale(s), zero_point(zp) {}

    /// Convert to the C struct consumed by the tiny_quant_* functions.
    tiny_quant_params_t to_c() const
    {
        // Value-initialise so any field of the C struct that is not mirrored
        // here is zero rather than indeterminate.
        tiny_quant_params_t p{};
        p.dtype = dtype;
        p.scale = scale;
        p.zero_point = zero_point;
        return p;
    }
};

/**
 * @brief Choose quantisation parameters for @p t.
 *
 * FP8 dtypes: scale = max|x| / format maximum (TINY_FP8_E4M3_MAX or
 * TINY_FP8_E5M2_MAX). Integer dtypes: delegated to tiny_quant_calibrate_minmax().
 * zero_point is always 0.
 */
QuantParams calibrate(const Tensor &t, tiny_dtype_t dtype = TINY_DTYPE_INT8);

/**
 * @brief Quantise every element of @p src into the raw buffer @p dst.
 * @param dst Receives src.size elements of the type selected by params.dtype
 *            (1 byte each for INT8/FP8, 2 bytes for INT16).
 * @return TINY_OK; TINY_ERR_AI_INVALID_SHAPE if dst is null;
 *         TINY_ERR_AI_INVALID_DTYPE for unsupported dtypes; otherwise the
 *         error of the underlying C conversion.
 */
tiny_error_t quantize(const Tensor &src, uint8_t *dst, const QuantParams &params);

/**
 * @brief Inverse of quantize(): decode dst.size elements from @p src into @p dst.
 * @return TINY_OK; TINY_ERR_AI_INVALID_SHAPE if src is null;
 *         TINY_ERR_AI_INVALID_DTYPE for unsupported dtypes; otherwise the
 *         error of the underlying C conversion.
 */
tiny_error_t dequantize(const uint8_t *src, Tensor &dst, const QuantParams &params);

/**
 * @brief Calibrate @p t for INT8 (result stored in @p params) and return an
 *        INT8 copy of its data.
 * @return Buffer of t.size int8_t obtained from TINY_AI_MALLOC, which the caller
 *         must release, or nullptr on failure.
 */
int8_t *quantize_weights(const Tensor &t, QuantParams &params);

/**
 * @brief Rescale INT8 codes between scales:
 *        dst[i] = saturate_int8(round(src[i] * src_scale / dst_scale)).
 * @return TINY_OK; TINY_ERR_INVALID_ARG for null pointers or n <= 0;
 *         TINY_ERR_AI_QUANT_FAILED if dst_scale < TINY_MATH_MIN_DENOMINATOR.
 */
tiny_error_t requantize_int8(const int8_t *src, int8_t *dst, int n,
                             float src_scale, float dst_scale);
} // namespace tiny
#endif // __cplusplus
tiny_quant.cpp(C++ 实现)
/**
* @file tiny_quant.cpp
* @brief C++ quantisation utilities — Tensor-level calibration and PTQ helpers.
*/
#include "tiny_quant.hpp"
#include "tiny_tensor.hpp"
#ifdef __cplusplus
#include <cmath>
#include <cstring>
namespace tiny
{
QuantParams calibrate(const Tensor &t, tiny_dtype_t dtype)
{
QuantParams qp;
qp.dtype = dtype;
qp.zero_point = 0;
if (dtype == TINY_DTYPE_FP8_E4M3 || dtype == TINY_DTYPE_FP8_E5M2)
{
float max_val = (dtype == TINY_DTYPE_FP8_E4M3) ? TINY_FP8_E4M3_MAX : TINY_FP8_E5M2_MAX;
float abs_max = 0.0f;
for (int i = 0; i < t.size; i++)
{
float a = fabsf(t.data[i]);
if (a > abs_max) abs_max = a;
}
if (abs_max < TINY_MATH_MIN_DENOMINATOR) abs_max = 1.0f;
qp.scale = abs_max / max_val;
return qp;
}
tiny_quant_params_t cp;
tiny_quant_calibrate_minmax(t.data, t.size, dtype, &cp);
qp.scale = cp.scale;
qp.zero_point = cp.zero_point;
return qp;
}
tiny_error_t quantize(const Tensor &src, uint8_t *dst, const QuantParams ¶ms)
{
if (!dst) return TINY_ERR_AI_INVALID_SHAPE;
switch (params.dtype)
{
case TINY_DTYPE_INT8:
{
tiny_quant_params_t cp = params.to_c();
return tiny_quant_f32_to_int8(src.data, (int8_t *)dst, src.size, &cp);
}
case TINY_DTYPE_INT16:
{
tiny_quant_params_t cp = params.to_c();
return tiny_quant_f32_to_int16(src.data, (int16_t *)dst, src.size, &cp);
}
case TINY_DTYPE_FP8_E4M3:
fp32_to_fp8_e4m3_batch(src.data, dst, src.size); return TINY_OK;
case TINY_DTYPE_FP8_E5M2:
fp32_to_fp8_e5m2_batch(src.data, dst, src.size); return TINY_OK;
default:
return TINY_ERR_AI_INVALID_DTYPE;
}
}
tiny_error_t dequantize(const uint8_t *src, Tensor &dst, const QuantParams ¶ms)
{
if (!src) return TINY_ERR_AI_INVALID_SHAPE;
switch (params.dtype)
{
case TINY_DTYPE_INT8:
{
tiny_quant_params_t cp = params.to_c();
return tiny_quant_int8_to_f32((const int8_t *)src, dst.data, dst.size, &cp);
}
case TINY_DTYPE_INT16:
{
tiny_quant_params_t cp = params.to_c();
return tiny_quant_int16_to_f32((const int16_t *)src, dst.data, dst.size, &cp);
}
case TINY_DTYPE_FP8_E4M3:
fp8_e4m3_to_fp32_batch(src, dst.data, dst.size); return TINY_OK;
case TINY_DTYPE_FP8_E5M2:
fp8_e5m2_to_fp32_batch(src, dst.data, dst.size); return TINY_OK;
default:
return TINY_ERR_AI_INVALID_DTYPE;
}
}
int8_t *quantize_weights(const Tensor &t, QuantParams ¶ms)
{
params = calibrate(t, TINY_DTYPE_INT8);
int8_t *buf = (int8_t *)TINY_AI_MALLOC((size_t)t.size * sizeof(int8_t));
if (!buf) return nullptr;
tiny_quant_params_t cp = params.to_c();
tiny_quant_f32_to_int8(t.data, buf, t.size, &cp);
return buf;
}
tiny_error_t requantize_int8(const int8_t *src, int8_t *dst, int n,
float src_scale, float dst_scale)
{
if (!src || !dst || n <= 0) return TINY_ERR_INVALID_ARG;
if (dst_scale < TINY_MATH_MIN_DENOMINATOR) return TINY_ERR_AI_QUANT_FAILED;
float ratio = src_scale / dst_scale;
for (int i = 0; i < n; i++)
{
int32_t v = (int32_t)roundf((float)src[i] * ratio);
dst[i] = (int8_t)(v < TINY_INT8_MIN ? TINY_INT8_MIN : (v > TINY_INT8_MAX ? TINY_INT8_MAX : v));
}
return TINY_OK;
}
} // namespace tiny
#endif // __cplusplus