Skip to content

Commit 61365e5

Browse files
authored
Add GenAIVision Example (Ameba-AIoT#281)
* Add GenAIVision Example * Update GenAIVision Example - create NNGenAIVision class - add vision prompt support for Llama model * Update GenAIVision API
1 parent 756eb07 commit 61365e5

File tree

3 files changed

+383
-0
lines changed

3 files changed

+383
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
3+
This sketch shows an example of sending image prompts to generative AI vision APIs.
4+
5+
openAI platform - openAI vision
6+
https://platform.openai.com/docs/guides/vision
7+
8+
Google AI Studio - Gemini vision
9+
https://ai.google.dev/gemini-api/docs/vision
10+
11+
GroqCloud - Llama vision
12+
https://console.groq.com/docs/overview
13+
14+
Example Guide: https://ameba-arduino-doc.readthedocs.io/en/latest/amebapro2/Example_Guides/Neural%20Network/Generative%20AI%20Vision.html
15+
16+
Credit : ChungYi Fu (Kaohsiung, Taiwan)
17+
18+
*/
19+
20+
String openAI_key = ""; // paste your generated openAI API key here
21+
String Gemini_key = ""; // paste your generated Gemini API key here
22+
String Llama_key = ""; // paste your generated Llama API key here
23+
char wifi_ssid[] = "Network_SSID5"; // change to your network SSID
24+
char wifi_pass[] = "Password"; // change to your network password
25+
26+
#include <WiFi.h>
27+
#include "NNGenAIVision.h"
28+
#include "VideoStream.h"
29+
WiFiSSLClient client;
30+
NNGenAIVision llm;
31+
VideoSetting config(768, 768, CAM_FPS, VIDEO_JPEG, 1);
32+
#define CHANNEL 0
33+
34+
uint32_t img_addr = 0;
35+
uint32_t img_len = 0;
36+
37+
String prompt_msg = "Please describe the image, and if there is a text, please summarize the content";
38+
39+
void initWiFi()
40+
{
41+
for (int i = 0; i < 2; i++) {
42+
WiFi.begin(wifi_ssid, wifi_pass);
43+
44+
delay(1000);
45+
Serial.println("");
46+
Serial.print("Connecting to ");
47+
Serial.println(wifi_ssid);
48+
49+
long int StartTime = millis();
50+
while (WiFi.status() != WL_CONNECTED) {
51+
delay(500);
52+
if ((StartTime + 5000) < millis()) {
53+
break;
54+
}
55+
}
56+
57+
if (WiFi.status() == WL_CONNECTED) {
58+
Serial.println("");
59+
Serial.println("STAIP address: ");
60+
Serial.println(WiFi.localIP());
61+
Serial.println("");
62+
63+
break;
64+
}
65+
}
66+
}
67+
68+
void setup()
69+
{
70+
Serial.begin(115200);
71+
72+
initWiFi();
73+
74+
config.setRotation(0);
75+
Camera.configVideoChannel(CHANNEL, config);
76+
Camera.videoInit();
77+
Camera.channelBegin(CHANNEL);
78+
Camera.printInfo();
79+
80+
delay(5000);
81+
82+
// Vision prompt using same taken image
83+
Camera.getImage(0, &img_addr, &img_len);
84+
85+
// openAI vision prompt
86+
llm.openaivision(openAI_key, prompt_msg, img_addr, img_len, client);
87+
88+
// Gemini vision prompt
89+
// llm.geminivision(Gemini_key, prompt_msg, img_addr, img_len, client);
90+
91+
// Llama vision prompt
92+
// llm.llamavision(Llama_key, prompt_msg, img_addr, img_len, client);
93+
}
94+
95+
void loop()
96+
{
97+
// do nothing
98+
}
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
#include "NNGenAIVision.h"
2+
#include <ArduinoJson.h>
3+
#include "Base64.h"
4+
5+
NNGenAIVision::NNGenAIVision()
6+
{
7+
}
8+
9+
NNGenAIVision::~NNGenAIVision()
10+
{
11+
}
12+
13+
// Model: gpt-4o-mini, Server: openAI Platform
14+
void NNGenAIVision::openaivision(String key, String message, uint32_t img_addr, uint32_t img_len, WiFiSSLClient client)
15+
{
16+
const char *myDomain = "api.openai.com";
17+
String getResponse = "", Feedback = "";
18+
Serial.println("Connect to " + String(myDomain));
19+
if (client.connect(myDomain, 443)) {
20+
Serial.println("Connection successful");
21+
22+
uint8_t *fbBuf = (uint8_t *)img_addr;
23+
size_t fbLen = img_len;
24+
25+
char *input = (char *)fbBuf;
26+
char output[base64_enc_len(3)];
27+
String imageFile = "data:image/jpeg;base64,";
28+
for (int i = 0; i < fbLen; i++) {
29+
base64_encode(output, (input++), 3);
30+
if (i % 3 == 0) {
31+
imageFile += String(output);
32+
}
33+
}
34+
String Data = "{\"model\": \"gpt-4o-mini\", \"messages\": [{\"role\": \"user\",\"content\": [{ \"type\": \"text\", \"text\": \"" + message + "\"},{\"type\": \"image_url\", \"image_url\": {\"url\": \"" + imageFile + "\"}}]}]}";
35+
36+
client.println("POST /v1/chat/completions HTTP/1.1");
37+
client.println("Host: " + String(myDomain));
38+
client.println("Authorization: Bearer " + key);
39+
client.println("Content-Type: application/json; charset=utf-8");
40+
client.println("Content-Length: " + String(Data.length()));
41+
client.println("Connection: close");
42+
client.println();
43+
44+
int Index;
45+
for (Index = 0; Index < Data.length(); Index = Index + 1024) {
46+
client.print(Data.substring(Index, Index + 1024));
47+
}
48+
49+
int waitTime = 10000;
50+
long startTime = millis();
51+
boolean state = false;
52+
boolean markState = false;
53+
while ((startTime + waitTime) > millis()) {
54+
Serial.print(".");
55+
delay(100);
56+
while (client.available()) {
57+
char c = client.read();
58+
if (String(c) == "{") {
59+
markState = true;
60+
}
61+
if (state == true && markState == true) {
62+
Feedback += String(c);
63+
}
64+
if (c == '\n') {
65+
if (getResponse.length() == 0) {
66+
state = true;
67+
}
68+
getResponse = "";
69+
} else if (c != '\r') {
70+
getResponse += String(c);
71+
}
72+
startTime = millis();
73+
}
74+
if (Feedback.length() > 0) {
75+
break;
76+
}
77+
}
78+
Serial.println();
79+
client.stop();
80+
81+
JsonObject obj;
82+
DynamicJsonDocument doc(4096);
83+
deserializeJson(doc, Feedback);
84+
obj = doc.as<JsonObject>();
85+
getResponse = obj["choices"][0]["message"]["content"].as<String>();
86+
if (getResponse == "null") {
87+
getResponse = obj["error"]["message"].as<String>();
88+
}
89+
} else {
90+
getResponse = "Connected to " + String(myDomain) + " failed.";
91+
}
92+
Serial.println("Response from GPT:");
93+
Serial.println(getResponse);
94+
}
95+
96+
// Model: gemini-1.5-flash, Server: Google AI Studio
97+
void NNGenAIVision::geminivision(String key, String message, uint32_t img_addr, uint32_t img_len, WiFiSSLClient client)
98+
{
99+
const char *myDomain = "generativelanguage.googleapis.com";
100+
String getResponse = "", Feedback = "";
101+
Serial.println("Connect to " + String(myDomain));
102+
if (client.connect(myDomain, 443)) {
103+
Serial.println("Connection successful");
104+
105+
uint8_t *fbBuf = (uint8_t *)img_addr;
106+
size_t fbLen = img_len;
107+
char *input = (char *)fbBuf;
108+
char output[base64_enc_len(3)];
109+
String imageFile = "";
110+
for (int i = 0; i < fbLen; i++) {
111+
base64_encode(output, (input++), 3);
112+
if (i % 3 == 0) {
113+
imageFile += String(output);
114+
}
115+
}
116+
String Data = "{\"contents\": [{\"parts\": [{\"text\": \"" + message + "\"}, {\"inline_data\": {\"mime_type\":\"image/jpeg\",\"data\":\"" + imageFile + "\"}}]}]}";
117+
// String Data = "{\"contents\": [{\"parts\": [{\"text\": \""+message+"\"}]}]}";
118+
119+
client.println("POST /v1beta/models/gemini-1.5-flash-latest:generateContent?key=" + key + " HTTP/1.1");
120+
client.println("Host: " + String(myDomain));
121+
client.println("Content-Type: application/json; charset=utf-8");
122+
client.println("Content-Length: " + String(Data.length()));
123+
client.println("Connection: close");
124+
client.println();
125+
126+
int Index;
127+
for (Index = 0; Index < Data.length(); Index = Index + 1024) {
128+
client.print(Data.substring(Index, Index + 1024));
129+
}
130+
131+
int waitTime = 10000;
132+
long startTime = millis();
133+
boolean state = false;
134+
boolean markState = false;
135+
while ((startTime + waitTime) > millis()) {
136+
Serial.print(".");
137+
delay(100);
138+
while (client.available()) {
139+
char c = client.read();
140+
if (String(c) == "{") {
141+
markState = true;
142+
}
143+
if (state == true && markState == true) {
144+
Feedback += String(c);
145+
}
146+
if (c == '\n') {
147+
if (getResponse.length() == 0) {
148+
state = true;
149+
}
150+
getResponse = "";
151+
} else if (c != '\r') {
152+
getResponse += String(c);
153+
}
154+
startTime = millis();
155+
}
156+
if (Feedback.length() > 0) {
157+
break;
158+
}
159+
}
160+
Serial.println();
161+
client.stop();
162+
163+
JsonObject obj;
164+
DynamicJsonDocument doc(4096);
165+
deserializeJson(doc, Feedback);
166+
obj = doc.as<JsonObject>();
167+
getResponse = obj["candidates"][0]["content"]["parts"][0]["text"].as<String>();
168+
if (getResponse == "null") {
169+
getResponse = obj["error"]["message"].as<String>();
170+
}
171+
} else {
172+
getResponse = "Connected to " + String(myDomain) + " failed.";
173+
}
174+
Serial.println("Response from Gemini:");
175+
Serial.println(getResponse);
176+
}
177+
178+
// Model: llama-3.2-90b-vision-preview, Server: groq
179+
void NNGenAIVision::llamavision(String key, String message, uint32_t img_addr, uint32_t img_len, WiFiSSLClient client)
180+
{
181+
const char *myDomain = "api.groq.com";
182+
String getResponse = "", Feedback = "";
183+
Serial.println("Connect to " + String(myDomain));
184+
if (client.connect(myDomain, 443)) {
185+
Serial.println("Connection successful");
186+
187+
uint8_t *fbBuf = (uint8_t *)img_addr;
188+
size_t fbLen = img_len;
189+
190+
char *input = (char *)fbBuf;
191+
char output[base64_enc_len(3)];
192+
String imageFile = "data:image/jpeg;base64,";
193+
for (int i = 0; i < fbLen; i++) {
194+
base64_encode(output, (input++), 3);
195+
if (i % 3 == 0) {
196+
imageFile += String(output);
197+
}
198+
}
199+
String Data = "{\"model\": \"llama-3.2-90b-vision-preview\", \"messages\": [{\"role\": \"user\",\"content\": [{ \"type\": \"text\", \"text\": \"" + message + "\"},{\"type\": \"image_url\", \"image_url\": {\"url\": \"" + imageFile + "\"}}]}]}";
200+
201+
client.println("POST /openai/v1/chat/completions HTTP/1.1");
202+
client.println("Host: " + String(myDomain));
203+
client.println("Authorization: Bearer " + key);
204+
client.println("Content-Type: application/json; charset=utf-8");
205+
client.println("Content-Length: " + String(Data.length()));
206+
client.println("Connection: close");
207+
client.println();
208+
209+
int Index;
210+
for (Index = 0; Index < Data.length(); Index = Index + 1024) {
211+
client.print(Data.substring(Index, Index + 1024));
212+
}
213+
214+
int waitTime = 10000;
215+
long startTime = millis();
216+
boolean state = false;
217+
boolean markState = false;
218+
while ((startTime + waitTime) > millis()) {
219+
Serial.print(".");
220+
delay(100);
221+
while (client.available()) {
222+
char c = client.read();
223+
if (String(c) == "{") {
224+
markState = true;
225+
}
226+
if (state == true && markState == true) {
227+
Feedback += String(c);
228+
}
229+
if (c == '\n') {
230+
if (getResponse.length() == 0) {
231+
state = true;
232+
}
233+
getResponse = "";
234+
} else if (c != '\r') {
235+
getResponse += String(c);
236+
}
237+
startTime = millis();
238+
}
239+
if (Feedback.length() > 0) {
240+
break;
241+
}
242+
}
243+
Serial.println();
244+
client.stop();
245+
246+
JsonObject obj;
247+
DynamicJsonDocument doc(4096);
248+
deserializeJson(doc, Feedback);
249+
obj = doc.as<JsonObject>();
250+
getResponse = obj["choices"][0]["message"]["content"].as<String>();
251+
if (getResponse == "null") {
252+
getResponse = obj["error"]["message"].as<String>();
253+
}
254+
} else {
255+
getResponse = "Connected to " + String(myDomain) + " failed.";
256+
}
257+
Serial.println("Response from Llama:");
258+
Serial.println(getResponse);
259+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef __NN_GENAIVISION_H__
2+
#define __NN_GENAIVISION_H__
3+
4+
#include "WiFi.h"
5+
#include <ArduinoJson.h>
6+
#include "Base64.h"
7+
8+
#ifdef __cplusplus
9+
extern "C" {
10+
#endif
11+
12+
#ifdef __cplusplus
13+
}
14+
#endif
15+
16+
class NNGenAIVision {
17+
public:
18+
NNGenAIVision();
19+
~NNGenAIVision();
20+
21+
void openaivision(String key, String message, uint32_t img_addr, uint32_t img_len, WiFiSSLClient client);
22+
void geminivision(String key, String message, uint32_t img_addr, uint32_t img_len, WiFiSSLClient client);
23+
void llamavision(String key, String message, uint32_t img_addr, uint32_t img_len, WiFiSSLClient client);
24+
};
25+
26+
#endif

0 commit comments

Comments
 (0)