// WAPaintCanvas.metal
// WhatsApp
// Created by Kuan Yong on 7/1/16.
// Copyright © 2016 WhatsApp. All rights reserved.

#include <metal_stdlib>
#include <metal_texture>

using namespace metal;

// Interpolated data passed from layerQuadVertex to layerQuadFragment.
struct VertexInOut {
    // Clip-space position consumed by the rasterizer.
    float4 position [[ position ]];
    // Texture coordinate, exposed under the user-defined attribute name
    // "texturecoord".
    float2 texCoord [[ user(texturecoord) ]];
};

// Per-vertex record with a 2D position, a texture coordinate and a color.
// It is written by the convertPointToVertex...Color functions and read back
// by the brush/stamp vertex functions through buffer(4).
struct VertexInPositionTextureColor {
    float2 position;  // alignment: 8
    float2 texCoord;  // alignment: 8
    half4 color;      // alignment: 8
};

// Rasterizer payload for the colored brush/stamp pipelines: the transformed
// position plus the texture coordinate and color copied from the input vertex.
struct VertexOutPositionTextureColor {
    float4 position [[ position ]];  // alignment: 16
    float2 texCoord;                 // alignment: 8
    half4 color;                     // alignment: 8
};

// Per-vertex record with a 2D position and a texture coordinate (no color).
// It is written by convertPointToVertexPositionSize and read back by the
// desaturate/pixelate/eraser vertex functions through buffer(4).
struct VertexInPositionTexture {
    float2 position;  // alignment: 8
    float2 texCoord;  // alignment: 8
};

// Rasterizer payload for the uncolored brush pipelines: the transformed
// position plus the texture coordinate copied from the input vertex.
struct VertexOutPositionTexture {
    float4 position [[ position ]];  // alignment: 16
    float2 texCoord;                 // alignment: 8
};

// Sampler used for brush/stamp textures: reads outside [0, 1] return zero and
// texels are linearly filtered.
// NOTE(review): the original declaration was cut off after `filter::linear,`.
// The trailing `mip_filter::none` is reconstructed to match the two samplers
// below (it is also Metal's default mip filter) -- confirm against the
// original file.
constexpr sampler texSampler(address::clamp_to_zero, filter::linear,
                             mip_filter::none);

// Sampler used when compositing a layer quad: linear filtering, no mipmapping.
constexpr sampler layerSampler(filter::linear, mip_filter::none);
// Sampler used for the clean-plate texture: nearest-texel lookup, no mipmapping.
constexpr sampler cleanPlateSampler(filter::nearest, mip_filter::none);

// Pass-through vertex function for a layer quad.
//
// position: per-vertex 2D positions (buffer 0); used as-is with z = 0, w = 1.
// texCoord: per-vertex texture coordinates (buffer 1).
// v_id:     index of the vertex being processed.
// Returns the position and texture coordinate for the rasterizer.
vertex VertexInOut layerQuadVertex(const device float2 *position [[ buffer(0) ]],
                                   const device float2 *texCoord [[ buffer(1) ]],
                                   uint v_id [[ vertex_id ]]) {
    VertexInOut v_out;
    v_out.position = float4(position[v_id], 0.0f, 1.0f);
    v_out.texCoord = texCoord[v_id];
    return v_out;
}

// Fragment function for a layer quad: returns the texel of `tex2D` sampled
// with `layerSampler` at the interpolated texture coordinate, unmodified.
fragment half4 layerQuadFragment(VertexInOut inFrag [[ stage_in ]],
                                 texture2d<half> tex2D [[ texture(0) ]]) {
    return tex2D.sample(layerSampler, inFrag.texCoord);
}
// Expands one input point into six colored, textured vertices describing an
// axis-aligned quad centered on the point.
//
// pos_in:        quad centers (buffer 0).
// size_in:       quad sizes; x is the width, y the height (buffer 1).
// color_in:      per-point colors, narrowed to half4 on output (buffer 2).
// textureCoord0: texture coordinate assigned to the (-x, -y) corner (buffer 7).
// textureCoord1: texture coordinate assigned to the (+x, +y) corner (buffer 8).
// v_out:         destination; vertices [v_id * 6, v_id * 6 + 5] are written
//                (buffer 4). The caller must size it for six vertices per point.
// v_id:          index of the point being expanded.
//
// The first and last corners are written twice (v0, v0, v1, v2, v3, v3) --
// presumably so consecutive quads can be drawn as a single triangle strip
// joined by degenerate triangles; confirm against the draw call.
vertex void
convertPointToVertexPositionSizeColor(const device float2 *pos_in [[ buffer(0) ]],
                                      const device float2 *size_in [[ buffer(1) ]],
                                      const device float4 *color_in [[ buffer(2) ]],
                                      constant float2 &textureCoord0 [[ buffer(7) ]],
                                      constant float2 &textureCoord1 [[ buffer(8) ]],
                                      device VertexInPositionTextureColor *v_out [[ buffer(4) ]],
                                      uint v_id [[ vertex_id ]]) {
    const float2 pos = pos_in[v_id];
    // All four corners share the point's color; narrow it to half once.
    const half4 color = half4(color_in[v_id]);
    const float width2 = size_in[v_id].x * 0.5f;
    const float height2 = size_in[v_id].y * 0.5f;
    const uint outID = v_id * 6;

    VertexInPositionTextureColor v0, v1, v2, v3;

    v0.position = float2(pos.x - width2, pos.y - height2);
    v0.texCoord = float2(textureCoord0.x, textureCoord0.y);
    v0.color = color;

    v1.position = float2(pos.x - width2, pos.y + height2);
    v1.texCoord = float2(textureCoord0.x, textureCoord1.y);
    v1.color = color;

    v2.position = float2(pos.x + width2, pos.y - height2);
    v2.texCoord = float2(textureCoord1.x, textureCoord0.y);
    v2.color = color;

    v3.position = float2(pos.x + width2, pos.y + height2);
    v3.texCoord = float2(textureCoord1.x, textureCoord1.y);
    v3.color = color;

    v_out[outID] = v0;
    v_out[outID + 1] = v0;
    v_out[outID + 2] = v1;
    v_out[outID + 3] = v2;
    v_out[outID + 4] = v3;
    v_out[outID + 5] = v3;
}

// Expands one input point into six textured vertices describing an
// axis-aligned quad centered on the point (no color attribute).
//
// pos_in:        quad centers (buffer 0).
// size_in:       quad sizes; x is the width, y the height (buffer 1).
// v_out:         destination; vertices [v_id * 6, v_id * 6 + 5] are written
//                (buffer 4). The caller must size it for six vertices per point.
// textureCoord0: texture coordinate assigned to the (-x, -y) corner (buffer 7).
// textureCoord1: texture coordinate assigned to the (+x, +y) corner (buffer 8).
// v_id:          index of the point being expanded.
//
// The first and last corners are written twice (v0, v0, v1, v2, v3, v3) --
// presumably so consecutive quads can be drawn as a single triangle strip
// joined by degenerate triangles; confirm against the draw call.
vertex void
convertPointToVertexPositionSize(const device float2 *pos_in [[ buffer(0) ]],
                                 const device float2 *size_in [[ buffer(1) ]],
                                 device VertexInPositionTexture *v_out [[ buffer(4) ]],
                                 constant float2 &textureCoord0 [[ buffer(7) ]],
                                 constant float2 &textureCoord1 [[ buffer(8) ]],
                                 uint v_id [[ vertex_id ]]) {
    const float2 pos = pos_in[v_id];
    const float width2 = size_in[v_id].x * 0.5f;
    const float height2 = size_in[v_id].y * 0.5f;
    const uint outID = v_id * 6;

    VertexInPositionTexture v0, v1, v2, v3;

    v0.position = float2(pos.x - width2, pos.y - height2);
    v0.texCoord = float2(textureCoord0.x, textureCoord0.y);

    v1.position = float2(pos.x - width2, pos.y + height2);
    v1.texCoord = float2(textureCoord0.x, textureCoord1.y);

    v2.position = float2(pos.x + width2, pos.y - height2);
    v2.texCoord = float2(textureCoord1.x, textureCoord0.y);

    v3.position = float2(pos.x + width2, pos.y + height2);
    v3.texCoord = float2(textureCoord1.x, textureCoord1.y);

    v_out[outID] = v0;
    v_out[outID + 1] = v0;
    v_out[outID + 2] = v1;
    v_out[outID + 3] = v2;
    v_out[outID + 4] = v3;
    v_out[outID + 5] = v3;
}

// Expands one input point into six colored, textured vertices describing a
// quad centered on the point and rotated by a per-point angle.
//
// pos_in:        quad centers (buffer 0).
// size_in:       quad sizes; x is the width, y the height (buffer 1).
// color_in:      per-point colors, narrowed to half4 on output (buffer 2).
// angle_in:      per-point rotation angles, passed to sin/cos (buffer 3).
// ratio:         scale applied to the y component of the rotated x axis and
//                divided out of the x component of the rotated y axis
//                (buffer 6). Must be non-zero. Presumably the aspect ratio of
//                the target coordinate space -- confirm with the caller.
// textureCoord0: texture coordinate assigned to the (-x, -y) corner (buffer 7).
// textureCoord1: texture coordinate assigned to the (+x, +y) corner (buffer 8).
// v_out:         destination; vertices [v_id * 6, v_id * 6 + 5] are written
//                (buffer 4). The caller must size it for six vertices per point.
// v_id:          index of the point being expanded.
//
// The first and last corners are written twice (v0, v0, v1, v2, v3, v3) --
// presumably so consecutive quads can be drawn as a single triangle strip
// joined by degenerate triangles; confirm against the draw call.
vertex void
convertPointToVertexPositionSizeColorAngle(const device float2 *pos_in [[ buffer(0) ]],
                                           const device float2 *size_in [[ buffer(1) ]],
                                           const device float4 *color_in [[ buffer(2) ]],
                                           const device float *angle_in [[ buffer(3) ]],
                                           constant float &ratio [[ buffer(6) ]],
                                           constant float2 &textureCoord0 [[ buffer(7) ]],
                                           constant float2 &textureCoord1 [[ buffer(8) ]],
                                           device VertexInPositionTextureColor *v_out [[ buffer(4) ]],
                                           uint v_id [[ vertex_id ]]) {
    const float2 pos = pos_in[v_id];
    // All four corners share the point's color; narrow it to half once.
    const half4 color = half4(color_in[v_id]);
    const float width2 = size_in[v_id].x * 0.5f;
    const float height2 = size_in[v_id].y * 0.5f;

    // Evaluate the trigonometric terms once; both half-axes use them.
    const float angle = angle_in[v_id];
    const float cosAngle = cos(angle);
    const float sinAngle = sin(angle);

    // Rotated half-extent vectors along the quad's local x and y axes.
    const float2 x = float2(cosAngle * width2, sinAngle * ratio * width2);
    const float2 y = float2(-sinAngle / ratio * height2, cosAngle * height2);
    const uint outID = v_id * 6;

    VertexInPositionTextureColor v0, v1, v2, v3;

    v0.position = pos - x - y;
    v0.texCoord = float2(textureCoord0.x, textureCoord0.y);
    v0.color = color;

    v1.position = pos - x + y;
    v1.texCoord = float2(textureCoord0.x, textureCoord1.y);
    v1.color = color;

    v2.position = pos + x - y;
    v2.texCoord = float2(textureCoord1.x, textureCoord0.y);
    v2.color = color;

    v3.position = pos + x + y;
    v3.texCoord = float2(textureCoord1.x, textureCoord1.y);
    v3.color = color;

    v_out[outID] = v0;
    v_out[outID + 1] = v0;
    v_out[outID + 2] = v1;
    v_out[outID + 3] = v2;
    v_out[outID + 4] = v3;
    v_out[outID + 5] = v3;
}

// Vertex function for the brush-picking pass.
//
// v_in: expanded quad vertices (buffer 4).
// mat:  3x3 matrix applied to the homogeneous 2D position (x, y, 1)
//       (buffer 5).
// v_id: index of the vertex being processed.
// Returns the transformed position with the texture coordinate and color
// copied through unchanged.
vertex VertexOutPositionTextureColor
brushPickingVertex(const device VertexInPositionTextureColor *v_in [[ buffer(4) ]],
                   constant float3x3 &mat [[ buffer(5) ]],
                   uint v_id [[ vertex_id ]]) {
    const VertexInPositionTextureColor v = v_in[v_id];

    // The matrix product yields (x, y, w); depth is fixed at zero.
    const float3 xyw = mat * float3(v.position, 1.0f);

    VertexOutPositionTextureColor v_out;
    v_out.position = float4(xyw.x, xyw.y, 0.0f, xyw.z);
    v_out.texCoord = v.texCoord;
    v_out.color = v.color;
    return v_out;
}

// Fragment function for the brush-picking pass.
//
// inFrag:    interpolated vertex output; `color` is the value written for
//            covered pixels.
// tex2D:     brush texture; only its alpha channel is read (texture 0).
// dst_color: current value of color attachment 0.
// Returns the vertex color composited over `dst_color` with a source-over
// blend (src + dst * (1 - src.a)), clamped to [0, 1].
fragment half4
brushPickingFragment(VertexOutPositionTextureColor inFrag [[ stage_in ]],
                     texture2d<float> tex2D [[ texture(0) ]],
                     half4 dst_color [[ color(0) ]]) {
    // Force texture alpha to go to 1.0 to make transparent areas pickable:
    // any alpha of at least 0.001 saturates to full coverage.
    float alpha = min(1.0f, 1000.0f * tex2D.sample(texSampler, inFrag.texCoord).a);
    half4 src_color = alpha * inFrag.color;
    half4 color = src_color + dst_color * (1.0h - src_color.a);
    return saturate(color);
}

// Vertex function for the basic (colored) brush.
//
// v_in: expanded quad vertices (buffer 4).
// mat:  3x3 matrix applied to the homogeneous 2D position (x, y, 1)
//       (buffer 5).
// v_id: index of the vertex being processed.
// Returns the transformed position with the texture coordinate and color
// copied through unchanged.
vertex VertexOutPositionTextureColor
basicBrushVertex(const device VertexInPositionTextureColor *v_in [[ buffer(4) ]],
                 constant float3x3 &mat [[ buffer(5) ]],
                 uint v_id [[ vertex_id ]]) {
    const VertexInPositionTextureColor v = v_in[v_id];

    // The matrix product yields (x, y, w); depth is fixed at zero.
    const float3 xyw = mat * float3(v.position, 1.0f);

    VertexOutPositionTextureColor v_out;
    v_out.position = float4(xyw.x, xyw.y, 0.0f, xyw.z);
    v_out.texCoord = v.texCoord;
    v_out.color = v.color;
    return v_out;
}

// Fragment function for the basic (colored) brush.
//
// inFrag:    interpolated vertex output; `color` tints the brush texel.
// tex2D:     brush texture (texture 0).
// dst_color: current value of color attachment 0.
// Returns the tinted texel composited over `dst_color` with a source-over
// blend, with alpha replaced by max(src.a, dst.a), clamped to [0, 1].
fragment half4
basicBrushFragment(VertexOutPositionTextureColor inFrag [[ stage_in ]],
                   texture2d<half> tex2D [[ texture(0) ]],
                   half4 dst_color [[ color(0) ]]) {
    half4 src_color = tex2D.sample(texSampler, inFrag.texCoord) * inFrag.color;
    half4 color = src_color + dst_color * (1.0h - src_color.a);
    // Setting the output alpha this way is a common approach for rendering
    // hard brushes where textured quads are stacked on top of one another.
    color.a = max(src_color.a, dst_color.a);
    return saturate(color);
}

// Vertex function for the desaturate brush.
//
// v_in: expanded quad vertices (buffer 4).
// mat:  3x3 matrix applied to the homogeneous 2D position (x, y, 1)
//       (buffer 5).
// v_id: index of the vertex being processed.
// Returns the transformed position with the texture coordinate copied through
// unchanged.
vertex VertexOutPositionTexture
desaturateBrushVertex(const device VertexInPositionTexture *v_in [[ buffer(4) ]],
                      constant float3x3 &mat [[ buffer(5) ]],
                      uint v_id [[ vertex_id ]]) {
    const VertexInPositionTexture v = v_in[v_id];

    // The matrix product yields (x, y, w); depth is fixed at zero.
    const float3 xyw = mat * float3(v.position, 1.0f);

    VertexOutPositionTexture v_out;
    v_out.position = float4(xyw.x, xyw.y, 0.0f, xyw.z);
    v_out.texCoord = v.texCoord;
    return v_out;
}

// Fragment function for the desaturate brush: paints a grayscale version of
// what is currently visible (existing paint over the clean plate), masked by
// the brush texture.
//
// inFrag:          interpolated vertex output; `position` is divided by
//                  `windowSize` to address the clean plate.
// windowSize:      divisor that maps the fragment position to [0, 1] texture
//                  coordinates (buffer 0). Must be non-zero in both components.
// brushTex2D:      brush mask texture (texture 0).
// cleanPlateTex2D: clean-plate texture, sampled with the y axis flipped
//                  (texture 1).
// dst_color:       current value of color attachment 0.
// Returns the masked gray color composited over `dst_color` with a
// source-over blend, clamped to [0, 1].
fragment half4
desaturateBrushFragment(VertexOutPositionTexture inFrag [[ stage_in ]],
                        constant float2 &windowSize [[ buffer(0) ]],
                        texture2d<half> brushTex2D [[ texture(0) ]],
                        texture2d<half> cleanPlateTex2D [[ texture(1) ]],
                        half4 dst_color [[ color(0) ]]) {
    // Normalized clean-plate coordinate; y is flipped relative to the
    // fragment position.
    float2 plateCoord = float2(inFrag.position.x / windowSize.x,
                               1.0f - inFrag.position.y / windowSize.y);
    half4 plateColor = cleanPlateTex2D.sample(cleanPlateSampler, plateCoord);

    // Existing paint composited over the clean plate.
    half4 src_color = dst_color + plateColor * (1.0h - dst_color.a);

    // Rec. 709 luma weights (0.2126, 0.7152, 0.0722).
    half gray = src_color.r * 0.2126h + src_color.g * 0.7152h + src_color.b * 0.0722h;
    src_color = half4(gray, gray, gray, src_color.a);

    // Mask by the brush texture, then blend over the destination.
    src_color = src_color * brushTex2D.sample(texSampler, inFrag.texCoord);
    half4 color = src_color + dst_color * (1.0h - src_color.a);
    return saturate(color);
}

// Vertex function for the pixelate brush.
//
// v_in: expanded quad vertices (buffer 4).
// mat:  3x3 matrix applied to the homogeneous 2D position (x, y, 1)
//       (buffer 5).
// v_id: index of the vertex being processed.
// Returns the transformed position with the texture coordinate copied through
// unchanged.
vertex VertexOutPositionTexture
pixelateBrushVertex(const device VertexInPositionTexture *v_in [[ buffer(4) ]],
                    constant float3x3 &mat [[ buffer(5) ]],
                    uint v_id [[ vertex_id ]]) {
    const VertexInPositionTexture v = v_in[v_id];

    // The matrix product yields (x, y, w); depth is fixed at zero.
    const float3 xyw = mat * float3(v.position, 1.0f);

    VertexOutPositionTexture v_out;
    v_out.position = float4(xyw.x, xyw.y, 0.0f, xyw.z);
    v_out.texCoord = v.texCoord;
    return v_out;
}

// Fragment function for the pixelate brush: paints the clean plate sampled on
// a coarse grid, masked by the brush texture.
//
// inFrag:          interpolated vertex output; `position` is divided by
//                  `windowSize` to address the clean plate.
// windowSize:      divisor that maps the fragment position to [0, 1] texture
//                  coordinates (buffer 0). Must be non-zero in both components.
// width:           horizontal grid step in normalized texture coordinates
//                  (buffer 1). Must be non-zero, since the position is divided
//                  by it.
// brushTex2D:      brush mask texture (texture 0).
// cleanPlateTex2D: clean-plate texture, sampled with the y axis flipped
//                  (texture 1).
// dst_color:       current value of color attachment 0.
// Returns the masked clean-plate color composited over `dst_color` with a
// source-over blend, clamped to [0, 1].
fragment half4
pixelateBrushFragment(VertexOutPositionTexture inFrag [[ stage_in ]],
                      constant float2 &windowSize [[ buffer(0) ]],
                      constant float &width [[ buffer(1) ]],
                      texture2d<half> brushTex2D [[ texture(0) ]],
                      texture2d<half> cleanPlateTex2D [[ texture(1) ]],
                      half4 dst_color [[ color(0) ]]) {
    // Normalized clean-plate coordinate; y is flipped relative to the
    // fragment position.
    float2 position = float2(inFrag.position.x / windowSize.x,
                             1.0f - inFrag.position.y / windowSize.y);

    // Scale the vertical step by the aspect ratio of windowSize.
    float aspectRatio = windowSize.x / windowSize.y;
    float2 stepSize = float2(width, aspectRatio * width);

    // Snap to the nearest grid point so every fragment in a cell reads the
    // same clean-plate texel.
    position = round(position / stepSize) * stepSize;

    half4 src_color = cleanPlateTex2D.sample(cleanPlateSampler, position);
    src_color = src_color * brushTex2D.sample(texSampler, inFrag.texCoord);
    half4 color = src_color + dst_color * (1.0h - src_color.a);
    return saturate(color);
}
// Erasing fragment function for the colored-brush vertex layout.
//
// inFrag:    interpolated vertex output; only `texCoord` is used.
// tex2D:     brush texture; only its alpha channel is read (texture 0).
// dst_color: current value of color attachment 0.
// Returns `dst_color` scaled by (1 - brush alpha), so fully opaque brush
// texels clear the destination and transparent ones leave it unchanged.
fragment half4
basicBrushEraseFragment(VertexOutPositionTextureColor inFrag [[ stage_in ]],
                        texture2d<half> tex2D [[ texture(0) ]],
                        half4 dst_color [[ color(0) ]]) {
    half src_alpha = tex2D.sample(texSampler, inFrag.texCoord).a;
    return dst_color * (1.0h - src_alpha);
}

// Vertex function for the eraser.
//
// v_in: expanded quad vertices (buffer 4).
// mat:  3x3 matrix applied to the homogeneous 2D position (x, y, 1)
//       (buffer 5).
// v_id: index of the vertex being processed.
// Returns the transformed position with the texture coordinate copied through
// unchanged.
vertex VertexOutPositionTexture
eraserVertex(const device VertexInPositionTexture *v_in [[ buffer(4) ]],
             constant float3x3 &mat [[ buffer(5) ]],
             uint v_id [[ vertex_id ]]) {
    const VertexInPositionTexture v = v_in[v_id];

    // The matrix product yields (x, y, w); depth is fixed at zero.
    const float3 xyw = mat * float3(v.position, 1.0f);

    VertexOutPositionTexture v_out;
    v_out.position = float4(xyw.x, xyw.y, 0.0f, xyw.z);
    v_out.texCoord = v.texCoord;
    return v_out;
}

// Fragment function for the eraser.
//
// inFrag:    interpolated vertex output; only `texCoord` is used.
// tex2D:     eraser texture; only its alpha channel is read (texture 0).
// dst_color: current value of color attachment 0.
// Returns `dst_color` scaled by (1 - eraser alpha), so fully opaque eraser
// texels clear the destination and transparent ones leave it unchanged.
fragment half4
eraserFragment(VertexOutPositionTexture inFrag [[ stage_in ]],
               texture2d<half> tex2D [[ texture(0) ]],
               half4 dst_color [[ color(0) ]]) {
    half src_alpha = tex2D.sample(texSampler, inFrag.texCoord).a;
    return dst_color * (1.0h - src_alpha);
}

// Vertex function for stamps.
//
// v_in: expanded quad vertices (buffer 4).
// mat:  3x3 matrix applied to the homogeneous 2D position (x, y, 1)
//       (buffer 5).
// v_id: index of the vertex being processed.
// Returns the transformed position with the texture coordinate and color
// copied through unchanged.
vertex VertexOutPositionTextureColor
stampVertex(const device VertexInPositionTextureColor *v_in [[ buffer(4) ]],
            constant float3x3 &mat [[ buffer(5) ]],
            uint v_id [[ vertex_id ]]) {
    const VertexInPositionTextureColor v = v_in[v_id];

    // The matrix product yields (x, y, w); depth is fixed at zero.
    const float3 xyw = mat * float3(v.position, 1.0f);

    VertexOutPositionTextureColor v_out;
    v_out.position = float4(xyw.x, xyw.y, 0.0f, xyw.z);
    v_out.texCoord = v.texCoord;
    v_out.color = v.color;
    return v_out;
}

// Fragment function for stamps.
//
// inFrag:    interpolated vertex output; only the alpha of `color` is used,
//            as a uniform scale on the stamp texel.
// tex2D:     stamp texture (texture 0).
// dst_color: current value of color attachment 0.
// Returns the scaled texel composited over `dst_color` with a source-over
// blend, clamped to [0, 1].
fragment half4
stampFragment(VertexOutPositionTextureColor inFrag [[ stage_in ]],
              texture2d<half> tex2D [[ texture(0) ]],
              half4 dst_color [[ color(0) ]]) {
    // The vertex color's alpha is < 1.0 only while the stamp is being dimmed.
    // NOTE(review): the original comment was cut off after "eligible to be";
    // the exact condition (presumably deletion) should be confirmed.
    half4 src_color = tex2D.sample(texSampler, inFrag.texCoord) * inFrag.color.a;
    half4 color = src_color + dst_color * (1.0h - src_color.a);
    return saturate(color);
}

