Skip to content

Commit

Permalink
perf: improve embedding layer performance (#729)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jul 16, 2023
1 parent 9b5b39a commit 9c1a222
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 26 deletions.
8 changes: 8 additions & 0 deletions base/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ func (rng RandomGenerator) NormalMatrix(row, col int, mean, stdDev float32) [][]
return ret
}

// NormalVector makes a vector of the given size filled with normally
// distributed random floats with the given mean and standard deviation.
func (rng RandomGenerator) NormalVector(size int, mean, stdDev float32) []float32 {
	vec := make([]float32, size)
	for i := range vec {
		vec[i] = mean + stdDev*float32(rng.NormFloat64())
	}
	return vec
}

// UniformMatrix makes a matrix filled with uniform random floats.
func (rng RandomGenerator) UniformMatrix(row, col int, low, high float32) [][]float32 {
ret := make([][]float32, row)
Expand Down
2 changes: 0 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,4 @@ require (

replace gorm.io/driver/sqlite v1.3.4 => github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e

replace gorgonia.org/gorgonia v0.9.18-0.20230327110624-d1c17944ed22 => github.com/gorse-io/gorgonia v0.0.0-20230619134452-7125bfc38f14

replace gorgonia.org/tensor v0.9.23 => github.com/gorse-io/tensor v0.0.0-20230617102451-4c006ddc5162
5 changes: 3 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,6 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorse-io/dashboard v0.0.0-20230319140716-18e3dabe9366 h1:s4CZgfU5HnOtJnGJ1EmUipY1IBpZa0nQmtTTz9gvXvM=
github.com/gorse-io/dashboard v0.0.0-20230319140716-18e3dabe9366/go.mod h1:w74IGf70uM5ZCeXmkBhLl3Ux6D+HpBryzcc75VfZA4s=
github.com/gorse-io/gorgonia v0.0.0-20230619134452-7125bfc38f14 h1:/znIy+Zye7MrXVKUG4L1pjKxFWHA5slrwyt+UbxrJiI=
github.com/gorse-io/gorgonia v0.0.0-20230619134452-7125bfc38f14/go.mod h1:TdV4UXKprIVs6rZv0IiRcMznbn4diHrFUc5lXfwB8pM=
github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e h1:uPQtYQzG1QcC3Qbv+tuEe8Q2l++V4KEcqYSSwB9qobg=
github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e/go.mod h1:PmIOwYnI+F1lRKd6F/PdLXGgI8GZ5H8x8z1yx0+0bmQ=
github.com/gorse-io/tensor v0.0.0-20230617102451-4c006ddc5162 h1:W4aIbIvkE9/9PLuGJ7OcWuEtTeUaXgTd2enX440+e7Q=
Expand Down Expand Up @@ -721,6 +719,7 @@ go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw=
go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E=
go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E=
go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E=
go4.org/unsafe/assume-no-moving-gc v0.0.0-20230525183740-e7c30c78aeb2 h1:WJhcL4p+YeDxmZWg141nRm7XC8IDmhz7lk5GpadO1Sg=
go4.org/unsafe/assume-no-moving-gc v0.0.0-20230525183740-e7c30c78aeb2/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E=
Expand Down Expand Up @@ -1221,6 +1220,8 @@ gorgonia.org/dawson v1.2.0 h1:hJ/aofhfkReSnJdSMDzypRZ/oWDL1TmeYOauBnXKdFw=
gorgonia.org/dawson v1.2.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws=
gorgonia.org/gorgonia v0.9.2/go.mod h1:ZtOb9f/wM2OMta1ISGspQ4roGDgz9d9dKOaPNvGR+ec=
gorgonia.org/gorgonia v0.9.17/go.mod h1:g66b5Z6ATUdhVqYl2ZAAwblv5hnGW08vNinGLcnrceI=
gorgonia.org/gorgonia v0.9.18-0.20230327110624-d1c17944ed22 h1:l63Ws8VHVzDD1UugrjSFNQ2+GsJ8gO9X8S7U8Ay2z6Y=
gorgonia.org/gorgonia v0.9.18-0.20230327110624-d1c17944ed22/go.mod h1:kYe25GPmZ+1ycLqfKDQx+50UIhklCU7lSDXiotON/f4=
gorgonia.org/tensor v0.9.0-beta/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w=
gorgonia.org/tensor v0.9.17/go.mod h1:75SMdLLhZ+2oB0/EE8lFEIt1Caoykdd4bz1mAe59deg=
gorgonia.org/tensor v0.9.20/go.mod h1:75SMdLLhZ+2oB0/EE8lFEIt1Caoykdd4bz1mAe59deg=
Expand Down
159 changes: 137 additions & 22 deletions model/click/deepfm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ import (
"modernc.org/mathutil"
)

// Adam optimizer hyperparameters used by the manual embedding-gradient
// update in DeepFM.backward. Values are the defaults suggested by
// Kingma & Ba, "Adam: A Method for Stochastic Optimization".
const (
	beta1 float32 = 0.9   // exponential decay rate for the first moment estimate
	beta2 float32 = 0.999 // exponential decay rate for the second moment estimate
	eps   float32 = 1e-8  // small constant added to the denominator for numerical stability
)

type DeepFM struct {
BaseFactorizationMachine

Expand All @@ -42,15 +48,21 @@ type DeepFM struct {

vm gorgonia.VM
g *gorgonia.ExprGraph
indices *gorgonia.Node
embeddingV *gorgonia.Node
embeddingW *gorgonia.Node
values *gorgonia.Node
output *gorgonia.Node
target *gorgonia.Node
cost *gorgonia.Node
v *gorgonia.Node
w *gorgonia.Node
b *gorgonia.Node
learnables []*gorgonia.Node
v [][]float32
w []float32
m_v [][]float32
m_w []float32
v_v [][]float32
v_w []float32
t int

// Hyper parameters
batchSize int
Expand Down Expand Up @@ -139,7 +151,9 @@ func (fm *DeepFM) BatchPredict(x []lo.Tuple2[[]int32, []float32]) []float32 {
indicesTensor, valuesTensor, _ := fm.convertToTensors(x, nil)
predictions := make([]float32, 0, len(x))
for i := 0; i < len(x); i += fm.batchSize {
lo.Must0(gorgonia.Let(fm.indices, lo.Must1(indicesTensor.Slice(gorgonia.S(i, i+fm.batchSize)))))
v, w := fm.embedding(lo.Must1(indicesTensor.Slice(gorgonia.S(i, i+fm.batchSize))))
lo.Must0(gorgonia.Let(fm.embeddingV, v))
lo.Must0(gorgonia.Let(fm.embeddingW, w))
lo.Must0(gorgonia.Let(fm.values, lo.Must1(valuesTensor.Slice(gorgonia.S(i, i+fm.batchSize)))))
lo.Must0(fm.vm.RunAll())
predictions = append(predictions, fm.output.Value().Data().([]float32)...)
Expand Down Expand Up @@ -175,11 +189,14 @@ func (fm *DeepFM) Fit(trainSet *Dataset, testSet *Dataset, config *FitConfig) Sc
fitStart := time.Now()
cost := float32(0)
for i := 0; i < trainSet.Count(); i += fm.batchSize {
lo.Must0(gorgonia.Let(fm.indices, lo.Must1(indicesTensor.Slice(gorgonia.S(i, i+fm.batchSize)))))
v, w := fm.embedding(lo.Must1(indicesTensor.Slice(gorgonia.S(i, i+fm.batchSize))))
lo.Must0(gorgonia.Let(fm.embeddingV, v))
lo.Must0(gorgonia.Let(fm.embeddingW, w))
lo.Must0(gorgonia.Let(fm.values, lo.Must1(valuesTensor.Slice(gorgonia.S(i, i+fm.batchSize)))))
lo.Must0(gorgonia.Let(fm.target, lo.Must1(targetTensor.Slice(gorgonia.S(i, i+fm.batchSize)))))
lo.Must0(fm.vm.RunAll())

fm.backward(lo.Must1(indicesTensor.Slice(gorgonia.S(i, i+fm.batchSize))), epoch)
cost += fm.cost.Value().Data().(float32)
lo.Must0(solver.Step(gorgonia.NodesToValueGrads(fm.learnables)))
fm.vm.Reset()
Expand Down Expand Up @@ -217,21 +234,19 @@ func (fm *DeepFM) Init(trainSet *Dataset) {
fm.numDimension = mathutil.MaxVal(fm.numDimension, len(x))
}

fm.v = gorgonia.NewMatrix(fm.g, tensor.Float32,
gorgonia.WithShape(fm.numFeatures, fm.nFactors),
gorgonia.WithName("v"),
gorgonia.WithInit(gorgonia.Gaussian(float64(fm.initMean), float64(fm.initStdDev))))
fm.w = gorgonia.NewMatrix(fm.g, tensor.Float32,
gorgonia.WithShape(fm.numFeatures, 1),
gorgonia.WithName("w"),
gorgonia.WithInit(gorgonia.Gaussian(float64(fm.initMean), float64(fm.initStdDev))))
fm.v = fm.GetRandomGenerator().NormalMatrix(fm.numFeatures, fm.nFactors, fm.initMean, fm.initStdDev)
fm.w = fm.GetRandomGenerator().NormalVector(fm.numFeatures, fm.initMean, fm.initStdDev)
fm.m_v = zeros(fm.numFeatures, fm.nFactors)
fm.m_w = make([]float32, fm.numFeatures)
fm.v_v = zeros(fm.numFeatures, fm.nFactors)
fm.v_w = make([]float32, fm.numFeatures)
fm.b = gorgonia.NewMatrix(fm.g, tensor.Float32,
gorgonia.WithShape(1, 1),
gorgonia.WithName("b"),
gorgonia.WithInit(gorgonia.Zeroes()))
fm.learnables = []*gorgonia.Node{fm.v, fm.w, fm.b}

fm.forward(fm.batchSize)
fm.learnables = []*gorgonia.Node{fm.b, fm.embeddingV, fm.embeddingW}
lo.Must1(gorgonia.Grad(fm.cost, fm.learnables...))

fm.vm = gorgonia.NewTapeMachine(fm.g, gorgonia.BindDualValues(fm.learnables...))
Expand All @@ -252,9 +267,12 @@ func (fm *DeepFM) Complexity() int {

func (fm *DeepFM) forward(batchSize int) {
// input nodes
fm.indices = gorgonia.NodeFromAny(fm.g,
tensor.New(tensor.WithShape(batchSize, fm.numDimension), tensor.WithBacking(make([]float32, batchSize*fm.numDimension))),
gorgonia.WithName("indices"))
fm.embeddingV = gorgonia.NodeFromAny(fm.g,
tensor.New(tensor.WithShape(batchSize, fm.numDimension, fm.nFactors), tensor.WithBacking(make([]float32, batchSize*fm.numDimension*fm.nFactors))),
gorgonia.WithName("embeddingV"))
fm.embeddingW = gorgonia.NodeFromAny(fm.g,
tensor.New(tensor.WithShape(batchSize, fm.numDimension, 1), tensor.WithBacking(make([]float32, batchSize*fm.numDimension))),
gorgonia.WithName("embeddingW"))
fm.values = gorgonia.NodeFromAny(fm.g,
tensor.New(tensor.WithShape(batchSize, fm.numDimension), tensor.WithBacking(make([]float32, batchSize*fm.numDimension))),
gorgonia.WithName("values"))
Expand All @@ -263,18 +281,16 @@ func (fm *DeepFM) forward(batchSize int) {
gorgonia.WithName("target"))

// factorization machine
v := gorgonia.Must(gorgonia.Embedding(fm.v, fm.indices))
w := gorgonia.Must(gorgonia.Embedding(fm.w, fm.indices))
x := gorgonia.Must(gorgonia.Reshape(fm.values, []int{batchSize, fm.numDimension, 1}))
vx := gorgonia.Must(gorgonia.BatchedMatMul(v, x, true))
vx := gorgonia.Must(gorgonia.BatchedMatMul(fm.embeddingV, x, true))
sumSquare := gorgonia.Must(gorgonia.Square(vx))
v2 := gorgonia.Must(gorgonia.Square(v))
v2 := gorgonia.Must(gorgonia.Square(fm.embeddingV))
x2 := gorgonia.Must(gorgonia.Square(x))
squareSum := gorgonia.Must(gorgonia.BatchedMatMul(v2, x2, true))
sum := gorgonia.Must(gorgonia.Sub(sumSquare, squareSum))
sum = gorgonia.Must(gorgonia.Sum(sum, 1))
sum = gorgonia.Must(gorgonia.Mul(sum, fm.nodeFromFloat64(0.5)))
linear := gorgonia.Must(gorgonia.BatchedMatMul(w, x, true, false))
linear := gorgonia.Must(gorgonia.BatchedMatMul(fm.embeddingW, x, true, false))
fm.output = gorgonia.Must(gorgonia.BroadcastAdd(
gorgonia.Must(gorgonia.Reshape(linear, []int{batchSize})),
fm.b,
Expand All @@ -286,6 +302,97 @@ func (fm *DeepFM) forward(batchSize int) {
fm.cost = fm.bceWithLogits(fm.target, fm.output)
}

// embedding gathers the latent-factor rows (fm.v) and linear-weight entries
// (fm.w) for a 2-D batch of feature indices, returning dense tensors of
// shapes (batch, dim, nFactors) and (batch, dim, 1) that feed the inputs of
// the computation graph.
func (fm *DeepFM) embedding(indices tensor.View) (v, w *tensor.Dense) {
	shape := indices.Shape()
	if len(shape) != 2 {
		panic("indices must be 2-dimensional")
	}
	batchSize, numDimension := shape[0], shape[1]

	backingV := make([]float32, batchSize*numDimension*fm.nFactors)
	backingW := make([]float32, batchSize*numDimension)
	for i := 0; i < batchSize; i++ {
		for j := 0; j < numDimension; j++ {
			// Indices are stored as float32 inside the tensor; convert back
			// to an integer row id before looking up the embedding tables.
			row := int(lo.Must1(indices.At(i, j)).(float32))
			offset := (i*numDimension + j) * fm.nFactors
			copy(backingV[offset:offset+fm.nFactors], fm.v[row])
			backingW[i*numDimension+j] = fm.w[row]
		}
	}

	v = tensor.New(tensor.WithShape(batchSize, numDimension, fm.nFactors), tensor.WithBacking(backingV))
	w = tensor.New(tensor.WithShape(batchSize, numDimension, 1), tensor.WithBacking(backingW))
	return
}

// backward applies one Adam optimization step to the embedding tables fm.v
// and fm.w, using the gradients gorgonia computed for the dense embedding
// input nodes (fm.embeddingV / fm.embeddingW). Gradients are first scattered
// back from batch positions to the unique feature indices they came from,
// then each touched row is updated with bias-corrected Adam.
//
// NOTE(review): the t parameter is never used — fm.t is the Adam timestep
// counter incremented below; the caller passes the epoch here. Confirm
// whether the parameter can be dropped.
func (fm *DeepFM) backward(indices tensor.View, t int) {
	s := indices.Shape()
	if len(s) != 2 {
		panic("indices must be 2-dimensional")
	}
	batchSize, numDimension := s[0], s[1]

	// Raw gradients w.r.t. the dense (batch, dim, …) embedding inputs.
	gradEmbeddingV := lo.Must1(fm.embeddingV.Grad()).Data().([]float32)
	gradEmbeddingW := lo.Must1(fm.embeddingW.Grad()).Data().([]float32)
	// Accumulated gradients keyed by feature index: a feature appearing at
	// several batch positions sums its contributions (scatter-add).
	gradV := make(map[int][]float32)
	gradW := make(map[int]float32)

	for i := 0; i < batchSize; i++ {
		for j := 0; j < numDimension; j++ {
			// Indices are stored as float32 in the tensor; convert to int.
			index := int(lo.Must1(indices.At(i, j)).(float32))

			if _, exist := gradV[index]; !exist {
				gradV[index] = make([]float32, fm.nFactors)
			}
			for k := 0; k < fm.nFactors; k++ {
				gradV[index][k] += gradEmbeddingV[i*numDimension*fm.nFactors+j*fm.nFactors+k]
			}

			if _, exist := gradW[index]; !exist {
				gradW[index] = 0
			}
			gradW[index] += gradEmbeddingW[i*numDimension+j]
		}
	}

	// Advance the Adam timestep and precompute the bias-correction factors.
	fm.t++
	correction1 := float32(1 - math32.Pow(beta1, float32(fm.t)))
	correction2 := float32(1 - math32.Pow(beta2, float32(fm.t)))

	// Update the factor matrix rows that were touched by this batch.
	for index, grad := range gradV {
		for k := 0; k < fm.nFactors; k++ {
			// Add L2 regularization and average the gradient over the batch.
			grad[k] += fm.reg * fm.v[index][k]
			grad[k] /= float32(batchSize)
			// m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
			fm.m_v[index][k] = beta1*fm.m_v[index][k] + (1-beta1)*grad[k]
			// v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2
			fm.v_v[index][k] = beta2*fm.v_v[index][k] + (1-beta2)*grad[k]*grad[k]
			// \hat{m}_t = m_t / (1 - beta_1^t)
			mHat := fm.m_v[index][k] / correction1
			// \hat{v}_t = v_t / (1 - beta_2^t)
			vHat := fm.v_v[index][k] / correction2
			// \theta_t = \theta_{t-1} - \eta * \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)
			fm.v[index][k] -= fm.lr * mHat / (math32.Sqrt(vHat) + eps)
		}
	}

	// Update the linear weights that were touched by this batch.
	for index, grad := range gradW {
		// Add L2 regularization and average the gradient over the batch.
		grad += fm.reg * fm.w[index]
		grad /= float32(batchSize)
		// m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
		fm.m_w[index] = beta1*fm.m_w[index] + (1-beta1)*grad
		// v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2
		fm.v_w[index] = beta2*fm.v_w[index] + (1-beta2)*grad*grad
		// \hat{m}_t = m_t / (1 - beta_1^t)
		mHat := fm.m_w[index] / correction1
		// \hat{v}_t = v_t / (1 - beta_2^t)
		vHat := fm.v_w[index] / correction2
		// \theta_t = \theta_{t-1} - \eta * \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)
		fm.w[index] -= fm.lr * mHat / (math32.Sqrt(vHat) + eps)
	}
}

func (fm *DeepFM) convertToTensors(x []lo.Tuple2[[]int32, []float32], y []float32) (indicesTensor, valuesTensor, targetTensor *tensor.Dense) {
if y != nil && len(x) != len(y) {
panic("length of x and y must be equal")
Expand Down Expand Up @@ -351,3 +458,11 @@ func (fm *DeepFM) bceWithLogits(target, prediction *gorgonia.Node) *gorgonia.Nod
// nodeFromFloat64 wraps a scalar constant as a uniquely-named node in the
// expression graph. A random UUID name avoids collisions when the same
// constant is added more than once.
//
// Fix: the parameter was named "any", which shadows Go's predeclared
// identifier of the same name; renamed to value. (The function name says
// Float64 but the parameter is float32 — kept as-is since callers depend
// on it.)
func (fm *DeepFM) nodeFromFloat64(value float32) *gorgonia.Node {
	return gorgonia.NodeFromAny(fm.g, value, gorgonia.WithName(uuid.NewString()))
}

// zeros allocates an a-by-b matrix of float32 with every element set to zero.
func zeros(a, b int) [][]float32 {
	matrix := make([][]float32, a)
	for i := 0; i < a; i++ {
		matrix[i] = make([]float32, b)
	}
	return matrix
}

0 comments on commit 9c1a222

Please sign in to comment.