mirror of
				https://github.com/gohugoio/hugo.git
				synced 2024-05-11 05:54:58 +00:00 
			
		
		
		
	tpl: Allow the partial template func to return any type
This commit adds support for return values in partials.
This means that you can now do this and similar:
    {{ $v := add . 42 }}
    {{ return $v }}
Partials without a `return` statement will be rendered as before.
This works for both `partial` and `partialCached`.
Fixes #5783
			
			
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							9225db636e
						
					
				
				
					commit
					a55640de8e
				
			@@ -20,6 +20,12 @@ type Eqer interface {
 | 
			
		||||
	Eq(other interface{}) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ProbablyEq is an equal check that may return false positives, but never
 | 
			
		||||
// a false negative.
 | 
			
		||||
type ProbablyEqer interface {
 | 
			
		||||
	ProbablyEq(other interface{}) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Comparer can be used to compare two values.
 | 
			
		||||
// This will be used when using the le, ge etc. operators in the templates.
 | 
			
		||||
// Compare returns -1 if the given version is less than, 0 if equal and 1 if greater than
 | 
			
		||||
 
 | 
			
		||||
@@ -264,3 +264,44 @@ Hugo: {{ hugo.Generator }}
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestPartialWithReturn(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	b := newTestSitesBuilder(t).WithSimpleConfigFile()
 | 
			
		||||
 | 
			
		||||
	b.WithTemplatesAdded(
 | 
			
		||||
		"index.html", `
 | 
			
		||||
Test Partials With Return Values:
 | 
			
		||||
 | 
			
		||||
add42: 50: {{ partial "add42.tpl" 8 }}
 | 
			
		||||
dollarContext: 60: {{ partial "dollarContext.tpl" 18 }}
 | 
			
		||||
adder: 70: {{ partial "dict.tpl" (dict "adder" 28) }}
 | 
			
		||||
complex: 80: {{ partial "complex.tpl" 38 }}
 | 
			
		||||
`,
 | 
			
		||||
		"partials/add42.tpl", `
 | 
			
		||||
		{{ $v := add . 42 }}
 | 
			
		||||
		{{ return $v }}
 | 
			
		||||
		`,
 | 
			
		||||
		"partials/dollarContext.tpl", `
 | 
			
		||||
{{ $v := add $ 42 }}
 | 
			
		||||
{{ return $v }}
 | 
			
		||||
`,
 | 
			
		||||
		"partials/dict.tpl", `
 | 
			
		||||
{{ $v := add $.adder 42 }}
 | 
			
		||||
{{ return $v }}
 | 
			
		||||
`,
 | 
			
		||||
		"partials/complex.tpl", `
 | 
			
		||||
{{ return add . 42 }}
 | 
			
		||||
`,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	b.CreateSites().Build(BuildCfg{})
 | 
			
		||||
 | 
			
		||||
	b.AssertFileContent("public/index.html",
 | 
			
		||||
		"add42: 50: 50",
 | 
			
		||||
		"dollarContext: 60: 60",
 | 
			
		||||
		"adder: 70: 70",
 | 
			
		||||
		"complex: 80: 80",
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -23,6 +23,12 @@ import (
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gohugoio/hugo/compare"
 | 
			
		||||
 | 
			
		||||
	"github.com/gohugoio/hugo/common/hreflect"
 | 
			
		||||
 | 
			
		||||
	"github.com/spf13/cast"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// The Provider interface defines an interface for measuring metrics.
 | 
			
		||||
@@ -35,20 +41,20 @@ type Provider interface {
 | 
			
		||||
	WriteMetrics(w io.Writer)
 | 
			
		||||
 | 
			
		||||
	// TrackValue tracks the value for diff calculations etc.
 | 
			
		||||
	TrackValue(key, value string)
 | 
			
		||||
	TrackValue(key string, value interface{})
 | 
			
		||||
 | 
			
		||||
	// Reset clears the metric store.
 | 
			
		||||
	Reset()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type diff struct {
 | 
			
		||||
	baseline string
 | 
			
		||||
	baseline interface{}
 | 
			
		||||
	count    int
 | 
			
		||||
	simSum   int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *diff) add(v string) *diff {
 | 
			
		||||
	if d.baseline == "" {
 | 
			
		||||
func (d *diff) add(v interface{}) *diff {
 | 
			
		||||
	if !hreflect.IsTruthful(v) {
 | 
			
		||||
		d.baseline = v
 | 
			
		||||
		d.count = 1
 | 
			
		||||
		d.simSum = 100 // If we get only one it is very cache friendly.
 | 
			
		||||
@@ -90,7 +96,7 @@ func (s *Store) Reset() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TrackValue tracks the value for diff calculations etc.
 | 
			
		||||
func (s *Store) TrackValue(key, value string) {
 | 
			
		||||
func (s *Store) TrackValue(key string, value interface{}) {
 | 
			
		||||
	if !s.calculateHints {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -191,13 +197,43 @@ func (b bySum) Less(i, j int) bool { return b[i].sum > b[j].sum }
 | 
			
		||||
 | 
			
		||||
// howSimilar is a naive diff implementation that returns
 | 
			
		||||
// a number between 0-100 indicating how similar a and b are.
 | 
			
		||||
// 100 is when all words in a also exists in b.
 | 
			
		||||
func howSimilar(a, b string) int {
 | 
			
		||||
 | 
			
		||||
func howSimilar(a, b interface{}) int {
 | 
			
		||||
	if a == b {
 | 
			
		||||
		return 100
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	as, err1 := cast.ToStringE(a)
 | 
			
		||||
	bs, err2 := cast.ToStringE(b)
 | 
			
		||||
 | 
			
		||||
	if err1 == nil && err2 == nil {
 | 
			
		||||
		return howSimilarStrings(as, bs)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err1 != err2 {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	e1, ok1 := a.(compare.Eqer)
 | 
			
		||||
	e2, ok2 := b.(compare.Eqer)
 | 
			
		||||
	if ok1 && ok2 && e1.Eq(e2) {
 | 
			
		||||
		return 100
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO(bep) implement ProbablyEq for Pages etc.
 | 
			
		||||
	pe1, pok1 := a.(compare.ProbablyEqer)
 | 
			
		||||
	pe2, pok2 := b.(compare.ProbablyEqer)
 | 
			
		||||
	if pok1 && pok2 && pe1.ProbablyEq(pe2) {
 | 
			
		||||
		return 90
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// howSimilar is a naive diff implementation that returns
 | 
			
		||||
// a number between 0-100 indicating how similar a and b are.
 | 
			
		||||
// 100 is when all words in a also exists in b.
 | 
			
		||||
func howSimilarStrings(a, b string) int {
 | 
			
		||||
 | 
			
		||||
	// Give some weight to the word positions.
 | 
			
		||||
	const partitionSize = 4
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -36,6 +36,13 @@ func init() {
 | 
			
		||||
			},
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		// TODO(bep) we need the return to be a valid identifier, but
 | 
			
		||||
		// should consider another way of adding it.
 | 
			
		||||
		ns.AddMethodMapping(func() string { return "" },
 | 
			
		||||
			[]string{"return"},
 | 
			
		||||
			[][2]string{},
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		ns.AddMethodMapping(ctx.IncludeCached,
 | 
			
		||||
			[]string{"partialCached"},
 | 
			
		||||
			[][2]string{},
 | 
			
		||||
 
 | 
			
		||||
@@ -18,10 +18,14 @@ package partials
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	texttemplate "text/template"
 | 
			
		||||
 | 
			
		||||
	"github.com/gohugoio/hugo/tpl"
 | 
			
		||||
 | 
			
		||||
	bp "github.com/gohugoio/hugo/bufferpool"
 | 
			
		||||
	"github.com/gohugoio/hugo/deps"
 | 
			
		||||
)
 | 
			
		||||
@@ -62,8 +66,22 @@ type Namespace struct {
 | 
			
		||||
	cachedPartials *partialCache
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Include executes the named partial and returns either a string,
 | 
			
		||||
// when the partial is a text/template, or template.HTML when html/template.
 | 
			
		||||
// contextWrapper makes room for a return value in a partial invocation.
 | 
			
		||||
type contextWrapper struct {
 | 
			
		||||
	Arg    interface{}
 | 
			
		||||
	Result interface{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Set sets the return value and returns an empty string.
 | 
			
		||||
func (c *contextWrapper) Set(in interface{}) string {
 | 
			
		||||
	c.Result = in
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Include executes the named partial.
 | 
			
		||||
// If the partial contains a return statement, that value will be returned.
 | 
			
		||||
// Else, the rendered output will be returned:
 | 
			
		||||
// A string if the partial is a text/template, or template.HTML when html/template.
 | 
			
		||||
func (ns *Namespace) Include(name string, contextList ...interface{}) (interface{}, error) {
 | 
			
		||||
	if strings.HasPrefix(name, "partials/") {
 | 
			
		||||
		name = name[8:]
 | 
			
		||||
@@ -83,31 +101,54 @@ func (ns *Namespace) Include(name string, contextList ...interface{}) (interface
 | 
			
		||||
		// For legacy reasons.
 | 
			
		||||
		templ, found = ns.deps.Tmpl.Lookup(n + ".html")
 | 
			
		||||
	}
 | 
			
		||||
	if found {
 | 
			
		||||
 | 
			
		||||
	if !found {
 | 
			
		||||
		return "", fmt.Errorf("partial %q not found", name)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var info tpl.Info
 | 
			
		||||
	if ip, ok := templ.(tpl.TemplateInfoProvider); ok {
 | 
			
		||||
		info = ip.TemplateInfo()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var w io.Writer
 | 
			
		||||
 | 
			
		||||
	if info.HasReturn {
 | 
			
		||||
		// Wrap the context sent to the template to capture the return value.
 | 
			
		||||
		// Note that the template is rewritten to make sure that the dot (".")
 | 
			
		||||
		// and the $ variable points to Arg.
 | 
			
		||||
		context = &contextWrapper{
 | 
			
		||||
			Arg: context,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// We don't care about any template output.
 | 
			
		||||
		w = ioutil.Discard
 | 
			
		||||
	} else {
 | 
			
		||||
		b := bp.GetBuffer()
 | 
			
		||||
		defer bp.PutBuffer(b)
 | 
			
		||||
		w = b
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
		if err := templ.Execute(b, context); err != nil {
 | 
			
		||||
	if err := templ.Execute(w, context); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
		if _, ok := templ.(*texttemplate.Template); ok {
 | 
			
		||||
			s := b.String()
 | 
			
		||||
	var result interface{}
 | 
			
		||||
 | 
			
		||||
	if ctx, ok := context.(*contextWrapper); ok {
 | 
			
		||||
		result = ctx.Result
 | 
			
		||||
	} else if _, ok := templ.(*texttemplate.Template); ok {
 | 
			
		||||
		result = w.(fmt.Stringer).String()
 | 
			
		||||
	} else {
 | 
			
		||||
		result = template.HTML(w.(fmt.Stringer).String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ns.deps.Metrics != nil {
 | 
			
		||||
				ns.deps.Metrics.TrackValue(n, s)
 | 
			
		||||
			}
 | 
			
		||||
			return s, nil
 | 
			
		||||
		ns.deps.Metrics.TrackValue(n, result)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
		s := b.String()
 | 
			
		||||
		if ns.deps.Metrics != nil {
 | 
			
		||||
			ns.deps.Metrics.TrackValue(n, s)
 | 
			
		||||
		}
 | 
			
		||||
		return template.HTML(s), nil
 | 
			
		||||
	return result, nil
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return "", fmt.Errorf("partial %q not found", name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IncludeCached executes and caches partial templates.  An optional variant
 | 
			
		||||
 
 | 
			
		||||
@@ -22,10 +22,17 @@ type Info struct {
 | 
			
		||||
	// Set for shortcode templates with any {{ .Inner }}
 | 
			
		||||
	IsInner bool
 | 
			
		||||
 | 
			
		||||
	// Set for partials with a return statement.
 | 
			
		||||
	HasReturn bool
 | 
			
		||||
 | 
			
		||||
	// Config extracted from template.
 | 
			
		||||
	Config Config
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (info Info) IsZero() bool {
 | 
			
		||||
	return info.Config.Version == 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	Version int
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -51,15 +51,17 @@ func (t *templateHandler) addAceTemplate(name, basePath, innerPath string, baseC
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	isShort := isShortcode(name)
 | 
			
		||||
	typ := resolveTemplateType(name)
 | 
			
		||||
 | 
			
		||||
	info, err := applyTemplateTransformersToHMLTTemplate(isShort, templ)
 | 
			
		||||
	info, err := applyTemplateTransformersToHMLTTemplate(typ, templ)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if isShort {
 | 
			
		||||
	if typ == templateShortcode {
 | 
			
		||||
		t.addShortcodeVariant(name, info, templ)
 | 
			
		||||
	} else {
 | 
			
		||||
		t.templateInfo[name] = info
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -139,6 +139,18 @@ func templateNameAndVariants(name string) (string, []string) {
 | 
			
		||||
	return name, variants
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func resolveTemplateType(name string) templateType {
 | 
			
		||||
	if isShortcode(name) {
 | 
			
		||||
		return templateShortcode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.Contains(name, "partials/") {
 | 
			
		||||
		return templatePartial
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return templateUndefined
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isShortcode(name string) bool {
 | 
			
		||||
	return strings.Contains(name, "shortcodes/")
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -90,6 +90,11 @@ type templateHandler struct {
 | 
			
		||||
	// (language, output format etc.) of that shortcode.
 | 
			
		||||
	shortcodes map[string]*shortcodeTemplates
 | 
			
		||||
 | 
			
		||||
	// templateInfo maps template name to some additional information about that template.
 | 
			
		||||
	// Note that for shortcodes that same information is embedded in the
 | 
			
		||||
	// shortcodeTemplates type.
 | 
			
		||||
	templateInfo map[string]tpl.Info
 | 
			
		||||
 | 
			
		||||
	// text holds all the pure text templates.
 | 
			
		||||
	text *textTemplates
 | 
			
		||||
	html *htmlTemplates
 | 
			
		||||
@@ -172,18 +177,30 @@ func (t *templateHandler) Lookup(name string) (tpl.Template, bool) {
 | 
			
		||||
		// The templates are stored without the prefix identificator.
 | 
			
		||||
		name = strings.TrimPrefix(name, textTmplNamePrefix)
 | 
			
		||||
 | 
			
		||||
		return t.text.Lookup(name)
 | 
			
		||||
		return t.applyTemplateInfo(t.text.Lookup(name))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Look in both
 | 
			
		||||
	if te, found := t.html.Lookup(name); found {
 | 
			
		||||
		return te, true
 | 
			
		||||
		return t.applyTemplateInfo(te, true)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return t.text.Lookup(name)
 | 
			
		||||
	return t.applyTemplateInfo(t.text.Lookup(name))
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *templateHandler) applyTemplateInfo(templ tpl.Template, found bool) (tpl.Template, bool) {
 | 
			
		||||
	if adapter, ok := templ.(*tpl.TemplateAdapter); ok {
 | 
			
		||||
		if adapter.Info.IsZero() {
 | 
			
		||||
			if info, found := t.templateInfo[templ.Name()]; found {
 | 
			
		||||
				adapter.Info = info
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return templ, found
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This currently only applies to shortcodes and what we get here is the
 | 
			
		||||
// shortcode name.
 | 
			
		||||
func (t *templateHandler) LookupVariant(name string, variants tpl.TemplateVariants) (tpl.Template, bool, bool) {
 | 
			
		||||
@@ -246,6 +263,7 @@ func (t *templateHandler) clone(d *deps.Deps) *templateHandler {
 | 
			
		||||
		Deps:         d,
 | 
			
		||||
		layoutsFs:    d.BaseFs.Layouts.Fs,
 | 
			
		||||
		shortcodes:   make(map[string]*shortcodeTemplates),
 | 
			
		||||
		templateInfo: t.templateInfo,
 | 
			
		||||
		html:         &htmlTemplates{t: template.Must(t.html.t.Clone()), overlays: make(map[string]*template.Template), templatesCommon: t.html.templatesCommon},
 | 
			
		||||
		text:         &textTemplates{textTemplate: &textTemplate{t: texttemplate.Must(t.text.t.Clone())}, overlays: make(map[string]*texttemplate.Template), templatesCommon: t.text.templatesCommon},
 | 
			
		||||
		errors:       make([]*templateErr, 0),
 | 
			
		||||
@@ -309,6 +327,7 @@ func newTemplateAdapter(deps *deps.Deps) *templateHandler {
 | 
			
		||||
		Deps:         deps,
 | 
			
		||||
		layoutsFs:    deps.BaseFs.Layouts.Fs,
 | 
			
		||||
		shortcodes:   make(map[string]*shortcodeTemplates),
 | 
			
		||||
		templateInfo: make(map[string]tpl.Info),
 | 
			
		||||
		html:         htmlT,
 | 
			
		||||
		text:         textT,
 | 
			
		||||
		errors:       make([]*templateErr, 0),
 | 
			
		||||
@@ -463,15 +482,17 @@ func (t *htmlTemplates) addTemplateIn(tt *template.Template, name, tpl string) e
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	isShort := isShortcode(name)
 | 
			
		||||
	typ := resolveTemplateType(name)
 | 
			
		||||
 | 
			
		||||
	info, err := applyTemplateTransformersToHMLTTemplate(isShort, templ)
 | 
			
		||||
	info, err := applyTemplateTransformersToHMLTTemplate(typ, templ)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if isShort {
 | 
			
		||||
	if typ == templateShortcode {
 | 
			
		||||
		t.handler.addShortcodeVariant(name, info, templ)
 | 
			
		||||
	} else {
 | 
			
		||||
		t.handler.templateInfo[name] = info
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
@@ -511,7 +532,7 @@ func (t *textTemplate) parseIn(tt *texttemplate.Template, name, tpl string) (*te
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := applyTemplateTransformersToTextTemplate(false, templ); err != nil {
 | 
			
		||||
	if _, err := applyTemplateTransformersToTextTemplate(templateUndefined, templ); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return templ, nil
 | 
			
		||||
@@ -524,15 +545,17 @@ func (t *textTemplates) addTemplateIn(tt *texttemplate.Template, name, tpl strin
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	isShort := isShortcode(name)
 | 
			
		||||
	typ := resolveTemplateType(name)
 | 
			
		||||
 | 
			
		||||
	info, err := applyTemplateTransformersToTextTemplate(isShort, templ)
 | 
			
		||||
	info, err := applyTemplateTransformersToTextTemplate(typ, templ)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if isShort {
 | 
			
		||||
	if typ == templateShortcode {
 | 
			
		||||
		t.handler.addShortcodeVariant(name, info, templ)
 | 
			
		||||
	} else {
 | 
			
		||||
		t.handler.templateInfo[name] = info
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
@@ -737,7 +760,7 @@ func (t *htmlTemplates) handleMaster(name, overlayFilename, masterFilename strin
 | 
			
		||||
	// * https://github.com/golang/go/issues/16101
 | 
			
		||||
	// * https://github.com/gohugoio/hugo/issues/2549
 | 
			
		||||
	overlayTpl = overlayTpl.Lookup(overlayTpl.Name())
 | 
			
		||||
	if _, err := applyTemplateTransformersToHMLTTemplate(false, overlayTpl); err != nil {
 | 
			
		||||
	if _, err := applyTemplateTransformersToHMLTTemplate(templateUndefined, overlayTpl); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -777,7 +800,7 @@ func (t *textTemplates) handleMaster(name, overlayFilename, masterFilename strin
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	overlayTpl = overlayTpl.Lookup(overlayTpl.Name())
 | 
			
		||||
	if _, err := applyTemplateTransformersToTextTemplate(false, overlayTpl); err != nil {
 | 
			
		||||
	if _, err := applyTemplateTransformersToTextTemplate(templateUndefined, overlayTpl); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	t.overlays[name] = overlayTpl
 | 
			
		||||
@@ -847,15 +870,17 @@ func (t *templateHandler) addTemplateFile(name, baseTemplatePath, path string) e
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		isShort := isShortcode(name)
 | 
			
		||||
		typ := resolveTemplateType(name)
 | 
			
		||||
 | 
			
		||||
		info, err := applyTemplateTransformersToHMLTTemplate(isShort, templ)
 | 
			
		||||
		info, err := applyTemplateTransformersToHMLTTemplate(typ, templ)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if isShort {
 | 
			
		||||
		if typ == templateShortcode {
 | 
			
		||||
			t.addShortcodeVariant(templateName, info, templ)
 | 
			
		||||
		} else {
 | 
			
		||||
			t.templateInfo[name] = info
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -39,6 +39,14 @@ var reservedContainers = map[string]bool{
 | 
			
		||||
	"Data": true,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type templateType int
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	templateUndefined templateType = iota
 | 
			
		||||
	templateShortcode
 | 
			
		||||
	templatePartial
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type templateContext struct {
 | 
			
		||||
	decl     decl
 | 
			
		||||
	visited  map[string]bool
 | 
			
		||||
@@ -47,14 +55,16 @@ type templateContext struct {
 | 
			
		||||
	// The last error encountered.
 | 
			
		||||
	err error
 | 
			
		||||
 | 
			
		||||
	// Only needed for shortcodes
 | 
			
		||||
	isShortcode bool
 | 
			
		||||
	typ templateType
 | 
			
		||||
 | 
			
		||||
	// Set when we're done checking for config header.
 | 
			
		||||
	configChecked bool
 | 
			
		||||
 | 
			
		||||
	// Contains some info about the template
 | 
			
		||||
	tpl.Info
 | 
			
		||||
 | 
			
		||||
	// Store away the return node in partials.
 | 
			
		||||
	returnNode *parse.CommandNode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c templateContext) getIfNotVisited(name string) *parse.Tree {
 | 
			
		||||
@@ -84,12 +94,12 @@ func createParseTreeLookup(templ *template.Template) func(nn string) *parse.Tree
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func applyTemplateTransformersToHMLTTemplate(isShortcode bool, templ *template.Template) (tpl.Info, error) {
 | 
			
		||||
	return applyTemplateTransformers(isShortcode, templ.Tree, createParseTreeLookup(templ))
 | 
			
		||||
func applyTemplateTransformersToHMLTTemplate(typ templateType, templ *template.Template) (tpl.Info, error) {
 | 
			
		||||
	return applyTemplateTransformers(typ, templ.Tree, createParseTreeLookup(templ))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func applyTemplateTransformersToTextTemplate(isShortcode bool, templ *texttemplate.Template) (tpl.Info, error) {
 | 
			
		||||
	return applyTemplateTransformers(isShortcode, templ.Tree,
 | 
			
		||||
func applyTemplateTransformersToTextTemplate(typ templateType, templ *texttemplate.Template) (tpl.Info, error) {
 | 
			
		||||
	return applyTemplateTransformers(typ, templ.Tree,
 | 
			
		||||
		func(nn string) *parse.Tree {
 | 
			
		||||
			tt := templ.Lookup(nn)
 | 
			
		||||
			if tt != nil {
 | 
			
		||||
@@ -99,19 +109,54 @@ func applyTemplateTransformersToTextTemplate(isShortcode bool, templ *texttempla
 | 
			
		||||
		})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func applyTemplateTransformers(isShortcode bool, templ *parse.Tree, lookupFn func(name string) *parse.Tree) (tpl.Info, error) {
 | 
			
		||||
func applyTemplateTransformers(typ templateType, templ *parse.Tree, lookupFn func(name string) *parse.Tree) (tpl.Info, error) {
 | 
			
		||||
	if templ == nil {
 | 
			
		||||
		return tpl.Info{}, errors.New("expected template, but none provided")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c := newTemplateContext(lookupFn)
 | 
			
		||||
	c.isShortcode = isShortcode
 | 
			
		||||
	c.typ = typ
 | 
			
		||||
 | 
			
		||||
	err := c.applyTransformations(templ.Root)
 | 
			
		||||
	_, err := c.applyTransformations(templ.Root)
 | 
			
		||||
 | 
			
		||||
	if err == nil && c.returnNode != nil {
 | 
			
		||||
		// This is a partial with a return statement.
 | 
			
		||||
		c.Info.HasReturn = true
 | 
			
		||||
		templ.Root = c.wrapInPartialReturnWrapper(templ.Root)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return c.Info, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	partialReturnWrapperTempl = `{{ $_hugo_dot := $ }}{{ $ := .Arg }}{{ with .Arg }}{{ $_hugo_dot.Set ("PLACEHOLDER") }}{{ end }}`
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var partialReturnWrapper *parse.ListNode
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	templ, err := texttemplate.New("").Parse(partialReturnWrapperTempl)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	partialReturnWrapper = templ.Tree.Root
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *templateContext) wrapInPartialReturnWrapper(n *parse.ListNode) *parse.ListNode {
 | 
			
		||||
	wrapper := partialReturnWrapper.CopyList()
 | 
			
		||||
	withNode := wrapper.Nodes[2].(*parse.WithNode)
 | 
			
		||||
	retn := withNode.List.Nodes[0]
 | 
			
		||||
	setCmd := retn.(*parse.ActionNode).Pipe.Cmds[0]
 | 
			
		||||
	setPipe := setCmd.Args[1].(*parse.PipeNode)
 | 
			
		||||
	// Replace PLACEHOLDER with the real return value.
 | 
			
		||||
	// Note that this is a PipeNode, so it will be wrapped in parens.
 | 
			
		||||
	setPipe.Cmds = []*parse.CommandNode{c.returnNode}
 | 
			
		||||
	withNode.List.Nodes = append(n.Nodes, retn)
 | 
			
		||||
 | 
			
		||||
	return wrapper
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// The truth logic in Go's template package is broken for certain values
 | 
			
		||||
// for the if and with keywords. This works around that problem by wrapping
 | 
			
		||||
// the node passed to if/with in a getif conditional.
 | 
			
		||||
@@ -141,7 +186,7 @@ func (c *templateContext) wrapWithGetIf(p *parse.PipeNode) {
 | 
			
		||||
// 1) Make all .Params.CamelCase and similar into lowercase.
 | 
			
		||||
// 2) Wraps every with and if pipe in getif
 | 
			
		||||
// 3) Collects some information about the template content.
 | 
			
		||||
func (c *templateContext) applyTransformations(n parse.Node) error {
 | 
			
		||||
func (c *templateContext) applyTransformations(n parse.Node) (bool, error) {
 | 
			
		||||
	switch x := n.(type) {
 | 
			
		||||
	case *parse.ListNode:
 | 
			
		||||
		if x != nil {
 | 
			
		||||
@@ -169,12 +214,16 @@ func (c *templateContext) applyTransformations(n parse.Node) error {
 | 
			
		||||
			c.decl[x.Decl[0].Ident[0]] = x.Cmds[0].String()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, cmd := range x.Cmds {
 | 
			
		||||
			c.applyTransformations(cmd)
 | 
			
		||||
		for i, cmd := range x.Cmds {
 | 
			
		||||
			keep, _ := c.applyTransformations(cmd)
 | 
			
		||||
			if !keep {
 | 
			
		||||
				x.Cmds = append(x.Cmds[:i], x.Cmds[i+1:]...)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	case *parse.CommandNode:
 | 
			
		||||
		c.collectInner(x)
 | 
			
		||||
		keep := c.collectReturnNode(x)
 | 
			
		||||
 | 
			
		||||
		for _, elem := range x.Args {
 | 
			
		||||
			switch an := elem.(type) {
 | 
			
		||||
@@ -191,9 +240,10 @@ func (c *templateContext) applyTransformations(n parse.Node) error {
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return keep, c.err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return c.err
 | 
			
		||||
	return true, c.err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *templateContext) applyTransformationsToNodes(nodes ...parse.Node) {
 | 
			
		||||
@@ -229,7 +279,7 @@ func (c *templateContext) hasIdent(idents []string, ident string) bool {
 | 
			
		||||
// on the form:
 | 
			
		||||
//    {{ $_hugo_config:= `{ "version": 1 }` }}
 | 
			
		||||
func (c *templateContext) collectConfig(n *parse.PipeNode) {
 | 
			
		||||
	if !c.isShortcode {
 | 
			
		||||
	if c.typ != templateShortcode {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if c.configChecked {
 | 
			
		||||
@@ -271,7 +321,7 @@ func (c *templateContext) collectConfig(n *parse.PipeNode) {
 | 
			
		||||
// collectInner determines if the given CommandNode represents a
 | 
			
		||||
// shortcode call to its .Inner.
 | 
			
		||||
func (c *templateContext) collectInner(n *parse.CommandNode) {
 | 
			
		||||
	if !c.isShortcode {
 | 
			
		||||
	if c.typ != templateShortcode {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if c.Info.IsInner || len(n.Args) == 0 {
 | 
			
		||||
@@ -295,6 +345,28 @@ func (c *templateContext) collectInner(n *parse.CommandNode) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *templateContext) collectReturnNode(n *parse.CommandNode) bool {
 | 
			
		||||
	if c.typ != templatePartial || c.returnNode != nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(n.Args) < 2 {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ident, ok := n.Args[0].(*parse.IdentifierNode)
 | 
			
		||||
	if !ok || ident.Ident != "return" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.returnNode = n
 | 
			
		||||
	// Remove the "return" identifiers
 | 
			
		||||
	c.returnNode.Args = c.returnNode.Args[1:]
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// indexOfReplacementStart will return the index of where to start doing replacement,
 | 
			
		||||
// -1 if none needed.
 | 
			
		||||
func (d decl) indexOfReplacementStart(idents []string) int {
 | 
			
		||||
 
 | 
			
		||||
@@ -180,7 +180,7 @@ PARAMS SITE GLOBAL3: {{ $site.Params.LOWER }}
 | 
			
		||||
func TestParamsKeysToLower(t *testing.T) {
 | 
			
		||||
	t.Parallel()
 | 
			
		||||
 | 
			
		||||
	_, err := applyTemplateTransformers(false, nil, nil)
 | 
			
		||||
	_, err := applyTemplateTransformers(templateUndefined, nil, nil)
 | 
			
		||||
	require.Error(t, err)
 | 
			
		||||
 | 
			
		||||
	templ, err := template.New("foo").Funcs(testFuncs).Parse(paramsTempl)
 | 
			
		||||
@@ -484,7 +484,7 @@ func TestCollectInfo(t *testing.T) {
 | 
			
		||||
			require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
			c := newTemplateContext(createParseTreeLookup(templ))
 | 
			
		||||
			c.isShortcode = true
 | 
			
		||||
			c.typ = templateShortcode
 | 
			
		||||
			c.applyTransformations(templ.Tree.Root)
 | 
			
		||||
 | 
			
		||||
			assert.Equal(test.expected, c.Info)
 | 
			
		||||
@@ -492,3 +492,46 @@ func TestCollectInfo(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestPartialReturn(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name      string
 | 
			
		||||
		tplString string
 | 
			
		||||
		expected  bool
 | 
			
		||||
	}{
 | 
			
		||||
		{"Basic", `
 | 
			
		||||
{{ $a := "Hugo Rocks!" }}
 | 
			
		||||
{{ return $a }}
 | 
			
		||||
`, true},
 | 
			
		||||
		{"Expression", `
 | 
			
		||||
{{ return add 32 }}
 | 
			
		||||
`, true},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	echo := func(in interface{}) interface{} {
 | 
			
		||||
		return in
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	funcs := template.FuncMap{
 | 
			
		||||
		"return": echo,
 | 
			
		||||
		"add":    echo,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
			assert := require.New(t)
 | 
			
		||||
 | 
			
		||||
			templ, err := template.New("foo").Funcs(funcs).Parse(test.tplString)
 | 
			
		||||
			require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
			_, err = applyTemplateTransformers(templatePartial, templ.Tree, createParseTreeLookup(templ))
 | 
			
		||||
 | 
			
		||||
			// Just check that it doesn't fail in this test. We have functional tests
 | 
			
		||||
			// in hugoblib.
 | 
			
		||||
			assert.NoError(err)
 | 
			
		||||
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user