
GPU HashTable as a custom op for TensorFlow

TensorFlow Op Registration

  • First, the mutable hash table must be registered as a TensorFlow op.

  • This op outputs a table_handle, a resource handle that refers to the table instance.

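    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/framework/shape_inference.h"

    using namespace tensorflow;
    using shape_inference::InferenceContext;

    // Creates the GPU hash table and returns a handle to it.
    // MutableHashTableShape() is the shape helper used by TensorFlow's own
    // MutableHashTable ops (lookup_ops.cc); it is assumed to be visible here.
    REGISTER_OP("CUDAMutableHashTableV2")
        .Output("table_handle: resource")
        .Attr("container: string = ''")
        .Attr("shared_name: string = ''")
        .Attr("use_node_name_sharing: bool = false")
        .Attr("total_buckets: int = 131072")  // 2^17
        .Attr("key_dtype: type")
        .Attr("value_dtype: type")
        .SetIsStateful()
        .SetShapeFn([](InferenceContext* c) {
          return MutableHashTableShape(c, /*key=*/c->Scalar(),
                                       /*value=*/c->Scalar());
        });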
  • Following the style of TensorFlow's built-in CPU hash tables, the following lookup operations can be registered:

    • LookupTableFind
    • LookupTableInsert
    • LookupTableSize
    • LookupTableExport
  • Those operations should take table_handle as the first input.
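To make the handle-first convention concrete, here is a minimal host-only sketch in plain C++ with no TensorFlow dependency. TableRegistry, TableHandle and the four free functions are hypothetical stand-ins: the registry plays the role of whatever resource manager sits behind the table_handle output, and each function mirrors one of the ops above by taking the handle as its first argument. The int64 key and float value types are assumptions chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for the `resource` handle produced by the creation op.
using TableHandle = std::int64_t;
// Hypothetical table type: int64 keys, float values (assumed dtypes).
using Table = std::unordered_map<std::int64_t, float>;

// Hypothetical registry mapping handles to table instances.
class TableRegistry {
 public:
  // Plays the role of CUDAMutableHashTableV2: creates a table, returns its handle.
  TableHandle Create() {
    TableHandle h = next_handle_++;
    tables_.emplace(h, std::make_unique<Table>());
    return h;
  }

  // Resolves a handle to its table; unknown handles are an error.
  Table& Get(TableHandle h) {
    auto it = tables_.find(h);
    if (it == tables_.end()) throw std::out_of_range("unknown table handle");
    return *it->second;
  }

 private:
  TableHandle next_handle_ = 1;
  std::unordered_map<TableHandle, std::unique_ptr<Table>> tables_;
};

// Mirrors LookupTableInsert: handle first, then keys and values.
void LookupTableInsert(TableRegistry& reg, TableHandle h,
                       const std::vector<std::int64_t>& keys,
                       const std::vector<float>& values) {
  assert(keys.size() == values.size() && "keys and values must match in length");
  Table& t = reg.Get(h);
  for (std::size_t i = 0; i < keys.size(); ++i) t[keys[i]] = values[i];
}

// Mirrors LookupTableFind: missing keys yield default_value.
std::vector<float> LookupTableFind(TableRegistry& reg, TableHandle h,
                                   const std::vector<std::int64_t>& keys,
                                   float default_value) {
  const Table& t = reg.Get(h);
  std::vector<float> out;
  out.reserve(keys.size());
  for (std::int64_t k : keys) {
    auto it = t.find(k);
    out.push_back(it == t.end() ? default_value : it->second);
  }
  return out;
}

// Mirrors LookupTableSize: number of stored entries.
std::int64_t LookupTableSize(TableRegistry& reg, TableHandle h) {
  return static_cast<std::int64_t>(reg.Get(h).size());
}

// Mirrors LookupTableExport: all keys and values (in unspecified order).
std::pair<std::vector<std::int64_t>, std::vector<float>> LookupTableExport(
    TableRegistry& reg, TableHandle h) {
  std::pair<std::vector<std::int64_t>, std::vector<float>> out;
  for (const auto& kv : reg.Get(h)) {
    out.first.push_back(kv.first);
    out.second.push_back(kv.second);
  }
  return out;
}
```

In an actual kernel the handle would arrive as the op's first input tensor and be resolved to the table object; the sketch only shows why every operation needs the handle before anything else, and that two tables created separately never share state.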

GPU Context

  • Following the dependency injection design pattern, GPU memory operations can be encapsulated in a separate context class that is passed into the HashTable class and stored as a member.
  • For example, the context could hold a cudaStream_t member and expose Alloc(), Free(), MemcpyD2H() and MemcpyH2D() methods.
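A minimal sketch of this design, assuming plain C++ and a host-memory fake in place of CUDA so that it compiles anywhere. DeviceContext, HostContext and HashTable (with its UploadKeys/DownloadKeys methods) are hypothetical names; a CUDA-backed context would implement the same four methods and hold the cudaStream_t. The table receives its context only through the constructor, which is what makes the fake swappable.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical device-context interface. A CUDA implementation would hold a
// cudaStream_t and forward these calls to the CUDA runtime (omitted here).
class DeviceContext {
 public:
  virtual ~DeviceContext() = default;
  virtual void* Alloc(std::size_t bytes) = 0;
  virtual void Free(void* ptr) = 0;
  virtual void MemcpyH2D(void* dst_device, const void* src_host, std::size_t bytes) = 0;
  virtual void MemcpyD2H(void* dst_host, const void* src_device, std::size_t bytes) = 0;
};

// Hypothetical host-memory fake: "device" memory is ordinary heap memory,
// and it counts live allocations so tests can detect leaks.
class HostContext : public DeviceContext {
 public:
  void* Alloc(std::size_t bytes) override {
    void* p = std::malloc(bytes);
    if (p != nullptr) ++live_;
    return p;
  }
  void Free(void* ptr) override {
    if (ptr != nullptr) {
      --live_;
      std::free(ptr);
    }
  }
  void MemcpyH2D(void* dst, const void* src, std::size_t bytes) override {
    std::memcpy(dst, src, bytes);
  }
  void MemcpyD2H(void* dst, const void* src, std::size_t bytes) override {
    std::memcpy(dst, src, bytes);
  }
  int live_allocations() const { return live_; }

 private:
  int live_ = 0;
};

// Hypothetical table that uses an injected context instead of creating one.
class HashTable {
 public:
  HashTable(std::shared_ptr<DeviceContext> ctx, std::size_t total_buckets)
      : ctx_(std::move(ctx)), total_buckets_(total_buckets) {
    keys_ = static_cast<std::int64_t*>(
        ctx_->Alloc(total_buckets_ * sizeof(std::int64_t)));
  }
  ~HashTable() { ctx_->Free(keys_); }
  HashTable(const HashTable&) = delete;
  HashTable& operator=(const HashTable&) = delete;

  // Copies host keys into "device" storage (no hashing; shows the data path only).
  void UploadKeys(const std::vector<std::int64_t>& host_keys) {
    assert(host_keys.size() <= total_buckets_);
    ctx_->MemcpyH2D(keys_, host_keys.data(), host_keys.size() * sizeof(std::int64_t));
  }

  // Copies the first n stored keys back to the host.
  std::vector<std::int64_t> DownloadKeys(std::size_t n) const {
    assert(n <= total_buckets_);
    std::vector<std::int64_t> out(n);
    ctx_->MemcpyD2H(out.data(), keys_, n * sizeof(std::int64_t));
    return out;
  }

 private:
  std::shared_ptr<DeviceContext> ctx_;
  std::size_t total_buckets_;
  std::int64_t* keys_;
};
```

Because the table never constructs its own context, a test can inject HostContext and check, for example, that every Alloc() is matched by a Free() once the table is destroyed.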

GPU Context in TensorFlow

  • TensorFlow manages GPU resources through its own module, StreamExecutor.
  • Inside a kernel's Compute(OpKernelContext* ctx), the se::Stream object can be obtained via ctx->op_device_context()->stream().
  • Allocation should use ctx->allocate_persistent(). The buffer is then owned by the returned PersistentTensor and released by TensorFlow's allocator when that object is destroyed, so no explicit Free() call is needed.
  • For memory copies, wrap a raw GPU pointer and its size in se::DeviceMemoryBase, then call stream_->ThenMemcpy() on the se::Stream* to transfer data between host memory and GPU memory.
  • To obtain the raw cudaStream_t behind the se::Stream:
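    se::Stream* stream_ = ctx->op_device_context()->stream();
    // GpuStreamMemberHack() returns a void* pointing at the underlying CUDA stream handle.
    const cudaStream_t* cu_stream_ptr =
        CHECK_NOTNULL(reinterpret_cast<const cudaStream_t*>(
            stream_->implementation()->GpuStreamMemberHack()));
    // Dereference to get the stream usable in CUDA runtime calls and kernel launches.
    cudaStream_t cu_stream = *cu_stream_ptr;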