[BUG] Calling a function from a kernel doesn't seem to work correctly · Issue #4713 · modular/modular · GitHub
@prabhuramachandran

Description

Bug description

Actual behavior

I have a very simple GPU kernel that should produce non-zero output when run, but it produces all zeros. The code below is fairly self-explanatory.

Expected behavior

The kernel squares each element of an input LayoutTensor and stores the result in another LayoutTensor, so I expect the output to contain the squares of the inputs (0.0, 1.0, 4.0, ..., 225.0). When I do the squaring directly inside the kernel function, that is what I get, but if I instead call a helper function that squares its argument, I get all zeros. There is no compilation error, and nothing in the documentation says this can't be done, so this is rather surprising.

Steps to reproduce

Here is a complete example:

```mojo
from gpu.host import DeviceContext, DeviceBuffer, HostBuffer
from gpu.id import block_dim, block_idx, thread_idx
from layout import Layout, LayoutTensor
from sys import has_accelerator


alias float_type = DType.float32
alias size = 16
alias layout = Layout.row_major(size)
alias Tensor = LayoutTensor[float_type, layout, MutableAnyOrigin]


fn comp[eltype: DType](x: SIMD[eltype, 1]) -> __type_of(x):
    # Helper that squares its scalar argument.
    return x**2


fn sq_kernel[eltype: DType](
    input: LayoutTensor[eltype, layout, MutableAnyOrigin],
    output: LayoutTensor[eltype, layout, MutableAnyOrigin],
):
    var tid = block_idx.x * block_dim.x + thread_idx.x
    if tid < size:
        # Uncommenting the line below and commenting out the comp[...] line produces the correct output.
        # output[tid] = input[tid]**2
        output[tid] = comp[eltype](input.load[1](tid, 1))
        # Also, using input[tid] here is painful: it produces compiler error messages that are hard to understand and debug.


fn bug1(ctx: DeviceContext) raises:
    host_buf = ctx.enqueue_create_host_buffer[float_type](size)
    dev_buf = ctx.enqueue_create_buffer[float_type](size)
    dev = LayoutTensor[float_type, layout, MutableAnyOrigin](dev_buf)
    host = LayoutTensor[float_type, layout, MutableAnyOrigin](host_buf)
    host_buf1 = ctx.enqueue_create_host_buffer[float_type](size)
    dev_buf1 = ctx.enqueue_create_buffer[float_type](size)
    output_dev = LayoutTensor[float_type, layout, MutableAnyOrigin](dev_buf1)
    output_host = LayoutTensor[float_type, layout, MutableAnyOrigin](host_buf1)

    # Fill the input with 0, 1, ..., size - 1.
    for i in range(size):
        host[i] = i
    print(host[8])
    ctx.enqueue_copy(dst_buf=dev_buf, src_buf=host_buf)
    ctx.synchronize()
    # Launch 4 blocks of 4 threads: one thread per element.
    ctx.enqueue_function[sq_kernel[float_type]](dev, output_dev, grid_dim=4, block_dim=4)
    ctx.enqueue_copy(dst_buf=host_buf1, src_buf=dev_buf1)
    ctx.synchronize()
    # Prints all zeros (0.0) when it should print the squares of the input values.
    for i in range(size):
        print(output_host[i], end=' ')
    print()


def main():
    @parameter
    if not has_accelerator():
        print("No compatible GPU found")
    else:
        ctx = DeviceContext()
        bug1(ctx)
```
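As a hedged sanity check (not part of the original report, and not run against this nightly), the sketch below exercises the same `comp` helper on the CPU, outside any kernel. If it prints the squares of 0 through 15, the helper itself behaves correctly, which would point at the kernel path (for example the `input.load[1](tid, 1)` read, or how the helper call is compiled for the GPU) rather than at `comp`. Only `comp` is taken from the report; the `main` driver is illustrative.

```mojo
# CPU-only check of the `comp` helper from the reproducer above.
# Illustrative sketch: nothing here touches the GPU.


fn comp[eltype: DType](x: SIMD[eltype, 1]) -> __type_of(x):
    # Same helper as in the report: square a single scalar.
    return x**2


def main():
    # Square 0, 1, ..., 15 on the CPU and print the results on one line.
    for i in range(16):
        var x = Scalar[DType.float32](i)
        print(comp[DType.float32](x), end=" ")
    print()
```

Keeping the helper and input range identical to the reproducer means the only difference from the failing case is that no kernel launch is involved.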

System information

  • magic info
$ magic info
     Magic version: 0.7.2
System
------------
       Pixi version: 0.41.4
           Platform: linux-64
   Virtual packages: __unix=0=0
                   : __linux=6.8.0=0
                   : __glibc=2.39=0
                   : __cuda=12.4=0
                   : __archspec=1=zen2
          Cache dir: /home/user/.cache/rattler/cache
       Auth storage: /home/user/.rattler/credentials.json
   Config locations: No config files found

Global
------------
            Bin dir: /home/user/.modular/bin
    Environment dir: /home/user/.modular/envs
       Manifest dir: /home/user/.modular/manifests/pixi-global.toml

Project
------------
               Name: bench
            Version: 0.1.0
      Manifest file: /home/user/mojoproject.toml
       Last updated: 27-05-2025 12:12:20

Environments
------------
        Environment: default
           Features: default
           Channels: https://conda.modular.com/max-nightly, https://conda.modular.com/max, https://repo.prefix.dev/modular-community, conda-forge
   Dependency count: 1
       Dependencies: max
   Target platforms: linux-64
  • magic list max
$ magic list max
Package     Version               Build    Size       Kind   Source
max         25.4.0.dev2025052605  release  9.2 KiB    conda  max
max-core    25.4.0.dev2025052605  release  211.7 MiB  conda  max-core
max-python  25.4.0.dev2025052605  release  13.9 MiB   conda  max-python

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), modular-repo, mojo (Issues that are related to mojo)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
