From 7038c6bbeaebcbce3d992ba68f28ba99ab56f09d Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Tue, 11 Feb 2025 17:17:19 -0800
Subject: [PATCH 1/9] Practice Round Problems

---
 examples/grayscale_py/eval.py              |  1 +
 examples/grayscale_py/reference.py         | 44 ++++++++++
 examples/grayscale_py/submission_triton.py | 70 ++++++++++++++++
 examples/grayscale_py/task.py              |  9 +++
 examples/grayscale_py/task.yml             | 33 ++++++++
 examples/grayscale_py/utils.py             | 94 ++++++++++++++++++++++
 6 files changed, 251 insertions(+)
 create mode 120000 examples/grayscale_py/eval.py
 create mode 100644 examples/grayscale_py/reference.py
 create mode 100644 examples/grayscale_py/submission_triton.py
 create mode 100644 examples/grayscale_py/task.py
 create mode 100644 examples/grayscale_py/task.yml
 create mode 100644 examples/grayscale_py/utils.py

diff --git a/examples/grayscale_py/eval.py b/examples/grayscale_py/eval.py
new file mode 120000
index 00000000..caf621bd
--- /dev/null
+++ b/examples/grayscale_py/eval.py
@@ -0,0 +1 @@
+../eval.py
\ No newline at end of file
diff --git a/examples/grayscale_py/reference.py b/examples/grayscale_py/reference.py
new file mode 100644
index 00000000..264e733e
--- /dev/null
+++ b/examples/grayscale_py/reference.py
@@ -0,0 +1,44 @@
+from utils import verbose_allclose
+import torch
+from task import input_t, output_t
+
+def ref_kernel(data: input_t) -> output_t:
+    """
+    Reference implementation of RGB to grayscale conversion using PyTorch.
+    Uses the standard coefficients: Y = 0.2989 R + 0.5870 G + 0.1140 B
+
+    Args:
+        data: RGB tensor of shape (H, W, 3) with values in [0, 1]
+    Returns:
+        Grayscale tensor of shape (H, W) with values in [0, 1]
+    """
+    # Standard RGB to Grayscale coefficients
+    weights = torch.tensor([0.2989, 0.5870, 0.1140],
+                           device=data.device,
+                           dtype=data.dtype)
+    return torch.sum(data * weights, dim=-1)
+
+def generate_input(size: int, seed: int) -> input_t:
+    """
+    Generates random RGB image tensor of specified size.
+    Returns:
+        Tensor of shape (size, size, 3) with values in [0, 1]
+    """
+    gen = torch.Generator(device='cuda')
+    gen.manual_seed(seed)
+    return torch.rand(size, size, 3,
+                      device='cuda',
+                      dtype=torch.float32,
+                      generator=gen).contiguous()
+
+def check_implementation(
+    data: input_t,
+    output: output_t,
+) -> str:
+    expected = ref_kernel(data)
+    reasons = verbose_allclose(output, expected, rtol=1e-4, atol=1e-4)
+
+    if len(reasons) > 0:
+        return "mismatch found! 
custom implementation doesn't match reference: " + reasons[0] + + return '' \ No newline at end of file diff --git a/examples/grayscale_py/submission_triton.py b/examples/grayscale_py/submission_triton.py new file mode 100644 index 00000000..27be62f9 --- /dev/null +++ b/examples/grayscale_py/submission_triton.py @@ -0,0 +1,70 @@ +import torch +import triton +import triton.language as tl +from task import input_t, output_t + +@triton.jit +def grayscale_kernel( + input_ptr, output_ptr, + H, W, + stride_h, stride_w, stride_c, + BLOCK_SIZE: tl.constexpr, +): + # Program ID + pid = tl.program_id(0) + + # Calculate start indices + block_start_h = (pid // ((W + BLOCK_SIZE - 1) // BLOCK_SIZE)) * BLOCK_SIZE + block_start_w = (pid % ((W + BLOCK_SIZE - 1) // BLOCK_SIZE)) * BLOCK_SIZE + + # Offsets for this block + offs_h = block_start_h + tl.arange(0, BLOCK_SIZE) + offs_w = block_start_w + tl.arange(0, BLOCK_SIZE) + + # Create mask for valid pixels + mask = (offs_h[:, None] < H) & (offs_w[None, :] < W) + + # RGB to Grayscale coefficients + R_COEF = 0.2989 + G_COEF = 0.5870 + B_COEF = 0.1140 + + # Calculate base pointer for each pixel in the block + base_ptr = offs_h[:, None] * stride_h + offs_w[None, :] * stride_w + + # Load RGB channels + r = tl.load(input_ptr + base_ptr + 0 * stride_c, mask=mask, other=0.0) + g = tl.load(input_ptr + base_ptr + 1 * stride_c, mask=mask, other=0.0) + b = tl.load(input_ptr + base_ptr + 2 * stride_c, mask=mask, other=0.0) + + # Convert to grayscale + gray = R_COEF * r + G_COEF * g + B_COEF * b + + # Store result + out_ptr = offs_h[:, None] * W + offs_w[None, :] + tl.store(output_ptr + out_ptr, gray, mask=mask) + +def custom_kernel(data: input_t) -> output_t: + H, W, C = data.shape + assert C == 3, "Input must be an RGB image" + + # Create output tensor + output = torch.empty((H, W), device=data.device, dtype=data.dtype) + + # Calculate strides + stride_h = W * C + stride_w = C + stride_c = 1 + + # Launch kernel + BLOCK_SIZE = 32 + grid = ((H + BLOCK_SIZE - 1) // BLOCK_SIZE) * ((W + BLOCK_SIZE - 1) // BLOCK_SIZE) + + grayscale_kernel[grid]( + data, output, + H, W, + stride_h, stride_w, stride_c, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return output \ No newline at end of file diff --git a/examples/grayscale_py/task.py b/examples/grayscale_py/task.py new file mode 100644 index 00000000..4a717fcc --- /dev/null +++ b/examples/grayscale_py/task.py @@ -0,0 +1,9 @@ +from typing import TypedDict, TypeVar +import torch + +input_t = TypeVar("input_t", bound=torch.Tensor) # Input will be (H, W, 3) RGB tensor +output_t = TypeVar("output_t", bound=torch.Tensor) # Output will be (H, W) grayscale tensor + +class TestSpec(TypedDict): + size: int # Size of the square image (H=W) + seed: int \ No newline at end of file diff --git a/examples/grayscale_py/task.yml b/examples/grayscale_py/task.yml new file mode 100644 index 00000000..b14a81b5 --- /dev/null +++ b/examples/grayscale_py/task.yml @@ -0,0 +1,33 @@ +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "eval.py"} + +lang: "py" + +description: | + Implement an RGB to grayscale conversion kernel that matches the reference implementation. 
+ The kernel should convert RGB images to grayscale using the standard coefficients: + Y = 0.2989 R + 0.5870 G + 0.1140 B + + Input: RGB tensor of shape (H, W, 3) with values in [0, 1] + Output: Grayscale tensor of shape (H, W) with values in [0, 1] + +config: + main: "eval.py" + +tests: + - {"size": 127, "seed": 4242} + - {"size": 128, "seed": 5236} + - {"size": 129, "seed": 1001} + - {"size": 256, "seed": 5531} + - {"size": 512, "seed": 9173} + +benchmarks: + - {"size": 1024, "seed": 54352} + - {"size": 2048, "seed": 93246} + - {"size": 4096, "seed": 6256} + - {"size": 8192, "seed": 8841} + - {"size": 16384, "seed": 6252} \ No newline at end of file diff --git a/examples/grayscale_py/utils.py b/examples/grayscale_py/utils.py new file mode 100644 index 00000000..5363abdd --- /dev/null +++ b/examples/grayscale_py/utils.py @@ -0,0 +1,94 @@ +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +def verbose_allclose( + tensor1: torch.Tensor, + tensor2: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + tensor1 (torch.Tensor): First tensor to compare. + tensor2 (torch.Tensor): Second tensor to compare. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + list[str]: List of error messages if tensors don't match, empty list otherwise. 
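+
+    Example (illustrative only; small CPU tensors for brevity):
+        >>> verbose_allclose(torch.ones(3), torch.ones(3))
+        []
+        >>> verbose_allclose(torch.ones(3), torch.zeros(4))
+        ['SIZE MISMATCH']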
+ """ + # Check if the shapes of the tensors match + if tensor1.shape != tensor2.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(tensor1 - tensor2) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(tensor2) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.sum().item() + + # Check if all elements are close + all_close = num_mismatched == 0 + + # Return error messages if there are mismatches + if not all_close and num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}", + f"Mismatched elements: {mismatched_indices}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {tensor1[i]} {tensor2[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] \ No newline at end of file From 8f38e1b8e6deaa6dc73b8665344faf23ee5becd7 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 11 Feb 2025 17:29:19 -0800 Subject: [PATCH 2/9] vector sum --- examples/vectorsum_py/eval.py | 1 + examples/vectorsum_py/reference.py | 35 +++++++++++ examples/vectorsum_py/submission.py | 54 +++++++++++++++++ examples/vectorsum_py/task.py | 9 +++ examples/vectorsum_py/task.yml | 31 ++++++++++ examples/vectorsum_py/utils.py | 93 +++++++++++++++++++++++++++++ 6 files changed, 223 insertions(+) create mode 120000 examples/vectorsum_py/eval.py create mode 100644 examples/vectorsum_py/reference.py create mode 100644 examples/vectorsum_py/submission.py create mode 100644 examples/vectorsum_py/task.py create mode 100644 examples/vectorsum_py/task.yml create mode 100644 examples/vectorsum_py/utils.py diff --git a/examples/vectorsum_py/eval.py b/examples/vectorsum_py/eval.py new file mode 120000 index 00000000..caf621bd --- /dev/null +++ b/examples/vectorsum_py/eval.py @@ -0,0 +1 @@ +../eval.py \ No newline at end of file diff --git a/examples/vectorsum_py/reference.py b/examples/vectorsum_py/reference.py new file mode 100644 index 00000000..2b662f81 --- /dev/null +++ b/examples/vectorsum_py/reference.py @@ -0,0 +1,35 @@ +from utils import verbose_allclose +import torch +from task import input_t, output_t + +def ref_kernel(data: input_t) -> output_t: + """ + Reference implementation of vector sum reduction using PyTorch. + Args: + data: Input tensor to be reduced + Returns: + Tensor containing the sum of all elements + """ + return data.sum() + +def generate_input(size: int, seed: int) -> input_t: + """ + Generates random input tensor of specified shape. 
+    Returns:
+        Tensor to be reduced
+    """
+    gen = torch.Generator(device='cuda')
+    gen.manual_seed(seed)
+    return torch.randn(size, device='cuda', dtype=torch.float32, generator=gen).contiguous()
+
+def check_implementation(
+    data: input_t,
+    output: output_t,
+) -> str:
+    expected = ref_kernel(data)
+    reasons = verbose_allclose(output, expected)
+
+    if len(reasons) > 0:
+        return "mismatch found! custom implementation doesn't match reference: " + reasons[0]
+
+    return ''
\ No newline at end of file
diff --git a/examples/vectorsum_py/submission.py b/examples/vectorsum_py/submission.py
new file mode 100644
index 00000000..4e1969e2
--- /dev/null
+++ b/examples/vectorsum_py/submission.py
@@ -0,0 +1,54 @@
+import torch
+import triton
+import triton.language as tl
+from task import input_t, output_t
+
+@triton.jit
+def sum_kernel(
+    x_ptr,
+    output_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Parallel reduction kernel that sums elements in chunks.
+    Each thread block reduces BLOCK_SIZE elements.
+    """
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load data
+    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+
+    # Compute local reduction
+    block_sum = tl.sum(x, axis=0)
+
+    # Store the partial sum
+    tl.atomic_add(output_ptr, block_sum)
+
+def custom_kernel(data: input_t) -> output_t:
+    """
+    Performs parallel reduction to compute sum of all elements.
+    Args:
+        data: Input tensor to be reduced
+    Returns:
+        Tensor containing the sum of all elements
+    """
+    n_elements = data.numel()
+    output = torch.zeros(1, device=data.device, dtype=data.dtype)
+
+    # Configure kernel
+    BLOCK_SIZE = 1024
+    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+
+    # Launch kernel
+    sum_kernel[grid](
+        data,
+        output,
+        n_elements,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    return output[0]
\ No newline at end of file
diff --git a/examples/vectorsum_py/task.py b/examples/vectorsum_py/task.py
new file mode 100644
index 00000000..62e5dae0
--- /dev/null
+++ b/examples/vectorsum_py/task.py
@@ -0,0 +1,9 @@
+from typing import TypedDict, TypeVar
+import torch
+
+input_t = TypeVar("input_t", bound=torch.Tensor)
+output_t = TypeVar("output_t", bound=torch.Tensor)
+
+class TestSpec(TypedDict):
+    size: int
+    seed: int
\ No newline at end of file
diff --git a/examples/vectorsum_py/task.yml b/examples/vectorsum_py/task.yml
new file mode 100644
index 00000000..1c8b6018
--- /dev/null
+++ b/examples/vectorsum_py/task.yml
@@ -0,0 +1,31 @@
+# name: vectorsum-cuda-inline
+
+files:
+  - {"name": "submission.py", "source": "@SUBMISSION@"}
+  - {"name": "task.py", "source": "task.py"}
+  - {"name": "utils.py", "source": "utils.py"}
+  - {"name": "reference.py", "source": "reference.py"}
+  - {"name": "eval.py", "source": "eval.py"}
+
+lang: "py"
+
+description: |
+  Implement a vector sum reduction kernel that matches the reference implementation.
+  The kernel should compute the sum of all elements in the input tensor.
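+
+  Input: 1D float32 tensor of shape (size,)
+  Output: 0-d tensor containing the sum of all elements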
+ +config: + main: "eval.py" + +tests: + - {"size": 1023, "seed": 4242} + - {"size": 1024, "seed": 5236} + - {"size": 1025, "seed": 1001} + - {"size": 2048, "seed": 5531} + - {"size": 4096, "seed": 9173} + +benchmarks: + - {"size": 8192, "seed": 54352} + - {"size": 16384, "seed": 93246} + - {"size": 32768, "seed": 6256} + - {"size": 65536, "seed": 8841} + - {"size": 131072, "seed": 6252} \ No newline at end of file diff --git a/examples/vectorsum_py/utils.py b/examples/vectorsum_py/utils.py new file mode 100644 index 00000000..cb7e26bb --- /dev/null +++ b/examples/vectorsum_py/utils.py @@ -0,0 +1,93 @@ +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + +def verbose_allclose( + tensor1: torch.Tensor, + tensor2: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + tensor1 (torch.Tensor): First tensor to compare. + tensor2 (torch.Tensor): Second tensor to compare. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + list[str]: List of error messages if tensors don't match, empty list otherwise. + """ + # Check if the shapes of the tensors match + if tensor1.shape != tensor2.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(tensor1 - tensor2) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(tensor2) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.sum().item() + + # Check if all elements are close + all_close = num_mismatched == 0 + + # Return detailed information if there are mismatches + if not all_close and num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}", + f"Mismatched elements: {mismatched_indices}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {tensor1[i]} {tensor2[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... 
and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] \ No newline at end of file From 4ec2998a91977aa3cb1aca952c0fb7119d08527c Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 11 Feb 2025 17:33:47 -0800 Subject: [PATCH 3/9] conv2d --- examples/conv2d_py/eval.py | 1 + examples/conv2d_py/reference.py | 62 ++++++++++++++++ examples/conv2d_py/submission.py | 123 +++++++++++++++++++++++++++++++ examples/conv2d_py/task.py | 18 +++++ examples/conv2d_py/task.yml | 31 ++++++++ 5 files changed, 235 insertions(+) create mode 120000 examples/conv2d_py/eval.py create mode 100644 examples/conv2d_py/reference.py create mode 100644 examples/conv2d_py/submission.py create mode 100644 examples/conv2d_py/task.py create mode 100644 examples/conv2d_py/task.yml diff --git a/examples/conv2d_py/eval.py b/examples/conv2d_py/eval.py new file mode 120000 index 00000000..caf621bd --- /dev/null +++ b/examples/conv2d_py/eval.py @@ -0,0 +1 @@ +../eval.py \ No newline at end of file diff --git a/examples/conv2d_py/reference.py b/examples/conv2d_py/reference.py new file mode 100644 index 00000000..0eb48e6f --- /dev/null +++ b/examples/conv2d_py/reference.py @@ -0,0 +1,62 @@ +from utils import verbose_allclose +import torch +import torch.nn.functional as F +from task import input_t, output_t, KernelSpec + +def ref_kernel(data: input_t, spec: KernelSpec) -> output_t: + """ + Reference implementation of 2D convolution using PyTorch. + Args: + data: Tuple of (input tensor, kernel tensor) + spec: Convolution specifications (stride, padding) + Returns: + Output tensor after convolution + """ + input_tensor, kernel = data + return F.conv2d( + input_tensor, + kernel, + stride=spec.stride, + padding=spec.padding + ) + +def generate_input(size: int, kernel_size: int, channels: int, batch: int, seed: int) -> input_t: + """ + Generates random input and kernel tensors. + Returns: + Tuple of (input tensor, kernel tensor) + """ + gen = torch.Generator(device='cuda') + gen.manual_seed(seed) + + # Generate input tensor: [batch, in_channels, height, width] + input_tensor = torch.randn( + batch, channels, size, size, + device='cuda', + dtype=torch.float32, + generator=gen + ).contiguous() + + # Generate kernel tensor: [out_channels, in_channels, kernel_height, kernel_width] + # Here we use same number of output channels as input channels for simplicity + kernel = torch.randn( + channels, channels, kernel_size, kernel_size, + device='cuda', + dtype=torch.float32, + generator=gen + ).contiguous() + + return (input_tensor, kernel) + +def check_implementation( + data: input_t, + spec: KernelSpec, + output: output_t, +) -> str: + expected = ref_kernel(data, spec) + reasons = verbose_allclose(output, expected, rtol=1e-3, atol=1e-3) + + if len(reasons) > 0: + return "mismatch found! 
custom implementation doesn't match reference: " + reasons[0] + + return '' \ No newline at end of file diff --git a/examples/conv2d_py/submission.py b/examples/conv2d_py/submission.py new file mode 100644 index 00000000..1867acd3 --- /dev/null +++ b/examples/conv2d_py/submission.py @@ -0,0 +1,123 @@ +import torch +import triton +import triton.language as tl +from task import input_t, output_t, KernelSpec + +@triton.jit +def conv2d_kernel( + # Pointers to matrices + input_ptr, kernel_ptr, output_ptr, + # Matrix dimensions + batch, in_channels, out_channels, + in_height, in_width, + kernel_size, stride, padding, + out_height, out_width, + # Block sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + 2D Convolution kernel. + Each thread block handles computation for a BLOCK_SIZE_M x BLOCK_SIZE_N region of the output. + """ + # Program ID + pid = tl.program_id(0) + + # Calculate output position + n_blocks_m = triton.cdiv(out_height, BLOCK_SIZE_M) + batch_idx = pid // (n_blocks_m * out_channels) + tmp = pid % (n_blocks_m * out_channels) + out_ch = tmp // n_blocks_m + block_m = tmp % n_blocks_m + + # Calculate output row and column ranges for this block + out_m = block_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + out_n = tl.arange(0, BLOCK_SIZE_N) + + # Calculate input positions with padding offset + in_m = out_m * stride - padding + in_n = out_n * stride - padding + + # Initialize output accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Iterate over input channels and kernel positions + for in_ch in range(in_channels): + for kh in range(kernel_size): + for kw in range(kernel_size): + # Calculate input positions + h_pos = in_m + kh + w_pos = in_n + kw + + # Create masks for valid positions + m_mask = (h_pos >= 0) & (h_pos < in_height) + n_mask = (w_pos >= 0) & (w_pos < in_width) + mask = m_mask[:, None] & n_mask[None, :] + + # Load input values + x_pos = h_pos[:, None] * in_width + w_pos[None, :] + input_idx = ((batch_idx * in_channels + in_ch) * in_height * in_width + x_pos) + x = tl.load(input_ptr + input_idx, mask=mask, other=0.0) + + # Load kernel value + k_idx = ((out_ch * in_channels + in_ch) * kernel_size * kernel_size + + kh * kernel_size + kw) + k = tl.load(kernel_ptr + k_idx) + + # Accumulate + acc += k * x + + # Write output + out_pos = out_m[:, None] * out_width + out_n[None, :] + output_idx = ((batch_idx * out_channels + out_ch) * out_height * out_width + + out_pos) + + # Create output mask + m_mask = out_m < out_height + n_mask = out_n < out_width + mask = m_mask[:, None] & n_mask[None, :] + + # Store output + tl.store(output_ptr + output_idx, acc, mask=mask) + +def custom_kernel(data: input_t, spec: KernelSpec) -> output_t: + """ + Performs 2D convolution using Triton kernel. 
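+    Output spatial dimensions follow the usual convolution arithmetic,
+    computed explicitly below:
+        out = (in + 2 * padding - kernel_size) // stride + 1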
+    Args:
+        data: Tuple of (input tensor, kernel tensor)
+        spec: Convolution specifications
+    Returns:
+        Output tensor after convolution
+    """
+    input_tensor, kernel = data
+    batch, in_channels, in_height, in_width = input_tensor.shape
+    out_channels, _, kernel_size, _ = kernel.shape
+
+    # Calculate output dimensions
+    out_height = ((in_height + 2 * spec.padding - kernel_size) // spec.stride) + 1
+    out_width = ((in_width + 2 * spec.padding - kernel_size) // spec.stride) + 1
+
+    # Allocate output
+    output = torch.empty(
+        (batch, out_channels, out_height, out_width),
+        device=input_tensor.device,
+        dtype=input_tensor.dtype
+    )
+
+    # Configure kernel
+    BLOCK_SIZE_M = 8
+    BLOCK_SIZE_N = 8
+    grid = (batch * out_channels * triton.cdiv(out_height, BLOCK_SIZE_M),)
+
+    # Launch kernel
+    conv2d_kernel[grid](
+        input_tensor, kernel, output,
+        batch, in_channels, out_channels,
+        in_height, in_width,
+        kernel_size, spec.stride, spec.padding,
+        out_height, out_width,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+    )
+
+    return output
\ No newline at end of file
diff --git a/examples/conv2d_py/task.py b/examples/conv2d_py/task.py
new file mode 100644
index 00000000..6cce0e6e
--- /dev/null
+++ b/examples/conv2d_py/task.py
@@ -0,0 +1,18 @@
+from typing import TypedDict, TypeVar, Tuple
+import torch
+from dataclasses import dataclass
+
+input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor])
+output_t = TypeVar("output_t", bound=torch.Tensor)
+
+@dataclass
+class KernelSpec:
+    stride: int
+    padding: int
+
+class TestSpec(TypedDict):
+    size: int
+    kernel_size: int
+    channels: int
+    batch: int
+    seed: int
\ No newline at end of file
diff --git a/examples/conv2d_py/task.yml b/examples/conv2d_py/task.yml
new file mode 100644
index 00000000..ad073fcb
--- /dev/null
+++ b/examples/conv2d_py/task.yml
@@ -0,0 +1,31 @@
+# name: conv2d-cuda-inline
+
+files:
+  - {"name": "submission.py", "source": "@SUBMISSION@"}
+  - {"name": "task.py", "source": "task.py"}
+  - {"name": "utils.py", "source": "utils.py"}
+  - {"name": "reference.py", "source": "reference.py"}
+  - {"name": "eval.py", "source": "eval.py"}
+
+lang: "py"
+
+description: |
+  Implement a 2D convolution kernel that matches the reference implementation.
+  The kernel should perform 2D convolution with the given specifications (stride and padding).
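+
+  Input: tuple of (input tensor of shape (batch, channels, size, size), kernel tensor of shape (channels, channels, kernel_size, kernel_size))
+  Output: tensor of shape (batch, channels, out_height, out_width)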
+ +config: + main: "eval.py" + +tests: + - {"size": 32, "kernel_size": 3, "channels": 16, "batch": 1, "seed": 4242} + - {"size": 32, "kernel_size": 5, "channels": 16, "batch": 2, "seed": 5236} + - {"size": 64, "kernel_size": 3, "channels": 32, "batch": 1, "seed": 1001} + - {"size": 64, "kernel_size": 5, "channels": 32, "batch": 2, "seed": 5531} + - {"size": 128, "kernel_size": 3, "channels": 64, "batch": 1, "seed": 9173} + +benchmarks: + - {"size": 128, "kernel_size": 3, "channels": 64, "batch": 4, "seed": 54352} + - {"size": 128, "kernel_size": 5, "channels": 64, "batch": 4, "seed": 93246} + - {"size": 256, "kernel_size": 3, "channels": 128, "batch": 2, "seed": 6256} + - {"size": 256, "kernel_size": 5, "channels": 128, "batch": 2, "seed": 8841} + - {"size": 512, "kernel_size": 3, "channels": 256, "batch": 1, "seed": 6252} \ No newline at end of file From 1f92c90ee21101178de1a80384f1ee29e9e5c75f Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 11 Feb 2025 17:36:51 -0800 Subject: [PATCH 4/9] prefixsum --- examples/prefixsum_py/reference.py | 35 +++++++++ examples/prefixsum_py/submission.py | 113 ++++++++++++++++++++++++++++ examples/prefixsum_py/task.py | 9 +++ examples/prefixsum_py/task.yml | 31 ++++++++ 4 files changed, 188 insertions(+) create mode 100644 examples/prefixsum_py/reference.py create mode 100644 examples/prefixsum_py/submission.py create mode 100644 examples/prefixsum_py/task.py create mode 100644 examples/prefixsum_py/task.yml diff --git a/examples/prefixsum_py/reference.py b/examples/prefixsum_py/reference.py new file mode 100644 index 00000000..bce90273 --- /dev/null +++ b/examples/prefixsum_py/reference.py @@ -0,0 +1,35 @@ +from utils import verbose_allclose +import torch +from task import input_t, output_t + +def ref_kernel(data: input_t) -> output_t: + """ + Reference implementation of inclusive prefix sum using PyTorch. + Args: + data: Input tensor to compute prefix sum on + Returns: + Tensor containing the inclusive prefix sum + """ + return torch.cumsum(data, dim=0) + +def generate_input(size: int, seed: int) -> input_t: + """ + Generates random input tensor. + Returns: + Tensor to compute prefix sum on + """ + gen = torch.Generator(device='cuda') + gen.manual_seed(seed) + return torch.randn(size, device='cuda', dtype=torch.float32, generator=gen).contiguous() + +def check_implementation( + data: input_t, + output: output_t, +) -> str: + expected = ref_kernel(data) + reasons = verbose_allclose(output, expected, rtol=1e-5, atol=1e-5) + + if len(reasons) > 0: + return "mismatch found! custom implementation doesn't match reference: " + reasons[0] + + return '' \ No newline at end of file diff --git a/examples/prefixsum_py/submission.py b/examples/prefixsum_py/submission.py new file mode 100644 index 00000000..1d06b8ab --- /dev/null +++ b/examples/prefixsum_py/submission.py @@ -0,0 +1,113 @@ +import torch +import triton +import triton.language as tl +from task import input_t, output_t + +@triton.jit +def scan_kernel( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + """ + Single-block inclusive prefix sum kernel. + Uses a two-pass approach: up-sweep and down-sweep. 
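+    (Up-sweep accumulates partial sums in a tree; down-sweep then
+    propagates running totals back down so each position ends up
+    holding its inclusive prefix sum.)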
+ """ + # Get program ID and allocate shared memory + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Load data into shared memory + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + + # Up-sweep: Build sum tree + offset = 1 + for d in range(triton.next_power_of_2(BLOCK_SIZE) // 2): + mask = tl.arange(0, BLOCK_SIZE) % (2 * offset) == (2 * offset - 1) + vals = tl.where(mask, x, 0.0) + vals = tl.sum(vals, axis=0) + x = x + tl.where(mask, -x + vals, 0.0) + offset *= 2 + + # Down-sweep: Distribute sums + for d in range(triton.next_power_of_2(BLOCK_SIZE) // 2 - 1, -1, -1): + offset = 1 << d + mask = tl.arange(0, BLOCK_SIZE) % (2 * offset) == (2 * offset - 1) + vals = tl.where(mask, x, 0.0) + x = x + tl.where(tl.arange(0, BLOCK_SIZE) % (2 * offset) >= offset, vals, 0.0) + + # Store results + output_mask = offsets < n_elements + tl.store(output_ptr + offsets, x, mask=output_mask) + +@triton.jit +def block_sum_kernel( + block_sums_ptr, + output_ptr, + block_size, + n_blocks, + BLOCK_SIZE: tl.constexpr, +): + """ + Adds block sums to subsequent blocks to get final prefix sum. + """ + pid = tl.program_id(0) + block_idx = pid + 1 # Skip first block + + if block_idx < n_blocks: + # Load block sum from previous block + prev_sum = tl.load(block_sums_ptr + block_idx - 1) + + # Add to all elements in current block + offsets = block_idx * block_size + tl.arange(0, BLOCK_SIZE) + mask = offsets < (block_idx + 1) * block_size + x = tl.load(output_ptr + offsets, mask=mask, other=0.0) + x = x + prev_sum + tl.store(output_ptr + offsets, x, mask=mask) + +def custom_kernel(data: input_t) -> output_t: + """ + Multi-block prefix sum implementation. + Args: + data: Input tensor + Returns: + Tensor containing inclusive prefix sum + """ + n_elements = data.numel() + output = torch.empty_like(data) + + # Configure kernel + BLOCK_SIZE = 1024 + n_blocks = triton.cdiv(n_elements, BLOCK_SIZE) + + # Phase 1: Compute prefix sum within each block + scan_kernel[(n_blocks,)]( + data, + output, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + + if n_blocks > 1: + # Get block sums + block_sums = torch.empty(n_blocks, device=data.device, dtype=data.dtype) + block_sums[0] = output[BLOCK_SIZE-1] + for i in range(1, n_blocks-1): + block_sums[i] = output[(i+1)*BLOCK_SIZE-1] + + # Compute prefix sum of block sums + block_sums = torch.cumsum(block_sums, dim=0) + + # Phase 2: Add block sums to subsequent blocks + block_sum_kernel[(n_blocks-1,)]( + block_sums, + output, + BLOCK_SIZE, + n_blocks, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return output \ No newline at end of file diff --git a/examples/prefixsum_py/task.py b/examples/prefixsum_py/task.py new file mode 100644 index 00000000..62e5dae0 --- /dev/null +++ b/examples/prefixsum_py/task.py @@ -0,0 +1,9 @@ +from typing import TypedDict, TypeVar +import torch + +input_t = TypeVar("input_t", bound=torch.Tensor) +output_t = TypeVar("output_t", bound=torch.Tensor) + +class TestSpec(TypedDict): + size: int + seed: int \ No newline at end of file diff --git a/examples/prefixsum_py/task.yml b/examples/prefixsum_py/task.yml new file mode 100644 index 00000000..36cfab42 --- /dev/null +++ b/examples/prefixsum_py/task.yml @@ -0,0 +1,31 @@ +# name: prefixsum-cuda-inline + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": 
"eval.py"} + +lang: "py" + +description: | + Implement an inclusive prefix sum (scan) kernel using CUDA inline function that matches the reference implementation. + The kernel should compute the cumulative sum of all elements up to each position. + +config: + main: "eval.py" + +tests: + - {"size": 1023, "seed": 4242} + - {"size": 1024, "seed": 5236} + - {"size": 1025, "seed": 1001} + - {"size": 2048, "seed": 5531} + - {"size": 4096, "seed": 9173} + +benchmarks: + - {"size": 8192, "seed": 54352} + - {"size": 16384, "seed": 93246} + - {"size": 32768, "seed": 6256} + - {"size": 65536, "seed": 8841} + - {"size": 131072, "seed": 6252} \ No newline at end of file From a5297cbb3afb90289027cbbcf186ff30d9f81be6 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 11 Feb 2025 17:58:49 -0800 Subject: [PATCH 5/9] more --- examples/histogram_py/eval.py | 0 examples/histogram_py/reference.py | 47 +++++++++++++ examples/histogram_py/submission.py | 81 ++++++++++++++++++++++ examples/histogram_py/task.py | 16 +++++ examples/histogram_py/task.yml | 31 +++++++++ examples/mergesort_py/reference.py | 35 ++++++++++ examples/mergesort_py/submission.py | 102 ++++++++++++++++++++++++++++ examples/mergesort_py/task.py | 9 +++ examples/mergesort_py/task.yml | 31 +++++++++ examples/prefixsum_py/eval.py | 1 + 10 files changed, 353 insertions(+) create mode 100644 examples/histogram_py/eval.py create mode 100644 examples/histogram_py/reference.py create mode 100644 examples/histogram_py/submission.py create mode 100644 examples/histogram_py/task.py create mode 100644 examples/histogram_py/task.yml create mode 100644 examples/mergesort_py/reference.py create mode 100644 examples/mergesort_py/submission.py create mode 100644 examples/mergesort_py/task.py create mode 100644 examples/mergesort_py/task.yml create mode 120000 examples/prefixsum_py/eval.py diff --git a/examples/histogram_py/eval.py b/examples/histogram_py/eval.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/histogram_py/reference.py b/examples/histogram_py/reference.py new file mode 100644 index 00000000..8fb766b8 --- /dev/null +++ b/examples/histogram_py/reference.py @@ -0,0 +1,47 @@ +from utils import verbose_allclose +import torch +from task import input_t, output_t, HistogramSpec + +def ref_kernel(data: input_t, spec: HistogramSpec) -> output_t: + """ + Reference implementation of histogram using PyTorch. + Args: + data: Input tensor to compute histogram on + spec: Histogram specifications (num_bins, min_val, max_val) + Returns: + Tensor containing bin counts + """ + # Clip values to range + clipped = torch.clamp(data, spec.min_val, spec.max_val) + + # Scale to bin indices + bin_width = (spec.max_val - spec.min_val) / spec.num_bins + indices = ((clipped - spec.min_val) / bin_width).long() + indices = torch.clamp(indices, 0, spec.num_bins - 1) + + # Count values in each bin + return torch.bincount(indices, minlength=spec.num_bins).to(torch.float32) + +def generate_input(size: int, seed: int) -> input_t: + """ + Generates random input tensor with values roughly in [0, 1]. 
+    Returns:
+        Tensor to compute histogram on
+    """
+    gen = torch.Generator(device='cuda')
+    gen.manual_seed(seed)
+    # Generate values with normal distribution for interesting histograms
+    return torch.randn(size, device='cuda', dtype=torch.float32, generator=gen).contiguous()
+
+def check_implementation(
+    data: input_t,
+    spec: HistogramSpec,
+    output: output_t,
+) -> str:
+    expected = ref_kernel(data, spec)
+    reasons = verbose_allclose(output, expected)
+
+    if len(reasons) > 0:
+        return "mismatch found! custom implementation doesn't match reference: " + reasons[0]
+
+    return ''
\ No newline at end of file
diff --git a/examples/histogram_py/submission.py b/examples/histogram_py/submission.py
new file mode 100644
index 00000000..a0e28216
--- /dev/null
+++ b/examples/histogram_py/submission.py
@@ -0,0 +1,81 @@
+import torch
+import triton
+import triton.language as tl
+from task import input_t, output_t, HistogramSpec
+
+@triton.jit
+def histogram_kernel(
+    x_ptr,
+    output_ptr,
+    n_elements,
+    num_bins,
+    min_val,
+    max_val,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Parallel histogram kernel.
+    Each thread block processes BLOCK_SIZE elements and atomically
+    increments the matching bins of the global histogram.
+    """
+    # Program ID
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load data
+    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+
+    # Clip values to range
+    x = tl.minimum(tl.maximum(x, min_val), max_val)
+
+    # Convert to bin indices
+    bin_width = (max_val - min_val) / num_bins
+    indices = ((x - min_val) / bin_width).to(tl.int32)
+    indices = tl.minimum(tl.maximum(indices, 0), num_bins - 1)
+
+    # Each in-range lane atomically increments its bin of the global
+    # histogram; masked-off lanes contribute nothing.
+    tl.atomic_add(output_ptr + indices, 1.0, mask=mask)
+
+def custom_kernel(data: input_t, spec: HistogramSpec) -> output_t:
+    """
+    Computes histogram using parallel reduction.
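+    Counts are accumulated with float32 atomic adds of 1.0, which stay
+    exact as long as no single bin receives more than 2**24 elements.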
+    Args:
+        data: Input tensor
+        spec: Histogram specifications
+    Returns:
+        Tensor containing bin counts
+    """
+    n_elements = data.numel()
+
+    # Initialize output histogram
+    output = torch.zeros(spec.num_bins, device=data.device, dtype=torch.float32)
+
+    # Configure kernel
+    BLOCK_SIZE = 1024
+    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+
+    # Launch kernel
+    histogram_kernel[grid](
+        data,
+        output,
+        n_elements,
+        spec.num_bins,
+        spec.min_val,
+        spec.max_val,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    return output
\ No newline at end of file
diff --git a/examples/histogram_py/task.py b/examples/histogram_py/task.py
new file mode 100644
index 00000000..e9d7fadf
--- /dev/null
+++ b/examples/histogram_py/task.py
@@ -0,0 +1,16 @@
+from typing import TypedDict, TypeVar
+import torch
+from dataclasses import dataclass
+
+input_t = TypeVar("input_t", bound=torch.Tensor)
+output_t = TypeVar("output_t", bound=torch.Tensor)
+
+@dataclass
+class HistogramSpec:
+    num_bins: int
+    min_val: float
+    max_val: float
+
+class TestSpec(TypedDict):
+    size: int
+    seed: int
\ No newline at end of file
diff --git a/examples/histogram_py/task.yml b/examples/histogram_py/task.yml
new file mode 100644
index 00000000..a1bfeb31
--- /dev/null
+++ b/examples/histogram_py/task.yml
@@ -0,0 +1,31 @@
+# name: histogram-cuda-inline
+
+files:
+  - {"name": "submission.py", "source": "@SUBMISSION@"}
+  - {"name": "task.py", "source": "task.py"}
+  - {"name": "utils.py", "source": "utils.py"}
+  - {"name": "reference.py", "source": "reference.py"}
+  - {"name": "eval.py", "source": "eval.py"}
+
+lang: "py"
+
+description: |
+  Implement a histogram kernel that matches the reference implementation.
+  The kernel should count the number of elements falling into each bin across the specified range.
+
+config:
+  main: "eval.py"
+
+tests:
+  - {"size": 1023, "seed": 4242}
+  - {"size": 1024, "seed": 5236}
+  - {"size": 1025, "seed": 1001}
+  - {"size": 2048, "seed": 5531}
+  - {"size": 4096, "seed": 9173}
+
+benchmarks:
+  - {"size": 8192, "seed": 54352}
+  - {"size": 16384, "seed": 93246}
+  - {"size": 32768, "seed": 6256}
+  - {"size": 65536, "seed": 8841}
+  - {"size": 131072, "seed": 6252}
\ No newline at end of file
diff --git a/examples/mergesort_py/reference.py b/examples/mergesort_py/reference.py
new file mode 100644
index 00000000..c8d05f77
--- /dev/null
+++ b/examples/mergesort_py/reference.py
@@ -0,0 +1,35 @@
+from utils import verbose_allclose
+import torch
+from task import input_t, output_t
+
+def ref_kernel(data: input_t) -> output_t:
+    """
+    Reference implementation of sort using PyTorch.
+    Args:
+        data: Input tensor to be sorted
+    Returns:
+        Sorted tensor
+    """
+    return torch.sort(data)[0]
+
+def generate_input(size: int, seed: int) -> input_t:
+    """
+    Generates random input tensor.
+    Returns:
+        Tensor to be sorted
+    """
+    gen = torch.Generator(device='cuda')
+    gen.manual_seed(seed)
+    return torch.randn(size, device='cuda', dtype=torch.float32, generator=gen).contiguous()
+
+def check_implementation(
+    data: input_t,
+    output: output_t,
+) -> str:
+    expected = ref_kernel(data)
+    reasons = verbose_allclose(output, expected)
+
+    if len(reasons) > 0:
+        return "mismatch found! 
custom implementation doesn't match reference: " + reasons[0] + + return '' \ No newline at end of file diff --git a/examples/mergesort_py/submission.py b/examples/mergesort_py/submission.py new file mode 100644 index 00000000..58da89a5 --- /dev/null +++ b/examples/mergesort_py/submission.py @@ -0,0 +1,102 @@ +import torch +import triton +import triton.language as tl +from task import input_t, output_t + +@triton.jit +def merge_kernel( + x_ptr, + temp_ptr, + n_elements, + chunk_size, + BLOCK_SIZE: tl.constexpr, +): + """ + Merges sorted chunks of size chunk_size into sorted chunks of size 2*chunk_size. + Each thread block handles merging two adjacent sorted chunks. + """ + # Program ID + pid = tl.program_id(0) + + # Calculate start of the two chunks to merge + chunk_pair_start = pid * (2 * chunk_size) + if chunk_pair_start >= n_elements: + return + + # Calculate sizes of chunks to merge (handle last chunk specially) + left_size = min(chunk_size, n_elements - chunk_pair_start) + right_start = chunk_pair_start + chunk_size + right_size = min(chunk_size, n_elements - right_start) if right_start < n_elements else 0 + + # Load left chunk + left_idx = chunk_pair_start + tl.arange(0, BLOCK_SIZE) + left_mask = left_idx < (chunk_pair_start + left_size) + left = tl.load(x_ptr + left_idx, mask=left_mask, other=float('inf')) + + # Load right chunk + right_idx = right_start + tl.arange(0, BLOCK_SIZE) + right_mask = right_idx < (right_start + right_size) + right = tl.load(x_ptr + right_idx, mask=right_mask, other=float('inf')) + + # Merge chunks using parallel merge path + output = tl.zeros([2 * BLOCK_SIZE], dtype=tl.float32) + float('inf') + left_ptr = 0 + right_ptr = 0 + out_ptr = 0 + + for i in range(left_size + right_size): + # Compare current elements + take_left = (left_ptr < left_size and + (right_ptr >= right_size or left[left_ptr] <= right[right_ptr])) + + # Store smaller element + output[out_ptr] = tl.where(take_left, left[left_ptr], right[right_ptr]) + + # Advance pointers + left_ptr = left_ptr + tl.where(take_left, 1, 0) + right_ptr = right_ptr + tl.where(take_left, 0, 1) + out_ptr = out_ptr + 1 + + # Store merged result + out_idx = chunk_pair_start + tl.arange(0, 2 * BLOCK_SIZE) + out_mask = out_idx < min(chunk_pair_start + left_size + right_size, n_elements) + tl.store(temp_ptr + out_idx, output, mask=out_mask) + +def custom_kernel(data: input_t) -> output_t: + """ + Implements parallel merge sort. 
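+    (Note: the bottom-up merge passes below assume each initial
+    BLOCK_SIZE-long run is already sorted before the first pass.)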
+ Args: + data: Input tensor to be sorted + Returns: + Sorted tensor + """ + n_elements = data.numel() + if n_elements <= 1: + return data.clone() + + # Allocate temporary buffer for merging + temp = torch.empty_like(data) + output = data.clone() + + # Configure kernel + BLOCK_SIZE = 512 # Should be power of 2 for simplicity + + # Bottom-up merge sort + chunk_size = BLOCK_SIZE + while chunk_size < n_elements: + n_chunk_pairs = triton.cdiv(n_elements, 2 * chunk_size) + + # Launch merge kernel + merge_kernel[(n_chunk_pairs,)]( + output, + temp, + n_elements, + chunk_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Swap buffers + output, temp = temp, output + chunk_size *= 2 + + return output \ No newline at end of file diff --git a/examples/mergesort_py/task.py b/examples/mergesort_py/task.py new file mode 100644 index 00000000..62e5dae0 --- /dev/null +++ b/examples/mergesort_py/task.py @@ -0,0 +1,9 @@ +from typing import TypedDict, TypeVar +import torch + +input_t = TypeVar("input_t", bound=torch.Tensor) +output_t = TypeVar("output_t", bound=torch.Tensor) + +class TestSpec(TypedDict): + size: int + seed: int \ No newline at end of file diff --git a/examples/mergesort_py/task.yml b/examples/mergesort_py/task.yml new file mode 100644 index 00000000..17c99107 --- /dev/null +++ b/examples/mergesort_py/task.yml @@ -0,0 +1,31 @@ +# name: mergesort-cuda-inline + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "eval.py"} + +lang: "py" + +description: | + Implement a sort kernel that matches the reference implementation. + The kernel should sort the input array in ascending order using the merge sort algorithm. 
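+
+  Input: 1D float32 tensor of shape (size,)
+  Output: tensor of the same shape, sorted in ascending order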
+ +config: + main: "eval.py" + +tests: + - {"size": 1023, "seed": 4242} + - {"size": 1024, "seed": 5236} + - {"size": 1025, "seed": 1001} + - {"size": 2048, "seed": 5531} + - {"size": 4096, "seed": 9173} + +benchmarks: + - {"size": 8192, "seed": 54352} + - {"size": 16384, "seed": 93246} + - {"size": 32768, "seed": 6256} + - {"size": 65536, "seed": 8841} + - {"size": 131072, "seed": 6252} \ No newline at end of file diff --git a/examples/prefixsum_py/eval.py b/examples/prefixsum_py/eval.py new file mode 120000 index 00000000..caf621bd --- /dev/null +++ b/examples/prefixsum_py/eval.py @@ -0,0 +1 @@ +../eval.py \ No newline at end of file From 6a9142644df2e44f5cec486475beadf31a939125 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 12 Feb 2025 20:43:18 -0800 Subject: [PATCH 6/9] fix grayscale folder --- examples/grayscale_py/submission_triton.py | 70 ---------------- examples/grayscale_py/utils.py | 95 +--------------------- 2 files changed, 1 insertion(+), 164 deletions(-) delete mode 100644 examples/grayscale_py/submission_triton.py mode change 100644 => 120000 examples/grayscale_py/utils.py diff --git a/examples/grayscale_py/submission_triton.py b/examples/grayscale_py/submission_triton.py deleted file mode 100644 index 27be62f9..00000000 --- a/examples/grayscale_py/submission_triton.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import triton -import triton.language as tl -from task import input_t, output_t - -@triton.jit -def grayscale_kernel( - input_ptr, output_ptr, - H, W, - stride_h, stride_w, stride_c, - BLOCK_SIZE: tl.constexpr, -): - # Program ID - pid = tl.program_id(0) - - # Calculate start indices - block_start_h = (pid // ((W + BLOCK_SIZE - 1) // BLOCK_SIZE)) * BLOCK_SIZE - block_start_w = (pid % ((W + BLOCK_SIZE - 1) // BLOCK_SIZE)) * BLOCK_SIZE - - # Offsets for this block - offs_h = block_start_h + tl.arange(0, BLOCK_SIZE) - offs_w = block_start_w + tl.arange(0, BLOCK_SIZE) - - # Create mask for valid pixels - mask = (offs_h[:, None] < H) & (offs_w[None, :] < W) - - # RGB to Grayscale coefficients - R_COEF = 0.2989 - G_COEF = 0.5870 - B_COEF = 0.1140 - - # Calculate base pointer for each pixel in the block - base_ptr = offs_h[:, None] * stride_h + offs_w[None, :] * stride_w - - # Load RGB channels - r = tl.load(input_ptr + base_ptr + 0 * stride_c, mask=mask, other=0.0) - g = tl.load(input_ptr + base_ptr + 1 * stride_c, mask=mask, other=0.0) - b = tl.load(input_ptr + base_ptr + 2 * stride_c, mask=mask, other=0.0) - - # Convert to grayscale - gray = R_COEF * r + G_COEF * g + B_COEF * b - - # Store result - out_ptr = offs_h[:, None] * W + offs_w[None, :] - tl.store(output_ptr + out_ptr, gray, mask=mask) - -def custom_kernel(data: input_t) -> output_t: - H, W, C = data.shape - assert C == 3, "Input must be an RGB image" - - # Create output tensor - output = torch.empty((H, W), device=data.device, dtype=data.dtype) - - # Calculate strides - stride_h = W * C - stride_w = C - stride_c = 1 - - # Launch kernel - BLOCK_SIZE = 32 - grid = ((H + BLOCK_SIZE - 1) // BLOCK_SIZE) * ((W + BLOCK_SIZE - 1) // BLOCK_SIZE) - - grayscale_kernel[grid]( - data, output, - H, W, - stride_h, stride_w, stride_c, - BLOCK_SIZE=BLOCK_SIZE, - ) - - return output \ No newline at end of file diff --git a/examples/grayscale_py/utils.py b/examples/grayscale_py/utils.py deleted file mode 100644 index 5363abdd..00000000 --- a/examples/grayscale_py/utils.py +++ /dev/null @@ -1,94 +0,0 @@ -import random -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - 
np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. Falling back to CPU.") - return torch.device("cpu") - - -def verbose_allclose( - tensor1: torch.Tensor, - tensor2: torch.Tensor, - rtol=1e-05, - atol=1e-08, - max_print=5 -) -> list[str]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - - Parameters: - tensor1 (torch.Tensor): First tensor to compare. - tensor2 (torch.Tensor): Second tensor to compare. - rtol (float): Relative tolerance. - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - - Returns: - list[str]: List of error messages if tensors don't match, empty list otherwise. - """ - # Check if the shapes of the tensors match - if tensor1.shape != tensor2.shape: - return ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(tensor1 - tensor2) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(tensor2) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2)) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2)) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.sum().item() - - # Check if all elements are close - all_close = num_mismatched == 0 - - # Return error messages if there are mismatches - if not all_close and num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}", - f"Mismatched elements: {mismatched_indices}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {tensor1[i]} {tensor2[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... 
and {num_mismatched - max_print} more mismatched elements.")
-        return mismatch_details
-
-    return []
\ No newline at end of file
diff --git a/examples/grayscale_py/utils.py b/examples/grayscale_py/utils.py
new file mode 120000
index 00000000..50fbc6d8
--- /dev/null
+++ b/examples/grayscale_py/utils.py
@@ -0,0 +1 @@
+../utils.py
\ No newline at end of file

From c81d6a1d1f5c1bc3b63297a897c158a295569e8b Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Wed, 12 Feb 2025 20:57:33 -0800
Subject: [PATCH 7/9] update

---
 examples/conv2d_py/submission.py | 123 -------------------------
 examples/conv2d_py/utils.py      |   1 +
 examples/prefixsum_py/utils.py   |   1 +
 examples/vectorsum_py/utils.py   |  94 +----------------------
 4 files changed, 3 insertions(+), 216 deletions(-)
 delete mode 100644 examples/conv2d_py/submission.py
 create mode 120000 examples/conv2d_py/utils.py
 create mode 120000 examples/prefixsum_py/utils.py
 mode change 100644 => 120000 examples/vectorsum_py/utils.py

diff --git a/examples/conv2d_py/submission.py b/examples/conv2d_py/submission.py
deleted file mode 100644
index 1867acd3..00000000
--- a/examples/conv2d_py/submission.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import torch
-import triton
-import triton.language as tl
-from task import input_t, output_t, KernelSpec
-
-@triton.jit
-def conv2d_kernel(
-    # Pointers to matrices
-    input_ptr, kernel_ptr, output_ptr,
-    # Matrix dimensions
-    batch, in_channels, out_channels,
-    in_height, in_width,
-    kernel_size, stride, padding,
-    out_height, out_width,
-    # Block sizes
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-):
-    """
-    2D Convolution kernel.
-    Each thread block handles computation for a BLOCK_SIZE_M x BLOCK_SIZE_N region of the output.
-    """
-    # Program ID
-    pid = tl.program_id(0)
-
-    # Calculate output position
-    n_blocks_m = triton.cdiv(out_height, BLOCK_SIZE_M)
-    batch_idx = pid // (n_blocks_m * out_channels)
-    tmp = pid % (n_blocks_m * out_channels)
-    out_ch = tmp // n_blocks_m
-    block_m = tmp % n_blocks_m
-
-    # Calculate output row and column ranges for this block
-    out_m = block_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    out_n = tl.arange(0, BLOCK_SIZE_N)
-
-    # Calculate input positions with padding offset
-    in_m = out_m * stride - padding
-    in_n = out_n * stride - padding
-
-    # Initialize output accumulator
-    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
-    # Iterate over input channels and kernel positions
-    for in_ch in range(in_channels):
-        for kh in range(kernel_size):
-            for kw in range(kernel_size):
-                # Calculate input positions
-                h_pos = in_m + kh
-                w_pos = in_n + kw
-
-                # Create masks for valid positions
-                m_mask = (h_pos >= 0) & (h_pos < in_height)
-                n_mask = (w_pos >= 0) & (w_pos < in_width)
-                mask = m_mask[:, None] & n_mask[None, :]
-
-                # Load input values
-                x_pos = h_pos[:, None] * in_width + w_pos[None, :]
-                input_idx = ((batch_idx * in_channels + in_ch) * in_height * in_width + x_pos)
-                x = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
-
-                # Load kernel value
-                k_idx = ((out_ch * in_channels + in_ch) * kernel_size * kernel_size +
-                         kh * kernel_size + kw)
-                k = tl.load(kernel_ptr + k_idx)
-
-                # Accumulate
-                acc += k * x
-
-    # Write output
-    out_pos = out_m[:, None] * out_width + out_n[None, :]
-    output_idx = ((batch_idx * out_channels + out_ch) * out_height * out_width +
-                  out_pos)
-
-    # Create output mask
-    m_mask = out_m < out_height
-    n_mask = out_n < out_width
-    mask = m_mask[:, None] & n_mask[None, :]
-
-    # Store output
-    tl.store(output_ptr + output_idx, acc,
-             mask=mask)
-
-def custom_kernel(data: input_t, spec: KernelSpec) -> output_t:
-    """
-    Performs 2D convolution using Triton kernel.
-    Args:
-        data: Tuple of (input tensor, kernel tensor)
-        spec: Convolution specifications
-    Returns:
-        Output tensor after convolution
-    """
-    input_tensor, kernel = data
-    batch, in_channels, in_height, in_width = input_tensor.shape
-    out_channels, _, kernel_size, _ = kernel.shape
-
-    # Calculate output dimensions
-    out_height = ((in_height + 2 * spec.padding - kernel_size) // spec.stride) + 1
-    out_width = ((in_width + 2 * spec.padding - kernel_size) // spec.stride) + 1
-
-    # Allocate output
-    output = torch.empty(
-        (batch, out_channels, out_height, out_width),
-        device=input_tensor.device,
-        dtype=input_tensor.dtype
-    )
-
-    # Configure kernel
-    BLOCK_SIZE_M = 8
-    BLOCK_SIZE_N = 8
-    grid = (batch * out_channels * triton.cdiv(out_height, BLOCK_SIZE_M),)
-
-    # Launch kernel
-    conv2d_kernel[grid](
-        input_tensor, kernel, output,
-        batch, in_channels, out_channels,
-        in_height, in_width,
-        kernel_size, spec.stride, spec.padding,
-        out_height, out_width,
-        BLOCK_SIZE_M=BLOCK_SIZE_M,
-        BLOCK_SIZE_N=BLOCK_SIZE_N,
-    )
-
-    return output
\ No newline at end of file
diff --git a/examples/conv2d_py/utils.py b/examples/conv2d_py/utils.py
new file mode 120000
index 00000000..50fbc6d8
--- /dev/null
+++ b/examples/conv2d_py/utils.py
@@ -0,0 +1 @@
+../utils.py
\ No newline at end of file
diff --git a/examples/prefixsum_py/utils.py b/examples/prefixsum_py/utils.py
new file mode 120000
index 00000000..50fbc6d8
--- /dev/null
+++ b/examples/prefixsum_py/utils.py
@@ -0,0 +1 @@
+../utils.py
\ No newline at end of file
diff --git a/examples/vectorsum_py/utils.py b/examples/vectorsum_py/utils.py
deleted file mode 100644
index cb7e26bb..00000000
--- a/examples/vectorsum_py/utils.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import random
-import numpy as np
-import torch
-
-
-def set_seed(seed=42):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-
-
-def get_device(use_cuda: bool = True) -> torch.device:
-    """Get the appropriate device (GPU or CPU)."""
-    if use_cuda:
-        if torch.cuda.is_available():
-            return torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            return torch.device("mps")
-        else:
-            print("No compatible GPU found. Falling back to CPU.")
-    return torch.device("cpu")
-
-def verbose_allclose(
-    tensor1: torch.Tensor,
-    tensor2: torch.Tensor,
-    rtol=1e-05,
-    atol=1e-08,
-    max_print=5
-) -> list[str]:
-    """
-    Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
-
-    Parameters:
-    tensor1 (torch.Tensor): First tensor to compare.
-    tensor2 (torch.Tensor): Second tensor to compare.
-    rtol (float): Relative tolerance.
-    atol (float): Absolute tolerance.
-    max_print (int): Maximum number of mismatched elements to print.
-
-    Returns:
-    list[str]: List of error messages if tensors don't match, empty list otherwise.
-    """
-    # Check if the shapes of the tensors match
-    if tensor1.shape != tensor2.shape:
-        return ["SIZE MISMATCH"]
-
-    # Calculate the difference between the tensors
-    diff = torch.abs(tensor1 - tensor2)
-
-    # Determine the tolerance
-    tolerance = atol + rtol * torch.abs(tensor2)
-
-    # Find tolerance mismatched elements
-    tol_mismatched = diff > tolerance
-
-    # Find nan mismatched elements
-    nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
-
-    # Find +inf mismatched elements
-    posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
-    # Find -inf mismatched elements
-    neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
-
-    # Find all mismatched elements
-    mismatched = torch.logical_or(
-        torch.logical_or(tol_mismatched, nan_mismatched),
-        torch.logical_or(posinf_mismatched, neginf_mismatched),
-    )
-
-    mismatched_indices = torch.nonzero(mismatched)
-
-    # Count the number of mismatched elements
-    num_mismatched = mismatched.sum().item()
-
-    # Check if all elements are close
-    all_close = num_mismatched == 0
-
-    # Return detailed information if there are mismatches
-    if not all_close and num_mismatched >= 1:
-        mismatch_details = [f"Number of mismatched elements: {num_mismatched}",
-                            f"Mismatched elements: {mismatched_indices}"]
-
-        for index in mismatched_indices[:max_print]:
-            i = tuple(index.tolist())
-            mismatch_details.append(f"ERROR AT {i}: {tensor1[i]} {tensor2[i]}")
-        if num_mismatched > max_print:
-            mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
-        return mismatch_details
-
-    return []
\ No newline at end of file
diff --git a/examples/vectorsum_py/utils.py b/examples/vectorsum_py/utils.py
new file mode 120000
index 00000000..50fbc6d8
--- /dev/null
+++ b/examples/vectorsum_py/utils.py
@@ -0,0 +1 @@
+../utils.py
\ No newline at end of file

From a08ab1f147111d480bcba86ca84616b62b212870 Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Wed, 12 Feb 2025 21:03:45 -0800
Subject: [PATCH 8/9] fix sorting

---
 examples/mergesort_py/submission.py             | 102 ------------------
 examples/sort_py/eval.py                        |   1 +
 examples/{mergesort_py => sort_py}/reference.py |   0
 examples/sort_py/submission.py                  |  12 +++
 examples/{mergesort_py => sort_py}/task.py      |   0
 examples/{mergesort_py => sort_py}/task.yml     |   0
 examples/sort_py/utils.py                       |   1 +
 7 files changed, 14 insertions(+), 102 deletions(-)
 delete mode 100644 examples/mergesort_py/submission.py
 create mode 120000 examples/sort_py/eval.py
 rename examples/{mergesort_py => sort_py}/reference.py (100%)
 create mode 100644 examples/sort_py/submission.py
 rename examples/{mergesort_py => sort_py}/task.py (100%)
 rename examples/{mergesort_py => sort_py}/task.yml (100%)
 create mode 120000 examples/sort_py/utils.py

diff --git a/examples/mergesort_py/submission.py b/examples/mergesort_py/submission.py
deleted file mode 100644
index 58da89a5..00000000
--- a/examples/mergesort_py/submission.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import torch
-import triton
-import triton.language as tl
-from task import input_t, output_t
-
-@triton.jit
-def merge_kernel(
-    x_ptr,
-    temp_ptr,
-    n_elements,
-    chunk_size,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    Merges sorted chunks of size chunk_size into sorted chunks of size 2*chunk_size.
-    Each thread block handles merging two adjacent sorted chunks.
-    """
-    # Program ID
-    pid = tl.program_id(0)
-
-    # Calculate start of the two chunks to merge
-    chunk_pair_start = pid * (2 * chunk_size)
-    if chunk_pair_start >= n_elements:
-        return
-
-    # Calculate sizes of chunks to merge (handle last chunk specially)
-    left_size = min(chunk_size, n_elements - chunk_pair_start)
-    right_start = chunk_pair_start + chunk_size
-    right_size = min(chunk_size, n_elements - right_start) if right_start < n_elements else 0
-
-    # Load left chunk
-    left_idx = chunk_pair_start + tl.arange(0, BLOCK_SIZE)
-    left_mask = left_idx < (chunk_pair_start + left_size)
-    left = tl.load(x_ptr + left_idx, mask=left_mask, other=float('inf'))
-
-    # Load right chunk
-    right_idx = right_start + tl.arange(0, BLOCK_SIZE)
-    right_mask = right_idx < (right_start + right_size)
-    right = tl.load(x_ptr + right_idx, mask=right_mask, other=float('inf'))
-
-    # Merge chunks using parallel merge path
-    output = tl.zeros([2 * BLOCK_SIZE], dtype=tl.float32) + float('inf')
-    left_ptr = 0
-    right_ptr = 0
-    out_ptr = 0
-
-    for i in range(left_size + right_size):
-        # Compare current elements
-        take_left = (left_ptr < left_size and
-                     (right_ptr >= right_size or left[left_ptr] <= right[right_ptr]))
-
-        # Store smaller element
-        output[out_ptr] = tl.where(take_left, left[left_ptr], right[right_ptr])
-
-        # Advance pointers
-        left_ptr = left_ptr + tl.where(take_left, 1, 0)
-        right_ptr = right_ptr + tl.where(take_left, 0, 1)
-        out_ptr = out_ptr + 1
-
-    # Store merged result
-    out_idx = chunk_pair_start + tl.arange(0, 2 * BLOCK_SIZE)
-    out_mask = out_idx < min(chunk_pair_start + left_size + right_size, n_elements)
-    tl.store(temp_ptr + out_idx, output, mask=out_mask)
-
-def custom_kernel(data: input_t) -> output_t:
-    """
-    Implements parallel merge sort.
-    Args:
-        data: Input tensor to be sorted
-    Returns:
-        Sorted tensor
-    """
-    n_elements = data.numel()
-    if n_elements <= 1:
-        return data.clone()
-
-    # Allocate temporary buffer for merging
-    temp = torch.empty_like(data)
-    output = data.clone()
-
-    # Configure kernel
-    BLOCK_SIZE = 512  # Should be power of 2 for simplicity
-
-    # Bottom-up merge sort
-    chunk_size = BLOCK_SIZE
-    while chunk_size < n_elements:
-        n_chunk_pairs = triton.cdiv(n_elements, 2 * chunk_size)
-
-        # Launch merge kernel
-        merge_kernel[(n_chunk_pairs,)](
-            output,
-            temp,
-            n_elements,
-            chunk_size,
-            BLOCK_SIZE=BLOCK_SIZE,
-        )
-
-        # Swap buffers
-        output, temp = temp, output
-        chunk_size *= 2
-
-    return output
\ No newline at end of file
diff --git a/examples/sort_py/eval.py b/examples/sort_py/eval.py
new file mode 120000
index 00000000..caf621bd
--- /dev/null
+++ b/examples/sort_py/eval.py
@@ -0,0 +1 @@
+../eval.py
\ No newline at end of file
diff --git a/examples/mergesort_py/reference.py b/examples/sort_py/reference.py
similarity index 100%
rename from examples/mergesort_py/reference.py
rename to examples/sort_py/reference.py
diff --git a/examples/sort_py/submission.py b/examples/sort_py/submission.py
new file mode 100644
index 00000000..d50f4eae
--- /dev/null
+++ b/examples/sort_py/submission.py
@@ -0,0 +1,12 @@
+import torch
+from task import input_t, output_t
+
+def custom_kernel(data: input_t) -> output_t:
+    """
+    Implements sort using PyTorch.
+    Args:
+        data: Input tensor to be sorted
+    Returns:
+        Sorted tensor
+    """
+    return torch.sort(data)[0]
\ No newline at end of file
diff --git a/examples/mergesort_py/task.py b/examples/sort_py/task.py
similarity index 100%
rename from examples/mergesort_py/task.py
rename to examples/sort_py/task.py
diff --git a/examples/mergesort_py/task.yml b/examples/sort_py/task.yml
similarity index 100%
rename from examples/mergesort_py/task.yml
rename to examples/sort_py/task.yml
diff --git a/examples/sort_py/utils.py b/examples/sort_py/utils.py
new file mode 120000
index 00000000..50fbc6d8
--- /dev/null
+++ b/examples/sort_py/utils.py
@@ -0,0 +1 @@
+../utils.py
\ No newline at end of file

From 4a89a0c086ee8ea7892304edd4aa2a5859f8c5ef Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Wed, 12 Feb 2025 21:52:29 -0800
Subject: [PATCH 9/9] fix submissions

---
 examples/histogram_py/submission.py |  81 --------------------
 examples/histogram_py/utils.py      |   1 +
 examples/prefixsum_py/submission.py | 113 ----------------------------
 examples/sort_py/submission.py      |   6 +++-
 examples/vectorsum_py/submission.py |   7 ++++-
 5 files changed, 10 insertions(+), 198 deletions(-)
 delete mode 100644 examples/histogram_py/submission.py
 create mode 120000 examples/histogram_py/utils.py
 delete mode 100644 examples/prefixsum_py/submission.py

diff --git a/examples/histogram_py/submission.py b/examples/histogram_py/submission.py
deleted file mode 100644
index a0e28216..00000000
--- a/examples/histogram_py/submission.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import torch
-import triton
-import triton.language as tl
-from task import input_t, output_t, HistogramSpec
-
-@triton.jit
-def histogram_kernel(
-    x_ptr,
-    output_ptr,
-    n_elements,
-    num_bins,
-    min_val,
-    max_val,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    Parallel histogram kernel.
-    Each thread block processes BLOCK_SIZE elements and maintains a local histogram,
-    then atomically adds to the global histogram.
-    """
-    # Program ID
-    pid = tl.program_id(0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
-
-    # Load data
-    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-
-    # Clip values to range
-    x = tl.minimum(tl.maximum(x, min_val), max_val)
-
-    # Convert to bin indices
-    bin_width = (max_val - min_val) / num_bins
-    indices = ((x - min_val) / bin_width).to(tl.int32)
-    indices = tl.minimum(tl.maximum(indices, 0), num_bins - 1)
-
-    # Initialize local histogram in shared memory
-    local_hist = tl.zeros([num_bins], dtype=tl.float32)
-
-    # Populate local histogram
-    for i in range(BLOCK_SIZE):
-        if offsets[i] < n_elements:
-            bin_idx = indices[i]
-            tl.atomic_add(local_hist + bin_idx, 1.0)
-
-    # Add local histogram to global histogram
-    for bin_idx in range(num_bins):
-        if local_hist[bin_idx] > 0:
-            tl.atomic_add(output_ptr + bin_idx, local_hist[bin_idx])
-
-def custom_kernel(data: input_t, spec: HistogramSpec) -> output_t:
-    """
-    Computes histogram using parallel reduction.
-    Args:
-        data: Input tensor
-        spec: Histogram specifications
-    Returns:
-        Tensor containing bin counts
-    """
-    n_elements = data.numel()
-
-    # Initialize output histogram
-    output = torch.zeros(spec.num_bins, device=data.device, dtype=torch.float32)
-
-    # Configure kernel
-    BLOCK_SIZE = 1024
-    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
-
-    # Launch kernel
-    histogram_kernel[grid](
-        data,
-        output,
-        n_elements,
-        spec.num_bins,
-        spec.min_val,
-        spec.max_val,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
-
-    return output
\ No newline at end of file
diff --git a/examples/histogram_py/utils.py b/examples/histogram_py/utils.py
new file mode 120000
index 00000000..50fbc6d8
--- /dev/null
+++ b/examples/histogram_py/utils.py
@@ -0,0 +1 @@
+../utils.py
\ No newline at end of file
diff --git a/examples/prefixsum_py/submission.py b/examples/prefixsum_py/submission.py
deleted file mode 100644
index 1d06b8ab..00000000
--- a/examples/prefixsum_py/submission.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import torch
-import triton
-import triton.language as tl
-from task import input_t, output_t
-
-@triton.jit
-def scan_kernel(
-    x_ptr,
-    output_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    Single-block inclusive prefix sum kernel.
-    Uses a two-pass approach: up-sweep and down-sweep.
-    """
-    # Get program ID and allocate shared memory
-    pid = tl.program_id(0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
-
-    # Load data into shared memory
-    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-
-    # Up-sweep: Build sum tree
-    offset = 1
-    for d in range(triton.next_power_of_2(BLOCK_SIZE) // 2):
-        mask = tl.arange(0, BLOCK_SIZE) % (2 * offset) == (2 * offset - 1)
-        vals = tl.where(mask, x, 0.0)
-        vals = tl.sum(vals, axis=0)
-        x = x + tl.where(mask, -x + vals, 0.0)
-        offset *= 2
-
-    # Down-sweep: Distribute sums
-    for d in range(triton.next_power_of_2(BLOCK_SIZE) // 2 - 1, -1, -1):
-        offset = 1 << d
-        mask = tl.arange(0, BLOCK_SIZE) % (2 * offset) == (2 * offset - 1)
-        vals = tl.where(mask, x, 0.0)
-        x = x + tl.where(tl.arange(0, BLOCK_SIZE) % (2 * offset) >= offset, vals, 0.0)
-
-    # Store results
-    output_mask = offsets < n_elements
-    tl.store(output_ptr + offsets, x, mask=output_mask)
-
-@triton.jit
-def block_sum_kernel(
-    block_sums_ptr,
-    output_ptr,
-    block_size,
-    n_blocks,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    Adds block sums to subsequent blocks to get final prefix sum.
-    """
-    pid = tl.program_id(0)
-    block_idx = pid + 1  # Skip first block
-
-    if block_idx < n_blocks:
-        # Load block sum from previous block
-        prev_sum = tl.load(block_sums_ptr + block_idx - 1)
-
-        # Add to all elements in current block
-        offsets = block_idx * block_size + tl.arange(0, BLOCK_SIZE)
-        mask = offsets < (block_idx + 1) * block_size
-        x = tl.load(output_ptr + offsets, mask=mask, other=0.0)
-        x = x + prev_sum
-        tl.store(output_ptr + offsets, x, mask=mask)
-
-def custom_kernel(data: input_t) -> output_t:
-    """
-    Multi-block prefix sum implementation.
-    Args:
-        data: Input tensor
-    Returns:
-        Tensor containing inclusive prefix sum
-    """
-    n_elements = data.numel()
-    output = torch.empty_like(data)
-
-    # Configure kernel
-    BLOCK_SIZE = 1024
-    n_blocks = triton.cdiv(n_elements, BLOCK_SIZE)
-
-    # Phase 1: Compute prefix sum within each block
-    scan_kernel[(n_blocks,)](
-        data,
-        output,
-        n_elements,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
-
-    if n_blocks > 1:
-        # Get block sums
-        block_sums = torch.empty(n_blocks, device=data.device, dtype=data.dtype)
-        block_sums[0] = output[BLOCK_SIZE-1]
-        for i in range(1, n_blocks-1):
-            block_sums[i] = output[(i+1)*BLOCK_SIZE-1]
-
-        # Compute prefix sum of block sums
-        block_sums = torch.cumsum(block_sums, dim=0)
-
-        # Phase 2: Add block sums to subsequent blocks
-        block_sum_kernel[(n_blocks-1,)](
-            block_sums,
-            output,
-            BLOCK_SIZE,
-            n_blocks,
-            BLOCK_SIZE=BLOCK_SIZE,
-        )
-
-    return output
\ No newline at end of file
diff --git a/examples/sort_py/submission.py b/examples/sort_py/submission.py
index d50f4eae..5a4915c9 100644
--- a/examples/sort_py/submission.py
+++ b/examples/sort_py/submission.py
@@ -1,7 +1,7 @@
 import torch
 from task import input_t, output_t
 
-def custom_kernel(data: input_t) -> output_t:
+def _custom_kernel(data: input_t) -> output_t:
     """
     Implements sort using PyTorch.
     Args:
@@ -9,4 +9,6 @@ def custom_kernel(data: input_t) -> output_t:
     Returns:
         Sorted tensor
     """
-    return torch.sort(data)[0]
\ No newline at end of file
+    return torch.sort(data)[0]
+
+custom_kernel = torch.compile(_custom_kernel, mode="reduce-overhead")
\ No newline at end of file
diff --git a/examples/vectorsum_py/submission.py b/examples/vectorsum_py/submission.py
index 4e1969e2..8ac3ac13 100644
--- a/examples/vectorsum_py/submission.py
+++ b/examples/vectorsum_py/submission.py
@@ -28,7 +28,7 @@ def sum_kernel(
     # Store the partial sum
     tl.atomic_add(output_ptr, block_sum)
 
-def custom_kernel(data: input_t) -> output_t:
+def _custom_kernel(data: input_t) -> output_t:
     """
     Performs parallel reduction to compute sum of all elements.
     Args:
@@ -51,4 +51,7 @@ def custom_kernel(data: input_t) -> output_t:
         BLOCK_SIZE=BLOCK_SIZE,
     )
 
-    return output[0]
\ No newline at end of file
+    return output[0]
+
+# Compile the kernel for better performance
+custom_kernel = torch.compile(_custom_kernel, mode="reduce-overhead")
\ No newline at end of file
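
Note on PATCH 9/9: both updated submissions keep the raw implementation in a
private _custom_kernel and export custom_kernel wrapped with torch.compile.
The sketch below shows that wrapping pattern in isolation. It is a minimal
standalone script, not the eval.py harness (which this series does not show),
so the __main__ block and the sort body are assumptions for illustration only.

    import torch

    def _custom_kernel(data: torch.Tensor) -> torch.Tensor:
        # Raw implementation: torch.sort returns (values, indices); keep values.
        return torch.sort(data)[0]

    # mode="reduce-overhead" uses CUDA graphs to amortize Python and kernel-launch
    # overhead across repeated calls, which suits small, frequently invoked kernels.
    custom_kernel = torch.compile(_custom_kernel, mode="reduce-overhead")

    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.rand(1 << 20, device=device)
        out = custom_kernel(x)  # first call compiles; later calls reuse the compiled graph
        assert torch.equal(out, torch.sort(x)[0])

The split matters because torch.compile replaces the function object: keeping
the uncompiled _custom_kernel around preserves an eager fallback for debugging,
while the harness imports the compiled custom_kernel name unchanged.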