Metal for Accelerating Machine Learning
WWDC 2018
Metal Performance Shaders
GPU-accelerated primitives, optimized for iOS and macOS
- Image processing
- Linear algebra
- Machine learning
  - Inference
  - Training (new)
- Ray tracing (new)
Training
data:image/s3,"s3://crabby-images/b9d90/b9d906290bcf4cf2275945688988d06dfaaf14d6" alt=""
Inference
data:image/s3,"s3://crabby-images/a11aa/a11aa57f02d1e4375739f00ec081f80c8e15790c" alt=""
CNN Inference Enhancements
FP16 accumulation
- Available on the Apple A11 Bionic GPU for
  - Convolution
  - Convolution transpose
- Sufficient precision for commonly used neural networks
- Delivers better performance than FP32 accumulation
data:image/s3,"s3://crabby-images/1704f/1704f67ff76c7a3bf6b448ca34a147985f2b80da" alt=""
CNN Training
data:image/s3,"s3://crabby-images/117a9/117a931ae5088b4842644884ecbc0bfb25b817d9" alt=""
Training
data:image/s3,"s3://crabby-images/bc24f/bc24fa72e0b785b63ecd5e73291bb1cd1bc5a2d1" alt=""
Forward Pass
data:image/s3,"s3://crabby-images/217c0/217c0a8acbbdb18f5304f487b366778b9bc9ae09" alt=""
Loss computation
data:image/s3,"s3://crabby-images/90ff0/90ff06ce38d87962d9d6992151b7a894ac9682a5" alt=""
Gradient pass
data:image/s3,"s3://crabby-images/b8a9c/b8a9c69368afb1aab2aac619f55dbe32d35ccfdf" alt=""
Weight update
data:image/s3,"s3://crabby-images/522c2/522c2e321c87c5ed1d37d4b1cdd44f86b9654273" alt=""
Iterate
- Forward pass → Loss computation → Gradient pass → Weight update
data:image/s3,"s3://crabby-images/9b7a7/9b7a7e182bb28d259a51a87e4423753b01243bd6" alt=""
Training a Neural Network with MPS
- Create training graph
- Prepare inputs
- Specify weights
- Execute graph (the graph updates weights)
- Complete training process
Create Training Graph
- Describe neural network using graph API
data:image/s3,"s3://crabby-images/2f52b/2f52bbf2b36cca959972d9dfcd8e10bf44ed6c9a" alt=""
- Image nodes — Data
data:image/s3,"s3://crabby-images/0e938/0e938e53ba62a155d9a96bba23f50c3c78e0111c" alt=""
- Filter nodes — Operations
data:image/s3,"s3://crabby-images/01622/01622a244e56f1a175c4caa73493848e52e19cb2" alt=""
data:image/s3,"s3://crabby-images/04511/04511515f21547f735436b949b48f0d461a8056f" alt=""
Create an Inference Graph
data:image/s3,"s3://crabby-images/593dc/593dc7197122b9e6b48329d3c2b7b712d01fc775" alt=""
data:image/s3,"s3://crabby-images/eb209/eb20999e6a6244a681afa667cb02023138f7ef95" alt=""
Prepare Inputs
- Inputs to the graph
- Batch of source images
- Batch of source states
data:image/s3,"s3://crabby-images/e1f96/e1f9608458b51470e171d31423d1816049eac2c7" alt=""
Batches
- Batches are arrays of images or states
data:image/s3,"s3://crabby-images/007c8/007c899e289efb08ecc73639b5c2e847cba70459" alt=""
States
- MPSState passes the state of a forward node to its gradient node
- Graph manages all states
data:image/s3,"s3://crabby-images/241d7/241d70676da4fff7111864f249271b9558c6e29c" alt=""
Loss Labels
data:image/s3,"s3://crabby-images/23632/236320f05116f7ae187a7915866a35390e4b56ef" alt=""
Data Source Providers
- Available for
  - Convolution
  - Fully connected
  - Batch normalization
  - Instance normalization
- Just-in-time loading and purging of weights data
- Minimize memory footprint
data:image/s3,"s3://crabby-images/cf4fb/cf4fb9c416890ce6bb588c45e8f016394868179c" alt=""
data:image/s3,"s3://crabby-images/867fd/867fd90f6c140040f8ce93b06d604c020cf55fb9" alt=""
Execute graph
data:image/s3,"s3://crabby-images/0856c/0856c2b15e4c407fe54a3d857156494b824eaeda" alt=""
Updating Weights
- Implement optional update method on Data Source Provider
- Graph calls update method automatically
data:image/s3,"s3://crabby-images/b52e5/b52e590d8fd6708a890143db6785aa6dc5b40766" alt=""
Optimizer
- Describes how to take an update step on training parameters
- Used in the update method of the data source provider
- Variants
  - MPSNNOptimizerAdam
  - MPSNNOptimizerStochasticGradientDescent
  - MPSNNOptimizerRMSProp
  - Custom
data:image/s3,"s3://crabby-images/9eb0e/9eb0eecb7dbcdcbea19486d14daf630fad20433b" alt=""
data:image/s3,"s3://crabby-images/da25a/da25a53ec20b7f76456b147934ba3e238a60b08e" alt=""
Complete training process
data:image/s3,"s3://crabby-images/649e8/649e8d36a55f0967a286755173d448d2f7eb7401" alt=""
Demo
data:image/s3,"s3://crabby-images/1b816/1b81695732b565df5b36c13476cdb4b44d758d8e" alt=""
CNN
- 1 to 1
data:image/s3,"s3://crabby-images/c5819/c5819ccf092e5d17c099a213e3b0f65de50e4274" alt=""
RNN
- 1 to Many
data:image/s3,"s3://crabby-images/4ae93/4ae932edebb1b82cfdd0578b63652ac4bae31948" alt=""
- Many to Many
data:image/s3,"s3://crabby-images/f3a38/f3a3856eb93f0deb3ef6ee0191e0d8054e904182" alt=""
Recurrent Neural Networks
Variants for inference and training (new)
- Single Gate
- Long Short-Term Memory (LSTM)
- Gated Recurrent Unit (GRU)
- Minimally Gated Unit (MGU)
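The simplest of these, the single-gate layer, updates a hidden state from the current input and the previous state at every step, reusing the same weights. The scalar sketch below assumes a tanh activation and made-up weights; MPS layers (configured, for example, through `MPSRNNSingleGateDescriptor`) operate on matrices, so this only illustrates the recurrence and the many-to-many versus many-to-one outputs.

```python
import math
from typing import List

def single_gate_rnn(xs: List[float], w_x: float = 0.5, w_h: float = 0.8,
                    bias: float = 0.0, h0: float = 0.0) -> List[float]:
    """Apply h_t = tanh(w_x * x_t + w_h * h_{t-1} + bias) along a sequence.

    Returns one hidden state per input step (many to many); the last state
    alone is what a many-to-one classifier would read.
    """
    h = h0
    states = []
    for x in xs:
        h = math.tanh(w_x * x + w_h * h + bias)  # same weights reused at every step
        states.append(h)
    return states

states = single_gate_rnn([1.0, 0.0, 0.0])
summary = states[-1]
```

An LSTM, GRU, or MGU replaces the single tanh update with gated combinations of the same two inputs; the loop over time steps stays the same.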
Activity Classifier
Inference
data:image/s3,"s3://crabby-images/4c641/4c6413bdda64ba734e295874be2fe58a3edda7cf" alt=""
Training
data:image/s3,"s3://crabby-images/9ba9b/9ba9b45cba5c25d31553ce37ce6c9877c4f1ebac" alt=""
data:image/s3,"s3://crabby-images/000a2/000a239e2b1028cf34357c2915a22c7673ffd2c8" alt=""
data:image/s3,"s3://crabby-images/0cd27/0cd27bee74eeca0acc256890e34cd079a342ddc2" alt=""
data:image/s3,"s3://crabby-images/71eec/71eec62787d20c2b23327fed70bbae23c514e3d7" alt=""
Data Converters
data:image/s3,"s3://crabby-images/2a56e/2a56ed8b2df70d02f64755c516cd5b746e8ece3d" alt=""
data:image/s3,"s3://crabby-images/4efc4/4efc40d198c4888890fb4acf926c00f937212775" alt=""
Demo
Object classification training using TensorFlow with MPS
data:image/s3,"s3://crabby-images/f11bc/f11bc5abf45b419bbdd8091a406ae6d4587c58a5" alt=""
data:image/s3,"s3://crabby-images/2d4e2/2d4e275bd5aee376d7e11047708df9e17a4a84ec" alt=""