Skip to content

Commit aaf9303

Browse files
authored
Upgrade to libtorch v1.6.0 (#60)
1 parent 2e658f0 commit aaf9303

File tree

4 files changed

+18
-13
lines changed

4 files changed

+18
-13
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake")
66

77
option(DOWNLOAD_DATASETS "Download datasets used in the tutorials." ON)
88

9-
set(PYTORCH_VERSION "1.5.1")
9+
set(PYTORCH_VERSION "1.6.0")
1010

1111
find_package(Torch ${PYTORCH_VERSION} EXACT QUIET PATHS "${CMAKE_SOURCE_DIR}/libtorch")
1212

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
<br />
77
<img src="https://img.shields.io/travis/prabhuomkar/pytorch-cpp">
88
<img src="https://img.shields.io/github/license/prabhuomkar/pytorch-cpp">
9-
<img src="https://img.shields.io/badge/libtorch-1.5.1-ee4c2c">
9+
<img src="https://img.shields.io/badge/libtorch-1.6.0-ee4c2c">
1010
<img src="https://img.shields.io/badge/cmake-3.14-064f8d">
1111
</p>
1212

1313

14-
| OS (Compiler)\\libtorch | 1.5.1 | nightly |
14+
| OS (Compiler)\\libtorch | 1.6.0 | nightly |
1515
| :---------------------: | :---------------------------------------------------------------------------------------------------: | :-----: |
1616
| macOS (clang 9.1) | ![Status](https://travis-matrix-badges.herokuapp.com/repos/prabhuomkar/pytorch-cpp/branches/master/1) | |
1717
| macOS (clang 10.0) | ![Status](https://travis-matrix-badges.herokuapp.com/repos/prabhuomkar/pytorch-cpp/branches/master/2) | |
@@ -31,7 +31,7 @@ This repository provides tutorial code in C++ for deep learning researchers to l
3131

3232
1. [C++](http://www.cplusplus.com/doc/tutorial/introduction/)
3333
2. [CMake](https://cmake.org/download/)
34-
3. [LibTorch v1.5.1](https://pytorch.org/cppdocs/installing.html)
34+
3. [LibTorch v1.6.0](https://pytorch.org/cppdocs/installing.html)
3535
4. [Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/download.html)
3636

3737

@@ -68,7 +68,7 @@ Some useful options:
6868

6969
| Option | Default | Description |
7070
| :------------- |:------------|-----:|
71-
| `-D CUDA_V=(9.2\|10.1\|10.2\|none)` | `none` | Download libtorch for a CUDA version (`none` = download CPU version). |
71+
| `-D CUDA_V=(9.2 [Linux only]\|10.1\|10.2\|none)` | `none` | Download libtorch for a CUDA version (`none` = download CPU version). |
7272
| `-D DOWNLOAD_DATASETS=(OFF\|ON)` | `ON` | Download all datasets used in the tutorials. |
7373
| `-D CMAKE_PREFIX_PATH=path/to/libtorch/share/cmake/Torch` | `<empty>` | Skip the downloading of libtorch and use your own local version (see Requirements) instead. |
7474
| `-D CMAKE_BUILD_TYPE=(Release\|Debug)` | `<empty>` (`Release` when downloading libtorch on Windows) | Set the build type (`Release` = compile with optimizations)|

cmake/download_libtorch.cmake

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
22

33
include(FetchContent)
44

5-
set(CUDA_V "none" CACHE STRING "Determines libtorch CUDA version to download (9.2, 10.1 or 10.2).")
5+
set(CUDA_V "none" CACHE STRING "Determines libtorch CUDA version to download (9.2 [Linux only], 10.1 or 10.2).")
66

77
if(${CUDA_V} STREQUAL "none")
88
set(LIBTORCH_DEVICE "cpu")
@@ -13,14 +13,18 @@ elseif(${CUDA_V} STREQUAL "10.1")
1313
elseif(${CUDA_V} STREQUAL "10.2")
1414
set(LIBTORCH_DEVICE "cu102")
1515
else()
16-
message(FATAL_ERROR "Invalid CUDA version specified, must be 9.2, 10.1, 10.2 or none!")
16+
message(FATAL_ERROR "Invalid CUDA version specified, must be 9.2 [Linux only], 10.1, 10.2 or none!")
1717
endif()
1818

1919
if(NOT ${LIBTORCH_DEVICE} STREQUAL "cu102")
2020
set(LIBTORCH_DEVICE_TAG "%2B${LIBTORCH_DEVICE}")
2121
endif()
2222

2323
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
24+
if(${LIBTORCH_DEVICE} STREQUAL "cu92")
25+
message(FATAL_ERROR "PyTorch ${PYTORCH_VERSION} does not support CUDA 9.2 on Windows. Please use CPU or upgrade to CUDA versions 10.1 or 10.2.")
26+
endif()
27+
2428
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/${LIBTORCH_DEVICE}/libtorch-win-shared-with-deps-${PYTORCH_VERSION}${LIBTORCH_DEVICE_TAG}.zip")
2529
set(CMAKE_BUILD_TYPE "Release")
2630
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")

tutorials/advanced/image_captioning/src/score.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright 2020-present pytorch-cpp Authors
22
#include "score.h"
33
#include <vector>
4+
#include <cmath>
45

56
namespace score {
67
namespace {
@@ -9,20 +10,20 @@ torch::Tensor n_grams(const torch::Tensor &sequence, size_t n) {
910
torch::empty({0}, sequence.options()) : sequence.unfold(-1, n, 1);
1011
}
1112

12-
size_t closest_reference_size(size_t hypothesis_size, const std::vector<torch::Tensor> &references) {
13-
std::vector<size_t> reference_sizes(references.size());
13+
int64_t closest_reference_size(int64_t hypothesis_size, const std::vector<torch::Tensor> &references) {
14+
std::vector<int64_t> reference_sizes(references.size());
1415

1516
std::transform(references.cbegin(), references.cend(), reference_sizes.begin(),
1617
[](const auto &reference) { return reference.size(0); });
1718

1819
return *std::min_element(reference_sizes.cbegin(), reference_sizes.cend(),
19-
[hypothesis_size](auto l, auto r) {
20-
return std::abs<int64_t>(l - hypothesis_size) <=
21-
std::abs<int64_t>(r - hypothesis_size);
20+
[hypothesis_size](int64_t l, int64_t r) {
21+
return std::abs(l - hypothesis_size) <=
22+
std::abs(r - hypothesis_size);
2223
});
2324
}
2425

25-
double brevity_penalty(size_t hypothesis_size, size_t closest_ref_size) {
26+
double brevity_penalty(int64_t hypothesis_size, int64_t closest_ref_size) {
2627
return (hypothesis_size > closest_ref_size) ? 1.0 :
2728
std::exp(1.0 - static_cast<double>(closest_ref_size) / hypothesis_size);
2829
}

0 commit comments

Comments
 (0)