diff --git a/README.md b/README.md index 27c120d..9aef0fd 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ export GLADIA_API_KEY=your_key ./gladia transcribe podcast.mp3 --language en,fr,de ./gladia transcribe mixed.mp3 --code-switching --language en,fr ./gladia transcribe call.wav --diarize -o srt -./gladia transcribe podcast.mp3 --model solaria-3 +./gladia transcribe podcast.mp3 --model solaria-3 --language en ``` ## Commands @@ -64,10 +64,10 @@ export GLADIA_API_KEY=your_key | Flag | Default | Description | |------|---------|-------------| | `-o`, `--output` | `text` | Output: `text`, `json`, `json-full`, `srt`, `vtt` | -| `--language` | — | Expected language(s), comma-separated (`en` or `en,fr,de`) | -| `--code-switching`, `--code-switch` | off | Detect language per utterance | +| `--language` | — | Expected language(s), comma-separated (`en` or `en,fr,de`); narrows detection, does not enable code switching | +| `--cs`, `--code-switching` | off | Re-detect language on each utterance (mixed-language audio; solaria-1 only) | | `--diarize` | off | **Optional.** Identify speakers in the transcript | -| `--model` | — | STT model: `solaria-1` or `solaria-3` (default: API default) | +| `--model` | — | STT model: `solaria-1` or `solaria-3`. Solaria-3 accepts at most one `--language` (`en`, `fr`, `de`, `es`, or `it`) and does not support code switching. | | `-v`, `--verbose` | off | Show progress while polling | **Global flag** (any command): `--gladia-key` — API key if not in env or `~/.gladia` @@ -78,10 +78,10 @@ export GLADIA_API_KEY=your_key |------|-------------| | Auto-detect | `transcribe ` | | Constrain detection | `--language en,fr,de` (no code switching) | -| Code switching | `--code-switching` (+ optional `--language` hints) | +| Code switching | `--cs` or `--code-switching` (+ optional `--language` hints) | -- **`--language`** — tells Gladia which language(s) to expect. Several codes (`en,fr,de`) narrow detection; they do **not** turn on code switching. -- **`--code-switching`** — separate option: re-detect language on each utterance. Combine with `--language` when you know which languages may appear. +- **`--language`** — limits which language(s) Gladia considers (`en,fr,de` is a hint list, not per-utterance switching). +- **`--cs`** / **`--code-switching`** — turns on per-utterance language detection. Add `--language` to restrict which languages may appear. Not available with `solaria-3`. ```bash ./gladia languages # list valid codes diff --git a/cmd/transcribe.go b/cmd/transcribe.go index 4a2267d..86701cd 100644 --- a/cmd/transcribe.go +++ b/cmd/transcribe.go @@ -31,11 +31,11 @@ Examples: gladia transcribe podcast.mp3 --language en gladia transcribe interview.mp3 --code-switching gladia transcribe interview.mp3 --language en,fr,de - gladia transcribe call.wav --code-switch --language en -o json + gladia transcribe call.wav --cs --language en -o json gladia transcribe call.wav --diarize -o srt - gladia transcribe podcast.mp3 --model solaria-3 + gladia transcribe podcast.mp3 --model solaria-3 --language en gladia transcribe https://example.com/audio.mp3 -o json`, - Args: cobra.ExactArgs(1), + Args: validateTranscribeArgs, RunE: func(cmd *cobra.Command, args []string) error { if err := validateOutputFormat(outputFormat); err != nil { return err @@ -45,12 +45,19 @@ Examples: return err } + if err := validateLanguageFlag(languageFlag); err != nil { + return err + } + langs, err := types.ParseLanguages(languageFlag) if err != nil { return err } - codeSwitchSet := cmd.Flags().Changed("code-switching") || cmd.Flags().Changed("code-switch") + codeSwitchSet := cmd.Flags().Changed("code-switching") || cmd.Flags().Changed("cs") + if err := validateModelConfig(modelFlag, langs, codeSwitchSet, codeSwitching); err != nil { + return err + } langConfig, err := buildLanguageConfig(langs, codeSwitching, codeSwitchSet) if err != nil { return err @@ -69,7 +76,7 @@ Examples: } transcriptionReq := gladia.TranscriptionRequest{ - Model: modelFlag, + Model: normalizeModel(modelFlag), LanguageConfig: langConfig, Diarization: diarization, } @@ -91,12 +98,37 @@ Examples: } cmd.Flags().StringVarP(&outputFormat, "output", "o", "text", "Output format: text, json, json-full, srt, vtt") - cmd.Flags().StringVar(&languageFlag, "language", "", "Optional ISO 639-1 code(s), comma-separated (e.g. en or en,fr,de)") - cmd.Flags().BoolVar(&codeSwitching, "code-switching", false, "Enable code switching (detect language per utterance; independent of --language)") - cmd.Flags().BoolVar(&codeSwitching, "code-switch", false, "Alias for --code-switching") + cmd.Flags().StringVar(&languageFlag, "language", "", "Expected language(s), comma-separated (e.g. en or en,fr,de); does not enable code switching") + const codeSwitchingUsage = "Re-detect language on each utterance (for mixed-language audio; solaria-1 only)" + cmd.Flags().BoolVar(&codeSwitching, "cs", false, codeSwitchingUsage) + cmd.Flags().BoolVar(&codeSwitching, "code-switching", false, codeSwitchingUsage) + cmd.Flags().Lookup("cs").Hidden = true + cmd.Flags().Lookup("code-switching").Hidden = true cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Show progress while transcribing") cmd.Flags().BoolVar(&diarization, "diarize", false, "Enable speaker diarization") - cmd.Flags().StringVar(&modelFlag, "model", "", "STT model: solaria-1 or solaria-3 (default: API default)") + cmd.Flags().StringVar(&modelFlag, "model", "", "STT model: solaria-1 or solaria-3 (solaria-3 accepts at most one --language: en, fr, de, es, or it)") + + cmd.SetUsageTemplate(`Usage:{{if .Runnable}} + {{.UseLine}}{{end}}{{if .HasAvailableSubCommands}} + {{.CommandPath}} [command]{{end}}{{if gt (len .Aliases) 0}} + +Aliases: + {{.NameAndAliases}}{{end}}{{if .HasExample}} + +Examples: +{{.Example}}{{end}}{{if .HasAvailableLocalFlags}} + +Flags: + --cs, --code-switching — ` + codeSwitchingUsage + ` +{{.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableInheritedFlags}} + +Global Flags: +{{.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasHelpSubCommands}} + +Additional help topics:{{range .Commands}}{{if .IsAdditionalHelpTopicCommand}} + {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableSubCommands}} + +Use "{{.CommandPath}} [command] --help" for more information about a command.{{end}}`) return cmd } @@ -130,6 +162,7 @@ func validateOutputFormat(format string) error { } func validateModel(model string) error { + model = normalizeModel(model) if model == "" { return nil } @@ -141,6 +174,125 @@ func validateModel(model string) error { } } +var solaria3Languages = map[types.Language]bool{ + types.LanguageEn: true, + types.LanguageFr: true, + types.LanguageDe: true, + types.LanguageEs: true, + types.LanguageIt: true, +} + +func validateModelConfig(model string, langs []types.Language, codeSwitchSet, codeSwitching bool) error { + model = normalizeModel(model) + if model != "solaria-3" { + return nil + } + if codeSwitchSet && codeSwitching { + return fmt.Errorf("solaria-3 does not support code switching (use solaria-1 instead)") + } + switch len(langs) { + case 0: + return nil + case 1: + if !solaria3Languages[langs[0]] { + return fmt.Errorf("solaria-3 does not support language %q (use en, fr, de, es, or it)", langs[0]) + } + return nil + default: + codes := make([]string, len(langs)) + for i, lang := range langs { + codes[i] = string(lang) + } + return fmt.Errorf("solaria-3 accepts only one language, got %d (%s); use solaria-1 for multi-language", len(langs), strings.Join(codes, ", ")) + } +} + +func normalizeModel(model string) string { + model = strings.TrimSpace(strings.ToLower(model)) + return strings.ReplaceAll(model, " ", "-") +} + +func validateTranscribeArgs(cmd *cobra.Command, args []string) error { + if len(args) == 1 { + return nil + } + + langFlag, _ := cmd.Flags().GetString("language") + langFlag = strings.TrimSpace(langFlag) + + // gladia transcribe --language en fr meeting.wav + if len(args) == 2 && isKnownLanguageCode(args[0]) && !isKnownLanguageCode(args[1]) && langFlag != "" { + return spaceSeparatedLanguageError(joinLanguageCodes(langFlag, args[0])) + } + + // gladia transcribe meeting.wav --language en fr + var extraLangs []string + for _, arg := range args[1:] { + if isKnownLanguageCode(arg) { + extraLangs = append(extraLangs, arg) + } + } + if langFlag != "" && len(extraLangs) > 0 { + return spaceSeparatedLanguageError(joinLanguageCodes(append([]string{langFlag}, extraLangs...)...)) + } + + return fmt.Errorf("accepts 1 arg(s), received %d", len(args)) +} + +func validateLanguageFlag(s string) error { + s = strings.TrimSpace(s) + if s == "" || strings.Contains(s, ",") { + return nil + } + if strings.Contains(s, " ") { + parts := strings.Fields(s) + if len(parts) > 1 && allKnownLanguageCodes(parts) { + return spaceSeparatedLanguageError(parts) + } + } + return nil +} + +func spaceSeparatedLanguageError(codes []string) error { + normalized := make([]string, 0, len(codes)) + for _, code := range codes { + code = strings.ToLower(strings.TrimSpace(code)) + if code != "" { + normalized = append(normalized, code) + } + } + return fmt.Errorf("--language expects comma-separated codes (e.g. --language %s), not spaces", strings.Join(normalized, ",")) +} + +func joinLanguageCodes(codes ...string) []string { + out := make([]string, 0, len(codes)) + for _, code := range codes { + code = strings.ToLower(strings.TrimSpace(code)) + if code != "" { + out = append(out, code) + } + } + return out +} + +func allKnownLanguageCodes(codes []string) bool { + for _, code := range codes { + if !isKnownLanguageCode(code) { + return false + } + } + return len(codes) > 0 +} + +func isKnownLanguageCode(code string) bool { + code = strings.TrimSpace(code) + if code == "" { + return false + } + _, err := types.ParseLanguage(code) + return err == nil +} + func isHTTPURL(s string) bool { lower := strings.ToLower(s) return strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://") diff --git a/cmd/transcribe_test.go b/cmd/transcribe_test.go index 9ff8b15..36a8a2d 100644 --- a/cmd/transcribe_test.go +++ b/cmd/transcribe_test.go @@ -16,7 +16,7 @@ import ( ) func TestValidateModel(t *testing.T) { - for _, model := range []string{"", "solaria-1", "solaria-3"} { + for _, model := range []string{"", "solaria-1", "solaria-3", "solaria 3", " Solaria-3 "} { if err := validateModel(model); err != nil { t.Errorf("model %q: %v", model, err) } @@ -26,6 +26,33 @@ func TestValidateModel(t *testing.T) { } } +func TestValidateModelConfig_solaria3(t *testing.T) { + en := types.LanguageEn + fr := types.LanguageFr + ja := types.LanguageJp + + if err := validateModelConfig("solaria-3", []types.Language{en}, false, false); err != nil { + t.Fatalf("single supported language: %v", err) + } + if err := validateModelConfig("solaria-3", nil, false, false); err != nil { + t.Fatalf("no language should be allowed: %v", err) + } + if err := validateModelConfig("solaria-3", []types.Language{en, fr}, false, false); err == nil { + t.Fatal("expected error for multiple languages") + } else if !strings.Contains(err.Error(), "only one language") || !strings.Contains(err.Error(), "en, fr") { + t.Fatalf("unexpected error: %v", err) + } + if err := validateModelConfig("solaria-3", []types.Language{ja}, false, false); err == nil { + t.Fatal("expected error for unsupported language") + } + if err := validateModelConfig("solaria-3", []types.Language{en}, true, true); err == nil { + t.Fatal("expected error for code switching") + } + if err := validateModelConfig("solaria-1", nil, false, false); err != nil { + t.Fatalf("solaria-1 should not require language: %v", err) + } +} + func TestValidateOutputFormat(t *testing.T) { valid := []string{"text", "txt", "json", "json-full", "srt", "vtt"} for _, format := range valid { @@ -183,6 +210,111 @@ func TestTranscribeCommand_invalidModel(t *testing.T) { } } +func TestTranscribeCommand_invalidSolaria3MultipleLanguages(t *testing.T) { + withTempHome(t) + t.Setenv(envGladiaAPIKey, "k") + + cmd := newRootCmd() + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + cmd.SetArgs([]string{"transcribe", "https://example.com/a.wav", "--model", "solaria-3", "--language", "en,fr"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for solaria-3 with multiple languages") + } + if !strings.Contains(err.Error(), "only one language") || !strings.Contains(err.Error(), "en, fr") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestTranscribeCommand_solaria3WithoutLanguage(t *testing.T) { + withTempHome(t) + t.Setenv(envGladiaAPIKey, "test-key") + + var postedBody map[string]interface{} + donePayload := sampleTranscriptionResult() + donePayload.Status = "done" + doneBody, _ := json.Marshal(donePayload) + + server := httptest.NewServer(nil) + base := server.URL + server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && r.URL.Path == "/v2/pre-recorded": + _ = json.NewDecoder(r.Body).Decode(&postedBody) + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]string{"result_url": base + "/result/1"}) + case r.Method == http.MethodGet: + _, _ = w.Write(doneBody) + } + }) + defer server.Close() + + oldEndpoint := gladia.GladiaApiEndpoint + gladia.GladiaApiEndpoint = server.URL + t.Cleanup(func() { gladia.GladiaApiEndpoint = oldEndpoint }) + + cmd := newRootCmd() + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + cmd.SetArgs([]string{"transcribe", "https://example.com/a.wav", "--model", "solaria-3"}) + + if err := cmd.Execute(); err != nil { + t.Fatalf("execute: %v", err) + } + if postedBody["model"] != "solaria-3" { + t.Fatalf("model = %v", postedBody["model"]) + } + if _, ok := postedBody["language_config"]; ok { + t.Fatalf("language_config should be omitted, got %#v", postedBody["language_config"]) + } +} + +func TestTranscribeCommand_spaceSeparatedLanguage(t *testing.T) { + withTempHome(t) + t.Setenv(envGladiaAPIKey, "k") + + cases := []struct { + name string + args []string + want string + }{ + { + name: "after source", + args: []string{"transcribe", "https://example.com/a.wav", "--language", "en", "fr"}, + want: "en,fr", + }, + { + name: "before source", + args: []string{"transcribe", "--language", "en", "fr", "meeting.wav"}, + want: "en,fr", + }, + { + name: "quoted value", + args: []string{"transcribe", "https://example.com/a.wav", "--language", "en fr"}, + want: "en,fr", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cmd := newRootCmd() + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + cmd.SetArgs(tc.args) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "comma-separated") || !strings.Contains(err.Error(), tc.want) { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + func TestTranscribeCommand_invalidLanguage(t *testing.T) { withTempHome(t) t.Setenv(envGladiaAPIKey, "k") @@ -209,7 +341,7 @@ func TestTranscribeCommand_URLTextOutput(t *testing.T) { base := server.URL server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodPost && r.URL.Path == "/v2/transcription/": + case r.Method == http.MethodPost && r.URL.Path == "/v2/pre-recorded": w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(map[string]string{"result_url": base + "/result/1"}) case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/result/"): @@ -253,7 +385,7 @@ func TestTranscribeCommand_codeSwitchingWithoutLanguages(t *testing.T) { base := server.URL server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodPost && r.URL.Path == "/v2/transcription/": + case r.Method == http.MethodPost && r.URL.Path == "/v2/pre-recorded": _ = json.NewDecoder(r.Body).Decode(&postedBody) w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(map[string]string{"result_url": base + "/result/1"}) @@ -299,7 +431,7 @@ func TestTranscribeCommand_modelRequestBody(t *testing.T) { base := server.URL server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodPost && r.URL.Path == "/v2/transcription/": + case r.Method == http.MethodPost && r.URL.Path == "/v2/pre-recorded": _ = json.NewDecoder(r.Body).Decode(&postedBody) w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(map[string]string{"result_url": base + "/result/1"}) @@ -333,7 +465,7 @@ func TestTranscribeCommand_modelRequestBody(t *testing.T) { }) t.Run("with --model solaria-3", func(t *testing.T) { - run("--model", "solaria-3") + run("--model", "solaria-3", "--language", "en") if postedBody["model"] != "solaria-3" { t.Fatalf("model = %v", postedBody["model"]) } @@ -353,7 +485,7 @@ func TestTranscribeCommand_diarizationRequestBody(t *testing.T) { base := server.URL server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodPost && r.URL.Path == "/v2/transcription/": + case r.Method == http.MethodPost && r.URL.Path == "/v2/pre-recorded": _ = json.NewDecoder(r.Body).Decode(&postedBody) w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(map[string]string{"result_url": base + "/result/1"}) @@ -420,7 +552,7 @@ func TestTranscribeCommand_languageAndCodeSwitching(t *testing.T) { base := server.URL server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodPost && r.URL.Path == "/v2/transcription/": + case r.Method == http.MethodPost && r.URL.Path == "/v2/pre-recorded": _ = json.NewDecoder(r.Body).Decode(&postedBody) w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(map[string]string{"result_url": base + "/result/1"}) diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 92c3eca..52ce7a3 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -9,7 +9,7 @@ func TestGladiaClient_apiURL(t *testing.T) { } c.GladiaEndpoint = "https://api.gladia.io/" - if got := c.apiURL("/v2/transcription/"); got != "https://api.gladia.io/v2/transcription/" { + if got := c.apiURL("/v2/pre-recorded"); got != "https://api.gladia.io/v2/pre-recorded" { t.Fatalf("got %q", got) } } diff --git a/pkg/client/transcribe.go b/pkg/client/transcribe.go index 742ad57..31533a8 100644 --- a/pkg/client/transcribe.go +++ b/pkg/client/transcribe.go @@ -43,11 +43,11 @@ type TranscriptionRequest struct { LanguageConfig *LanguageConfig `json:"language_config,omitempty"` Diarization bool `json:"diarization,omitempty"` DiarizationConfig *DiarizationConfig `json:"diarization_config,omitempty"` - Summarization bool `json:"summarization"` - SummarizationConfig *SummarizationConfig `json:"summarization_config"` - Translation bool `json:"translation"` - TranslationConfig *TranslationConfig `json:"translation_config"` - CustomVocabulary []string `json:"custom_vocabulary"` + Summarization bool `json:"summarization,omitempty"` + SummarizationConfig *SummarizationConfig `json:"summarization_config,omitempty"` + Translation bool `json:"translation,omitempty"` + TranslationConfig *TranslationConfig `json:"translation_config,omitempty"` + CustomVocabulary []string `json:"custom_vocabulary,omitempty"` } type TranslationConfig struct { @@ -231,9 +231,8 @@ func (c *GladiaClient) UploadFile(filePath string) (string, error) { return uploadResp.AudioURL, nil } -// TranscribeAudioURL calls the /v2/transcription/ endpoint using the provided audioURL. +// TranscribeAudioURL calls the /v2/pre-recorded endpoint using the provided audioURL. func (c *GladiaClient) TranscribeAudioURL(audioURL string, reqBody TranscriptionRequest) (*TranscriptionResult, error) { - // Set the audio URL in the request body. reqBody.AudioURL = audioURL requestData, err := json.Marshal(reqBody) @@ -241,7 +240,7 @@ func (c *GladiaClient) TranscribeAudioURL(audioURL string, reqBody Transcription return nil, fmt.Errorf("failed to marshal transcription request: %w", err) } - resp, err := c.createAndExecuteRequest("POST", c.apiURL("/v2/transcription/"), bytes.NewReader(requestData)) + resp, err := c.createAndExecuteRequest("POST", c.apiURL("/v2/pre-recorded"), bytes.NewReader(requestData)) if err != nil { return nil, fmt.Errorf("transcription request failed: %w", err) } diff --git a/pkg/client/transcribe_test.go b/pkg/client/transcribe_test.go index 4981a82..afd7bfc 100644 --- a/pkg/client/transcribe_test.go +++ b/pkg/client/transcribe_test.go @@ -69,7 +69,7 @@ func TestTranscribeAudioURL_success(t *testing.T) { base := server.URL server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodPost && r.URL.Path == "/v2/transcription/": + case r.Method == http.MethodPost && r.URL.Path == "/v2/pre-recorded": _ = json.NewDecoder(r.Body).Decode(&posted) w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(TranscriptionResponse{ResultURL: base + "/poll"})