Skip to content

Instantly share code, notes, and snippets.

@Metaxal
Last active September 28, 2022 09:26
Show Gist options
  • Save Metaxal/498d66bc14903efe45a78b0da3f7e4fc to your computer and use it in GitHub Desktop.
Save Metaxal/498d66bc14903efe45a78b0da3f7e4fc to your computer and use it in GitHub Desktop.
softmax speedup issue with futures
#lang racket/base
(require racket/flonum
(only-in math/flonum flvector-sum) ; slow due to type/untype boundary
racket/fixnum
racket/future)
#;(#%declare #:unsafe) ; doesn't make much difference
(provide softmax-ref/unstable
softmax-ref/stable)
;===============;
;=== Softmax ===;
;===============;
#|
Given:
* a matrix β represented as a flat flvector of n-rows × n-cols,
* a set of selected rows of the β matrix (given as `idxs` which represent the starting indices
of the rows in the flvector),
* a selected column `act`,
we want to calculate this softmax expression:
⎛ ▁▁▁ ⎞
⎜ ╲ ⎟
exp⎜ ╱ β[row, act]⎟
⎜ ▔▔▔ ⎟
⎝row ∈ rows ⎠
──────────────────────────────────
n-cols
▁▁▁ ⎛ ▁▁▁ ⎞
╲ ⎜ ╲ ⎟
╱ exp⎜ ╱ β[row, col]⎟
▔▔▔ ⎜ ▔▔▔ ⎟
col=1 ⎝row ∈ rows ⎠
We want to perform this calculation as stably and as fast as possible.
Ref:
(require text-block)
(displayln ($formula '(/ (exp (list (sum "row ∈ rows" "") "β[row, act]" ))
(list (sum "col=1" "n-cols")
(exp (list (sum "row ∈ rows" "") "β[row, col]"))))))
|#
;; Fast but unstable version
(define (softmax-ref/unstable βmatrix idxs n-cols act)
(define θact 0.)
(define Z ; normalizer
(for/flsum ([col (in-range n-cols)])
(define θ
(flexp
(for/flsum ([idx (in-fxvector idxs)])
(flvector-ref βmatrix (fx+ idx col)))))
(when (fx= col act)
(set! θact θ))
θ))
(fl/ θact Z))
;; More stable version, where we subtract the max.
;; This shouldn't be more than 2x as slow than the fast but unsafe version,
;; so why is it 10x–20x slower??
(define (softmax-ref/stable βmatrix idxs n-cols act)
(define θact 0.)
(define βmax -inf.0)
(set! n-cols (fx+ n-cols))
(for ([col (in-range n-cols)])
(define β
(for/flsum ([idx (in-fxvector idxs)])
(flvector-ref βmatrix (fx+ idx col))))
(when (fl> β βmax)
(set! βmax β)))
(define Z
(for/flsum ([col (in-range n-cols)])
(define β
(for/flsum ([idx (in-fxvector idxs)])
(flvector-ref βmatrix (fx+ idx col))))
(define θu (flexp (fl- β βmax)))
(when (fx= col act)
(set! θact θu))
θu))
(fl/ θact Z))
;=============;
;=== Utils ===;
;=============;
;; `expr` must produce a flonum
;; TODO: Use `for/fold/derived` and in particular `split-for-body`
;; Really we should use `flsum`, but that's way too slow for now.
;; Let's start with this.
(define-syntax-rule (for/flsum (clause ...) body ... expr)
(for/fold ([s 0.])
(clause ...)
body ...
(fl+ s expr)))
;; groups the elements of the sequence xs into n lists of the same size ±1
;; Assumes the order doesn't matter
;; xs: sequence?
;; n: natural
;; -> (vectorof list?)
(define (group-into-lists xs n)
(define vec (make-vector n '()))
(define i 0)
(for ([x xs])
(vector-set! vec i (cons x (vector-ref vec i)))
(set! i (modulo (+ i 1) n)))
vec)
;; Sums results of calculation spread over futures.
;; Has a few issues, in particular doesn't handle `#:when` inter-clauses.
(define-syntax-rule (for/flsum/async #:n-futures n-futures (clause ...) body ... expr)
(let ([results (make-flvector n-futures)])
(for/async (clause ... [fut (in-naturals)])
body ...
(define res expr)
(flvector-set! results fut res))
(flvector-sum results)))
;============;
;=== Main ===;
;============;
(module+ main
(require racket/math)
(display (banner))
(printf "#processors: ~a\n\n" (processor-count))
(define n-cols 12)
(define n-rows 10000)
(define n-cells (* n-cols n-rows))
(define n-idxs 100) ; change this to 200 to see +nan.0 with `softmax-ref/unstable`
(define n-idxss 100000)
(define act 1) ; some selected column. Mock-up for an actual sequence of numbers
(define βmatrix (make-flvector n-cells 0.))
(for ([i (in-range n-cells)])
(flvector-set! βmatrix i (fl* -10. (random))))
(define idxss
(for/list ([j (in-range n-idxss)])
(for/fxvector #:length n-idxs ([i (in-range n-idxs)]) (* n-cols (random n-rows)))))
(displayln "No future:")
(collect-garbage)
(displayln
(time
(for/flsum ([idxs (in-list idxss)])
(softmax-ref/unstable βmatrix idxs n-cols act))))
(collect-garbage)
(displayln
(time
(for/flsum ([idxs (in-list idxss)])
(softmax-ref/stable βmatrix idxs n-cols act))))
;; Performs the same calculations but with different number of futures
(for ([n-futures (in-list '(1 2 3 4 5 10 15 20 30 40 50 100 150 200))])
(printf "\nn-futures: ~a\n" n-futures)
;; group the idxss into n-futures groups
(define idxss-groups (group-into-lists idxss n-futures))
;; We use a macro to make sure `the-softmax-ref` is lexically in function-call position
(define-syntax-rule (time-it the-softmax-ref)
(begin
(collect-garbage)
(time
;; Spread the calculation into n-futures
(for/flsum/async #:n-futures n-futures
([idxss (in-vector idxss-groups)])
(for/flsum ([idxs (in-list idxss)])
(the-softmax-ref βmatrix idxs n-cols act))))))
(displayln (time-it softmax-ref/unstable))
(displayln (time-it softmax-ref/stable))))
#| Results on a 64-cores machine
No future:
cpu time: 793 real time: 793 gc time: 51
n-futures: 1
cpu time: 344 real time: 323 gc time: 21 ; softmax-ref/unstable
cpu time: 803 real time: 803 gc time: 18 ; softmax-ref/stable
n-futures: 2
cpu time: 308 real time: 153 gc time: 12
cpu time: 1263 real time: 592 gc time: 404
n-futures: 3
cpu time: 337 real time: 119 gc time: 15
cpu time: 1295 real time: 474 gc time: 359
n-futures: 4
cpu time: 345 real time: 100 gc time: 15
cpu time: 1528 real time: 482 gc time: 426
n-futures: 5
cpu time: 340 real time: 77 gc time: 16
cpu time: 1608 real time: 470 gc time: 446
n-futures: 10
cpu time: 500 real time: 70 gc time: 20
cpu time: 2077 real time: 424 gc time: 460
n-futures: 15
cpu time: 438 real time: 50 gc time: 20
cpu time: 2602 real time: 437 gc time: 466
n-futures: 20
cpu time: 475 real time: 44 gc time: 23
cpu time: 3380 real time: 454 gc time: 473
n-futures: 30
cpu time: 448 real time: 36 gc time: 27
cpu time: 5159 real time: 510 gc time: 532
n-futures: 40
cpu time: 501 real time: 35 gc time: 30
cpu time: 5622 real time: 516 gc time: 554
n-futures: 50
cpu time: 558 real time: 34 gc time: 33
cpu time: 7377 real time: 588 gc time: 595
n-futures: 100
cpu time: 639 real time: 41 gc time: 46
cpu time: 8532 real time: 840 gc time: 789
n-futures: 150
cpu time: 651 real time: 42 gc time: 47
cpu time: 7806 real time: 1002 gc time: 876
n-futures: 200
cpu time: 616 real time: 42 gc time: 45
cpu time: 8839 real time: 945 gc time: 873
Observe how the unstable version speeds up significantly, but the stable version doesn't.
|#
@Metaxal
Copy link
Author

Metaxal commented Sep 28, 2022

See discussion and solution here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment