Sunteți pe pagina 1 din 3

import jcuda.driver.JCudaDriver.

_
import java.io._
import jcuda._
import jcuda.driver._
//remove if not needed
import scala.collection.JavaConversions._
/**
 * Vector-addition sample built on the JCuda driver API.
 *
 * Loads the "add" kernel from a PTX file (compiling it from
 * "JCudaVectorAddKernel.cu" with nvcc when the PTX file does not exist yet),
 * runs it on two float arrays and checks the result against a host-side sum.
 */
object JCudaVectorAdd {

  /**
   * Program entry point: sets up the driver, runs the kernel and prints
   * "Test PASSED" or "Test FAILED".
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Have JCuda throw exceptions instead of returning error codes.
    JCudaDriver.setExceptionsEnabled(true)

    // Obtain the PTX file, compiling the .cu file first if necessary.
    val ptxFileName = preparePtxFile("JCudaVectorAddKernel.cu")

    // Initialise the driver and create a context on device 0.
    cuInit(0)
    val device = new CUdevice()
    cuDeviceGet(device, 0)
    val context = new CUcontext()
    cuCtxCreate(context, 0, device)

    // Load the PTX module and look up the kernel function named "add".
    val module = new CUmodule()
    cuModuleLoad(module, ptxFileName)
    val function = new CUfunction()
    cuModuleGetFunction(function, module, "add")

    // Host input: both vectors hold 0, 1, 2, ... as floats.
    val numElements = 512
    val byteCount = numElements * Sizeof.FLOAT
    val hostInputA = Array.tabulate[Float](numElements)(_.toFloat)
    val hostInputB = Array.tabulate[Float](numElements)(_.toFloat)

    // Allocate the device buffers and copy the inputs over.
    val deviceInputA = new CUdeviceptr()
    cuMemAlloc(deviceInputA, byteCount)
    cuMemcpyHtoD(deviceInputA, Pointer.to(hostInputA), byteCount)
    val deviceInputB = new CUdeviceptr()
    cuMemAlloc(deviceInputB, byteCount)
    cuMemcpyHtoD(deviceInputB, Pointer.to(hostInputB), byteCount)
    val deviceOutput = new CUdeviceptr()
    cuMemAlloc(deviceOutput, byteCount)

    // Kernel arguments, passed as a pointer to an array of pointers:
    // element count, input A, input B, output.
    val kernelParameters = Pointer.to(
      Pointer.to(Array(numElements)),
      Pointer.to(deviceInputA),
      Pointer.to(deviceInputB),
      Pointer.to(deviceOutput))

    // Launch with enough blocks of blockSizeX threads to cover every element.
    val blockSizeX = 256
    val gridSizeX = Math.ceil(numElements.toDouble / blockSizeX).toInt
    cuLaunchKernel(function, gridSizeX, 1, 1, blockSizeX, 1, 1, 0, null, kernelParameters, null)
    cuCtxSynchronize()

    // Copy the result back to the host.
    val hostOutput = Array.ofDim[Float](numElements)
    cuMemcpyDtoH(Pointer.to(hostOutput), deviceOutput, byteCount)

    // Verify against the host-side sum, stopping at the first mismatch.
    val firstMismatch = (0 until numElements).find { i =>
      Math.abs(hostOutput(i) - (i + i)) > 1e-5
    }
    firstMismatch.foreach { i =>
      println("At index " + i + " found " + hostOutput(i) + " but expected " + (i + i))
    }
    println("Test " + (if (firstMismatch.isEmpty) "PASSED" else "FAILED"))

    // Release the device buffers.
    cuMemFree(deviceInputA)
    cuMemFree(deviceInputB)
    cuMemFree(deviceOutput)
  }

  /**
   * Returns the name of the PTX file that corresponds to the given CUDA
   * source file, compiling it with nvcc when the PTX file does not exist.
   *
   * The PTX name is the source name with its extension replaced by ".ptx";
   * a name without any extension simply gets ".ptx" appended.
   *
   * @param cuFileName name of the .cu source file
   * @return name of the existing or newly created PTX file
   * @throws IOException if the source file is missing, nvcc exits with a
   *                     non-zero status, or the wait for nvcc is interrupted
   */
  private def preparePtxFile(cuFileName: String): String = {
    val dotIndex = cuFileName.lastIndexOf('.')
    val baseName = if (dotIndex == -1) cuFileName else cuFileName.substring(0, dotIndex)
    val ptxFileName = baseName + ".ptx"

    if (!new File(ptxFileName).exists()) {
      compilePtx(cuFileName, ptxFileName)
    }
    ptxFileName
  }

  /**
   * Runs nvcc to compile `cuFileName` into `ptxFileName`.
   *
   * @throws IOException if the source file is missing, nvcc exits with a
   *                     non-zero status, or the wait for nvcc is interrupted
   */
  private def compilePtx(cuFileName: String, ptxFileName: String): Unit = {
    val cuFile = new File(cuFileName)
    if (!cuFile.exists()) {
      throw new IOException("Input file not found: " + cuFileName)
    }

    val modelString = "-m" + System.getProperty("sun.arch.data.model")
    // Pass the arguments as an array so that paths containing spaces are not
    // split into several arguments by Runtime.exec's whitespace tokenizer.
    val commandParts = Array("nvcc", modelString, "-ptx", cuFile.getPath, "-o", ptxFileName)
    println("Executing\n" + commandParts.mkString(" "))
    val process = Runtime.getRuntime.exec(commandParts)

    // NOTE(review): stderr is drained completely before stdout; a process
    // that writes a lot to stdout first could block here - confirm that this
    // is acceptable for nvcc.
    val errorMessage = new String(toByteArray(process.getErrorStream))
    val outputMessage = new String(toByteArray(process.getInputStream))

    val exitValue =
      try {
        process.waitFor()
      } catch {
        case e: InterruptedException =>
          // Restore the interrupt flag before reporting the failure.
          Thread.currentThread().interrupt()
          throw new IOException("Interrupted while waiting for nvcc output", e)
      }

    if (exitValue != 0) {
      println("nvcc process exitValue " + exitValue)
      println("errorMessage:\n" + errorMessage)
      println("outputMessage:\n" + outputMessage)
      throw new IOException("Could not create .ptx file: " + errorMessage)
    }
    println("Finished creating PTX file")
  }

  /**
   * Reads the given stream until end-of-stream and returns everything read.
   *
   * @param inputStream stream to drain; it is not closed by this method
   * @return the bytes read, in order
   * @throws IOException if reading from the stream fails
   */
  private def toByteArray(inputStream: InputStream): Array[Byte] = {
    val baos = new ByteArrayOutputStream()
    val buffer = Array.ofDim[Byte](8192)
    // read() returns -1 at end-of-stream; stop there rather than passing the
    // -1 on to write(), which would throw.
    Iterator
      .continually(inputStream.read(buffer))
      .takeWhile(_ != -1)
      .foreach(count => baos.write(buffer, 0, count))
    baos.toByteArray()
  }
}

S-ar putea să vă placă și