/**
 * @file example_attention.cpp
 * @author SHUAIWEN CUI (SHUAIWEN001@e.ntu.edu.sg)
 * @brief Multi-head self-attention example for tiny_ai.
 *
 * Dataset : Iris (150 samples, 4 features, 3 classes — from iris_data.hpp)
 * Model   : Dense(4,32)+ReLU → reshape [B,32]→[B,4,8]
 *           → Attention(embed_dim=8, heads=2) → GlobalAvgPool
 *           → Dense(8,3) [raw logits — cross-entropy handles softmax]
 * Training: Adam, 100 epochs, batch=16, cross-entropy
 *
 * @version 1.0
 * @date 2025-05-01
 * @copyright Copyright (c) 2025
 */
#include "tiny_ai.h"
#include "iris_data.hpp"
#include <cstdio>
#include <cstdlib>
#include <vector> // std::vector<ParamGroup> in the training path
// ESP-IDF task watchdog feeding for long-running training loops
#if defined(ESP_PLATFORM) || defined(IDF_VER) || defined(ESP32)
#include "esp_task_wdt.h"
#endif
#ifdef __cplusplus
using namespace tiny;
using namespace tiny_data;
// ============================================================================
// Entry point
// ============================================================================
void example_attention(void)
{
printf("\n");
printf("========================================\n");
printf(" tiny_ai | Attention Example (Iris)\n");
printf("========================================\n");
    // ---- Dataset ----
    Dataset dataset(
        &IRIS_X[0][0], IRIS_Y,
        IRIS_N_SAMPLES, IRIS_N_FEATURES, IRIS_N_CLASSES);
    Dataset train_ds(dataset), test_ds(dataset);
    dataset.split(0.2f, train_ds, test_ds, 42);
    printf("Dataset split: %d train / %d test\n",
           train_ds.size(), test_ds.size());
    // ---- Model components ----
    const int SEQ_LEN = IRIS_N_FEATURES; // 4 tokens (one per feature)
    const int EMB_DIM = 8;
    const int N_HEADS = 2;
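    // Iris rows are flat 4-feature vectors, so the projection below expands
    // each row to SEQ_LEN * EMB_DIM = 32 values that are then viewed as a
    // 4-token sequence of 8-dim embeddings. With 2 heads, each head attends
    // over head_dim = EMB_DIM / N_HEADS = 4 dimensions per token.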
    Dense embed_proj(IRIS_N_FEATURES, SEQ_LEN * EMB_DIM, true);
    ActivationLayer embed_act(ActType::RELU);
    Attention attn(EMB_DIM, N_HEADS, true);
    GlobalAvgPool gap;
    Dense classifier(EMB_DIM, IRIS_N_CLASSES, true);
printf("Model summary:\n");
printf(" Dense(%d, %d) + ReLU\n", IRIS_N_FEATURES, SEQ_LEN * EMB_DIM);
printf(" reshape [B, %d] -> [B, %d, %d] (tokens, embed_dim)\n",
SEQ_LEN * EMB_DIM, SEQ_LEN, EMB_DIM);
printf(" Attention(embed_dim=%d, heads=%d, head_dim=%d)\n",
EMB_DIM, N_HEADS, EMB_DIM / N_HEADS);
printf(" GlobalAvgPool\n");
printf(" Dense(%d, %d) [raw logits]\n", EMB_DIM, IRIS_N_CLASSES);
    Adam opt(1e-3f);
#if TINY_AI_TRAINING_ENABLED
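    // Only layers with trainable weights register parameters with the
    // optimizer; ReLU and GlobalAvgPool are parameter-free and are skipped.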
    std::vector<ParamGroup> params;
    embed_proj.collect_params(params);
    attn.collect_params(params);
    classifier.collect_params(params);
    opt.init(params);
    const int batch_size = 16;
    const int n_epochs = 100;
    const int print_every = 20;
    int *y_batch = (int *)TINY_AI_MALLOC((size_t)batch_size * sizeof(int));
    if (!y_batch)
    {
        printf(" Memory allocation failed!\n");
        return;
    }
printf("\nTraining...\n");
for (int epoch = 0; epoch < n_epochs; epoch++)
{
train_ds.shuffle(epoch + 1);
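#if defined(ESP_PLATFORM) || defined(IDF_VER) || defined(ESP32)
        // Feed the task watchdog once per epoch so a long training run does
        // not trip the TWDT on-device (harmless if this task is not
        // subscribed to the watchdog).
        esp_task_wdt_reset();
#endif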
        float epoch_loss = 0.0f;
        int n_batches = 0;
        Tensor X_batch;
        while (true)
        {
            int actual = train_ds.next_batch(X_batch, y_batch, batch_size);
            if (actual == 0) break; // dataset exhausted; final batch may be short
            // ------ Forward ------
            Tensor e0 = embed_proj.forward(X_batch);
            Tensor e1 = embed_act.forward(e0);
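            // View the flat [B, 32] activations as a [B, 4, 8] token
            // sequence (SEQ_LEN tokens of EMB_DIM values) for attention.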
            e1.reshape_3d(actual, SEQ_LEN, EMB_DIM);
            Tensor a0 = attn.forward(e1);
            Tensor p0 = gap.forward(a0); // mean over the 4 tokens: [B, 4, 8] -> [B, 8]
            Tensor logits = classifier.forward(p0);
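            // The classifier emits raw logits; cross_entropy_forward applies
            // the softmax internally (see the header note above).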
            float loss = cross_entropy_forward(logits, y_batch);
            epoch_loss += loss;
            // ------ Backward ------
            opt.zero_grad(params);
            Tensor dlogits = cross_entropy_backward(logits, y_batch);
            Tensor dp0 = classifier.backward(dlogits);
            Tensor da0 = gap.backward(dp0);
            Tensor de1 = attn.backward(da0);
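            // Undo the token view so the gradient reaches embed_act and
            // embed_proj as a flat [B, 32] tensor, mirroring the forward pass.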
            de1.reshape_2d(actual, SEQ_LEN * EMB_DIM);
            Tensor de0 = embed_act.backward(de1);
            embed_proj.backward(de0);
            opt.step(params);
            n_batches++;
        }
        if ((epoch + 1) % print_every == 0)
            printf("Epoch [%3d/%3d] loss: %.6f\n",
                   epoch + 1, n_epochs, epoch_loss / (float)n_batches);
    }
    TINY_AI_FREE(y_batch);
    // ---- Evaluate ----
    auto eval_accuracy = [&](Dataset &ds, const char *tag)
    {
        ds.reset();
        int correct = 0, total = 0;
        int *yb = (int *)TINY_AI_MALLOC((size_t)batch_size * sizeof(int));
        int *yp = (int *)TINY_AI_MALLOC((size_t)batch_size * sizeof(int));
        if (!yb || !yp)
        {
            printf(" Memory allocation failed!\n");
            TINY_AI_FREE(yb);
            TINY_AI_FREE(yp);
            return;
        }
        Tensor Xb;
        while (true)
        {
            int actual = ds.next_batch(Xb, yb, batch_size);
            if (actual == 0) break;
            Tensor e0 = embed_proj.forward(Xb);
            Tensor e1 = embed_act.forward(e0);
            e1.reshape_3d(actual, SEQ_LEN, EMB_DIM);
            Tensor a0 = attn.forward(e1);
            Tensor p0 = gap.forward(a0);
            Tensor logits = classifier.forward(p0);
            int n_cls = logits.shape[1];
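            // Argmax over raw logits: softmax is monotonic, so applying it
            // would not change the predicted class.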
            for (int i = 0; i < actual; i++)
            {
                float best = logits.at(i, 0);
                int pred = 0;
                for (int c = 1; c < n_cls; c++)
                    if (logits.at(i, c) > best) { best = logits.at(i, c); pred = c; }
                yp[i] = pred;
            }
            for (int i = 0; i < actual; i++) if (yp[i] == yb[i]) correct++;
            total += actual;
        }
        TINY_AI_FREE(yb);
        TINY_AI_FREE(yp);
        printf(" %s accuracy: %.2f%%\n", tag, 100.0f * correct / total);
    };
printf("\n--- Float32 Results ---\n");
eval_accuracy(train_ds, "Train");
eval_accuracy(test_ds, "Test ");
#else
printf("(Training disabled)\n");
#endif
printf("\nexample_attention DONE\n");
}
#endif // __cplusplus