Skip to content

Commit eddadf0

Browse files
committed
for testing only
1 parent 8ae4832 commit eddadf0

File tree

1 file changed

+360
-0
lines changed

1 file changed

+360
-0
lines changed

stream.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
import serial
2+
import time
3+
import numpy as np
4+
import serial.tools.list_ports
5+
import sys
6+
import threading
7+
import pyqtgraph as pg
8+
from pyqtgraph.Qt import QtWidgets, QtCore
9+
from scipy.fft import rfft, rfftfreq
10+
11+
class CircularBuffer:
12+
def __init__(self, size):
13+
self.size = size
14+
self.buffer = np.zeros(size, dtype=np.float32)
15+
self.write_index = 0
16+
self.read_index = 0
17+
self.count = 0
18+
self.lock = threading.Lock()
19+
20+
def write(self, data):
21+
with self.lock:
22+
if isinstance(data, (list, np.ndarray)):
23+
for value in data:
24+
self.buffer[self.write_index] = value
25+
self.write_index = (self.write_index + 1) % self.size
26+
if self.count < self.size:
27+
self.count += 1
28+
else:
29+
self.buffer[self.write_index] = data
30+
self.write_index = (self.write_index + 1) % self.size
31+
if self.count < self.size:
32+
self.count += 1
33+
34+
def read_block(self, block_size, start_offset=0):
35+
with self.lock:
36+
if self.count < block_size:
37+
return None
38+
39+
# Calculate the actual read start position
40+
read_start = (self.read_index + start_offset) % self.size
41+
# Read the block
42+
if read_start + block_size <= self.size:
43+
return self.buffer[read_start:read_start + block_size].copy()
44+
else:
45+
first_part = self.buffer[read_start:].copy()
46+
second_part = self.buffer[:block_size - len(first_part)].copy()
47+
return np.concatenate([first_part, second_part])
48+
49+
def advance_read_pointer(self, count=1):
50+
with self.lock:
51+
self.read_index = (self.read_index + count) % self.size
52+
self.count = max(0, self.count - count)
53+
54+
def get_count(self):
55+
with self.lock:
56+
return self.count
57+
58+
def is_full(self):
59+
with self.lock:
60+
return self.count >= self.size
61+
62+
class Chords_USB:
63+
SYNC_BYTE1 = 0xc7
64+
SYNC_BYTE2 = 0x7c
65+
END_BYTE = 0x01
66+
HEADER_LENGTH = 3
67+
68+
supported_boards = {
69+
"UNO-R3": {"sampling_rate": 250, "Num_channels": 6},
70+
"UNO-CLONE": {"sampling_rate": 250, "Num_channels": 6},
71+
"GENUINO-UNO": {"sampling_rate": 250, "Num_channels": 6},
72+
"UNO-R4": {"sampling_rate": 512, "Num_channels": 6},
73+
"RPI-PICO-RP2040": {"sampling_rate": 500, "Num_channels": 3},
74+
"NANO-CLONE": {"sampling_rate": 250, "Num_channels": 8},
75+
"NANO-CLASSIC": {"sampling_rate": 250, "Num_channels": 8},
76+
"STM32F4-BLACK-PILL": {"sampling_rate": 500, "Num_channels": 8},
77+
"STM32G4-CORE-BOARD": {"sampling_rate": 500, "Num_channels": 16},
78+
"MEGA-2560-R3": {"sampling_rate": 250, "Num_channels": 16},
79+
"MEGA-2560-CLONE": {"sampling_rate": 250, "Num_channels": 16},
80+
"GIGA-R1": {"sampling_rate": 500, "Num_channels": 6},
81+
"NPG-LITE": {"sampling_rate": 500, "Num_channels": 3},
82+
}
83+
84+
def __init__(self):
85+
self.ser = None
86+
self.buffer = bytearray()
87+
self.retry_limit = 4
88+
self.packet_length = None
89+
self.num_channels = None
90+
self.board = ""
91+
self.streaming_active = False
92+
self.sampling_rate = 512
93+
94+
# Circular buffers
95+
self.serial_buffer = CircularBuffer(512) # 512 samples for serial data
96+
self.fft_ready = threading.Event()
97+
self.fft_counter = 0
98+
self.overlap_samples = 16 # Overlap for smoother FFT updates
99+
100+
def connect_hardware(self, port, baudrate, timeout=1):
101+
try:
102+
self.ser = serial.Serial(port, baudrate=baudrate, timeout=timeout)
103+
retry_counter = 0
104+
response = None
105+
106+
while retry_counter < self.retry_limit:
107+
self.ser.write(b'WHORU\n')
108+
try:
109+
response = self.ser.readline().strip().decode()
110+
except UnicodeDecodeError:
111+
response = None
112+
113+
if response in self.supported_boards:
114+
self.board = response
115+
print(f"{response} detected at {port} with baudrate {baudrate}")
116+
self.num_channels = self.supported_boards[self.board]["Num_channels"]
117+
self.sampling_rate = self.supported_boards[self.board]["sampling_rate"]
118+
self.packet_length = (2 * self.num_channels) + self.HEADER_LENGTH + 1
119+
return True
120+
121+
retry_counter += 1
122+
123+
self.ser.close()
124+
except Exception as e:
125+
print(f"Connection Error: {e}")
126+
return False
127+
128+
def detect_hardware(self, timeout=1):
129+
baudrates = [230400, 115200]
130+
ports = serial.tools.list_ports.comports()
131+
132+
for port in ports:
133+
for baud in baudrates:
134+
print(f"Trying {port.device} at {baud}...")
135+
if self.connect_hardware(port.device, baud, timeout):
136+
return True
137+
138+
print("Unable to detect supported hardware.")
139+
return False
140+
141+
def send_command(self, command):
142+
if self.ser and self.ser.is_open:
143+
self.ser.flushInput()
144+
self.ser.flushOutput()
145+
self.ser.write(f"{command}\n".encode())
146+
time.sleep(0.1)
147+
response = self.ser.readline().decode('utf-8', errors='ignore').strip()
148+
return response
149+
return None
150+
151+
def read_data(self):
152+
try:
153+
raw_data = self.ser.read(self.ser.in_waiting or 1)
154+
if raw_data == b'':
155+
raise serial.SerialException("Serial port disconnected or No data received.")
156+
self.buffer.extend(raw_data)
157+
158+
while len(self.buffer) >= self.packet_length:
159+
sync_index = self.buffer.find(bytes([self.SYNC_BYTE1, self.SYNC_BYTE2]))
160+
if sync_index == -1:
161+
self.buffer.clear()
162+
continue
163+
164+
if len(self.buffer) >= (sync_index + self.packet_length):
165+
packet = self.buffer[sync_index:sync_index + self.packet_length]
166+
if packet[0] == self.SYNC_BYTE1 and packet[1] == self.SYNC_BYTE2 and packet[-1] == self.END_BYTE:
167+
channel_data = []
168+
169+
for ch in range(self.num_channels):
170+
high_byte = packet[2 * ch + self.HEADER_LENGTH]
171+
low_byte = packet[2 * ch + self.HEADER_LENGTH + 1]
172+
value = (high_byte << 8) | low_byte
173+
channel_data.append(float(value))
174+
175+
self.serial_buffer.write(channel_data[0])
176+
177+
if self.serial_buffer.get_count() >= 512: # Set FFT ready flag when we have enough data
178+
self.fft_ready.set()
179+
180+
del self.buffer[:sync_index + self.packet_length]
181+
else:
182+
del self.buffer[:sync_index + 1]
183+
except serial.SerialException as e:
184+
print(f"Serial error: {e}")
185+
self.cleanup()
186+
187+
def start_streaming(self):
188+
self.send_command('START')
189+
self.streaming_active = True
190+
try:
191+
while self.streaming_active:
192+
self.read_data()
193+
except KeyboardInterrupt:
194+
print("KeyboardInterrupt received.")
195+
self.cleanup()
196+
197+
def stop_streaming(self):
198+
self.streaming_active = False
199+
self.send_command('STOP')
200+
201+
def cleanup(self):
202+
self.stop_streaming()
203+
try:
204+
if self.ser and self.ser.is_open:
205+
self.ser.close()
206+
except Exception as e:
207+
print(f"Error during cleanup: {e}")
208+
209+
def signal_handler(self, sig, frame):
210+
self.cleanup()
211+
sys.exit(0)
212+
213+
class FFTProcessor:
214+
def __init__(self, serial_reader):
215+
self.serial_reader = serial_reader
216+
self.fft_buffer_size = 512
217+
self.overlap_samples = 16
218+
self.latest_fft_data = None
219+
self.latest_freqs = None
220+
self.latest_raw_data = None
221+
self.processing_active = False
222+
self.data_lock = threading.Lock()
223+
224+
def start_processing(self):
225+
self.processing_active = True
226+
processing_thread = threading.Thread(target=self._process_fft)
227+
processing_thread.daemon = True
228+
processing_thread.start()
229+
230+
def stop_processing(self):
231+
self.processing_active = False
232+
233+
def _process_fft(self):
234+
fft_counter = 0
235+
236+
while self.processing_active:
237+
if self.serial_reader.serial_buffer.get_count() >= self.fft_buffer_size:
238+
offset = (fft_counter * self.overlap_samples) % self.fft_buffer_size
239+
fft_data = self.serial_reader.serial_buffer.read_block(self.fft_buffer_size, start_offset=offset)
240+
241+
if fft_data is not None:
242+
fft_data = fft_data - np.mean(fft_data)
243+
window = np.hanning(len(fft_data))
244+
windowed_signal = fft_data * window
245+
fft_result = rfft(windowed_signal)
246+
freqs = rfftfreq(len(fft_data), 1.0/self.serial_reader.sampling_rate)
247+
power_spectrum = np.abs(fft_result)**2
248+
power_spectrum = power_spectrum / (self.serial_reader.sampling_rate * np.sum(window**2))
249+
250+
with self.data_lock:
251+
self.latest_fft_data = power_spectrum
252+
self.latest_freqs = freqs
253+
self.latest_raw_data = fft_data
254+
255+
fft_counter += 1
256+
if len(power_spectrum) > 5:
257+
start_idx = int(2.0 * len(power_spectrum) / (self.serial_reader.sampling_rate / 2))
258+
259+
sorted_indices = np.argsort(power_spectrum[start_idx:])[::-1] + start_idx
260+
peak1_idx = sorted_indices[0]
261+
peak1_freq = freqs[peak1_idx]
262+
print(f"Peak Frequency: {peak1_freq:.2f} Hz")
263+
264+
self.serial_reader.serial_buffer.advance_read_pointer(self.overlap_samples)
265+
266+
else:
267+
time.sleep(0.01)
268+
269+
def get_latest_data(self):
270+
with self.data_lock:
271+
return (
272+
self.latest_raw_data.copy() if self.latest_raw_data is not None else None,
273+
self.latest_fft_data.copy() if self.latest_fft_data is not None else None,
274+
self.latest_freqs.copy() if self.latest_freqs is not None else None
275+
)
276+
277+
class EEGMonitor(QtWidgets.QMainWindow):
278+
def __init__(self, serial_reader, fft_processor):
279+
super().__init__()
280+
self.serial_reader = serial_reader
281+
self.fft_processor = fft_processor
282+
self.setWindowTitle("EEG Monitor")
283+
self.setGeometry(100, 100, 1200, 600)
284+
285+
self.buffer_size = self.serial_reader.sampling_rate # 1 second buffer
286+
self.chunk_size = 5 # Chunk of new samples to read each update
287+
self.raw_data_buffer = np.zeros(self.buffer_size, dtype=np.float32)
288+
self.x_vals = np.arange(self.buffer_size)
289+
290+
self.init_ui()
291+
292+
self.timer = QtCore.QTimer()
293+
self.timer.timeout.connect(self.update)
294+
self.timer.start(20) # Every 20 ms
295+
296+
def init_ui(self):
297+
self.central_widget = QtWidgets.QWidget()
298+
self.setCentralWidget(self.central_widget)
299+
300+
self.layout = QtWidgets.QHBoxLayout(self.central_widget)
301+
302+
# Raw EEG plot
303+
self.raw_plot = pg.PlotWidget(title="Raw EEG Data")
304+
self.raw_plot.setLabel('left', 'Amplitude')
305+
self.raw_plot.setLabel('bottom', 'Sample')
306+
self.raw_plot.setYRange(0, 10000)
307+
self.raw_curve = self.raw_plot.plot(self.x_vals, self.raw_data_buffer, pen='y')
308+
309+
# FFT plot
310+
self.fft_plot = pg.PlotWidget(title="Real-time FFT")
311+
self.fft_plot.setLabel('left', 'Power')
312+
self.fft_plot.setLabel('bottom', 'Frequency (Hz)')
313+
self.fft_plot.setXRange(0, 50)
314+
self.fft_plot.setYRange(0, 200000)
315+
self.fft_curve = self.fft_plot.plot(pen='c')
316+
317+
self.layout.addWidget(self.raw_plot)
318+
self.layout.addWidget(self.fft_plot)
319+
320+
def update(self):
321+
try:
322+
new_data = self.serial_reader.serial_buffer.read_block(self.chunk_size)
323+
if new_data is not None and len(new_data) == self.chunk_size:
324+
self.raw_data_buffer = np.roll(self.raw_data_buffer, -self.chunk_size)
325+
self.raw_data_buffer[-self.chunk_size:] = new_data
326+
self.raw_curve.setData(self.x_vals, self.raw_data_buffer)
327+
328+
_, fft_data, freqs = self.fft_processor.get_latest_data()
329+
if fft_data is not None and freqs is not None:
330+
self.fft_curve.setData(freqs, fft_data)
331+
332+
except Exception as e:
333+
print(f"Error in update: {e}")
334+
335+
def main():
336+
serial_reader = Chords_USB()
337+
if not serial_reader.detect_hardware():
338+
print("Failed to detect hardware. Exiting...")
339+
return
340+
341+
fft_processor = FFTProcessor(serial_reader)
342+
343+
serial_thread = threading.Thread(target=serial_reader.start_streaming)
344+
serial_thread.daemon = True
345+
serial_thread.start()
346+
347+
fft_processor.start_processing()
348+
349+
app = QtWidgets.QApplication(sys.argv)
350+
window = EEGMonitor(serial_reader, fft_processor)
351+
window.show()
352+
353+
try:
354+
sys.exit(app.exec_())
355+
finally:
356+
fft_processor.stop_processing()
357+
serial_reader.cleanup()
358+
359+
if __name__ == "__main__":
360+
main()

0 commit comments

Comments
 (0)