Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lipsync pipeline with Parler-TTS #3175

Draft
wants to merge 8 commits into
base: ai-video
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,20 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_SegmentAnything2, config.ModelID, autoPrice)
}
case "lipsync":
_, ok := capabilityConstraints[core.Capability_Lipsync]
if !ok {
aiCaps = append(aiCaps, core.Capability_Lipsync)
capabilityConstraints[core.Capability_Lipsync] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_Lipsync].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_Lipsync, config.ModelID, autoPrice)
}
}

if len(aiCaps) > 0 {
Expand Down
1 change: 1 addition & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type AI interface {
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Lipsync(context.Context, worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
3 changes: 3 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ const (
Capability_Upscale
Capability_AudioToText
Capability_SegmentAnything2
Capability_Lipsync
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -116,6 +117,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_Upscale: "Upscale",
Capability_AudioToText: "Audio to text",
Capability_SegmentAnything2: "Segment anything 2",
Capability_Lipsync: "Lipsync",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down Expand Up @@ -207,6 +209,7 @@ func OptionalCapabilities() []Capability {
Capability_Upscale,
Capability_AudioToText,
Capability_SegmentAnything2,
Capability_Lipsync,
}
}

Expand Down
7 changes: 7 additions & 0 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSe
return orch.node.SegmentAnything2(ctx, req)
}

// Lipsync delegates a lipsync generation request to the underlying
// LivepeerNode, which forwards it to the attached AI worker.
func (orch *orchestrator) Lipsync(ctx context.Context, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
	return orch.node.Lipsync(ctx, req)
}

func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
if orch.node == nil || orch.node.Recipient == nil {
return nil
Expand Down Expand Up @@ -970,6 +974,9 @@ func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.GenAudioToTex
func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return n.AIWorker.SegmentAnything2(ctx, req)
}
// Lipsync forwards a lipsync generation request to the node's AI worker.
func (n *LivepeerNode) Lipsync(ctx context.Context, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
	return n.AIWorker.Lipsync(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,5 @@ require (
lukechampine.com/blake3 v1.2.1 // indirect
rsc.io/tmplfunc v0.0.3 // indirect
)

replace github.com/livepeer/ai-worker => ../ai-worker
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,6 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.5.0 h1:dgO6j9QVFPOq9omIcgB1YmgVSlhV94BMb6QO4WUocX8=
github.com/livepeer/ai-worker v0.5.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
Expand Down
56 changes: 56 additions & 0 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))
lp.transRPC.Handle("/lipsync", oapiReqValidator(lp.Lipsync()))

return nil
}
Expand Down Expand Up @@ -181,6 +182,49 @@ func (h *lphttp) SegmentAnything2() http.Handler {
})
}

// Lipsync returns the orchestrator-side HTTP handler for the /lipsync route.
// It parses the multipart request, applies a default model ID when none is
// supplied, and hands the request off to handleAIRequest for payment
// processing and execution.
func (h *lphttp) Lipsync() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		orch := h.orchestrator

		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

		// Log the remote IP for debugging
		clog.Infof(ctx, "Received lipsync request from %s", remoteAddr)

		// A body that cannot be parsed as multipart is a client error.
		multiRdr, err := r.MultipartReader()
		if err != nil {
			clog.Errorf(ctx, "Failed to read multipart form: %v", err)
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		// Bind the multipart form fields onto the request struct. Binding
		// failures are also caused by malformed client input, so respond with
		// 400 (Bad Request) rather than 500 — consistent with the gateway
		// handler for this route.
		var req worker.GenLipsyncMultipartRequestBody
		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			clog.Errorf(ctx, "Failed to bind multipart request: %v", err)
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		// Fall back to the default TTS model when the client omitted model_id.
		if req.ModelId == nil || *req.ModelId == "" {
			defaultModelID := "parler-tts/parler-tts-large-v1"
			req.ModelId = &defaultModelID
		} else {
			clog.Infof(ctx, "model_id received: %s", *req.ModelId)
		}

		// Additional debug for other form fields (if needed)
		clog.Infof(ctx, "Text input: %v", req.TextInput)
		clog.Infof(ctx, "Received image file: %v", req.Image)

		handleAIRequest(ctx, w, r, orch, req)
	})
}

func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
payment, err := getPayment(r.Header.Get(paymentHeader))
if err != nil {
Expand Down Expand Up @@ -324,6 +368,18 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels = int64(config.Height) * int64(config.Width)
case worker.GenLipsyncMultipartRequestBody:
pipeline = "lipsync"
cap = core.Capability_Lipsync
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context) (interface{}, error) {
return orch.Lipsync(ctx, v)
}

// TODO(pschroedl): Infer length of video based on tokenizing text input or length of audio input file
outPixels = int64(1000)
default:
respondWithError(w, "Unknown request type", http.StatusBadRequest)
return
Expand Down
55 changes: 55 additions & 0 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2()))
ls.HTTPMux.Handle("/lipsync", oapiReqValidator(ls.Lipsync()))

return nil
}
Expand Down Expand Up @@ -428,6 +429,60 @@ func (ls *LivepeerServer) SegmentAnything2() http.Handler {
})
}

// Lipsync returns the gateway-side HTTP handler for lipsync requests. It
// binds the multipart body, dispatches the request to an orchestrator session
// via processLipsync, and writes the JSON-encoded result to the client.
func (ls *LivepeerServer) Lipsync() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := clog.AddVal(r.Context(), clog.ClientIP, getRemoteAddr(r))
		requestID := string(core.RandomManifestID())
		ctx = clog.AddVal(ctx, "request_id", requestID)

		reader, err := r.MultipartReader()
		if err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		var req worker.GenLipsyncMultipartRequestBody
		if err := runtime.BindMultipart(&req, *reader); err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		clog.V(common.VERBOSE).Infof(ctx, "Received Lipsync request; image_size=%v model_id=%v", req.Image.FileSize(), req.ModelId)

		params := aiRequestParams{
			node:        ls.LivepeerNode,
			os:          drivers.NodeStorage.NewSession(requestID),
			sessManager: ls.AISessionManager,
		}

		start := time.Now()
		resp, err := processLipsync(ctx, params, req)
		if err != nil {
			// Map known error categories onto their HTTP status codes;
			// anything unrecognized is treated as an internal error.
			status := http.StatusInternalServerError
			var unavailableErr *ServiceUnavailableError
			var badReqErr *BadRequestError
			switch {
			case errors.As(err, &unavailableErr):
				status = http.StatusServiceUnavailable
			case errors.As(err, &badReqErr):
				status = http.StatusBadRequest
			}
			respondJsonError(ctx, w, err, status)
			return
		}

		clog.V(common.VERBOSE).Infof(ctx, "Processed Lipsync request model_id=%v took=%v", req.ModelId, time.Since(start))

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_ = json.NewEncoder(w).Encode(resp)
	})
}


func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
Expand Down
107 changes: 107 additions & 0 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-x
const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"
const defaultAudioToTextModelID = "openai/whisper-large-v3"
const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"
const defaultLipsyncModelID = "parler-tts/parler-tts-large-v1"

type ServiceUnavailableError struct {
err error
Expand Down Expand Up @@ -792,6 +793,103 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess
return &res, nil
}

// processLipsync runs a lipsync request through the generic AI request
// pipeline (session selection, payment, submission) and returns the
// generated video response.
func processLipsync(ctx context.Context, params aiRequestParams, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
	resp, err := processAIRequest(ctx, params, req)
	if err != nil {
		return nil, err
	}

	// Use a checked type assertion: an unchecked one would panic on an
	// unexpected response type instead of surfacing an error to the caller.
	videoResp, ok := resp.(*worker.VideoBinaryResponse)
	if !ok {
		return nil, errors.New("unexpected response type from lipsync pipeline")
	}

	return videoResp, nil
}

// submitLipsync sends a lipsync request to the orchestrator session sess and
// returns the generated video. Payment is sized from the input image
// dimensions as a stand-in for the real output size (see TODO below).
func submitLipsync(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
	// Resolve the model ID once up front: req.ModelId may be nil (the gateway
	// does not default it), and the error paths below previously dereferenced
	// it unconditionally, which would panic.
	modelID := ""
	if req.ModelId != nil {
		modelID = *req.ModelId
	}

	var buf bytes.Buffer
	mw, err := worker.NewLipsyncMultipartWriter(&buf, req)
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "lipsync", modelID, sess.OrchestratorInfo)
		}
		return nil, err
	}

	client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "lipsync", modelID, sess.OrchestratorInfo)
		}
		return nil, err
	}

	// Decode only the image header to obtain its dimensions; these drive the
	// payment calculation below.
	imageRdr, err := req.Image.Reader()
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "lipsync", modelID, sess.OrchestratorInfo)
		}
		return nil, err
	}
	config, _, err := image.DecodeConfig(imageRdr)
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "lipsync", modelID, sess.OrchestratorInfo)
		}
		return nil, err
	}

	// NOTE(review): this is the pixel count of a single frame, not a frame
	// count — confirm the intended payment unit for this pipeline.
	// TODO(pschroedl): Infer length of video based on tokenizing text input
	// or length of audio input file.
	outPixels := int64(config.Height) * int64(config.Width)

	// Prepare payment and balance update if applicable.
	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, outPixels)
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "lipsync", modelID, sess.OrchestratorInfo)
		}
		return nil, err
	}
	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)

	resp, err := client.GenLipsyncWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders)
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "lipsync", modelID, sess.OrchestratorInfo)
		}
		return nil, err
	}

	// A missing JSON200 means the orchestrator returned a non-200 response;
	// its body carries the error message.
	if resp.JSON200 == nil {
		return nil, errors.New(strings.TrimSuffix(string(resp.Body), "\n"))
	}

	// Mark the balance as receiving change now that the request succeeded.
	if balUpdate != nil {
		balUpdate.Status = ReceivedChange
	}

	// TODO: compute and record a latency score for lipsync requests once the
	// output size can be inferred (see outPixels above).

	// Log the AI request completion with latency score and pricing.
	if monitor.Enabled {
		var pricePerAIUnit float64
		if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
			pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
		}

		monitor.AIRequestFinished(ctx, "lipsync", modelID, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
	}

	return resp.JSON200, nil
}

func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) {
var cap core.Capability
var modelID string
Expand Down Expand Up @@ -852,6 +950,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitSegmentAnything2(ctx, params, sess, v)
}
case worker.GenLipsyncMultipartRequestBody:
cap = core.Capability_Lipsync
modelID = defaultLipsyncModelID
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitLipsync(ctx, params, sess, v)
}
default:
return nil, fmt.Errorf("unsupported request type %T", req)
}
Expand Down
1 change: 1 addition & 0 deletions server/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type Orchestrator interface {
Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Lipsync(ctx context.Context, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error)
}

// Balance describes methods for a session's balance maintenance
Expand Down
6 changes: 6 additions & 0 deletions server/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ func (r *stubOrchestrator) AudioToText(ctx context.Context, req worker.GenAudioT
func (r *stubOrchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return nil, nil
}
// Lipsync is a no-op stub. Its return type must be *worker.VideoBinaryResponse
// to satisfy the Orchestrator interface (the original returned
// *worker.VideoResponse, which does not match the interface and fails to
// compile).
func (r *stubOrchestrator) Lipsync(ctx context.Context, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
	return nil, nil
}
func (r *stubOrchestrator) CheckAICapacity(pipeline, modelID string) bool {
return true
}
Expand Down Expand Up @@ -1391,6 +1394,9 @@ func (r *mockOrchestrator) AudioToText(ctx context.Context, req worker.GenAudioT
func (r *mockOrchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return nil, nil
}
// Lipsync is a no-op mock. Its return type must be *worker.VideoBinaryResponse
// to satisfy the Orchestrator interface (the original returned
// *worker.VideoResponse, which does not match the interface and fails to
// compile).
func (r *mockOrchestrator) Lipsync(ctx context.Context, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
	return nil, nil
}
func (r *mockOrchestrator) CheckAICapacity(pipeline, modelID string) bool {
return true
}
Expand Down
Loading