Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion core/http/endpoints/elevenlabs/soundgeneration.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/audio"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
)
Expand Down Expand Up @@ -51,7 +52,11 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
if err != nil {
return err
}
return c.Attachment(filePath, filepath.Base(filePath))

filePath, contentType := audio.NormalizeAudioFile(filePath)
if contentType != "" {
c.Response().Header().Set("Content-Type", contentType)
}
return c.Attachment(filePath, filepath.Base(filePath))
}
}
5 changes: 5 additions & 0 deletions core/http/endpoints/elevenlabs/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/audio"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
)
Expand Down Expand Up @@ -39,6 +40,10 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
if err != nil {
return err
}
filePath, contentType := audio.NormalizeAudioFile(filePath)
if contentType != "" {
c.Response().Header().Set("Content-Type", contentType)
}
return c.Attachment(filePath, filepath.Base(filePath))
}
}
11 changes: 7 additions & 4 deletions core/http/endpoints/localai/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"

"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/xlog"

"github.com/mudler/LocalAI/pkg/audio"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/xlog"
)

// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
Expand Down Expand Up @@ -86,6 +85,10 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
return err
}

filePath, contentType := audio.NormalizeAudioFile(filePath)
if contentType != "" {
c.Response().Header().Set("Content-Type", contentType)
}
return c.Attachment(filePath, filepath.Base(filePath))
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ require (
)

require (
github.com/dhowden/tag v0.0.0-20240417053706-3d75831295e8 // indirect
github.com/ghodss/yaml v1.0.0 // indirect
github.com/labstack/gommon v0.4.2 // indirect
github.com/swaggo/files/v2 v2.0.2 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ github.com/decred/dcrd/crypto/blake256 v1.1.0 h1:zPMNGQCm0g4QTY27fOCorQW7EryeQ/U
github.com/decred/dcrd/crypto/blake256 v1.1.0/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40=
github.com/dhowden/tag v0.0.0-20240417053706-3d75831295e8 h1:OtSeLS5y0Uy01jaKK4mA/WVIYtpzVm63vLVAPzJXigg=
github.com/dhowden/tag v0.0.0-20240417053706-3d75831295e8/go.mod h1:apkPC/CR3s48O2D7Y++n1XWEpgPNNCjXYga3PPbJe2E=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
Expand Down
130 changes: 130 additions & 0 deletions pkg/audio/identify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package audio

import (
"io"
"os"
"path/filepath"
"strings"

"github.com/dhowden/tag"
"github.com/mudler/xlog"
)

// extensionFromFileType returns the file extension for tag.FileType.
func extensionFromFileType(ft tag.FileType) string {
switch ft {
case tag.FLAC:
return "flac"
case tag.MP3:
return "mp3"
case tag.OGG:
return "ogg"
case tag.M4A:
return "m4a"
case tag.M4B:
return "m4b"
case tag.M4P:
return "m4p"
case tag.ALAC:
return "m4a"
case tag.DSF:
return "dsf"
default:
return ""
}
}

// contentTypeFromFileType returns the MIME type for tag.FileType.
func contentTypeFromFileType(ft tag.FileType) string {
switch ft {
case tag.FLAC:
return "audio/flac"
case tag.MP3:
return "audio/mpeg"
case tag.OGG:
return "audio/ogg"
case tag.M4A, tag.M4B, tag.M4P, tag.ALAC:
return "audio/mp4"
case tag.DSF:
return "audio/dsd"
default:
return ""
}
}

// Identify reads from r and returns the detected audio extension and Content-Type.
// It uses github.com/dhowden/tag to identify the format from the stream.
// Returns ("", "", err) if the format could not be identified.
func Identify(r io.ReadSeeker) (ext string, contentType string, err error) {
_, fileType, err := tag.Identify(r)
if err != nil || fileType == tag.UnknownFileType {
return "", "", err
}
ext = extensionFromFileType(fileType)
contentType = contentTypeFromFileType(fileType)
if ext == "" || contentType == "" {
return "", "", nil
}
return ext, contentType, nil
}

// ContentTypeFromExtension returns the MIME type for common audio file extensions.
// Use as a fallback when Identify fails or when the file is not openable.
func ContentTypeFromExtension(path string) string {
ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
switch ext {
case "flac":
return "audio/flac"
case "mp3":
return "audio/mpeg"
case "wav":
return "audio/wav"
case "ogg":
return "audio/ogg"
case "m4a", "m4b", "m4p":
return "audio/mp4"
case "webm":
return "audio/webm"
default:
return ""
}
}

// NormalizeAudioFile opens the file at path, identifies its format with tag.Identify,
// and renames the file to have the correct extension if the current one does not match.
// It returns the path to use (possibly the renamed file) and the Content-Type to set.
// If identification fails, returns (path, ContentTypeFromExtension(path)).
func NormalizeAudioFile(path string) (finalPath string, contentType string) {
finalPath = path
f, err := os.Open(path)
if err != nil {
contentType = ContentTypeFromExtension(path)
return finalPath, contentType
}
defer f.Close()

ext, ct, identifyErr := Identify(f)
if identifyErr != nil || ext == "" || ct == "" {
contentType = ContentTypeFromExtension(path)
return finalPath, contentType
}
contentType = ct

currentExt := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
if currentExt == ext {
return finalPath, contentType
}

dir := filepath.Dir(path)
base := filepath.Base(path)
baseNoExt := strings.TrimSuffix(base, filepath.Ext(base))
if baseNoExt == "" {
baseNoExt = base
}
newPath := filepath.Join(dir, baseNoExt+"."+ext)
if renameErr := os.Rename(path, newPath); renameErr != nil {
xlog.Debug("Could not rename audio file to match type", "from", path, "to", newPath, "error", renameErr)
return finalPath, contentType
}
return newPath, contentType
}
10 changes: 9 additions & 1 deletion tests/e2e/mock-backend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,15 @@ func (m *MockBackend) GenerateVideo(ctx context.Context, in *pb.GenerateVideoReq

func (m *MockBackend) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
xlog.Debug("TTS called", "text", in.Text)
// Return success - actual audio would be in the Result message for real backends
dst := in.GetDst()
if dst != "" {
if err := os.MkdirAll(filepath.Dir(dst), 0750); err != nil {
return &pb.Result{Message: err.Error(), Success: false}, nil
}
if err := writeMinimalWAV(dst); err != nil {
return &pb.Result{Message: err.Error(), Success: false}, nil
}
}
return &pb.Result{
Message: "TTS audio generated successfully (mocked)",
Success: true,
Expand Down
33 changes: 19 additions & 14 deletions tests/e2e/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,20 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
Describe("TTS APIs", func() {
Context("TTS", func() {
It("should generate mocked audio", func() {
req, err := http.NewRequest("POST", apiURL+"/audio/speech", nil)
body := `{"model":"mock-model","input":"Hello world","voice":"default"}`
req, err := http.NewRequest("POST", apiURL+"/audio/speech", io.NopCloser(strings.NewReader(body)))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")

body := `{"model":"mock-model","input":"Hello world","voice":"default"}`
req.Body = http.NoBody
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader(body)), nil
}

// Use direct HTTP client for TTS endpoint
httpClient := &http.Client{Timeout: 30 * time.Second}
resp, err := httpClient.Do(req)
if err == nil {
defer resp.Body.Close()
Expect(resp.StatusCode).To(BeNumerically("<", 500))
}
Expect(err).ToNot(HaveOccurred())
defer resp.Body.Close()
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"), "TTS response should set an audio Content-Type")
data, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(len(data)).To(BeNumerically(">", 0), "TTS response body should be non-empty")
})
})
})
Expand All @@ -107,7 +104,11 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
resp, err := httpClient.Do(req)
Expect(err).ToNot(HaveOccurred())
defer resp.Body.Close()
Expect(resp.StatusCode).To(BeNumerically("<", 500))
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"), "sound-generation response should set an audio Content-Type (pkg/audio normalization)")
data, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(len(data)).To(BeNumerically(">", 0), "sound-generation response body should be non-empty")
})

It("should generate mocked sound (advanced mode)", func() {
Expand All @@ -120,7 +121,11 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
resp, err := httpClient.Do(req)
Expect(err).ToNot(HaveOccurred())
defer resp.Body.Close()
Expect(resp.StatusCode).To(BeNumerically("<", 500))
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"), "sound-generation response should set an audio Content-Type (pkg/audio normalization)")
data, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(len(data)).To(BeNumerically(">", 0), "sound-generation response body should be non-empty")
})
})

Expand Down
Loading