Skip to content
This repository was archived by the owner on Oct 19, 2023. It is now read-only.

google-assistant-sdk/pushtotalk: fix conversation_stream handling #188

Merged
7 commits merged on Mar 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions google-assistant-sdk/googlesamples/assistant/grpc/audio_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

"""Helper functions for audio streams."""

import array
import logging
import threading
import math
import time
import threading
import wave
import math
import array

import click
import sounddevice as sd
Expand Down Expand Up @@ -165,6 +165,9 @@ def start(self):
def stop(self):
pass

def flush(self):
pass


class SoundDeviceStream(object):
"""Audio stream based on an underlying sound device.
Expand Down Expand Up @@ -207,7 +210,7 @@ def write(self, buf):
return len(buf)

def flush(self):
if self._flush_size > 0:
if self._audio_stream.active and self._flush_size > 0:
self._audio_stream.write(b'\x00' * self._flush_size)

def start(self):
Expand All @@ -218,7 +221,6 @@ def start(self):
def stop(self):
"""Stop the underlying stream."""
if self._audio_stream.active:
self.flush()
self._audio_stream.stop()

def close(self):
Expand Down Expand Up @@ -264,29 +266,43 @@ def __init__(self, source, sink, iter_size, sample_width):
self._sink = sink
self._iter_size = iter_size
self._sample_width = sample_width
self._stop_recording = threading.Event()
self._start_playback = threading.Event()
self._volume_percentage = 50
self._stop_recording = threading.Event()
self._source_lock = threading.RLock()
self._recording = False
self._playing = False

def start_recording(self):
"""Start recording from the audio source."""
self._recording = True
self._stop_recording.clear()
self._source.start()
self._sink.start()

def stop_recording(self):
"""Stop recording from the audio source."""
self._stop_recording.set()
with self._source_lock:
self._source.stop()
self._recording = False

def start_playback(self):
"""Start playback to the audio sink."""
self._start_playback.set()
self._playing = True
self._sink.start()

def stop_playback(self):
"""Stop playback from the audio sink."""
self._start_playback.clear()
self._source.stop()
self._sink.flush()
self._sink.stop()
self._playing = False

@property
def recording(self):
return self._recording

@property
def playing(self):
return self._playing

@property
def volume_percentage(self):
Expand All @@ -299,19 +315,13 @@ def volume_percentage(self, new_volume_percentage):

def read(self, size):
"""Read bytes from the source (if currently recording).

Will returns an empty byte string, if stop_recording() was called.
"""
if self._stop_recording.is_set():
return b''
return self._source.read(size)
with self._source_lock:
return self._source.read(size)

def write(self, buf):
"""Write bytes to the sink (if currently playing).

Will block until start_playback() is called.
"""
self._start_playback.wait()
buf = align_buf(buf, self._sample_width)
buf = normalize_audio_buffer(buf, self.volume_percentage)
return self._sink.write(buf)
Expand All @@ -323,7 +333,10 @@ def close(self):

def __iter__(self):
"""Returns a generator reading data from the stream."""
return iter(lambda: self.read(self._iter_size), b'')
while True:
if self._stop_recording.is_set():
raise StopIteration
yield self.read(self._iter_size)

@property
def sample_rate(self):
Expand Down
22 changes: 10 additions & 12 deletions google-assistant-sdk/googlesamples/assistant/grpc/pushtotalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,26 +120,29 @@ def assist(self):
self.conversation_stream.start_recording()
logging.info('Recording audio request.')

def iter_assist_requests():
def iter_log_assist_requests():
for c in self.gen_assist_requests():
assistant_helpers.log_assist_request_without_audio(c)
yield c
self.conversation_stream.start_playback()
logging.debug('Reached end of AssistRequest iteration.')

# This generator yields AssistResponse proto messages
# received from the gRPC Google Assistant API.
for resp in self.assistant.Assist(iter_assist_requests(),
for resp in self.assistant.Assist(iter_log_assist_requests(),
self.deadline):
assistant_helpers.log_assist_response_without_audio(resp)
if resp.event_type == END_OF_UTTERANCE:
logging.info('End of audio request detected')
logging.info('End of audio request detected.')
logging.info('Stopping recording.')
self.conversation_stream.stop_recording()
if resp.speech_results:
logging.info('Transcript of user request: "%s".',
' '.join(r.transcript
for r in resp.speech_results))
logging.info('Playing assistant response.')
if len(resp.audio_out.audio_data) > 0:
if not self.conversation_stream.playing:
self.conversation_stream.start_playback()
logging.info('Playing assistant response.')
self.conversation_stream.write(resp.audio_out.audio_data)
if resp.dialog_state_out.conversation_state:
conversation_state = resp.dialog_state_out.conversation_state
Expand Down Expand Up @@ -317,21 +320,18 @@ def main(api_endpoint, credentials, project_id,
logging.info('Connecting to %s', api_endpoint)

# Configure audio source and sink.
audio_device = None
if input_audio_file:
audio_source = audio_helpers.WaveSource(
open(input_audio_file, 'rb'),
sample_rate=audio_sample_rate,
sample_width=audio_sample_width
)
else:
audio_source = audio_device = (
audio_device or audio_helpers.SoundDeviceStream(
audio_source = audio_helpers.SoundDeviceStream(
sample_rate=audio_sample_rate,
sample_width=audio_sample_width,
block_size=audio_block_size,
flush_size=audio_flush_size
)
)
if output_audio_file:
audio_sink = audio_helpers.WaveSink(
Expand All @@ -340,13 +340,11 @@ def main(api_endpoint, credentials, project_id,
sample_width=audio_sample_width
)
else:
audio_sink = audio_device = (
audio_device or audio_helpers.SoundDeviceStream(
audio_sink = audio_helpers.SoundDeviceStream(
sample_rate=audio_sample_rate,
sample_width=audio_sample_width,
block_size=audio_block_size,
flush_size=audio_flush_size
)
)
# Create conversation stream with the given audio source and sink.
conversation_stream = audio_helpers.ConversationStream(
Expand Down
48 changes: 36 additions & 12 deletions google-assistant-sdk/tests/test_audio_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import unittest

import time
import threading
import wave

from googlesamples.assistant.grpc import audio_helpers
Expand Down Expand Up @@ -87,12 +86,29 @@ def test_write_header(self):
self.assertEqual(b'RIFF', self.stream.getvalue()[:4])


class DummyStream(BytesIO):
class DummyStream(BytesIO, object):
started = False
stopped = False
flushed = False

def start(self):
pass
self.started = True

def stop(self):
pass
self.stopped = True

def read(self, *args):
if self.stopped:
return b''
return super(DummyStream, self).read(*args)

def write(self, *args):
if not self.started:
return
return super(DummyStream, self).write(*args)

def flush(self):
self.flushed = True


class ConversationStreamTest(unittest.TestCase):
Expand All @@ -114,17 +130,25 @@ def test_stop_recording(self):

def test_start_playback(self):
self.playback_started = False

def start_playback():
self.playback_started = True
self.stream.start_playback()
t = threading.Timer(0.1, start_playback)
t.start()
# write will block until start_playback is called.
self.stream.write(b'foo')
self.assertEqual(True, self.playback_started)
self.assertEqual(b'', self.sink.getvalue())
self.stream.start_playback()
self.stream.write(b'foo')
self.assertEqual(b'foo\0', self.sink.getvalue())

def test_sink_source_state(self):
self.assertEquals(False, self.source.started)
self.stream.start_recording()
self.assertEquals(True, self.source.started)
self.stream.stop_recording()
self.assertEquals(True, self.source.stopped)

self.assertEquals(False, self.sink.started)
self.stream.start_playback()
self.assertEquals(True, self.sink.started)
self.stream.stop_playback()
self.assertEquals(True, self.sink.stopped)

def test_oneshot_conversation(self):
self.assertEqual(b'audio', self.stream.read(5))
self.stream.stop_recording()
Expand Down