diff --git a/google-assistant-sdk/googlesamples/assistant/grpc/audio_helpers.py b/google-assistant-sdk/googlesamples/assistant/grpc/audio_helpers.py
index cd6d5de..6d62fc1 100644
--- a/google-assistant-sdk/googlesamples/assistant/grpc/audio_helpers.py
+++ b/google-assistant-sdk/googlesamples/assistant/grpc/audio_helpers.py
@@ -14,12 +14,12 @@
 
 """Helper functions for audio streams."""
 
+import array
 import logging
-import threading
+import math
 import time
+import threading
 import wave
-import math
-import array
 
 import click
 import sounddevice as sd
@@ -165,6 +165,9 @@ def start(self):
     def stop(self):
         pass
 
+    def flush(self):
+        pass
+
 
 class SoundDeviceStream(object):
     """Audio stream based on an underlying sound device.
@@ -207,7 +210,7 @@ def write(self, buf):
         return len(buf)
 
     def flush(self):
-        if self._flush_size > 0:
+        if self._audio_stream.active and self._flush_size > 0:
             self._audio_stream.write(b'\x00' * self._flush_size)
 
     def start(self):
@@ -218,7 +221,6 @@ def start(self):
     def stop(self):
         """Stop the underlying stream."""
         if self._audio_stream.active:
-            self.flush()
             self._audio_stream.stop()
 
     def close(self):
@@ -264,29 +266,43 @@ def __init__(self, source, sink, iter_size, sample_width):
         self._sink = sink
         self._iter_size = iter_size
         self._sample_width = sample_width
-        self._stop_recording = threading.Event()
-        self._start_playback = threading.Event()
         self._volume_percentage = 50
+        self._stop_recording = threading.Event()
+        self._source_lock = threading.RLock()
+        self._recording = False
+        self._playing = False
 
     def start_recording(self):
         """Start recording from the audio source."""
+        self._recording = True
         self._stop_recording.clear()
         self._source.start()
-        self._sink.start()
 
     def stop_recording(self):
         """Stop recording from the audio source."""
         self._stop_recording.set()
+        with self._source_lock:
+            self._source.stop()
+        self._recording = False
 
     def start_playback(self):
         """Start playback to the audio sink."""
-        self._start_playback.set()
+        self._playing = True
+        self._sink.start()
 
     def stop_playback(self):
         """Stop playback from the audio sink."""
-        self._start_playback.clear()
-        self._source.stop()
+        self._sink.flush()
         self._sink.stop()
+        self._playing = False
+
+    @property
+    def recording(self):
+        return self._recording
+
+    @property
+    def playing(self):
+        return self._playing
 
     @property
     def volume_percentage(self):
@@ -299,19 +315,13 @@ def volume_percentage(self, new_volume_percentage):
 
     def read(self, size):
         """Read bytes from the source (if currently recording).
-
-        Will returns an empty byte string, if stop_recording() was called.
         """
-        if self._stop_recording.is_set():
-            return b''
-        return self._source.read(size)
+        with self._source_lock:
+            return self._source.read(size)
 
     def write(self, buf):
         """Write bytes to the sink (if currently playing).
-
-        Will block until start_playback() is called.
         """
-        self._start_playback.wait()
         buf = align_buf(buf, self._sample_width)
         buf = normalize_audio_buffer(buf, self.volume_percentage)
         return self._sink.write(buf)
@@ -323,7 +333,10 @@ def close(self):
 
     def __iter__(self):
         """Returns a generator reading data from the stream."""
-        return iter(lambda: self.read(self._iter_size), b'')
+        while True:
+            if self._stop_recording.is_set():
+                raise StopIteration
+            yield self.read(self._iter_size)
 
     @property
     def sample_rate(self):
diff --git a/google-assistant-sdk/googlesamples/assistant/grpc/pushtotalk.py b/google-assistant-sdk/googlesamples/assistant/grpc/pushtotalk.py
index 5043553..c508fc3 100644
--- a/google-assistant-sdk/googlesamples/assistant/grpc/pushtotalk.py
+++ b/google-assistant-sdk/googlesamples/assistant/grpc/pushtotalk.py
@@ -120,26 +120,29 @@ def assist(self):
         self.conversation_stream.start_recording()
         logging.info('Recording audio request.')
 
-        def iter_assist_requests():
+        def iter_log_assist_requests():
             for c in self.gen_assist_requests():
                 assistant_helpers.log_assist_request_without_audio(c)
                 yield c
-            self.conversation_stream.start_playback()
+            logging.debug('Reached end of AssistRequest iteration.')
 
         # This generator yields AssistResponse proto messages
         # received from the gRPC Google Assistant API.
-        for resp in self.assistant.Assist(iter_assist_requests(),
+        for resp in self.assistant.Assist(iter_log_assist_requests(),
                                           self.deadline):
             assistant_helpers.log_assist_response_without_audio(resp)
             if resp.event_type == END_OF_UTTERANCE:
-                logging.info('End of audio request detected')
+                logging.info('End of audio request detected.')
+                logging.info('Stopping recording.')
                 self.conversation_stream.stop_recording()
             if resp.speech_results:
                 logging.info('Transcript of user request: "%s".',
                              ' '.join(r.transcript
                                       for r in resp.speech_results))
-                logging.info('Playing assistant response.')
             if len(resp.audio_out.audio_data) > 0:
+                if not self.conversation_stream.playing:
+                    self.conversation_stream.start_playback()
+                    logging.info('Playing assistant response.')
                 self.conversation_stream.write(resp.audio_out.audio_data)
             if resp.dialog_state_out.conversation_state:
                 conversation_state = resp.dialog_state_out.conversation_state
@@ -317,7 +320,6 @@ def main(api_endpoint, credentials, project_id,
     logging.info('Connecting to %s', api_endpoint)
 
     # Configure audio source and sink.
-    audio_device = None
    if input_audio_file:
        audio_source = audio_helpers.WaveSource(
            open(input_audio_file, 'rb'),
@@ -325,13 +327,11 @@ def main(api_endpoint, credentials, project_id,
             sample_width=audio_sample_width
         )
     else:
-        audio_source = audio_device = (
-            audio_device or audio_helpers.SoundDeviceStream(
+        audio_source = audio_helpers.SoundDeviceStream(
             sample_rate=audio_sample_rate,
             sample_width=audio_sample_width,
             block_size=audio_block_size,
             flush_size=audio_flush_size
-            )
         )
     if output_audio_file:
         audio_sink = audio_helpers.WaveSink(
@@ -340,13 +340,11 @@ def main(api_endpoint, credentials, project_id,
             sample_width=audio_sample_width
         )
     else:
-        audio_sink = audio_device = (
-            audio_device or audio_helpers.SoundDeviceStream(
+        audio_sink = audio_helpers.SoundDeviceStream(
             sample_rate=audio_sample_rate,
             sample_width=audio_sample_width,
             block_size=audio_block_size,
             flush_size=audio_flush_size
-            )
         )
     # Create conversation stream with the given audio source and sink.
     conversation_stream = audio_helpers.ConversationStream(
diff --git a/google-assistant-sdk/tests/test_audio_helpers.py b/google-assistant-sdk/tests/test_audio_helpers.py
index 6fefa9d..90a57ff 100644
--- a/google-assistant-sdk/tests/test_audio_helpers.py
+++ b/google-assistant-sdk/tests/test_audio_helpers.py
@@ -16,7 +16,6 @@
 
 import unittest
 import time
-import threading
 import wave
 
 from googlesamples.assistant.grpc import audio_helpers
@@ -87,12 +86,29 @@ def test_write_header(self):
         self.assertEqual(b'RIFF', self.stream.getvalue()[:4])
 
 
-class DummyStream(BytesIO):
+class DummyStream(BytesIO, object):
+    started = False
+    stopped = False
+    flushed = False
+
     def start(self):
-        pass
+        self.started = True
 
     def stop(self):
-        pass
+        self.stopped = True
+
+    def read(self, *args):
+        if self.stopped:
+            return b''
+        return super(DummyStream, self).read(*args)
+
+    def write(self, *args):
+        if not self.started:
+            return
+        return super(DummyStream, self).write(*args)
+
+    def flush(self):
+        self.flushed = True
 
 
 class ConversationStreamTest(unittest.TestCase):
@@ -114,17 +130,25 @@ def test_stop_recording(self):
 
     def test_start_playback(self):
         self.playback_started = False
-
-        def start_playback():
-            self.playback_started = True
-            self.stream.start_playback()
-        t = threading.Timer(0.1, start_playback)
-        t.start()
-        # write will block until start_playback is called.
         self.stream.write(b'foo')
-        self.assertEqual(True, self.playback_started)
+        self.assertEqual(b'', self.sink.getvalue())
+        self.stream.start_playback()
+        self.stream.write(b'foo')
         self.assertEqual(b'foo\0', self.sink.getvalue())
 
+    def test_sink_source_state(self):
+        self.assertEquals(False, self.source.started)
+        self.stream.start_recording()
+        self.assertEquals(True, self.source.started)
+        self.stream.stop_recording()
+        self.assertEquals(True, self.source.stopped)
+
+        self.assertEquals(False, self.sink.started)
+        self.stream.start_playback()
+        self.assertEquals(True, self.sink.started)
+        self.stream.stop_playback()
+        self.assertEquals(True, self.sink.stopped)
+
     def test_oneshot_conversation(self):
         self.assertEqual(b'audio', self.stream.read(5))
         self.stream.stop_recording()