Skip to content

Commit

Permalink
Merge pull request #76 from Dash-Industry-Forum/fix-host
Browse files Browse the repository at this point in the history
fix: fix use of configured or automatic scheme://host everywhere
  • Loading branch information
tobbee authored Aug 22, 2023
2 parents 1937bfd + fb567c1 commit 4615962
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 51 deletions.
26 changes: 10 additions & 16 deletions cmd/livesim2/app/configurl.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ type ResponseConfig struct {
TimeSubsWvtt []string `json:"TimeSubsWvttLanguages,omitempty"`
TimeSubsDurMS int `json:"TimeSubsDurMS,omitempty"`
TimeSubsRegion int `json:"TimeSubsRegion,omitempty"`
Scheme string `json:"Scheme,omitempty"`
Host string `json:"Host,omitempty"`
}

Expand Down Expand Up @@ -282,26 +281,21 @@ func (c *ResponseConfig) URLContentPart() string {
return strings.Join(c.URLParts[c.URLContentIdx:], "/")
}

// SetScheme sets Scheme to non-trivial cfgValue or tries to detect from request.
func (c *ResponseConfig) SetScheme(cfgValue string, r *http.Request) {
if cfgValue != "" {
c.Scheme = cfgValue
return
// fullHost uses non-empty cfgHost or extracts from requests scheme://host from request.
func fullHost(cfgHost string, r *http.Request) string {
if cfgHost != "" {
return cfgHost
}
if r.TLS == nil {
c.Scheme = "http"
} else {
c.Scheme = "https"
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
return fmt.Sprintf("%s://%s", scheme, r.Host)
}

// SetHost sets Host to non-trivial cfgValue or tries to detect from request.
// SetHost sets scheme://host to non-trivial cfgValue or tries to detect from request.
func (c *ResponseConfig) SetHost(cfgValue string, r *http.Request) {
if cfgValue != "" {
c.Host = cfgValue
return
}
c.Host = r.Host
c.Host = fullHost(cfgValue, r)
}

func ms2S(ms int) int {
Expand Down
3 changes: 1 addition & 2 deletions cmd/livesim2/app/handler_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import (
// indexHandlerFunc handles access to /.
func (s *Server) indexHandlerFunc(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
fullHost := getSchemeAndHost(r, s.Cfg)
err := s.htmlTemplates.ExecuteTemplate(w, "welcome.html", fullHost)
err := s.htmlTemplates.ExecuteTemplate(w, "welcome.html", fullHost(s.Cfg.Host, r))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
Expand Down
20 changes: 5 additions & 15 deletions cmd/livesim2/app/handler_livesim.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ func (s *Server) livesimHandlerFunc(w http.ResponseWriter, r *http.Request) {
http.Error(w, msg, http.StatusInternalServerError)
return
}
fullHost := getSchemeAndHost(r, s.Cfg)

var nowMS int // Set from query string or from wall-clock
q := r.URL.Query()
Expand All @@ -59,6 +58,8 @@ func (s *Server) livesimHandlerFunc(w http.ResponseWriter, r *http.Request) {
return
}

cfg.SetHost(s.Cfg.Host, r)

if cfg.TimeOffsetS != nil {
offsetMS := int(*cfg.TimeOffsetS * 1000)
nowMS += offsetMS
Expand All @@ -80,7 +81,7 @@ func (s *Server) livesimHandlerFunc(w http.ResponseWriter, r *http.Request) {
case ".mpd":
_, mpdName := path.Split(contentPart)
cfg.SetHost(s.Cfg.Host, r)
err := writeLiveMPD(log, w, cfg, a, mpdName, fullHost, nowMS)
err := writeLiveMPD(log, w, cfg, a, mpdName, nowMS)
if err != nil {
// TODO. Add more granular errors like 404 not found
msg := fmt.Sprintf("liveMPD: %s", err)
Expand Down Expand Up @@ -112,21 +113,10 @@ func (s *Server) livesimHandlerFunc(w http.ResponseWriter, r *http.Request) {
}
}

func getSchemeAndHost(r *http.Request, cfg *ServerConfig) string {
if cfg.Host != "" {
return cfg.Host
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
return fmt.Sprintf("%s://%s", scheme, r.Host)
}

func writeLiveMPD(log *zerolog.Logger, w http.ResponseWriter, cfg *ResponseConfig, a *asset, mpdName, host string, nowMS int) error {
func writeLiveMPD(log *zerolog.Logger, w http.ResponseWriter, cfg *ResponseConfig, a *asset, mpdName string, nowMS int) error {
work := make([]byte, 0, 1024)
buf := bytes.NewBuffer(work)
lMPD, err := LiveMPD(a, mpdName, cfg, host, nowMS)
lMPD, err := LiveMPD(a, mpdName, cfg, nowMS)
if err != nil {
return fmt.Errorf("convertToLive: %w", err)
}
Expand Down
10 changes: 4 additions & 6 deletions cmd/livesim2/app/livempd.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func calcWrapTimes(a *asset, cfg *ResponseConfig, nowMS int, tsbd m.Duration) wr
}

// LiveMPD generates a dynamic configured MPD for a VoD asset.
func LiveMPD(a *asset, mpdName string, cfg *ResponseConfig, host string, nowMS int) (*m.MPD, error) {
func LiveMPD(a *asset, mpdName string, cfg *ResponseConfig, nowMS int) (*m.MPD, error) {
mpd, err := a.getVodMPD(mpdName)
if err != nil {
return nil, err
Expand All @@ -63,8 +63,6 @@ func LiveMPD(a *asset, mpdName string, cfg *ResponseConfig, host string, nowMS i
}
if cfg.AddLocationFlag {
var strBuf strings.Builder
strBuf.WriteString(cfg.Scheme)
strBuf.WriteString("://")
strBuf.WriteString(cfg.Host)
for i := 1; i < len(cfg.URLParts); i++ {
strBuf.WriteString("/")
Expand All @@ -90,7 +88,7 @@ func LiveMPD(a *asset, mpdName string, cfg *ResponseConfig, host string, nowMS i
}
}

addUTCTimings(mpd, cfg, host)
addUTCTimings(mpd, cfg)

afterStop := false
endTimeMS := nowMS
Expand Down Expand Up @@ -525,7 +523,7 @@ func lastSegAvailTimeS(cfg *ResponseConfig, lsi lastSegInfo) float64 {
}

// addUTCTimings adds the UTCTiming elements to the MPD.
func addUTCTimings(mpd *m.MPD, cfg *ResponseConfig, host string) {
func addUTCTimings(mpd *m.MPD, cfg *ResponseConfig) {
if len(cfg.UTCTimingMethods) == 0 {
// default if none is set. Use HTTP with ms precision.
mpd.UTCTimings = []*m.DescriptorType{
Expand Down Expand Up @@ -567,7 +565,7 @@ func addUTCTimings(mpd *m.MPD, cfg *ResponseConfig, host string) {
case UtcTimingHead:
ut = &m.DescriptorType{
SchemeIdUri: "urn:mpeg:dash:utc:http-head:2014",
Value: fmt.Sprintf("%s%s", host, UtcTimingHeadAsset),
Value: fmt.Sprintf("%s%s", cfg.Host, UtcTimingHeadAsset),
}
case UtcTimingNone:
cfg.UTCTimingMethods = nil
Expand Down
22 changes: 10 additions & 12 deletions cmd/livesim2/app/livempd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestLiveMPD(t *testing.T) {
cfg := NewResponseConfig()
nowMS := 100_000
// Number template
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, "", nowMS)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, nowMS)
assert.NoError(t, err)
assert.Equal(t, "dynamic", *liveMPD.Type)
assert.Equal(t, m.DateTime("1970-01-01T00:00:00Z"), liveMPD.AvailabilityStartTime)
Expand All @@ -83,7 +83,7 @@ func TestLiveMPD(t *testing.T) {
}
// SegmentTimeline with $Time$
cfg.SegTimelineFlag = true
liveMPD, err = LiveMPD(asset, tc.mpdName, cfg, "", nowMS)
liveMPD, err = LiveMPD(asset, tc.mpdName, cfg, nowMS)
assert.NoError(t, err)
assert.Equal(t, "dynamic", *liveMPD.Type)
assert.Equal(t, m.DateTime("1970-01-01T00:00:00Z"), liveMPD.AvailabilityStartTime)
Expand Down Expand Up @@ -133,7 +133,7 @@ func TestLiveMPDWithTimeSubs(t *testing.T) {
cfg.TimeSubsStpp = []string{"en", "sv"}
nowMS := 100_000
// Number template
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, "", nowMS)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, nowMS)
assert.NoError(t, err)
assert.Equal(t, "dynamic", *liveMPD.Type)
aSets := liveMPD.Periods[0].AdaptationSets
Expand Down Expand Up @@ -208,7 +208,7 @@ func TestSegmentTimes(t *testing.T) {
}
for nowS := tc.startTimeS; nowS < tc.endTimeS; nowS++ {
nowMS := nowS * 1000
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, "", nowMS)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, nowMS)
wantedStartNr := (nowS - 62) / 2 // Sliding window of 60s + one segment
assert.NoError(t, err)
for _, as := range liveMPD.Periods[0].AdaptationSets {
Expand Down Expand Up @@ -460,7 +460,7 @@ func TestPublishTime(t *testing.T) {
}
err := verifyAndFillConfig(cfg, tc.nowMS)
require.NoError(t, err)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, "", tc.nowMS)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, tc.nowMS)
assert.NoError(t, err)
assert.Equal(t, m.ConvertToDateTimeS(int64(tc.availabilityStartTime)), liveMPD.AvailabilityStartTime)
assert.Equal(t, m.DateTime(tc.wantedPublishTime), liveMPD.PublishTime)
Expand Down Expand Up @@ -532,7 +532,7 @@ func TestNormalAvailabilityTimeOffset(t *testing.T) {
cfg.SegTimelineFlag = tc.segTimelineTime
sc := strConvAccErr{}
cfg.AvailabilityTimeOffsetS = sc.AtofInf("ato", tc.ato)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, "", tc.nowMS)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, tc.nowMS)
if tc.wantedErr != "" {
assert.EqualError(t, err, tc.wantedErr)
return
Expand Down Expand Up @@ -602,7 +602,7 @@ func TestUTCTiming(t *testing.T) {
}
err := verifyAndFillConfig(cfg, tc.nowMS)
require.NoError(t, err)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, "", tc.nowMS)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, tc.nowMS)
assert.NoError(t, err)
assert.Equal(t, m.DateTime(tc.wantedPublishTime), liveMPD.PublishTime)
assert.Equal(t, tc.wantedUTCTimings, len(liveMPD.UTCTimings))
Expand Down Expand Up @@ -700,7 +700,7 @@ func TestMultiPeriod(t *testing.T) {
default: // $Number$
// no flag
}
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, "", tc.nowMS)
liveMPD, err := LiveMPD(asset, tc.mpdName, cfg, tc.nowMS)
if tc.wantedErr != "" {
assert.EqualError(t, err, tc.wantedErr)
return
Expand Down Expand Up @@ -740,21 +740,19 @@ func TestRelStartStopTimeIntoLocation(t *testing.T) {
url: "/livesim2/startrel_-20/mup_3/stoprel_20/testpic_2s/Manifest.mpd",
nowMS: 1_000_000,
wantedLocation: "http://localhost:8888/livesim2/start_980/mup_3/stop_1020/testpic_2s/Manifest.mpd",
scheme: "http",
host: "localhost:8888",
host: "http://localhost:8888",
},
}

for _, c := range cases {
cfg, err := processURLCfg(c.url, c.nowMS)
require.NoError(t, err)
cfg.SetScheme(c.scheme, nil)
cfg.SetHost(c.host, nil)
contentPart := cfg.URLContentPart()
asset, ok := am.findAsset(contentPart)
require.True(t, ok)
_, mpdName := path.Split(contentPart)
liveMPD, err := LiveMPD(asset, mpdName, cfg, "", c.nowMS)
liveMPD, err := LiveMPD(asset, mpdName, cfg, c.nowMS)
require.NoError(t, err)
require.Equal(t, c.wantedLocation, string(liveMPD.Location[0]), "the right location element is not inserted")
}
Expand Down

0 comments on commit 4615962

Please sign in to comment.