YARP
Yet Another Robot Platform
 
Loading...
Searching...
No Matches
whisperSpeechTranscription.cpp
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: 2023-2023 Istituto Italiano di Tecnologia (IIT)
3 * SPDX-License-Identifier: BSD-3-Clause
4 */
5
7
#include <yarp/os/Log.h>
#include <yarp/os/LogStream.h>

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <regex>
#include <set>
#include <thread>

using namespace yarp::os;
using namespace yarp::dev;
21namespace {
22YARP_LOG_COMPONENT(WHISPER_SPEECHTR, "yarp.device.WhisperSpeechTranscription")
23}
24
26{
 // NOTE(review): the signature line of this body is missing from the listing; it looks like
 // the class constructor — confirm against WhisperSpeechTranscription.h.
 // Default language; open() overrides it when the "language" option is given.
27 m_language = "en";
 // whisper keeps only the raw pointer, so it refers to m_language's buffer and must be
 // refreshed whenever m_language is reassigned.
28 m_wparams.language=m_language.c_str();
 // Default model file; open() overrides it when the "model" option is given.
30 m_model = "ggml-base.en.bin";
31}
32
37
39{
40 if (config.check("threads","number of threads")){
41 m_wparams.n_threads = config.find("threads").asInt16();}
42 if (config.check("processors", "number of processors")){
43 m_wparams.n_threads = config.find("processors").asInt16();}
44 if (config.check("initial_prompt")) {
45 m_wparams.initial_prompt = config.find("initial_prompt").asString().c_str();}
46 if (config.check("duration","duration of audio to process in milliseconds")) {
47 m_wparams.duration_ms = config.find("duration").asInt32();}
48 if (config.check("offset_ms")) {
49 m_wparams.offset_ms = config.find("offset_ms").asInt32();}
50 if (config.check("speed_up")) {
51 m_wparams.speed_up = config.find("speed_up").asBool();}
52 if (config.check("thold_pt","word timestamp probability threshold")) {
53 m_wparams.thold_pt = config.find("thold_pt").asFloat32();}
54 if (config.check("entropy_thold","entropy threshold for decoder fail")) {
55 m_wparams.entropy_thold = config.find("entropy_thold").asFloat32();}
56 if (config.check("logprob_thold","log probability threshold for decoder fail")) {
57 m_wparams.logprob_thold = config.find("logprob_thold").asFloat32();}
58 if (config.check("print_timestamps", "print_timestamps")) {
59 m_wparams.logprob_thold = config.find("print_timestamps").asFloat32();}
60 if (config.check("model", "file containing the model")) {
61 m_model = config.find("model").asString();}
62 if (config.check("translate", "translate from source language to English")) {
63 m_wparams.translate = config.find("translate").asBool();}
64// if (config.check("diarize", "stereo audio diarization")) {
65// m_wparams.diarize = config.find("diarize").asBool();}
66 if(config.check("print_realtime", "print_realtime")) {
67 m_wparams.print_realtime = config.find("print_realtime").asBool();}
68 if(config.check("print_progress", "print_progress")) {
69 m_wparams.print_progress = config.find("print_progress").asBool();}
70 if (config.check("split_on_word", "split on word rather than on token")) {
71 m_wparams.split_on_word = config.find("split_on_word").asBool();}
72 if (config.check("best_of", "number of best candidates to keep")) {
73 m_wparams.greedy.best_of = config.find("best_of").asInt32();}
74 if (config.check("detect-language", "exit after automatically detecting language")) {
75 m_wparams.detect_language = config.find("detect-language").asBool();}
76 if (config.check("language", "spoken language ('auto' for auto-detect)")) {
77 m_wparams.language = config.find("language").asString().c_str();
78 m_language = m_wparams.language;//???
79 }
80 if (config.check("beam_size", " beam size for beam search")) {
81 m_wparams.beam_search.beam_size = config.find("beam_size").asInt32();
82 m_wparams.strategy = m_wparams.beam_search.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
83 }
85 int32_t max_len = 0;
86 bool no_fallback = false;
87 if (config.check("max-context", "maximum number of text context tokens to store")) {
88 max_context = config.find("max-context").asInt32();}
89 if (config.check("max-len", "maximum segment length in characters")) {
90 max_len = config.find("max-len").asInt32();}
91 if (config.check("no-fallback", "do not use temperature fallback while decoding")) {
92 no_fallback = config.find("no-fallback").asBool();}
93 if (config.check("remove_symbols","remove [] symbols from the text transcript")) {
94 m_no_symbols = config.find("remove_symbols").asBool();}
95 m_wparams.n_max_text_ctx = max_context >= 0 ? max_context : m_wparams.n_max_text_ctx;
96 m_wparams.token_timestamps = false || max_len > 0;
97 m_wparams.max_len = false && max_len == 0 ? 60 : max_len;
98 m_wparams.temperature_inc = no_fallback ? 0.0f : m_wparams.temperature_inc;
99
100 if (m_language != "auto" && whisper_lang_id(m_language.c_str()) == -1)
101 {
102 yCError(WHISPER_SPEECHTR, "error: unknown language '%s'\n", m_language.c_str());
103 return false;
104 }
105
106 // whisper init
107 if (m_model=="")
108 {
109 yCError(WHISPER_SPEECHTR, "Please provide full path to the model file with parameter --model\n");
110 return false;
111 }
112 m_ctx = whisper_init_from_file(m_model.c_str());
113 if (m_ctx == nullptr)
114 {
115 yCError(WHISPER_SPEECHTR, "Failed to initialize whisper context\n");
116 return false;
117 }
118
119 // print system information
120 {
121 yCInfo(WHISPER_SPEECHTR, "system_info: n_threads = %d / %d | %s\n",
122 m_wparams.n_threads * n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
123 }
124
125 // print some info about the processing
126 {
127 if (!whisper_is_multilingual(m_ctx))
128 {
129 if (m_wparams.language != "en" || m_wparams.translate)
130 {
131 m_wparams.language = "en";
132 m_wparams.translate = false;
133 yCWarning(WHISPER_SPEECHTR,"model is not multilingual, ignoring language and translation options");
134 }
135 }
136 if (m_wparams.detect_language)
137 {
138 m_wparams.language = "auto";
139 }
140 yCDebug(WHISPER_SPEECHTR, "%s: processing (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
141 __func__, int(m_pcmf32.size()), float(m_pcmf32.size()) / WHISPER_SAMPLE_RATE,
142 m_wparams.n_threads, n_processors,
143 m_wparams.language,
144 m_wparams.translate ? "translate" : "transcribe",
145 m_wparams.print_timestamps);
146 }
147 return true;
148}
149
151{
152 if (m_ctx)
153 {
154 whisper_free(m_ctx);
155 m_ctx=nullptr;
156 }
157 return true;
158}
159
160bool WhisperSpeechTranscription::setLanguage(const std::string& language)
161{
162 m_language=language;
163 yCInfo(WHISPER_SPEECHTR) << "Language set to" << language;
164 return true;
165}
166
168{
169 language = m_language;
170 return true;
171}
172
173bool WhisperSpeechTranscription::transcribe(const yarp::sig::Sound& sound, std::string& transcription, double& score)
174{
175 yCDebug(WHISPER_SPEECHTR)<< "received audio" << sound.getSamples();
176
177 score=0;
178 transcription="";
179
180 if (sound.getSamples() == 0 ||
181 sound.getChannels() == 0)
182 {
183 yCError(WHISPER_SPEECHTR) << "Invalid Sound sample received";
184 return false;
185 }
186
187 //copy the audio data;
188#if 0
189 m_pcmf32.resize(1000);
190 m_pcmf32s.clear();
191#else
192 m_pcmf32.resize(sound.getSamples());
193 size_t channel=0;
194 for (size_t i = 0; i < m_pcmf32.size(); i++)
195 {
196 m_pcmf32[i]=float(sound.get(i, channel))/65535.0;
197 }
198 m_pcmf32s.clear();
199#endif
200
201 // run the inference
202 {
203 if (whisper_full_parallel(m_ctx, m_wparams, m_pcmf32.data(), m_pcmf32.size(), n_processors) != 0)
204 {
205 yCError(WHISPER_SPEECHTR, "failed to process audio");
206 return false;
207 }
208 }
209
210 // output stuff
211 {
212 const int n_segments = whisper_full_n_segments(m_ctx);
213 for (int i = 0; i < n_segments; ++i) {
214 const char* text = whisper_full_get_segment_text(m_ctx, i);
215 transcription += std::string(text);
217 }
218 }
219
220 //remove symbols such as [bla bla]
221 if (m_no_symbols)
222 {
223 std::regex pattern1("\\[[^\\]]*\\]");
224 std::string input = transcription;
225 transcription = std::regex_replace(input, pattern1, "");
226 }
227
228 //assign the score
229 score = 1.0;
230 if (transcription.empty()) {score = 0.0;}
231 return true;
232}
virtual bool setLanguage(const std::string &language) override
Sets the language for speech transcription.
bool open(yarp::os::Searchable &config) override
Open the DeviceDriver.
bool close() override
Close the DeviceDriver.
virtual bool transcribe(const yarp::sig::Sound &sound, std::string &transcription, double &score) override
Performs the speech transcription.
virtual bool getLanguage(std::string &language) override
Gets the current language set for speech transcription.
A mini-server for performing network communication in the background.
A base class for nested structures that can be searched.
Definition Searchable.h:31
virtual bool check(const std::string &key) const =0
Check if there exists a property of the given name.
virtual Value & find(const std::string &key) const =0
Gets a value corresponding to a given keyword.
Class for storing sounds. See Audio in YARP for additional documentation on YARP audio.
Definition Sound.h:25
size_t getChannels() const
Get the number of channels of the sound.
Definition Sound.cpp:603
audio_sample get(size_t sample, size_t channel=0) const
Definition Sound.cpp:294
size_t getSamples() const
Get the number of samples contained in the sound.
Definition Sound.cpp:598
#define yCInfo(component,...)
#define yCError(component,...)
#define yCWarning(component,...)
#define yCDebug(component,...)
#define YARP_LOG_COMPONENT(name,...)
For streams capable of holding different kinds of content, check what they actually have.
An interface to the operating system, including Port based communication.