From 5990e30f99918792e8141fd978c4cbdc2eabd270 Mon Sep 17 00:00:00 2001 From: Tom Date: Fri, 15 Jan 2021 15:55:58 -0700 Subject: [PATCH] GstreamerMediaPlayer: Fix callback function use-after-free Allocate a copy of the callback function and promise, and keep it until it was executed. This fixes a problem where the callback function could be freed as soon as the callback function sets the promise value and the thread waiting on the promise exits the function immediately, freeing the closure data while it is still executing. --- .../include/MediaPlayer/BaseStreamSource.h | 6 - .../include/MediaPlayer/MediaPlayer.h | 2 +- .../include/MediaPlayer/PipelineInterface.h | 2 +- .../src/BaseStreamSource.cpp | 6 +- .../GStreamerMediaPlayer/src/MediaPlayer.cpp | 157 +++++++++--------- 5 files changed, 84 insertions(+), 89 deletions(-) diff --git a/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/BaseStreamSource.h b/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/BaseStreamSource.h index 64e7d80404..82a08d387a 100644 --- a/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/BaseStreamSource.h +++ b/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/BaseStreamSource.h @@ -181,12 +181,6 @@ class BaseStreamSource : public SourceInterface { /// Number of times reading data has been attempted since data was last successfully read. guint m_sourceRetryCount; - /// Function to invoke on the worker thread thread when more data is needed. - const std::function m_handleNeedDataFunction; - - /// Function to invoke on the worker thread thread when there is enough data. - const std::function m_handleEnoughDataFunction; - /// ID of the handler installed to receive need data signals. guint m_needDataHandlerId; diff --git a/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/MediaPlayer.h b/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/MediaPlayer.h index 5b3955afe8..d6c8fe7e30 100644 --- a/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/MediaPlayer.h +++ b/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/MediaPlayer.h @@ -138,7 +138,7 @@ class MediaPlayer void setDecoder(GstElement* decoder) override; GstElement* getDecoder() const override; GstElement* getPipeline() const override; - guint queueCallback(const std::function* callback) override; + guint queueCallback(std::function&& callback) override; guint attachSource(GSource* source) override; gboolean removeSource(guint tag) override; /// @} diff --git a/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/PipelineInterface.h b/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/PipelineInterface.h index 83021fc5ef..a16a803642 100644 --- a/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/PipelineInterface.h +++ b/MediaPlayer/GStreamerMediaPlayer/include/MediaPlayer/PipelineInterface.h @@ -75,7 +75,7 @@ class PipelineInterface { * @param callback The callback to queue. * @return The ID of the queued callback (for calling @c g_source_remove). */ - virtual guint queueCallback(const std::function* callback) = 0; + virtual guint queueCallback(std::function&& callback) = 0; /** * Attach the source to the worker thread. diff --git a/MediaPlayer/GStreamerMediaPlayer/src/BaseStreamSource.cpp b/MediaPlayer/GStreamerMediaPlayer/src/BaseStreamSource.cpp index 61e98ce5f0..812d019375 100644 --- a/MediaPlayer/GStreamerMediaPlayer/src/BaseStreamSource.cpp +++ b/MediaPlayer/GStreamerMediaPlayer/src/BaseStreamSource.cpp @@ -89,8 +89,6 @@ BaseStreamSource::BaseStreamSource(PipelineInterface* pipeline, const std::strin m_sourceId{0}, m_hasReadData{false}, m_sourceRetryCount{0}, - m_handleNeedDataFunction{[this]() { return handleNeedData(); }}, - m_handleEnoughDataFunction{[this]() { return handleEnoughData(); }}, m_needDataHandlerId{0}, m_enoughDataHandlerId{0}, m_seekDataHandlerId{0}, @@ -311,7 +309,7 @@ void BaseStreamSource::onNeedData(GstElement* pipeline, guint size, gpointer poi ACSDK_DEBUG9(LX("m_needDataCallbackId already set")); return; } - source->m_needDataCallbackId = source->m_pipeline->queueCallback(&source->m_handleNeedDataFunction); + source->m_needDataCallbackId = source->m_pipeline->queueCallback([source]() { return source->handleNeedData(); }); } gboolean BaseStreamSource::handleNeedData() { @@ -330,7 +328,7 @@ void BaseStreamSource::onEnoughData(GstElement* pipeline, gpointer pointer) { ACSDK_DEBUG9(LX("m_enoughDataCallbackId already set")); return; } - source->m_enoughDataCallbackId = source->m_pipeline->queueCallback(&source->m_handleEnoughDataFunction); + source->m_enoughDataCallbackId = source->m_pipeline->queueCallback([source]() { return source->handleEnoughData(); }); } gboolean BaseStreamSource::handleEnoughData() { diff --git a/MediaPlayer/GStreamerMediaPlayer/src/MediaPlayer.cpp b/MediaPlayer/GStreamerMediaPlayer/src/MediaPlayer.cpp index e893fe00ed..01de66bf58 100644 --- a/MediaPlayer/GStreamerMediaPlayer/src/MediaPlayer.cpp +++ b/MediaPlayer/GStreamerMediaPlayer/src/MediaPlayer.cpp @@ -200,13 +200,13 @@ MediaPlayer::SourceId MediaPlayer::setSource( const avsCommon::utils::AudioFormat* audioFormat, const SourceConfig& config) { ACSDK_DEBUG9(LX("setSourceCalled").d("name", RequiresShutdown::name()).d("sourceType", "AttachmentReader")); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, &reader, &promise, &config, audioFormat]() { - handleSetAttachmentReaderSource(std::move(reader), config, &promise, audioFormat); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, &reader, promise, &config, audioFormat]() { + handleSetAttachmentReaderSource(std::move(reader), config, promise.get(), audioFormat); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { auto sourceId = future.get(); // Assume that the Attachment is fully buffered - not ideal, revisit if needed. Should be fine for file streams // and resources. @@ -223,13 +223,13 @@ MediaPlayer::SourceId MediaPlayer::setSource( avsCommon::utils::MediaType format) { ACSDK_DEBUG9( LX("setSourceCalled").d("name", RequiresShutdown::name()).d("sourceType", "istream").d("format", format)); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, &stream, repeat, &config, &promise]() { - handleSetIStreamSource(stream, repeat, config, &promise); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, &stream, repeat, &config, promise]() { + handleSetIStreamSource(stream, repeat, config, promise.get()); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { auto sourceId = future.get(); // Assume that the Attachment is fully buffered - not ideal, revisit if needed. Should be fine for file streams // and resources. @@ -246,13 +246,13 @@ MediaPlayer::SourceId MediaPlayer::setSource( bool repeat, const PlaybackContext& playbackContext) { ACSDK_DEBUG9(LX("setSourceForUrlCalled").d("name", RequiresShutdown::name()).sensitive("url", url)); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, url, offset, &config, &promise, repeat]() { - handleSetUrlSource(url, offset, config, &promise, repeat); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, url, offset, &config, promise, repeat]() { + handleSetUrlSource(url, offset, config, promise.get(), repeat); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { return future.get(); } return ERROR_SOURCE_ID; @@ -281,14 +281,14 @@ bool MediaPlayer::play(MediaPlayer::SourceId id) { m_source->preprocess(); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, id, &promise]() { - handlePlay(id, &promise); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, id, promise]() { + handlePlay(id, promise.get()); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { return future.get(); } return false; @@ -296,13 +296,13 @@ bool MediaPlayer::play(MediaPlayer::SourceId id) { bool MediaPlayer::stop(MediaPlayer::SourceId id) { ACSDK_DEBUG9(LX("stopCalled").d("name", RequiresShutdown::name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, id, &promise]() { - handleStop(id, &promise); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, id, promise]() { + handleStop(id, promise.get()); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { return future.get(); } return false; @@ -310,13 +310,13 @@ bool MediaPlayer::stop(MediaPlayer::SourceId id) { bool MediaPlayer::pause(MediaPlayer::SourceId id) { ACSDK_DEBUG9(LX("pausedCalled").d("name", RequiresShutdown::name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, id, &promise]() { - handlePause(id, &promise); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, id, promise]() { + handlePause(id, promise.get()); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { return future.get(); } return false; @@ -324,13 +324,13 @@ bool MediaPlayer::pause(MediaPlayer::SourceId id) { bool MediaPlayer::resume(MediaPlayer::SourceId id) { ACSDK_DEBUG9(LX("resumeCalled").d("name", RequiresShutdown::name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, id, &promise]() { - handleResume(id, &promise); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, id, promise]() { + handleResume(id, promise.get()); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { return future.get(); } return false; @@ -338,14 +338,14 @@ bool MediaPlayer::resume(MediaPlayer::SourceId id) { std::chrono::milliseconds MediaPlayer::getOffset(MediaPlayer::SourceId id) { ACSDK_DEBUG9(LX("getOffsetCalled").d("name", RequiresShutdown::name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, id, &promise]() { - handleGetOffset(id, &promise); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, id, promise]() { + handleGetOffset(id, promise.get()); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { return future.get(); } return MEDIA_PLAYER_INVALID_OFFSET; @@ -363,14 +363,14 @@ void MediaPlayer::addObserver(std::shared_ptr obse } ACSDK_DEBUG9(LX("addObserverCalled").d("name", RequiresShutdown::name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, &promise, &observer]() { - handleAddObserver(&promise, observer); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, promise, &observer]() { + handleAddObserver(promise.get(), observer); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { future.wait(); } } @@ -382,27 +382,27 @@ void MediaPlayer::removeObserver(std::shared_ptr o } ACSDK_DEBUG9(LX("removeObserverCalled").d("name", RequiresShutdown::name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, &promise, &observer]() { - handleRemoveObserver(&promise, observer); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, promise, &observer]() { + handleRemoveObserver(promise.get(), observer); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { future.wait(); } } bool MediaPlayer::setVolume(int8_t volume) { ACSDK_DEBUG9(LX("setVolumeCalled").d("name", RequiresShutdown::name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, &promise, volume]() { - handleSetVolume(&promise, volume); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, promise, volume]() { + handleSetVolume(promise.get(), volume); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { return future.get(); } return false; @@ -455,13 +455,13 @@ void MediaPlayer::handleSetVolume(std::promise* promise, int8_t volume) { bool MediaPlayer::setMute(bool mute) { ACSDK_DEBUG9(LX("setMuteCalled").d("name", RequiresShutdown::name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, &promise, mute]() { - handleSetMute(&promise, mute); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, promise, mute]() { + handleSetMute(promise.get(), mute); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { return future.get(); } return false; @@ -483,13 +483,13 @@ void MediaPlayer::handleSetMute(std::promise* promise, bool mute) { bool MediaPlayer::getSpeakerSettings(SpeakerInterface::SpeakerSettings* settings) { ACSDK_DEBUG9(LX("getSpeakerSettingsCalled").d("name", RequiresShutdown::name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, &promise, settings]() { - handleGetSpeakerSettings(&promise, settings); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, promise, settings]() { + handleGetSpeakerSettings(promise.get(), settings); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { return future.get(); } return false; @@ -924,13 +924,16 @@ bool MediaPlayer::seek() { return seekSuccessful; } -guint MediaPlayer::queueCallback(const std::function* callback) { +guint MediaPlayer::queueCallback(std::function&& callback) { if (isShutdown()) { return UNQUEUED_CALLBACK; } auto source = g_idle_source_new(); g_source_set_callback( - source, reinterpret_cast(&onCallback), const_cast*>(callback), nullptr); + source, reinterpret_cast(&onCallback), new std::function(std::move(callback)), + [](gpointer data) { + delete reinterpret_cast*>(data); + }); auto sourceId = g_source_attach(source, m_workerContext); g_source_unref(source); return sourceId; @@ -996,13 +999,13 @@ gboolean MediaPlayer::onCallback(const std::function* callback) { void MediaPlayer::onPadAdded(GstElement* decoder, GstPad* pad, gpointer pointer) { auto mediaPlayer = static_cast(pointer); ACSDK_DEBUG9(LX("onPadAddedCalled").d("name", mediaPlayer->name())); - std::promise promise; - auto future = promise.get_future(); - std::function callback = [mediaPlayer, &promise, decoder, pad]() { - mediaPlayer->handlePadAdded(&promise, decoder, pad); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [mediaPlayer, promise, decoder, pad]() { + mediaPlayer->handlePadAdded(promise.get(), decoder, pad); return false; }; - if (mediaPlayer->queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (mediaPlayer->queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { future.wait(); } } @@ -1932,9 +1935,9 @@ void MediaPlayer::setEqualizerBandLevels(EqualizerBandLevelMap bandLevelMap) { if (!m_equalizerEnabled) { return; } - std::promise promise; - auto future = promise.get_future(); - std::function callback = [this, &promise, bandLevelMap]() { + auto promise = std::make_shared>(); + auto future = promise->get_future(); + std::function callback = [this, promise, bandLevelMap]() { auto it = bandLevelMap.find(EqualizerBand::BASS); if (bandLevelMap.end() != it) { g_object_set( @@ -1959,10 +1962,10 @@ void MediaPlayer::setEqualizerBandLevels(EqualizerBandLevelMap bandLevelMap) { static_cast(clampEqualizerLevel(it->second)), NULL); } - promise.set_value(); + promise->set_value(); return false; }; - if (queueCallback(&callback) != UNQUEUED_CALLBACK) { + if (queueCallback(std::move(callback)) != UNQUEUED_CALLBACK) { future.get(); } }