This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use mnist::*; | |
| fn main() -> Result<(), Box<dyn Error>> { | |
| // Deconstruct the returned Mnist struct. | |
| let Mnist { | |
| trn_img, | |
| trn_lbl, | |
| val_img, | |
| val_lbl, | |
| tst_img, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| [package] | |
| name = "simple_neural_networks" | |
| version = "0.1.0" | |
| edition = "2021" | |
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | |
| [dependencies] | |
| tch = "0.8.0" | |
| mnist = {version = "0.5.0", features = ["download"]} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| pub fn linear<'a, T: Borrow<super::Path<'a>>>( | |
| vs: T, | |
| in_dim: i64, | |
| out_dim: i64, | |
| c: LinearConfig, | |
| ) -> Linear |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| pub struct LinearConfig { | |
| pub ws_init: super::Init, | |
| pub bs_init: Option<super::Init>, | |
| pub bias: bool, | |
| } | |
| impl Default for LinearConfig { | |
| fn default() -> Self { | |
| LinearConfig { ws_init: super::Init::KaimingUniform, bs_init: None, bias: true } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| impl Init { | |
| /// Re-initializes an existing tensor with the specified initialization | |
| pub fn set(self, tensor: &mut Tensor) { | |
| match self { | |
| Init::Const(cst) => { | |
| let _ = tensor.fill_(cst); | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| match i { | |
| Init::Const(cst) => { | |
| // Optimize the cases that can be handled by a single C++ call. | |
| if cst == 0. { | |
| Tensor::f_zeros(dims, (Kind::Float, device)) | |
| } else if (cst - 1.).abs() <= std::f64::EPSILON { | |
| Tensor::f_ones(dims, (Kind::Float, device)) | |
| } else { | |
| Tensor::f_ones(dims, (Kind::Float, device)).map(|t| t * cst) | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Init::Const(cst) => { | |
| // Optimize the cases that can be handled by a single C++ call. | |
| if cst == 0. { | |
| Tensor::f_zeros(dims, (Kind::Float, device)) | |
| } else if (cst - 1.).abs() <= std::f64::EPSILON { | |
| Tensor::f_ones(dims, (Kind::Float, device)) | |
| } else { | |
| Tensor::f_ones(dims, (Kind::Float, device)).map(|t| t * cst) | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| !pip install sentence_transformers | |
| from sentence_transformers import SentenceTransformer, util | |
| # use roberta | |
| model = SentenceTransformer('stsb-roberta-large') | |
| def create_heatmap(similarity, cmap = "YlGnBu"): | |
| df = pd.DataFrame(similarity) | |
| df.columns = ['john', 'luke','mark', 'matt'] # row/column order: john 0, luke 1, mark 2, matt 3 | |
| df.index = ['john', 'luke','mark', 'matt'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def create_heatmap(similarity, cmap = "YlGnBu"): | |
| df = pd.DataFrame(similarity) | |
| df.columns = ['john', 'luke','mark', 'matt'] # row/column order: john 0, luke 1, mark 2, matt 3 | |
| df.index = ['john', 'luke','mark', 'matt'] | |
| fig, ax = plt.subplots(figsize=(5,5)) | |
| sns.heatmap(df, cmap=cmap) | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| import seaborn as sns |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def create_heatmap(similarity, cmap = "YlGnBu"): | |
| df = pd.DataFrame(similarity) | |
| df.columns = ['john', 'luke','mark', 'matt'] # row/column order: john 0, luke 1, mark 2, matt 3 | |
| df.index = ['john', 'luke','mark', 'matt'] | |
| fig, ax = plt.subplots(figsize=(5,5)) | |
| sns.heatmap(df, cmap=cmap) | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| import seaborn as sns |