feat: stream download cli

This commit is contained in:
2025-10-10 21:01:34 +02:00
commit 3f46bcc1bd
7 changed files with 375 additions and 0 deletions

View File

@@ -0,0 +1,271 @@
package downloader
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path"
"path/filepath"
"sort"
"strings"
"time"
"github.com/grafov/m3u8"
)
const (
defaultClientTimeout = 30 * time.Second
maxRedirects = 5
)
var (
errInvalidPlaylist = errors.New("unsupported playlist type. expected media playlist")
errUnsupportedMaster = errors.New("master playlist contains no playable variants")
errFFmpegMissing = errors.New("ffmpeg is required on PATH to transmux segments")
)
type segmentInfo struct {
Sequence uint64
URI string
}
// Download downloads the video stream referenced by the given URL into a file named outputName.
// If outputName is empty, the base name of the stream URL is used with an .mp4 extension.
func Download(ctx context.Context, streamURL, outputName string) error {
if streamURL == "" {
return errors.New("stream URL must not be empty")
}
parsed, err := url.Parse(streamURL)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
client := &http.Client{Timeout: defaultClientTimeout}
playlistBody, finalURL, err := fetchWithRedirects(ctx, client, parsed, maxRedirects)
if err != nil {
return fmt.Errorf("fetch playlist: %w", err)
}
defer playlistBody.Close()
playlist, listType, err := m3u8.DecodeFrom(bufio.NewReader(playlistBody), true)
if err != nil {
return fmt.Errorf("parse playlist: %w", err)
}
var segments []segmentInfo
switch listType {
case m3u8.MEDIA:
mediaPlaylist, ok := playlist.(*m3u8.MediaPlaylist)
if !ok {
return errInvalidPlaylist
}
segments = collectSegments(mediaPlaylist)
case m3u8.MASTER:
masterPlaylist, ok := playlist.(*m3u8.MasterPlaylist)
if !ok {
return errInvalidPlaylist
}
variantURL, err := selectVariant(masterPlaylist, finalURL)
if err != nil {
return err
}
return Download(ctx, variantURL, outputName)
default:
return errInvalidPlaylist
}
if len(segments) == 0 {
return errors.New("playlist contains no segments")
}
if outputName == "" {
outputName = inferOutputName(finalURL)
}
tempTS, err := os.CreateTemp("", "sdl-*.ts")
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
defer func() {
tempTS.Close()
os.Remove(tempTS.Name())
}()
for i, segment := range segments {
if err := downloadSegment(ctx, client, finalURL, segment, tempTS); err != nil {
return fmt.Errorf("download segment %d: %w", i, err)
}
}
if _, err := tempTS.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("rewind temp file: %w", err)
}
mp4Name := ensureMP4Extension(outputName)
if err := transmuxToMP4(ctx, tempTS.Name(), mp4Name); err != nil {
return err
}
return nil
}
func collectSegments(playlist *m3u8.MediaPlaylist) []segmentInfo {
segments := make([]segmentInfo, 0, len(playlist.Segments))
for _, segment := range playlist.Segments {
if segment == nil || segment.URI == "" {
continue
}
segments = append(segments, segmentInfo{Sequence: segment.SeqId, URI: segment.URI})
}
sort.Slice(segments, func(i, j int) bool {
return segments[i].Sequence < segments[j].Sequence
})
return segments
}
func selectVariant(master *m3u8.MasterPlaylist, base *url.URL) (string, error) {
var chosen *m3u8.Variant
for _, variant := range master.Variants {
if variant == nil || variant.URI == "" {
continue
}
if chosen == nil || variant.Bandwidth > chosen.Bandwidth {
chosen = variant
}
}
if chosen == nil {
return "", errUnsupportedMaster
}
resolved, err := base.Parse(chosen.URI)
if err != nil {
return "", fmt.Errorf("parse variant URL: %w", err)
}
return resolved.String(), nil
}
func downloadSegment(ctx context.Context, client *http.Client, base *url.URL, segment segmentInfo, output io.Writer) error {
segmentURL, err := base.Parse(segment.URI)
if err != nil {
return fmt.Errorf("parse segment URL: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, segmentURL.String(), nil)
if err != nil {
return fmt.Errorf("create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("fetch segment: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status %s", resp.Status)
}
if _, err := io.Copy(output, resp.Body); err != nil {
return fmt.Errorf("write segment: %w", err)
}
return nil
}
func fetchWithRedirects(ctx context.Context, client *http.Client, streamURL *url.URL, redirects int) (io.ReadCloser, *url.URL, error) {
currentURL := streamURL
for count := 0; count <= redirects; count++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, currentURL.String(), nil)
if err != nil {
return nil, nil, fmt.Errorf("create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("http get: %w", err)
}
switch {
case resp.StatusCode >= 300 && resp.StatusCode < 400:
location := resp.Header.Get("Location")
resp.Body.Close()
if location == "" {
return nil, nil, errors.New("redirect without location header")
}
nextURL, err := currentURL.Parse(location)
if err != nil {
return nil, nil, fmt.Errorf("parse redirect URL: %w", err)
}
currentURL = nextURL
continue
case resp.StatusCode == http.StatusOK:
return resp.Body, currentURL, nil
default:
resp.Body.Close()
return nil, nil, fmt.Errorf("unexpected status %s", resp.Status)
}
}
return nil, nil, errors.New("too many redirects")
}
func inferOutputName(streamURL *url.URL) string {
base := path.Base(streamURL.Path)
if base == "." || base == "/" || base == "" {
base = "download"
}
trimmed := strings.TrimSuffix(base, path.Ext(base))
if trimmed == "" {
trimmed = "download"
}
return trimmed + ".mp4"
}
func ensureMP4Extension(name string) string {
lower := strings.ToLower(name)
if strings.HasSuffix(lower, ".mp4") {
return name
}
ext := filepath.Ext(name)
trimmed := strings.TrimSuffix(name, ext)
if trimmed == "" {
trimmed = "output"
}
return trimmed + ".mp4"
}
func transmuxToMP4(ctx context.Context, tsPath, mp4Path string) error {
if _, err := exec.LookPath("ffmpeg"); err != nil {
return errFFmpegMissing
}
if err := os.MkdirAll(filepath.Dir(mp4Path), 0o755); err != nil {
return fmt.Errorf("ensure output directory: %w", err)
}
cmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", tsPath, "-c", "copy", "-movflags", "+faststart", mp4Path)
cmd.Stdout = io.Discard
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
os.Remove(mp4Path)
msg := strings.TrimSpace(stderr.String())
if msg != "" {
return fmt.Errorf("ffmpeg: %s", msg)
}
return fmt.Errorf("ffmpeg: %w", err)
}
return nil
}